Skip to content

Commit 0d38a99

Browse files
committed
input and output tag API
1 parent 3083980 commit 0d38a99

File tree

5 files changed

+78
-7
lines changed

5 files changed

+78
-7
lines changed

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,8 @@ class ModuleE2ETest {
6868
inputStream.close()
6969

7070
val module = Module.load(getTestFilePath("/mv3_xnnpack_fp32.pte"))
71-
val expectedBackends = arrayOf("XnnpackBackend")
7271
Assert.assertArrayEquals(
73-
expectedBackends,
72+
arrayOf("XnnpackBackend"),
7473
module.getMethodMetadata("forward").getBackends(),
7574
)
7675
}

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ class ModuleInstrumentationTest {
6262

6363
Assert.assertArrayEquals(arrayOf("forward"), module.getMethods())
6464
Assert.assertTrue(module.getMethodMetadata("forward").backends.isEmpty())
65+
Assert.assertArrayEquals(intArrayOf(1, 1, 3), module.getMethodMetadata("forward").inputTags)
66+
Assert.assertArrayEquals(intArrayOf(1), module.getMethodMetadata("forward").outputTags)
6567
}
6668

6769
@Test

extension/android/executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111
/** Helper class to access the metadata for a method from a Module */
1212
public class MethodMetadata {
1313
private String mName;
14-
1514
private String[] mBackends;
15+
private int[] mInputTags;
16+
private int[] mOutputTags;
1617

1718
MethodMetadata setName(String name) {
1819
mName = name;
@@ -37,4 +38,28 @@ MethodMetadata setBackends(String[] backends) {
3738
public String[] getBackends() {
3839
return mBackends;
3940
}
41+
42+
/**
43+
* @return Output tags
44+
*/
45+
public int[] getOutputTags() {
46+
return mOutputTags;
47+
}
48+
49+
MethodMetadata setOutputTags(int[] outputTags) {
50+
mOutputTags = outputTags;
51+
return this;
52+
}
53+
54+
/**
55+
* @return Input tags
56+
*/
57+
public int[] getInputTags() {
58+
return mInputTags;
59+
}
60+
61+
MethodMetadata setInputTags(int[] inputTags) {
62+
mInputTags = inputTags;
63+
return this;
64+
}
4065
}

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,13 @@ Map<String, MethodMetadata> populateMethodMeta() {
5959
Map<String, MethodMetadata> metadata = new HashMap<String, MethodMetadata>();
6060
for (int i = 0; i < methods.length; i++) {
6161
String name = methods[i];
62-
metadata.put(name, new MethodMetadata().setName(name).setBackends(getUsedBackends(name)));
62+
metadata.put(
63+
name,
64+
new MethodMetadata()
65+
.setName(name)
66+
.setBackends(getUsedBackends(name))
67+
.setInputTags(getInputTags(name))
68+
.setOutputTags(getOutputTags(name)));
6369
}
6470

6571
return metadata;
@@ -204,6 +210,12 @@ public String[] readLogBuffer() {
204210
@DoNotStrip
205211
private native String[] readLogBufferNative();
206212

213+
@DoNotStrip
214+
private native int[] getInputTags(String method);
215+
216+
@DoNotStrip
217+
private native int[] getOutputTags(String method);
218+
207219
/**
208220
* Dump the ExecuTorch ETRecord file to /data/local/tmp/result.etdump.
209221
*

extension/android/jni/jni_layer.cpp

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -453,10 +453,11 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
453453

454454
facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> getUsedBackends(
455455
facebook::jni::alias_ref<jstring> methodName) {
456-
auto methodMeta = module_->method_meta(methodName->toStdString()).get();
456+
auto method_meta =
457+
module_->method_meta(methodName->toStdString()).get();
457458
std::unordered_set<std::string> backends;
458-
for (auto i = 0; i < methodMeta.num_backends(); i++) {
459-
backends.insert(methodMeta.get_backend_name(i).get());
459+
for (auto i = 0; i < method_meta.num_backends(); i++) {
460+
backends.insert(method_meta.get_backend_name(i).get());
460461
}
461462

462463
facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> ret =
@@ -471,6 +472,36 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
471472
return ret;
472473
}
473474

475+
facebook::jni::local_ref<facebook::jni::JArrayInt> getInputTags(
476+
facebook::jni::alias_ref<jstring> methodName) {
477+
auto method_meta =
478+
module_->method_meta(methodName->toStdString()).get();
479+
auto num_inputs = method_meta.num_inputs();
480+
facebook::jni::local_ref<facebook::jni::JArrayInt> ret =
481+
facebook::jni::JArrayInt::newArray(num_inputs);
482+
483+
int i = 0;
484+
for (int i = 0; i < num_inputs; i++) {
485+
ret->pin()[i] = static_cast<uint32_t>(method_meta.input_tag(i).get());
486+
}
487+
return ret;
488+
}
489+
490+
facebook::jni::local_ref<facebook::jni::JArrayInt> getOutputTags(
491+
facebook::jni::alias_ref<jstring> methodName) {
492+
auto method_meta =
493+
module_->method_meta(methodName->toStdString()).get();
494+
auto num_outputs = method_meta.num_outputs();
495+
facebook::jni::local_ref<facebook::jni::JArrayInt> ret =
496+
facebook::jni::JArrayInt::newArray(num_outputs);
497+
498+
int i = 0;
499+
for (int i = 0; i < num_outputs; i++) {
500+
ret->pin()[i] = static_cast<uint32_t>(method_meta.output_tag(i).get());
501+
}
502+
return ret;
503+
}
504+
474505
static void registerNatives() {
475506
registerHybrid({
476507
makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid),
@@ -480,6 +511,8 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
480511
makeNativeMethod("etdump", ExecuTorchJni::etdump),
481512
makeNativeMethod("getMethods", ExecuTorchJni::getMethods),
482513
makeNativeMethod("getUsedBackends", ExecuTorchJni::getUsedBackends),
514+
makeNativeMethod("getInputTags", ExecuTorchJni::getInputTags),
515+
makeNativeMethod("getOutputTags", ExecuTorchJni::getOutputTags),
483516
});
484517
}
485518
};

0 commit comments

Comments
 (0)