This repository was archived by the owner on Aug 28, 2024. It is now read-only.

Commit 3d31a03

Demo app updates to use PyTorch 1.9 org.pytorch:pytorch_android_lite:1.9.0 (#151)
* initial commit
* Revert "initial commit" (this reverts commit 5a65775)
* main readme and helloworld/demo app readme updates
* build.gradle, README and code update for PT1.9 for HelloWorld and Object Detection using pytorch_android_lite:1.9.0
* build.gradle, README and code update for PT1.9 for Question Answering using pytorch_android_lite:1.9.0
* build.gradle, README and code update for PT1.9 for TorchVideo using pytorch_android_lite:1.9.0
* HelloWorld script fix
* README update for TorchVideo
1 parent 367d2d9 commit 3d31a03

File tree: 18 files changed, +134 −80 lines


HelloWorldApp/README.md (+6 −6)
@@ -8,7 +8,7 @@ This application runs TorchScript serialized TorchVision pretrained [MobileNet v
 Let’s start with model preparation. If you are familiar with PyTorch, you probably should already know how to train and save your model. In case you don’t, we are going to use a pre-trained image classification model(MobileNet v3), which is packaged in [TorchVision](https://pytorch.org/docs/stable/torchvision/index.html).
 To install it, run the command below:
 ```
-pip install torchvision
+pip install torch torchvision
 ```
 
 To serialize and optimize the model for Android, you can use the Python [script](https://github.com/pytorch/android-demo-app/blob/master/HelloWorldApp/trace_model.py) in the root folder of HelloWorld app:
@@ -22,12 +22,12 @@ model.eval()
 example = torch.rand(1, 3, 224, 224)
 traced_script_module = torch.jit.trace(model, example)
 optimized_traced_model = optimize_for_mobile(traced_script_module)
-optimized_traced_model.save("app/src/main/assets/model.pt")
+optimized_traced_model._save_for_lite_interpreter("app/src/main/assets/model.ptl")
 ```
 If everything works well, we should have our scripted and optimized model - `model.pt` generated in the assets folder of android application.
 That will be packaged inside android application as `asset` and can be used on the device.
 
-By using the new MobileNet v3 model instead of the old Resnet18 model, and by calling the `optimize_for_mobile` method on the traced model, the model inference time on a Pixel 3 gets decreased from over 230ms to about 40ms.
+By using the new MobileNet v3 model instead of the old Resnet18 model, and by calling the `optimize_for_mobile` method on the traced model, the model inference time on a Pixel 3 gets decreased from over 230ms to about 40ms.
 
 More details about TorchScript you can find in [tutorials on pytorch.org](https://pytorch.org/docs/stable/jit.html)
 
@@ -54,8 +54,8 @@ repositories {
 }
 
 dependencies {
-    implementation 'org.pytorch:pytorch_android:1.4.0'
-    implementation 'org.pytorch:pytorch_android_torchvision:1.4.0'
+    implementation 'org.pytorch:pytorch_android_lite:1.9.0'
+    implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'
 }
 ```
 Where `org.pytorch:pytorch_android` is the main dependency with PyTorch Android API, including libtorch native library for all 4 android abis (armeabi-v7a, arm64-v8a, x86, x86_64).
@@ -73,7 +73,7 @@ Bitmap bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
 
 #### 5. Loading TorchScript Module
 ```
-Module module = Module.load(assetFilePath(this, "model.pt"));
+Module module = LiteModuleLoader.load(assetFilePath(this, "model.pt"));
 ```
 `org.pytorch.Module` represents `torch::jit::script::Module` that can be loaded with `load` method specifying file path to the serialized to file model.
 

HelloWorldApp/app/build.gradle (+2 −9)
@@ -1,12 +1,5 @@
 apply plugin: 'com.android.application'
 
-repositories {
-    jcenter()
-    maven {
-        url "https://oss.sonatype.org/content/repositories/snapshots"
-    }
-}
-
 android {
     compileSdkVersion 28
     buildToolsVersion "29.0.2"
@@ -26,6 +19,6 @@ android {
 
 dependencies {
     implementation 'androidx.appcompat:appcompat:1.1.0'
-    implementation 'org.pytorch:pytorch_android:1.8.0-SNAPSHOT'
-    implementation 'org.pytorch:pytorch_android_torchvision:1.8.0-SNAPSHOT'
+    implementation 'org.pytorch:pytorch_android_lite:1.9.0'
+    implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'
 }
(binary file changed, 9.71 MB, not shown)

HelloWorldApp/app/src/main/java/org/pytorch/helloworld/MainActivity.java (+2 −1)
@@ -9,6 +9,7 @@
 import android.widget.TextView;
 
 import org.pytorch.IValue;
+import org.pytorch.LiteModuleLoader;
 import org.pytorch.Module;
 import org.pytorch.Tensor;
 import org.pytorch.torchvision.TensorImageUtils;
@@ -37,7 +38,7 @@ protected void onCreate(Bundle savedInstanceState) {
       bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
       // loading serialized torchscript module from packaged into app android asset model.pt,
       // app/src/model/assets/model.pt
-      module = Module.load(assetFilePath(this, "model.pt"));
+      module = LiteModuleLoader.load(assetFilePath(this, "model.pt"));
     } catch (IOException e) {
       Log.e("PytorchHelloWorld", "Error reading assets", e);
       finish();

HelloWorldApp/trace_model.py (+1 −1)
@@ -7,4 +7,4 @@
 example = torch.rand(1, 3, 224, 224)
 traced_script_module = torch.jit.trace(model, example)
 optimized_traced_model = optimize_for_mobile(traced_script_module)
-optimized_traced_model.save("app/src/main/assets/model.pt")
+optimized_traced_model._save_for_lite_interpreter("app/src/main/assets/model.pt")
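The one-line change above swaps TorchScript's default `save` for the lite-interpreter serializer that `pytorch_android_lite` can load. A minimal, self-contained sketch of the same trace → optimize → save pipeline, using a toy `torch.nn.Linear` in place of MobileNet v3 so it runs without downloading weights (the output file name `model.ptl` is illustrative):

```python
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

# Toy stand-in for the MobileNet v3 model the real script loads from TorchVision
model = torch.nn.Linear(4, 2).eval()
example = torch.rand(1, 4)

# Same three steps as trace_model.py: trace, optimize for mobile,
# then save in the lite-interpreter (bytecode) format
traced = torch.jit.trace(model, example)
optimized = optimize_for_mobile(traced)
optimized._save_for_lite_interpreter("model.ptl")
```

A model written with plain `save()` cannot be loaded by `LiteModuleLoader` on Android, which is why the script and the app had to change together in this commit.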

ObjectDetection/README.md (+12 −8)
@@ -6,9 +6,9 @@
 
 ## Prerequisites
 
-* PyTorch 1.7 or later (Optional)
+* PyTorch 1.9.0 or later (Optional)
 * Python 3.8 (Optional)
-* Android Pytorch library 1.7.0
+* Android Pytorch library pytorch_android_lite:1.9.0 and pytorch_android_torchvision:1.9.0
 * Android Studio 4.0.1 or later
 
 ## Quick Start
@@ -17,9 +17,9 @@ To Test Run the Object Detection Android App, follow the steps below:
 
 ### 1. Prepare the Model
 
-If you don't have the PyTorch environment set up to run the script, you can download the model file [here](https://drive.google.com/file/d/1QOxNfpy_j_1KbuhN8INw2AgAC82nEw0F/view?usp=sharing) to the `android-demo-app/ObjectDetection/app/src/main/assets` folder, then skip the rest of this step and go to step 2 directly.
+If you don't have the PyTorch environment set up to run the script, you can download the model file `yolov5s.torchscript.ptl` [here](https://drive.google.com/u/1/uc?id=1_MF7NVi9Csm1lizoSCp1wCtUUUpuhwet&export=download) to the `android-demo-app/ObjectDetection/app/src/main/assets` folder, then skip the rest of this step and go to step 2 directly.
 
-Be aware that the downloadable model file was created with PyTorch 1.7.0, matching the PyTorch Android library 1.7.0 specified in the project's `build.gradle` file as `implementation 'org.pytorch:pytorch_android:1.7.0'`. If you use a different version of PyTorch to create your model by following the instructions below, make sure you specify the same PyTorch Android library version in the `build.gradle` file to avoid possible errors caused by the version mismatch. Furthermore, if you want to use the latest PyTorch master code to create the model, follow the steps at [Building PyTorch Android from Source](https://pytorch.org/mobile/android/#building-pytorch-android-from-source) and [Using the PyTorch Android Libraries Built](https://pytorch.org/mobile/android/#using-the-pytorch-android-libraries-built-from-source-or-nightly) on how to use the model in Android.
+Be aware that the downloadable model file was created with PyTorch 1.9.0, matching the PyTorch Android library 1.9.0 specified in the project's `build.gradle` file as `implementation 'org.pytorch:pytorch_android_lite:1.9.0'`. If you use a different version of PyTorch to create your model by following the instructions below, make sure you specify the same PyTorch Android library version in the `build.gradle` file to avoid possible errors caused by the version mismatch. Furthermore, if you want to use the latest PyTorch master code to create the model, follow the steps at [Building PyTorch Android from Source](https://pytorch.org/mobile/android/#building-pytorch-android-from-source) and [Using the PyTorch Android Libraries Built](https://pytorch.org/mobile/android/#using-the-pytorch-android-libraries-built-from-source-or-nightly) on how to use the model in Android.
 
 The Python script `export.py` in the `models` folder of the [YOLOv5 repo](https://github.com/ultralytics/yolov5) is used to generate a TorchScript-formatted YOLOv5 model named `yolov5s.torchscript.pt` for mobile apps.
 
@@ -31,20 +31,24 @@ cd yolov5
 pip install -r requirements.txt
 ```
 
-Then edit `models/export.py` to make two changes:
+Then edit `models/export.py` to make the following four changes:
 
-* Change the line 50 from `model.model[-1].export = True` to `model.model[-1].export = False`
+* Change line 50 from `model.model[-1].export = True` to `model.model[-1].export = False`
+
+* Change line 56 from `f = opt.weights.replace('.pt', '.torchscript.pt')` to `f = opt.weights.replace('.pt', '.torchscript.ptl')`
 
 * Add the following two lines of model optimization code after line 57, between `ts = torch.jit.trace(model, img)` and `ts.save(f)`:
 
 ```
 from torch.utils.mobile_optimizer import optimize_for_mobile
-ts = optimize_for_mobile(ts)
+ts = optimize_for_mobile(ts)
 ```
 
+* Replace the line `ts.save(f)` with `ts._save_for_lite_interpreter(f)`.
+
 If you ignore this step, you can still create a TorchScript model for mobile apps to use, but the inference on a non-optimized model can take twice as long as the inference on an optimized model - using the Android app test images, the average inference time on an optimized and non-optimized model is 0.6 seconds and 1.18 seconds, respectively. See [SCRIPT AND OPTIMIZE FOR MOBILE RECIPE](https://pytorch.org/tutorials/recipes/script_optimized.html) for more details.
 
-Finally, run the script below to generate the optimized TorchScript model and copy the generated model file `yolov5s.torchscript.pt` to the `android-demo-app/ObjectDetection/app/src/main/assets` folder:
+Now run the script below to generate the optimized TorchScript model and copy the generated model file `yolov5s.torchscript.ptl` to the `android-demo-app/ObjectDetection/app/src/main/assets` folder:
 
 ```
 python models/export.py
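Applied together, the four edits above leave the tail of `export.py` in roughly the following shape. This is a hedged sketch, not the actual YOLOv5 file: `weights`, `model`, and `img` stand in for `opt.weights` and the real YOLOv5 model, which are defined earlier in the script:

```python
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

weights = "yolov5s.pt"                          # stand-in for opt.weights
f = weights.replace(".pt", ".torchscript.ptl")  # the edited line 56

model = torch.nn.Conv2d(3, 3, 1).eval()         # toy stand-in for the YOLOv5 model
img = torch.rand(1, 3, 32, 32)

ts = torch.jit.trace(model, img)
ts = optimize_for_mobile(ts)                    # the two inserted optimization lines
ts._save_for_lite_interpreter(f)                # replaces ts.save(f)
```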

ObjectDetection/app/build.gradle (+2 −2)
@@ -37,6 +37,6 @@ dependencies {
3737
implementation "androidx.camera:camera-core:$camerax_version"
3838
implementation "androidx.camera:camera-camera2:$camerax_version"
3939

40-
implementation 'org.pytorch:pytorch_android:1.7.0'
41-
implementation 'org.pytorch:pytorch_android_torchvision:1.7.0'
40+
implementation 'org.pytorch:pytorch_android_lite:1.9.0'
41+
implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'
4242
}

ObjectDetection/app/src/main/java/org/pytorch/demo/objectdetection/MainActivity.java (+21 −2)
@@ -30,8 +30,8 @@
 import android.widget.ProgressBar;
 
 import org.pytorch.IValue;
+import org.pytorch.LiteModuleLoader;
 import org.pytorch.Module;
-import org.pytorch.PyTorchAndroid;
 import org.pytorch.Tensor;
 import org.pytorch.torchvision.TensorImageUtils;
 
@@ -57,6 +57,25 @@ public class MainActivity extends AppCompatActivity implements Runnable {
     private Module mModule = null;
     private float mImgScaleX, mImgScaleY, mIvScaleX, mIvScaleY, mStartX, mStartY;
 
+    public static String assetFilePath(Context context, String assetName) throws IOException {
+        File file = new File(context.getFilesDir(), assetName);
+        if (file.exists() && file.length() > 0) {
+            return file.getAbsolutePath();
+        }
+
+        try (InputStream is = context.getAssets().open(assetName)) {
+            try (OutputStream os = new FileOutputStream(file)) {
+                byte[] buffer = new byte[4 * 1024];
+                int read;
+                while ((read = is.read(buffer)) != -1) {
+                    os.write(buffer, 0, read);
+                }
+                os.flush();
+            }
+            return file.getAbsolutePath();
+        }
+    }
+
     @Override
     protected void onCreate(Bundle savedInstanceState) {
         super.onCreate(savedInstanceState);
@@ -162,7 +181,7 @@ public void onClick(View v) {
         });
 
         try {
-            mModule = PyTorchAndroid.loadModuleFromAsset(getAssets(), "yolov5s.torchscript.pt");
+            mModule = LiteModuleLoader.load(MainActivity.assetFilePath(getApplicationContext(), "yolov5s.torchscript.ptl"));
             BufferedReader br = new BufferedReader(new InputStreamReader(getAssets().open("classes.txt")));
             String line;
             List<String> classes = new ArrayList<>();
ObjectDetection/app/src/main/java/org/pytorch/demo/objectdetection/ObjectDetectionActivity.java (+8 −4)
@@ -7,7 +7,6 @@
 import android.graphics.Rect;
 import android.graphics.YuvImage;
 import android.media.Image;
-import android.os.Bundle;
 import android.util.Log;
 import android.view.TextureView;
 import android.view.ViewStub;
@@ -17,8 +16,8 @@
 import androidx.camera.core.ImageProxy;
 
 import org.pytorch.IValue;
+import org.pytorch.LiteModuleLoader;
 import org.pytorch.Module;
-import org.pytorch.PyTorchAndroid;
 import org.pytorch.Tensor;
 import org.pytorch.torchvision.TensorImageUtils;
 
@@ -85,8 +84,13 @@ private Bitmap imgToBitmap(Image image) {
   @WorkerThread
   @Nullable
   protected AnalysisResult analyzeImage(ImageProxy image, int rotationDegrees) {
-    if (mModule == null) {
-      mModule = PyTorchAndroid.loadModuleFromAsset(getAssets(), "yolov5s.torchscript.pt");
+    try {
+      if (mModule == null) {
+        mModule = LiteModuleLoader.load(MainActivity.assetFilePath(getApplicationContext(), "yolov5s.torchscript.ptl"));
+      }
+    } catch (IOException e) {
+      Log.e("Object Detection", "Error reading assets", e);
+      return null;
     }
     Bitmap bitmap = imgToBitmap(image.getImage());
     Matrix matrix = new Matrix();

QuestionAnswering/README.md (+7 −24)
@@ -2,15 +2,15 @@
 
 ## Introduction
 
-Question Answering (QA) is one of the common and challenging Natural Language Processing tasks. With the revolutionary transformed-based [Bert](https://arxiv.org/abs/1810.04805) model coming out in October 2018, question answering models have reached their state of art accuracy by fine-tuning Bert-like models on QA datasets such as [Squad](https://rajpurkar.github.io/SQuAD-explorer). [Huggingface](https://huggingface.co)'s [DistilBert](https://huggingface.co/transformers/model_doc/distilbert.html) is a smaller and faster version of BERT - DistilBert "has 40% less parameters than bert-base-uncased, runs 60% faster while preserving over 95% of BERT’s performances as measured on the GLUE language understanding benchmark."
+Question Answering (QA) is one of the common and challenging Natural Language Processing tasks. With the revolutionary transformed-based [BERT](https://arxiv.org/abs/1810.04805) model coming out in October 2018, question answering models have reached their state of art accuracy by fine-tuning BERT-like models on QA datasets such as [Squad](https://rajpurkar.github.io/SQuAD-explorer). [Huggingface](https://huggingface.co)'s [DistilBERT](https://huggingface.co/transformers/model_doc/distilbert.html) is a smaller and faster version of BERT - DistilBERT "has 40% less parameters than bert-base-uncased, runs 60% faster while preserving over 95% of BERT’s performances as measured on the GLUE language understanding benchmark."
 
 In this demo app, written in Kotlin, we'll show how to quantize and convert the Huggingface's DistilBert QA model to TorchScript and how to use the scripted model on an Android demo app to perform question answering.
 
 ## Prerequisites
 
-* PyTorch 1.7 or later (Optional)
+* PyTorch 1.9.0 or later (Optional)
 * Python 3.8 (Optional)
-* Android Pytorch library 1.7 or later
+* Android Pytorch library org.pytorch:pytorch_android_lite:1.9.0
 * Android Studio 4.0.1 or later
 
 ## Quick Start
@@ -19,32 +19,15 @@ To Test Run the Android QA App, run the following commands on a Terminal:
 
 ### 1. Prepare the Model
 
-If you don't have PyTorch installed or want to have a quick try of the demo app, you can download the scripted QA model compressed in a zip file [here](https://drive.google.com/file/d/1RWZa_5oSQg5AfInkn344DN3FJ5WbbZbq/view?usp=sharing), then unzip it to the assets folder, and continue to Step 2.
+If you don't have PyTorch installed or want to have a quick try of the demo app, you can download the scripted QA model `qa360_quantized.ptl` [here](https://drive.google.com/file/d/1PgD3pAEf0riUiT3BfwHOm6UEGk8FfJzI/view?usp=sharing) and save it to the `QuestionAnswering/app/src/main/assets` folder, then continue to Step 2.
 
-Be aware that the downloadable model file was created with PyTorch 1.7.0, matching the PyTorch Android library 1.7.0 specified in the project's `build.gradle` file as `implementation 'org.pytorch:pytorch_android:1.7.0'`. If you use a different version of PyTorch to create your model by following the instructions below, make sure you specify the same PyTorch Android library version in the `build.gradle` file to avoid possible errors caused by the version mismatch. Furthermore, if you want to use the latest PyTorch master code to create the model, follow the steps at [Building PyTorch Android from Source](https://pytorch.org/mobile/android/#building-pytorch-android-from-source) and [Using the PyTorch Android Libraries Built](https://pytorch.org/mobile/android/#using-the-pytorch-android-libraries-built-from-source-or-nightly) on how to use the model in Android.
+Be aware that the downloadable model file was created with PyTorch 1.9.0, matching the PyTorch Android library 1.9.0 specified in the project's `build.gradle` file as `implementation 'org.pytorch:pytorch_android:1.9.0'`. If you use a different version of PyTorch to create your model by following the instructions below, make sure you specify the same PyTorch Android library version in the `build.gradle` file to avoid possible errors caused by the version mismatch. Furthermore, if you want to use the latest PyTorch master code to create the model, follow the steps at [Building PyTorch Android from Source](https://pytorch.org/mobile/android/#building-pytorch-android-from-source) and [Using the PyTorch Android Libraries Built](https://pytorch.org/mobile/android/#using-the-pytorch-android-libraries-built-from-source-or-nightly) on how to use the model in Android.
 
-With PyTorch 1.7 installed, first install the Huggingface `transformers` by running `pip install transformers` (the versions that have been tested are 4.0.0 and 4.1.1), then run `python convert_distilbert_qa.py`.
+With PyTorch 1.9.0 installed, first install the Huggingface `transformers` by running `pip install transformers`, then run `python convert_distilbert_qa.py`.
 
 Note that a pre-defined question and text, resulting in the size of the input tokens (of question and text) being 360, is used in the `convert_distilbert_qa.py`, and 360 is the maximum token size for the user text and question in the app. If the token size of the inputs of the text and question is less than 360, padding will be needed to make the model work correctly.
 
-After the script completes, copy the model file qa360_quantized.pt to the Android app's assets folder. [Dynamic quantization](https://pytorch.org/tutorials/intermediate/dynamic_quantization_bert_tutorial.html) is used to quantize the model to reduce its size to half, without causing inference difference in question answering - you can verify this by changing the last 4 lines of code in `convert_distilbert_qa.py` from:
-
-```
-model_dynamic_quantized = torch.quantization.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
-traced_model = torch.jit.trace(model_dynamic_quantized, inputs['input_ids'], strict=False)
-optimized_traced_model = optimize_for_mobile(traced_model)
-torch.jit.save(optimized_traced_model, "qa360_quantized.pt")
-```
-
-to
-
-```
-traced_model = torch.jit.trace(model, inputs['input_ids'], strict=False)
-optimized_traced_model = optimize_for_mobile(traced_model)
-torch.jit.save(optimized_traced_model, "qa360.pt")
-```
-
-and rerun `python convert_distilbert_qa.py` to generate a non-quantized model `qa360.pt` and use it in the app to compare with the quantized version `qa360_quantized.pt`.
+After the script completes, copy the model file `qa360_quantized.ptl` to the Android app's assets folder.
 
 
 ### 2. Build and run with Android Studio
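As the note above explains, the converted model's input size is fixed at 360 tokens, so shorter question-plus-text inputs must be padded before inference. A minimal sketch of that padding step (the token-id values and the pad id of 0 are illustrative; the real app tokenizes with the DistilBERT vocabulary):

```python
MAX_TOKENS = 360  # fixed input size baked into the converted model

def pad_token_ids(token_ids, pad_id=0, max_len=MAX_TOKENS):
    """Right-pad a token-id list to the model's fixed input length."""
    if len(token_ids) > max_len:
        raise ValueError(f"got {len(token_ids)} tokens; the model accepts at most {max_len}")
    return token_ids + [pad_id] * (max_len - len(token_ids))

# e.g. a short question+text pair of 5 token ids
padded = pad_token_ids([101, 2054, 2003, 102, 103])
```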

QuestionAnswering/app/build.gradle (+1 −1)
@@ -32,7 +32,7 @@ dependencies {
     implementation 'androidx.appcompat:appcompat:1.2.0'
     implementation 'androidx.constraintlayout:constraintlayout:2.0.4'
 
-    implementation 'org.pytorch:pytorch_android:1.7.0'
+    implementation 'org.pytorch:pytorch_android_lite:1.9.0'
     implementation "androidx.core:core-ktx:+"
     implementation "org.jetbrains.kotlin:kotlin-stdlib-jdk7:$kotlin_version"
 
0 commit comments