
Commit 0d2d8a6

alankelly authored and tensorflower-gardener committed
Experimental: Allow users to enable all features of default delegates in Python

XNNPack is currently the only default delegate, so this lets users benefit from all of its flag-protected features. Other delegates can use this mechanism in the future.

PiperOrigin-RevId: 618111218
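For reference, a minimal usage sketch of the new option from the Python API (the model path and thread count below are placeholders, not part of this commit):

    import numpy as np
    import tensorflow as tf

    # Hypothetical model file; any valid .tflite model works here.
    interpreter = tf.lite.Interpreter(
        model_path="model.tflite",
        num_threads=4,
        experimental_default_delegate_latest_features=True,  # new flag
    )
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    interpreter.set_tensor(
        input_details[0]["index"],
        np.zeros(input_details[0]["shape"], dtype=input_details[0]["dtype"]),
    )
    interpreter.invoke()

With the flag set, the XNNPack default delegate is created with its flag-protected features enabled, as shown in interpreter_wrapper.cc below.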
1 parent cd0b05b commit 0d2d8a6

8 files changed: +89 −45 lines changed

RELEASE.md

+3
@@ -75,6 +75,9 @@
 *   C API:
     *   The experimental `TfLiteRegistrationExternal` type has been renamed as
         `TfLiteOperator`, and likewise for the corresponding API functions.
+*   The Python TF Lite Interpreter bindings now have an option
+    `experimental_default_delegate_latest_features` to enable all default
+    delegate features.

 ## Thanks to our Contributors
tensorflow/lite/python/interpreter.py

+13 −7

@@ -396,6 +396,7 @@ def __init__(
       experimental_op_resolver_type=OpResolverType.AUTO,
       experimental_preserve_all_tensors=False,
       experimental_disable_delegate_clustering=False,
+      experimental_default_delegate_latest_features=False,
   ):
     """Constructor.

@@ -437,6 +438,8 @@ def __init__(
         this flag is currently experimental, and it might be removed/updated if
         the TF Lite converter doesn't drop such control dependencies in the
         model. Default is False.
+      experimental_default_delegate_latest_features: If true, default delegates
+        may enable all flag-protected features. Default is False.

     Raises:
       ValueError: If the interpreter was unable to create.
@@ -454,6 +457,12 @@ def __init__(
       raise ValueError('Unrecognized passed in op resolver type: {}'.format(
           experimental_op_resolver_type))

+    if num_threads is not None:
+      if not isinstance(num_threads, int):
+        raise ValueError('type of num_threads should be int')
+      if num_threads < 1:
+        raise ValueError('num_threads should >= 1')
+
     if model_path and not model_content:
       custom_op_registerers_by_name = [
           x for x in self._custom_op_registerers if isinstance(x, str)
@@ -468,6 +477,8 @@ def __init__(
           custom_op_registerers_by_func,
           experimental_preserve_all_tensors,
           experimental_disable_delegate_clustering,
+          int(num_threads or 1),
+          experimental_default_delegate_latest_features,
       )
       if not self._interpreter:
         raise ValueError('Failed to open {}'.format(model_path))
@@ -489,19 +500,14 @@ def __init__(
           custom_op_registerers_by_func,
           experimental_preserve_all_tensors,
           experimental_disable_delegate_clustering,
+          int(num_threads or 1),
+          experimental_default_delegate_latest_features,
       )
     elif not model_content and not model_path:
       raise ValueError('`model_path` or `model_content` must be specified.')
     else:
       raise ValueError('Can\'t both provide `model_path` and `model_content`')

-    if num_threads is not None:
-      if not isinstance(num_threads, int):
-        raise ValueError('type of num_threads should be int')
-      if num_threads < 1:
-        raise ValueError('num_threads should >= 1')
-      self._interpreter.SetNumThreads(num_threads)
-
     # Each delegate is a wrapper that owns the delegates that have been loaded
     # as plugins. The interpreter wrapper will be using them, but we need to
     # hold them in a list so that the lifetime is preserved at least as long as
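Because num_threads and the new flag are now forwarded to the native wrapper at construction time, the num_threads validation above also runs before the wrapper is created. A small sketch of the resulting constructor behavior (the model path is a placeholder):

    import tensorflow as tf

    try:
        # An invalid thread count is rejected before the native wrapper is built.
        tf.lite.Interpreter(model_path="model.tflite", num_threads=0)
    except ValueError as e:
        print(e)  # num_threads should >= 1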

tensorflow/lite/python/interpreter_wrapper/BUILD

+1
@@ -39,6 +39,7 @@ cc_library(
         "//tensorflow/lite/core/api",
         "//tensorflow/lite/core/c:common",
         "//tensorflow/lite/core/kernels:builtin_ops",
+        "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
         "//tensorflow/lite/kernels:reference_ops",
         "//tensorflow/lite/kernels/internal:compatibility",
         "//third_party/python_runtime:headers",  # buildcleaner: keep

tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc

+56 −29

@@ -32,6 +32,7 @@ limitations under the License.
 #include "tensorflow/lite/core/interpreter.h"
 #include "tensorflow/lite/core/kernels/register.h"
 #include "tensorflow/lite/core/model.h"
+#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
 #include "tensorflow/lite/kernels/internal/compatibility.h"
 #include "tensorflow/lite/kernels/register_ref.h"
 #include "tensorflow/lite/mutable_op_resolver.h"
@@ -85,18 +86,31 @@ using python_utils::PyDecrefDeleter;
 std::unique_ptr<Interpreter> CreateInterpreter(
     const InterpreterWrapper::Model* model,
     const tflite::MutableOpResolver& resolver, bool preserve_all_tensors,
-    bool disable_delegate_clustering) {
+    bool disable_delegate_clustering, int num_threads,
+    bool default_delegate_latest_features) {
   if (!model) {
     return nullptr;
   }

   ::tflite::python::ImportNumpy();

+  TfLiteDelegate* xnnpack_delegate = nullptr;
+  if (default_delegate_latest_features) {
+    auto opts = TfLiteXNNPackDelegateOptionsDefault();
+    opts.flags |= TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS;
+    opts.flags |= TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_SUBGRAPH_RESHAPING;
+    opts.num_threads = num_threads;
+    xnnpack_delegate = TfLiteXNNPackDelegateCreate(&opts);
+  }
   std::unique_ptr<Interpreter> interpreter;
   InterpreterOptions options;
   options.SetPreserveAllTensors(preserve_all_tensors);
   options.SetDisableDelegateClustering(disable_delegate_clustering);
   InterpreterBuilder builder(*model, resolver, &options);
+  if (default_delegate_latest_features) {
+    builder.AddDelegate(xnnpack_delegate);
+  }
+  builder.SetNumThreads(num_threads);
   if (builder(&interpreter) != kTfLiteOk) {
     return nullptr;
   }
@@ -200,29 +214,36 @@ InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper(
     const std::vector<std::string>& registerers_by_name,
     const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
     std::string* error_msg, bool preserve_all_tensors,
-    bool disable_delegate_clustering) {
+    bool disable_delegate_clustering, int num_threads,
+    bool default_delegate_latest_features) {
   if (!model) {
     *error_msg = error_reporter->message();
     return nullptr;
   }

   std::unique_ptr<tflite::MutableOpResolver> resolver;
-  switch (op_resolver_id) {
-    case kBuiltinOpResolver:
-      resolver = std::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
-      break;
-    case kBuiltinRefOpResolver:
-      resolver = std::make_unique<tflite::ops::builtin::BuiltinRefOpResolver>();
-      break;
-    case kBuiltinOpResolverWithoutDefaultDelegates:
-      resolver = std::make_unique<
-          tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>();
-      break;
-    default:
-      // This should not never happen because the eventual caller in
-      // interpreter.py should have passed a valid id here.
-      TFLITE_DCHECK(false);
-      return nullptr;
+  if (default_delegate_latest_features) {
+    resolver = std::make_unique<
+        tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>();
+  } else {
+    switch (op_resolver_id) {
+      case kBuiltinOpResolver:
+        resolver = std::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
+        break;
+      case kBuiltinRefOpResolver:
+        resolver =
+            std::make_unique<tflite::ops::builtin::BuiltinRefOpResolver>();
+        break;
+      case kBuiltinOpResolverWithoutDefaultDelegates:
+        resolver = std::make_unique<
+            tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>();
+        break;
+      default:
+        // This should never happen because the eventual caller in
+        // interpreter.py should have passed a valid id here.
+        TFLITE_DCHECK(false);
+        return nullptr;
+    }
   }

   for (const auto& registerer : registerers_by_name) {
@@ -232,9 +253,9 @@ InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper(
   for (const auto& registerer : registerers_by_func) {
     registerer(reinterpret_cast<uintptr_t>(resolver.get()));
   }
-  auto interpreter =
-      CreateInterpreter(model.get(), *resolver, preserve_all_tensors,
-                        disable_delegate_clustering);
+  auto interpreter = CreateInterpreter(
+      model.get(), *resolver, preserve_all_tensors, disable_delegate_clustering,
+      num_threads, default_delegate_latest_features);
   if (!interpreter) {
     *error_msg = error_reporter->message();
     return nullptr;
@@ -806,14 +827,16 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
     const std::vector<std::string>& registerers_by_name,
     const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
     std::string* error_msg, bool preserve_all_tensors,
-    bool disable_delegate_clustering) {
+    bool disable_delegate_clustering, int num_threads,
+    bool default_delegate_latest_features) {
   std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
   std::unique_ptr<InterpreterWrapper::Model> model =
       Model::BuildFromFile(model_path, error_reporter.get());
   return CreateInterpreterWrapper(
       std::move(model), op_resolver_id, std::move(error_reporter),
       registerers_by_name, registerers_by_func, error_msg, preserve_all_tensors,
-      disable_delegate_clustering);
+      disable_delegate_clustering, num_threads,
+      default_delegate_latest_features);
 }

 InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
@@ -822,15 +845,17 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
     bool preserve_all_tensors, bool disable_delegate_clustering) {
   return CreateWrapperCPPFromFile(
       model_path, op_resolver_id, registerers, {} /*registerers_by_func*/,
-      error_msg, preserve_all_tensors, disable_delegate_clustering);
+      error_msg, preserve_all_tensors, disable_delegate_clustering,
+      /*num_threads=*/1, /*default_delegate_latest_features=*/false);
 }

 InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
     PyObject* data, int op_resolver_id,
     const std::vector<std::string>& registerers_by_name,
     const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
     std::string* error_msg, bool preserve_all_tensors,
-    bool disable_delegate_clustering) {
+    bool disable_delegate_clustering, int num_threads,
+    bool default_delegate_latest_features) {
   char* buf = nullptr;
   Py_ssize_t length;
   std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
@@ -843,16 +868,18 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
   return CreateInterpreterWrapper(
       std::move(model), op_resolver_id, std::move(error_reporter),
       registerers_by_name, registerers_by_func, error_msg, preserve_all_tensors,
-      disable_delegate_clustering);
+      disable_delegate_clustering, num_threads,
+      default_delegate_latest_features);
 }

 InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
     PyObject* data, int op_resolver_id,
     const std::vector<std::string>& registerers, std::string* error_msg,
     bool preserve_all_tensors, bool disable_delegate_clustering) {
-  return CreateWrapperCPPFromBuffer(data, op_resolver_id, registerers, {},
-                                    error_msg, preserve_all_tensors,
-                                    disable_delegate_clustering);
+  return CreateWrapperCPPFromBuffer(
+      data, op_resolver_id, registerers, {}, error_msg, preserve_all_tensors,
+      disable_delegate_clustering, /*num_threads=*/1,
+      /*default_delegate_latest_features=*/false);
 }

 PyObject* InterpreterWrapper::ResetVariableTensors() {

tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h

+6 −3

@@ -57,7 +57,8 @@ class InterpreterWrapper {
       const std::vector<std::string>& registerers_by_name,
       const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
       std::string* error_msg, bool preserve_all_tensors,
-      bool disable_delegate_clustering);
+      bool disable_delegate_clustering, int num_threads,
+      bool default_delegate_latest_features);

   // SWIG caller takes ownership of pointer.
   static InterpreterWrapper* CreateWrapperCPPFromBuffer(
@@ -69,7 +70,8 @@ class InterpreterWrapper {
       const std::vector<std::string>& registerers_by_name,
       const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
       std::string* error_msg, bool preserve_all_tensors,
-      bool disable_delegate_clustering);
+      bool disable_delegate_clustering, int num_threads,
+      bool default_delegate_latest_features);

   ~InterpreterWrapper();
   PyObject* AllocateTensors(int subgraph_index);
@@ -126,7 +128,8 @@ class InterpreterWrapper {
       const std::vector<std::string>& registerers_by_name,
       const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
       std::string* error_msg, bool preserve_all_tensors,
-      bool disable_delegate_clustering);
+      bool disable_delegate_clustering, int num_threads,
+      bool default_delegate_latest_features);

   InterpreterWrapper(std::unique_ptr<Model> model,
                      std::unique_ptr<PythonErrorReporter> error_reporter,

tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper_pybind11.cc

+8 −4

@@ -53,12 +53,14 @@ PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper, m) {
       [](const std::string& model_path, int op_resolver_id,
          const std::vector<std::string>& registerers_by_name,
          const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
-         bool preserve_all_tensors, bool disable_delegate_clustering) {
+         bool preserve_all_tensors, bool disable_delegate_clustering,
+         int num_threads, bool default_delegate_latest_features) {
        std::string error;
        auto* wrapper = ::InterpreterWrapper::CreateWrapperCPPFromFile(
            model_path.c_str(), op_resolver_id, registerers_by_name,
            registerers_by_func, &error, preserve_all_tensors,
-            disable_delegate_clustering);
+            disable_delegate_clustering, num_threads,
+            default_delegate_latest_features);
        if (!wrapper) {
          throw std::invalid_argument(error);
        }
@@ -82,12 +84,14 @@ PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper, m) {
       [](const py::bytes& data, int op_resolver_id,
          const std::vector<std::string>& registerers_by_name,
          const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
-         bool preserve_all_tensors, bool disable_delegate_clustering) {
+         bool preserve_all_tensors, bool disable_delegate_clustering,
+         int num_threads, bool default_delegate_latest_features) {
        std::string error;
        auto* wrapper = ::InterpreterWrapper::CreateWrapperCPPFromBuffer(
            data.ptr(), op_resolver_id, registerers_by_name,
            registerers_by_func, &error, preserve_all_tensors,
-            disable_delegate_clustering);
+            disable_delegate_clustering, num_threads,
+            default_delegate_latest_features);
        if (!wrapper) {
          throw std::invalid_argument(error);
        }

tensorflow/tools/api/golden/v1/tensorflow.lite.-interpreter.pbtxt

+1 −1

@@ -4,7 +4,7 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'model_path\', \'model_content\', \'experimental_delegates\', \'num_threads\', \'experimental_op_resolver_type\', \'experimental_preserve_all_tensors\', \'experimental_disable_delegate_clustering\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'OpResolverType.AUTO\', \'False\', \'False\'], "
+    argspec: "args=[\'self\', \'model_path\', \'model_content\', \'experimental_delegates\', \'num_threads\', \'experimental_op_resolver_type\', \'experimental_preserve_all_tensors\', \'experimental_disable_delegate_clustering\', \'experimental_default_delegate_latest_features\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'OpResolverType.AUTO\', \'False\', \'False\', \'False\'], "
   }
   member_method {
     name: "allocate_tensors"

tensorflow/tools/api/golden/v2/tensorflow.lite.-interpreter.pbtxt

+1 −1

@@ -4,7 +4,7 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'model_path\', \'model_content\', \'experimental_delegates\', \'num_threads\', \'experimental_op_resolver_type\', \'experimental_preserve_all_tensors\', \'experimental_disable_delegate_clustering\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'OpResolverType.AUTO\', \'False\', \'False\'], "
+    argspec: "args=[\'self\', \'model_path\', \'model_content\', \'experimental_delegates\', \'num_threads\', \'experimental_op_resolver_type\', \'experimental_preserve_all_tensors\', \'experimental_disable_delegate_clustering\', \'experimental_default_delegate_latest_features\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'OpResolverType.AUTO\', \'False\', \'False\', \'False\'], "
   }
   member_method {
     name: "allocate_tensors"
