Skip to content

Commit 5268b24

Browse files
authored
Android method metadata
Differential Revision: D75156980 Pull Request resolved: #11023
1 parent 6daeb64 commit 5268b24

File tree

6 files changed

+113
-4
lines changed

6 files changed

+113
-4
lines changed

extension/android/BUCK

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ non_fbcode_target(_kind = fb_android_library,
88
srcs = [
99
"executorch_android/src/main/java/org/pytorch/executorch/DType.java",
1010
"executorch_android/src/main/java/org/pytorch/executorch/EValue.java",
11+
"executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java",
1112
"executorch_android/src/main/java/org/pytorch/executorch/Module.java",
1213
"executorch_android/src/main/java/org/pytorch/executorch/Tensor.java",
1314
"executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.java",

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ class ModuleE2ETest {
6969

7070
val module = Module.load(getTestFilePath("/mv3_xnnpack_fp32.pte"))
7171
val expectedBackends = arrayOf("XnnpackBackend")
72-
Assert.assertArrayEquals(expectedBackends, module.getUsedBackends("forward"))
72+
Assert.assertArrayEquals(expectedBackends, module.getMethodMetadata("forward").getBackends())
7373
}
7474

7575
@Test

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,15 @@ class ModuleInstrumentationTest {
5555
Assert.assertTrue(results[0].isTensor)
5656
}
5757

58+
@Test
59+
@Throws(IOException::class, URISyntaxException::class)
60+
fun testMethodMetadata() {
61+
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
62+
63+
Assert.assertArrayEquals(arrayOf("forward"), module.getMethods())
64+
Assert.assertTrue(module.getMethodMetadata("forward").backends.isEmpty())
65+
}
66+
5867
@Test
5968
@Throws(IOException::class)
6069
fun testModuleLoadMethodAndForward() {
@@ -91,7 +100,7 @@ class ModuleInstrumentationTest {
91100
Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong())
92101
}
93102

94-
@Test
103+
@Test(expected = RuntimeException::class)
95104
@Throws(IOException::class)
96105
fun testNonPteFile() {
97106
val module = Module.load(getTestFilePath(NON_PTE_FILE_NAME))
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package org.pytorch.executorch;
10+
11+
/** Helper class to access the metadata for a method from a Module */
12+
public class MethodMetadata {
13+
private String mName;
14+
15+
private String[] mBackends;
16+
17+
MethodMetadata setName(String name) {
18+
mName = name;
19+
return this;
20+
}
21+
22+
/**
23+
* @return Method name
24+
*/
25+
public String getName() {
26+
return mName;
27+
}
28+
29+
MethodMetadata setBackends(String[] backends) {
30+
mBackends = backends;
31+
return this;
32+
}
33+
34+
/**
35+
* @return Backends used for this method
36+
*/
37+
public String[] getBackends() {
38+
return mBackends;
39+
}
40+
}

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import com.facebook.soloader.nativeloader.NativeLoader;
1515
import com.facebook.soloader.nativeloader.SystemDelegate;
1616
import java.io.File;
17+
import java.util.HashMap;
18+
import java.util.Map;
1719
import java.util.concurrent.locks.Lock;
1820
import java.util.concurrent.locks.ReentrantLock;
1921
import org.pytorch.executorch.annotations.Experimental;
@@ -48,12 +50,27 @@ public class Module {
4850

4951
private final HybridData mHybridData;
5052

53+
private final Map<String, MethodMetadata> mMethodMetadata;
54+
5155
@DoNotStrip
5256
private static native HybridData initHybrid(
5357
String moduleAbsolutePath, int loadMode, int initHybrid);
5458

5559
private Module(String moduleAbsolutePath, int loadMode, int numThreads) {
5660
mHybridData = initHybrid(moduleAbsolutePath, loadMode, numThreads);
61+
62+
mMethodMetadata = populateMethodMeta();
63+
}
64+
65+
Map<String, MethodMetadata> populateMethodMeta() {
66+
String[] methods = getMethods();
67+
Map<String, MethodMetadata> metadata = new HashMap<String, MethodMetadata>();
68+
for (int i = 0; i < methods.length; i++) {
69+
String name = methods[i];
70+
metadata.put(name, new MethodMetadata().setName(name).setBackends(getUsedBackends(name)));
71+
}
72+
73+
return metadata;
5774
}
5875

5976
/** Lock protecting the non-thread safe methods in mHybridData. */
@@ -158,13 +175,34 @@ public int loadMethod(String methodName) {
158175
private native int loadMethodNative(String methodName);
159176

160177
/**
161-
* Returns the names of the methods in a certain method.
178+
* Returns the names of the backends in a certain method.
162179
*
163180
* @param methodName method name to query
164181
* @return an array of backend name
165182
*/
166183
@DoNotStrip
167-
public native String[] getUsedBackends(String methodName);
184+
private native String[] getUsedBackends(String methodName);
185+
186+
/**
187+
* Returns the names of methods.
188+
*
189+
* @return name of methods in this Module
190+
*/
191+
@DoNotStrip
192+
public native String[] getMethods();
193+
194+
/**
195+
* Get the corresponding @MethodMetadata for a method
196+
*
197+
* @param name method name
198+
* @return @MethodMetadata for this method
199+
*/
200+
public MethodMetadata getMethodMetadata(String name) {
201+
if (!mMethodMetadata.containsKey(name)) {
202+
throw new RuntimeException("method " + name + "does not exist for this module");
203+
}
204+
return mMethodMetadata.get(name);
205+
}
168206

169207
/** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */
170208
public String[] readLogBuffer() {

extension/android/jni/jni_layer.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,26 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
431431
return false;
432432
}
433433

434+
facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> getMethods() {
435+
const auto& names_result = module_->method_names();
436+
if (!names_result.ok()) {
437+
facebook::jni::throwNewJavaException(
438+
facebook::jni::gJavaLangIllegalArgumentException,
439+
"Cannot get load module");
440+
}
441+
const auto& methods = names_result.get();
442+
facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> ret =
443+
facebook::jni::JArrayClass<jstring>::newArray(methods.size());
444+
int i = 0;
445+
for (auto s : methods) {
446+
facebook::jni::local_ref<facebook::jni::JString> method_name =
447+
facebook::jni::make_jstring(s.c_str());
448+
(*ret)[i] = method_name;
449+
i++;
450+
}
451+
return ret;
452+
}
453+
434454
facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> getUsedBackends(
435455
facebook::jni::alias_ref<jstring> methodName) {
436456
auto methodMeta = module_->method_meta(methodName->toStdString()).get();
@@ -458,6 +478,7 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
458478
makeNativeMethod("loadMethodNative", ExecuTorchJni::load_method),
459479
makeNativeMethod("readLogBufferNative", ExecuTorchJni::readLogBuffer),
460480
makeNativeMethod("etdump", ExecuTorchJni::etdump),
481+
makeNativeMethod("getMethods", ExecuTorchJni::getMethods),
461482
makeNativeMethod("getUsedBackends", ExecuTorchJni::getUsedBackends),
462483
});
463484
}

0 commit comments

Comments
 (0)