Skip to content

Commit 0570874

Browse files
committed
[android][native_app] App example of linking to gradle deps native libs and torchscript CustomOp
1 parent 1e36f9e commit 0570874

12 files changed

+385
-0
lines changed

NativeApp/app/CMakeLists.txt

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
cmake_minimum_required(VERSION 3.4.1)
2+
set(TARGET pytorch_nativeapp)
3+
project(${TARGET} CXX)
4+
set(CMAKE_CXX_STANDARD 14)
5+
6+
set(build_DIR ${CMAKE_SOURCE_DIR}/build)
7+
8+
set(pytorch_testapp_cpp_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp)
9+
file(GLOB pytorch_testapp_SOURCES
10+
${pytorch_testapp_cpp_DIR}/pytorch_nativeapp.cpp
11+
)
12+
13+
add_library(${TARGET} SHARED
14+
${pytorch_testapp_SOURCES}
15+
)
16+
17+
file(GLOB PYTORCH_INCLUDE_DIRS "${build_DIR}/pytorch_android*.aar/headers")
18+
file(GLOB PYTORCH_LINK_DIRS "${build_DIR}/pytorch_android*.aar/jni/${ANDROID_ABI}")
19+
20+
target_compile_options(${TARGET} PRIVATE
21+
-fexceptions
22+
)
23+
24+
set(BUILD_SUBDIR ${ANDROID_ABI})
25+
26+
find_library(PYTORCH_LIBRARY pytorch_jni
27+
PATHS ${PYTORCH_LINK_DIRS}
28+
NO_CMAKE_FIND_ROOT_PATH)
29+
find_library(FBJNI_LIBRARY fbjni
30+
PATHS ${PYTORCH_LINK_DIRS}
31+
NO_CMAKE_FIND_ROOT_PATH)
32+
33+
# OpenCV
34+
if(NOT DEFINED ENV{OPENCV_ANDROID_SDK})
35+
message(FATAL_ERROR "Environment var OPENCV_ANDROID_SDK set")
36+
endif()
37+
38+
set(OPENCV_INCLUDE_DIR "$ENV{OPENCV_ANDROID_SDK}/sdk/native/jni/include")
39+
40+
target_include_directories(${TARGET} PRIVATE
41+
"${OPENCV_INCLUDE_DIR}"
42+
${PYTORCH_INCLUDE_DIRS})
43+
44+
set(OPENCV_LIB_DIR "$ENV{OPENCV_ANDROID_SDK}/sdk/native/libs/${ANDROID_ABI}")
45+
46+
find_library(OPENCV_LIBRARY opencv_java4
47+
PATHS ${OPENCV_LIB_DIR}
48+
NO_CMAKE_FIND_ROOT_PATH)
49+
50+
target_link_libraries(${TARGET}
51+
${PYTORCH_LIBRARY}
52+
${FBJNI_LIBRARY}
53+
${OPENCV_LIBRARY}
54+
log)

NativeApp/app/build.gradle

