@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/lite/core/interpreter.h"
#include "tensorflow/lite/core/kernels/register.h"
#include "tensorflow/lite/core/model.h"
+#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/register_ref.h"
#include "tensorflow/lite/mutable_op_resolver.h"
@@ -85,18 +86,31 @@ using python_utils::PyDecrefDeleter;
std::unique_ptr<Interpreter> CreateInterpreter(
    const InterpreterWrapper::Model* model,
    const tflite::MutableOpResolver& resolver, bool preserve_all_tensors,
-    bool disable_delegate_clustering) {
+    bool disable_delegate_clustering, int num_threads,
+    bool default_delegate_latest_features) {
  if (!model) {
    return nullptr;
  }

  ::tflite::python::ImportNumpy();

+  TfLiteDelegate* xnnpack_delegate = nullptr;
+  if (default_delegate_latest_features) {
+    auto opts = TfLiteXNNPackDelegateOptionsDefault();
+    opts.flags |= TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS;
+    opts.flags |= TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_SUBGRAPH_RESHAPING;
+    opts.num_threads = num_threads;
+    xnnpack_delegate = TfLiteXNNPackDelegateCreate(&opts);
+  }
  std::unique_ptr<Interpreter> interpreter;
  InterpreterOptions options;
  options.SetPreserveAllTensors(preserve_all_tensors);
  options.SetDisableDelegateClustering(disable_delegate_clustering);
  InterpreterBuilder builder(*model, resolver, &options);
+  if (default_delegate_latest_features) {
+    builder.AddDelegate(xnnpack_delegate);
+  }
+  builder.SetNumThreads(num_threads);
  if (builder(&interpreter) != kTfLiteOk) {
    return nullptr;
  }
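For orientation, here is a minimal standalone sketch (not part of this change) of how the same XNNPACK delegate options are typically wired up against an interpreter. The model path and thread count are placeholder assumptions, the inference calls are elided, and the manual delete calls reflect that the builder only borrows the delegate pointer:

```cpp
#include <memory>

#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

int main() {
  // Placeholder model path for illustration only.
  auto model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
  if (!model) return 1;

  // Same opt-in flags as the wrapper above: newest operators plus subgraph
  // reshaping, sharing the interpreter's thread count.
  TfLiteXNNPackDelegateOptions opts = TfLiteXNNPackDelegateOptionsDefault();
  opts.flags |= TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS;
  opts.flags |= TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_SUBGRAPH_RESHAPING;
  opts.num_threads = 4;
  TfLiteDelegate* xnnpack_delegate = TfLiteXNNPackDelegateCreate(&opts);

  // A resolver without default delegates avoids applying XNNPACK twice.
  tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder builder(*model, resolver);
  builder.AddDelegate(xnnpack_delegate);
  builder.SetNumThreads(4);
  if (builder(&interpreter) != kTfLiteOk || !interpreter) {
    TfLiteXNNPackDelegateDelete(xnnpack_delegate);
    return 1;
  }

  // ... AllocateTensors(), fill inputs, Invoke(), read outputs ...

  // The builder does not take ownership of the delegate, so destroy the
  // interpreter first and only then release the delegate.
  interpreter.reset();
  TfLiteXNNPackDelegateDelete(xnnpack_delegate);
  return 0;
}
```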
@@ -200,29 +214,36 @@ InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper(
    const std::vector<std::string>& registerers_by_name,
    const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
    std::string* error_msg, bool preserve_all_tensors,
-    bool disable_delegate_clustering) {
+    bool disable_delegate_clustering, int num_threads,
+    bool default_delegate_latest_features) {
  if (!model) {
    *error_msg = error_reporter->message();
    return nullptr;
  }

  std::unique_ptr<tflite::MutableOpResolver> resolver;
-  switch (op_resolver_id) {
-    case kBuiltinOpResolver:
-      resolver = std::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
-      break;
-    case kBuiltinRefOpResolver:
-      resolver = std::make_unique<tflite::ops::builtin::BuiltinRefOpResolver>();
-      break;
-    case kBuiltinOpResolverWithoutDefaultDelegates:
-      resolver = std::make_unique<
-          tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>();
-      break;
-    default:
-      // This should not never happen because the eventual caller in
-      // interpreter.py should have passed a valid id here.
-      TFLITE_DCHECK(false);
-      return nullptr;
+  if (default_delegate_latest_features) {
+    resolver = std::make_unique<
+        tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>();
+  } else {
+    switch (op_resolver_id) {
+      case kBuiltinOpResolver:
+        resolver = std::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
+        break;
+      case kBuiltinRefOpResolver:
+        resolver =
+            std::make_unique<tflite::ops::builtin::BuiltinRefOpResolver>();
+        break;
+      case kBuiltinOpResolverWithoutDefaultDelegates:
+        resolver = std::make_unique<
+            tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>();
+        break;
+      default:
+        // This should never happen because the eventual caller in
+        // interpreter.py should have passed a valid id here.
+        TFLITE_DCHECK(false);
+        return nullptr;
+    }
  }

  for (const auto& registerer : registerers_by_name) {
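A brief aside on the forced resolver above (an editor's note, not part of the change): `BuiltinOpResolver` applies XNNPACK as a default delegate when the interpreter is built, so reusing it here could end up applying XNNPACK twice once `CreateInterpreter()` adds its own delegate. The opt-out variant keeps exactly one delegate in play:

```cpp
// Illustrative comparison only; neither object is used beyond construction.
#include "tensorflow/lite/kernels/register.h"

// Registers the builtin ops and applies XNNPACK as a default delegate at
// interpreter build time.
tflite::ops::builtin::BuiltinOpResolver resolver_with_default_delegate;

// Registers the same builtin ops but attaches no default delegate, so an
// explicitly added delegate (as in CreateInterpreter above) is the only one.
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates plain_resolver;
```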
@@ -232,9 +253,9 @@ InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper(
  for (const auto& registerer : registerers_by_func) {
    registerer(reinterpret_cast<uintptr_t>(resolver.get()));
  }
-  auto interpreter =
-      CreateInterpreter(model.get(), *resolver, preserve_all_tensors,
-                        disable_delegate_clustering);
+  auto interpreter = CreateInterpreter(
+      model.get(), *resolver, preserve_all_tensors, disable_delegate_clustering,
+      num_threads, default_delegate_latest_features);
  if (!interpreter) {
    *error_msg = error_reporter->message();
    return nullptr;
@@ -806,14 +827,16 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
    const std::vector<std::string>& registerers_by_name,
    const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
    std::string* error_msg, bool preserve_all_tensors,
-    bool disable_delegate_clustering) {
+    bool disable_delegate_clustering, int num_threads,
+    bool default_delegate_latest_features) {
  std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
  std::unique_ptr<InterpreterWrapper::Model> model =
      Model::BuildFromFile(model_path, error_reporter.get());
  return CreateInterpreterWrapper(
      std::move(model), op_resolver_id, std::move(error_reporter),
      registerers_by_name, registerers_by_func, error_msg, preserve_all_tensors,
-      disable_delegate_clustering);
+      disable_delegate_clustering, num_threads,
+      default_delegate_latest_features);
}

InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
@@ -822,15 +845,17 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
    bool preserve_all_tensors, bool disable_delegate_clustering) {
  return CreateWrapperCPPFromFile(
      model_path, op_resolver_id, registerers, {} /*registerers_by_func*/,
-      error_msg, preserve_all_tensors, disable_delegate_clustering);
+      error_msg, preserve_all_tensors, disable_delegate_clustering,
+      /*num_threads=*/1, /*default_delegate_latest_features=*/false);
}

InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
    PyObject* data, int op_resolver_id,
    const std::vector<std::string>& registerers_by_name,
    const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
    std::string* error_msg, bool preserve_all_tensors,
-    bool disable_delegate_clustering) {
+    bool disable_delegate_clustering, int num_threads,
+    bool default_delegate_latest_features) {
  char* buf = nullptr;
  Py_ssize_t length;
  std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
@@ -843,16 +868,18 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
  return CreateInterpreterWrapper(
      std::move(model), op_resolver_id, std::move(error_reporter),
      registerers_by_name, registerers_by_func, error_msg, preserve_all_tensors,
-      disable_delegate_clustering);
+      disable_delegate_clustering, num_threads,
+      default_delegate_latest_features);
}

InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
    PyObject* data, int op_resolver_id,
    const std::vector<std::string>& registerers, std::string* error_msg,
    bool preserve_all_tensors, bool disable_delegate_clustering) {
-  return CreateWrapperCPPFromBuffer(data, op_resolver_id, registerers, {},
-                                    error_msg, preserve_all_tensors,
-                                    disable_delegate_clustering);
+  return CreateWrapperCPPFromBuffer(
+      data, op_resolver_id, registerers, {}, error_msg, preserve_all_tensors,
+      disable_delegate_clustering, /*num_threads=*/1,
+      /*default_delegate_latest_features=*/false);
}

PyObject* InterpreterWrapper::ResetVariableTensors() {
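One lifetime detail worth keeping in mind around the new `CreateInterpreter()` path (a general-purpose sketch, not part of this change, under the assumption that the caller rather than the builder owns the delegate): `InterpreterBuilder::AddDelegate()` only borrows the pointer, so code that manages an XNNPACK delegate itself commonly wraps it in a smart pointer whose deleter is `TfLiteXNNPackDelegateDelete` and keeps it alive for as long as the interpreter exists:

```cpp
// Sketch only: RAII ownership for an XNNPACK delegate.
#include <memory>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"

// unique_ptr whose deleter releases the delegate through the XNNPACK C API.
using XnnpackDelegatePtr =
    std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>;

// Builds a delegate with the same opt-in flags used by the wrapper above.
XnnpackDelegatePtr MakeXnnpackDelegate(int num_threads) {
  TfLiteXNNPackDelegateOptions opts = TfLiteXNNPackDelegateOptionsDefault();
  opts.flags |= TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS;
  opts.flags |= TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_SUBGRAPH_RESHAPING;
  opts.num_threads = num_threads;
  return XnnpackDelegatePtr(TfLiteXNNPackDelegateCreate(&opts),
                            &TfLiteXNNPackDelegateDelete);
}
```

The resulting pointer would then be handed to `builder.AddDelegate(delegate.get())` and destroyed only after the interpreter it serves.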