+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
apply plugin: 'com.android.application'
2+
3+
repositories {
4+
jcenter()
5+
maven {
6+
url "https://oss.sonatype.org/content/repositories/snapshots"
7+
}
8+
}
9+
10+
android {
11+
configurations {
12+
extractForNativeBuild
13+
}
14+
compileSdkVersion 28
15+
buildToolsVersion "29.0.2"
16+
defaultConfig {
17+
applicationId "org.pytorch.nativeapp"
18+
minSdkVersion 21
19+
targetSdkVersion 28
20+
versionCode 1
21+
versionName "1.0"
22+
externalNativeBuild {
23+
cmake {
24+
arguments "-DANDROID_STL=c++_shared"
25+
}
26+
}
27+
}
28+
buildTypes {
29+
release {
30+
minifyEnabled false
31+
}
32+
}
33+
externalNativeBuild {
34+
cmake {
35+
path "CMakeLists.txt"
36+
}
37+
}
38+
sourceSets {
39+
main {
40+
jniLibs.srcDirs = ['src/main/jniLibs']
41+
}
42+
}
43+
}
44+
45+
dependencies {
46+
implementation 'com.android.support:appcompat-v7:28.0.0'
47+
48+
implementation 'org.pytorch:pytorch_android:1.6.0-SNAPSHOT'
49+
extractForNativeBuild 'org.pytorch:pytorch_android:1.6.0-SNAPSHOT'
50+
}
51+
52+
task extractAARForNativeBuild {
53+
doLast {
54+
configurations.extractForNativeBuild.files.each {
55+
def file = it.absoluteFile
56+
copy {
57+
from zipTree(file)
58+
into "$buildDir/$file.name"
59+
include "headers/**"
60+
include "jni/**"
61+
}
62+
}
63+
}
64+
}
65+
66+
tasks.whenTaskAdded { task ->
67+
if (task.name.contains('externalNativeBuild')) {
68+
task.dependsOn(extractAARForNativeBuild)
69+
}
70+
}
+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
<?xml version="1.0" encoding="utf-8"?>
2+
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
3+
package="org.pytorch.nativeapp">
4+
5+
<application
6+
android:allowBackup="true"
7+
android:label="PyTorchNativeApp"
8+
android:supportsRtl="true"
9+
android:theme="@style/Theme.AppCompat.Light.DarkActionBar">
10+
11+
<activity android:name=".MainActivity">
12+
<intent-filter>
13+
<action android:name="android.intent.action.MAIN" />
14+
15+
<category android:name="android.intent.category.LAUNCHER" />
16+
</intent-filter>
17+
</activity>
18+
</application>
19+
</manifest>
+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
*
2+
*/
3+
!.gitignore
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
#include <android/log.h>
2+
#include <cassert>
3+
#include <cmath>
4+
#include <pthread.h>
5+
#include <unistd.h>
6+
#include <vector>
7+
#define ALOGI(...) \
8+
__android_log_print(ANDROID_LOG_INFO, "PyTorchNativeApp", __VA_ARGS__)
9+
#define ALOGE(...) \
10+
__android_log_print(ANDROID_LOG_ERROR, "PyTorchNativeApp", __VA_ARGS__)
11+
12+
#include "jni.h"
13+
14+
#include <opencv2/opencv.hpp>
15+
#include <torch/script.h>
16+
17+
namespace pytorch_nativeapp {
18+
namespace {
19+
torch::Tensor warp_perspective(torch::Tensor image, torch::Tensor warp) {
20+
cv::Mat image_mat(/*rows=*/image.size(0),
21+
/*cols=*/image.size(1),
22+
/*type=*/CV_32FC1,
23+
/*data=*/image.data_ptr<float>());
24+
cv::Mat warp_mat(/*rows=*/warp.size(0),
25+
/*cols=*/warp.size(1),
26+
/*type=*/CV_32FC1,
27+
/*data=*/warp.data_ptr<float>());
28+
29+
cv::Mat output_mat;
30+
cv::warpPerspective(image_mat, output_mat, warp_mat, /*dsize=*/{8, 8});
31+
32+
torch::Tensor output =
33+
torch::from_blob(output_mat.ptr<float>(), /*sizes=*/{8, 8});
34+
return output.clone();
35+
}
36+
37+
static auto registry =
38+
torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective);
39+
40+
template <typename T> void log(const char *m, T t) {
41+
std::ostringstream os;
42+
os << t << std::endl;
43+
ALOGI("%s %s", m, os.str().c_str());
44+
}
45+
46+
struct JITCallGuard {
47+
torch::autograd::AutoGradMode no_autograd_guard{false};
48+
torch::AutoNonVariableTypeMode non_var_guard{true};
49+
torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
50+
};
51+
} // namespace
52+
53+
static void loadAndForwardModel(JNIEnv *env, jclass, jstring jModelPath) {
54+
const char *modelPath = env->GetStringUTFChars(jModelPath, 0);
55+
assert(modelPath);
56+
57+
// To load torchscript model for mobile we need set these guards,
58+
// because mobile build doesn't support features like autograd for smaller
59+
// build size which is placed in `struct JITCallGuard` in this example. It may
60+
// change in future, you can track the latest changes keeping an eye in
61+
// android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp
62+
JITCallGuard guard;
63+
torch::jit::Module module = torch::jit::load(modelPath);
64+
module.eval();
65+
torch::Tensor x = torch::randn({4, 8});
66+
torch::Tensor y = torch::randn({8, 5});
67+
log("x:", x);
68+
log("y:", y);
69+
c10::IValue t_out = module.forward({x, y});
70+
log("result:", t_out);
71+
env->ReleaseStringUTFChars(jModelPath, modelPath);
72+
}
73+
} // namespace pytorch_nativeapp
74+
75+
JNIEXPORT jint JNI_OnLoad(JavaVM *vm, void *) {
76+
JNIEnv *env;
77+
if (vm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION_1_6) != JNI_OK) {
78+
return JNI_ERR;
79+
}
80+
81+
jclass c = env->FindClass("org/pytorch/nativeapp/NativeClient$NativePeer");
82+
if (c == nullptr) {
83+
return JNI_ERR;
84+
}
85+
86+
static const JNINativeMethod methods[] = {
87+
{"loadAndForwardModel", "(Ljava/lang/String;)V",
88+
(void *)pytorch_nativeapp::loadAndForwardModel},
89+
};
90+
int rc = env->RegisterNatives(c, methods,
91+
sizeof(methods) / sizeof(JNINativeMethod));
92+
93+
if (rc != JNI_OK) {
94+
return rc;
95+
}
96+
97+
return JNI_VERSION_1_6;
98+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package org.pytorch.nativeapp;
2+
3+
import android.content.Context;
4+
import android.os.Bundle;
5+
import android.util.Log;
6+
import androidx.appcompat.app.AppCompatActivity;
7+
import java.io.File;
8+
import java.io.FileOutputStream;
9+
import java.io.IOException;
10+
import java.io.InputStream;
11+
import java.io.OutputStream;
12+
13+
public class MainActivity extends AppCompatActivity {
14+
15+
private static final String TAG = "PyTorchNativeApp";
16+
17+
public static String assetFilePath(Context context, String assetName) {
18+
File file = new File(context.getFilesDir(), assetName);
19+
if (file.exists() && file.length() > 0) {
20+
return file.getAbsolutePath();
21+
}
22+
23+
try (InputStream is = context.getAssets().open(assetName)) {
24+
try (OutputStream os = new FileOutputStream(file)) {
25+
byte[] buffer = new byte[4 * 1024];
26+
int read;
27+
while ((read = is.read(buffer)) != -1) {
28+
os.write(buffer, 0, read);
29+
}
30+
os.flush();
31+
}
32+
return file.getAbsolutePath();
33+
} catch (IOException e) {
34+
Log.e(TAG, "Error process asset " + assetName + " to file path");
35+
}
36+
return null;
37+
}
38+
39+
@Override
40+
protected void onCreate(Bundle savedInstanceState) {
41+
super.onCreate(savedInstanceState);
42+
final String modelFileAbsoluteFilePath =
43+
new File(assetFilePath(this, "compute.pt")).getAbsolutePath();
44+
NativeClient.loadAndForwardModel(modelFileAbsoluteFilePath);
45+
}
46+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package org.pytorch.nativeapp;
2+
3+
public final class NativeClient {
4+
5+
public static void loadAndForwardModel(final String modelPath) {
6+
NativePeer.loadAndForwardModel(modelPath);
7+
}
8+
9+
private static class NativePeer {
10+
static {
11+
System.loadLibrary("pytorch_nativeapp");
12+
}
13+
14+
private static native void loadAndForwardModel(final String modelPath);
15+
}
16+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
*
2+
*/
3+
!.gitignore

NativeApp/build.gradle

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
buildscript {
2+
repositories {
3+
google()
4+
jcenter()
5+
}
6+
dependencies {
7+
classpath 'com.android.tools.build:gradle:3.5.0'
8+
}
9+
}
10+
11+
allprojects {
12+
repositories {
13+
google()
14+
jcenter()
15+
}
16+
}
17+
18+
task clean(type: Delete) {
19+
delete rootProject.buildDir
20+
}

NativeApp/gradle.properties

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
android.useAndroidX=true
2+
android.enableJetifier=true
3+

NativeApp/make_warp_perspective_pt.py

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import torch
2+
import torch.utils.cpp_extension
3+
4+
print(torch.version.__version__)
5+
op_source = """
6+
#include <opencv2/opencv.hpp>
7+
#include <torch/script.h>
8+
9+
torch::Tensor warp_perspective(torch::Tensor image, torch::Tensor warp) {
10+
cv::Mat image_mat(/*rows=*/image.size(0),
11+
/*cols=*/image.size(1),
12+
/*type=*/CV_32FC1,
13+
/*data=*/image.data_ptr<float>());
14+
cv::Mat warp_mat(/*rows=*/warp.size(0),
15+
/*cols=*/warp.size(1),
16+
/*type=*/CV_32FC1,
17+
/*data=*/warp.data_ptr<float>());
18+
19+
cv::Mat output_mat;
20+
cv::warpPerspective(image_mat, output_mat, warp_mat, /*dsize=*/{64, 64});
21+
22+
torch::Tensor output =
23+
torch::from_blob(output_mat.ptr<float>(), /*sizes=*/{64, 64});
24+
return output.clone();
25+
}
26+
27+
static auto registry =
28+
torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective);
29+
"""
30+
31+
torch.utils.cpp_extension.load_inline(
32+
name="warp_perspective",
33+
cpp_sources=op_source,
34+
extra_ldflags=["-lopencv_core", "-lopencv_imgproc"],
35+
is_python_module=False,
36+
verbose=True,
37+
)
38+
39+
print(torch.ops.my_ops.warp_perspective)
40+
41+
42+
@torch.jit.script
43+
def compute(x, y):
44+
if bool(x[0][0] == 42):
45+
z = 5
46+
else:
47+
z = 10
48+
x = torch.ops.my_ops.warp_perspective(x, torch.eye(3))
49+
return x.matmul(y) + z
50+
51+
52+
compute.save("app/src/main/assets/compute.pt")

NativeApp/settings.gradle

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
include ':app'

0 commit comments

Comments
 (0)