
Commit 737b24f

Add onnx tutorial (#990)
* fixup
* Remove from tracing shape check
* Add notebook with example
* Update readme examples
* Fixup
1 parent 076f684 commit 737b24f

3 files changed (+226 −1)


README.md (+1)
@@ -81,6 +81,7 @@ Congratulations! You are done! Now you can train your model with your favorite f
  - Training model for cars segmentation on CamVid dataset [here](https://github.com/qubvel/segmentation_models.pytorch/blob/main/examples/cars%20segmentation%20(camvid).ipynb).
  - Training SMP model with [Catalyst](https://github.com/catalyst-team/catalyst) (high-level framework for PyTorch), [TTAch](https://github.com/qubvel/ttach) (TTA library for PyTorch) and [Albumentations](https://github.com/albu/albumentations) (fast image augmentation library) - [here](https://github.com/catalyst-team/catalyst/blob/v21.02rc0/examples/notebooks/segmentation-tutorial.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/catalyst-team/catalyst/blob/v21.02rc0/examples/notebooks/segmentation-tutorial.ipynb)
  - Training SMP model with [Pytorch-Lightning](https://pytorch-lightning.readthedocs.io) framework - [here](https://github.com/ternaus/cloths_segmentation) (clothes binary segmentation by [@ternaus](https://github.com/ternaus)).
+ - Export trained model to ONNX - [notebook](https://github.com/qubvel/segmentation_models.pytorch/blob/main/examples/convert_to_onnx.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/convert_to_onnx.ipynb)

  ### 📦 Models <a name="models"></a>

examples/convert_to_onnx.ipynb (+223)
@@ -0,0 +1,223 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/convert_to_onnx.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# to make ONNX export work\n",
    "!pip install onnx onnxruntime"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the complete tutorial in the PyTorch docs:\n",
    " - https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import onnx\n",
    "import onnxruntime\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "import segmentation_models_pytorch as smp"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create a random model (or load your own model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = smp.Unet(\"resnet34\", encoder_weights=\"imagenet\", classes=1)\n",
    "model = model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Export the model to ONNX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dynamic_axes is used to specify the variable-length axes; it can be just the batch size\n",
    "dynamic_axes = {0: \"batch_size\", 2: \"height\", 3: \"width\"}\n",
    "\n",
    "onnx_model_name = \"unet_resnet34.onnx\"\n",
    "\n",
    "torch.onnx.export(\n",
    "    model,  # model being run\n",
    "    torch.randn(1, 3, 224, 224),  # model input\n",
    "    onnx_model_name,  # where to save the model (can be a file or file-like object)\n",
    "    export_params=True,  # store the trained parameter weights inside the model file\n",
    "    opset_version=17,  # the ONNX opset version to export to\n",
    "    do_constant_folding=True,  # whether to execute constant folding for optimization\n",
    "    input_names=[\"input\"],  # the model's input names\n",
    "    output_names=[\"output\"],  # the model's output names\n",
    "    dynamic_axes={  # variable-length axes\n",
    "        \"input\": dynamic_axes,\n",
    "        \"output\": dynamic_axes,\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# check the exported model with onnx first\n",
    "onnx_model = onnx.load(onnx_model_name)\n",
    "onnx.checker.check_model(onnx_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run with onnxruntime"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([[[[-1.41701847e-01, -4.63768840e-03,  1.21411584e-01, ...,\n",
       "            5.22197843e-01,  3.40217263e-01,  8.52423906e-02],\n",
       "          [-2.29843616e-01,  2.19401851e-01,  3.53053480e-01, ...,\n",
       "            2.79466838e-01,  3.20288718e-01, -2.22393833e-02],\n",
       "          [-3.12503517e-01, -3.66358161e-02,  1.19251609e-02, ...,\n",
       "           -5.48991561e-02,  3.71140465e-02, -1.82842150e-01],\n",
       "          ...,\n",
       "          [-3.02772015e-01, -4.22928065e-01, -1.49621412e-01, ...,\n",
       "           -1.42241001e-01, -9.90390778e-02, -1.33311331e-01],\n",
       "          [-1.08293816e-01, -1.28070369e-01, -5.43620177e-02, ...,\n",
       "           -8.64556879e-02, -1.74177170e-01,  6.03154302e-03],\n",
       "          [-1.29619062e-01, -2.96604559e-02, -2.86361389e-03, ...,\n",
       "           -1.91345289e-01, -1.82653710e-01,  1.17175849e-02]]],\n",
       "\n",
       "\n",
       "        [[[-6.16237633e-02,  1.12350248e-01,  1.59193069e-01, ...,\n",
       "            4.03313845e-01,  2.26862252e-01,  7.33022243e-02],\n",
       "          [-1.60109222e-01,  1.21696621e-01,  1.84655115e-01, ...,\n",
       "            1.20978586e-01,  2.45723248e-01,  1.00066036e-01],\n",
       "          [-2.11992145e-01,  1.71708465e-02, -1.57656223e-02, ...,\n",
       "           -1.11918494e-01, -1.64519548e-01, -1.73958957e-01],\n",
       "          ...,\n",
       "          [-2.79706120e-01, -2.87421644e-01, -5.19880295e-01, ...,\n",
       "           -8.30744207e-02, -3.48939300e-02,  1.26617640e-01],\n",
       "          [-2.62198627e-01, -2.91804910e-01, -2.82318443e-01, ...,\n",
       "            1.81179233e-02,  2.32534595e-02,  1.85002953e-01],\n",
       "          [-9.28771719e-02, -5.16399741e-05, -9.53909755e-03, ...,\n",
       "           -2.28582099e-02, -5.09671569e-02,  2.05268264e-02]]]],\n",
       "       dtype=float32)]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# create a sample with a different batch size, height and width\n",
    "# from those used in the export above\n",
    "sample = torch.randn(2, 3, 512, 512)\n",
    "\n",
    "ort_session = onnxruntime.InferenceSession(\n",
    "    onnx_model_name, providers=[\"CPUExecutionProvider\"]\n",
    ")\n",
    "\n",
    "# compute ONNX Runtime output prediction\n",
    "ort_inputs = {\"input\": sample.numpy()}\n",
    "ort_outputs = ort_session.run(output_names=None, input_feed=ort_inputs)\n",
    "ort_outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Verify it matches the PyTorch model output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exported model has been tested with ONNXRuntime, and the result looks good!\n"
     ]
    }
   ],
   "source": [
    "# compute PyTorch output prediction\n",
    "with torch.no_grad():\n",
    "    torch_out = model(sample)\n",
    "\n",
    "# compare ONNX Runtime and PyTorch results\n",
    "np.testing.assert_allclose(torch_out.numpy(), ort_outputs[0], rtol=1e-03, atol=1e-05)\n",
    "\n",
    "print(\"Exported model has been tested with ONNXRuntime, and the result looks good!\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
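
As the "Create a random model (or load your own model)" cell notes, the same export flow works with your own trained weights. A minimal sketch, assuming a checkpoint saved with torch.save(model.state_dict(), ...); the path "checkpoint.pth" and the architecture/class count are placeholders and must match what you actually trained:

import torch
import segmentation_models_pytorch as smp

# Rebuild the architecture the checkpoint was trained with (example values shown).
model = smp.Unet("resnet34", encoder_weights=None, classes=1)

# "checkpoint.pth" is a hypothetical path to your saved state dict.
state_dict = torch.load("checkpoint.pth", map_location="cpu")
model.load_state_dict(state_dict)
model = model.eval()

From here the torch.onnx.export, onnx.checker and onnxruntime cells above apply unchanged.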

segmentation_models_pytorch/base/model.py (+2 −1)
@@ -33,7 +33,8 @@ def check_input_shape(self, x):

     def forward(self, x):
         """Sequentially pass `x` through the model's encoder, decoder and heads"""

-        self.check_input_shape(x)
+        if not torch.jit.is_tracing():
+            self.check_input_shape(x)

         features = self.encoder(x)
         decoder_output = self.decoder(*features)
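
For context: check_input_shape verifies that the input height and width are divisible by the encoder's output stride, and during tracing that Python-side check only sees the example input, which gets in the way of exporting the dynamic height/width axes used in the notebook. torch.jit.is_tracing() returns True while torch.onnx.export (or torch.jit.trace) records the graph, so the check now runs only in eager mode. A minimal sketch of the pattern, with a hypothetical divisor of 32 standing in for the real output-stride check:

import torch

def forward(x):
    # Executed in eager mode, skipped while the model is being traced
    # (e.g. during torch.onnx.export).
    if not torch.jit.is_tracing():
        if x.shape[-2] % 32 != 0 or x.shape[-1] % 32 != 0:  # 32 is an assumed stride
            raise RuntimeError("Input height and width should be divisible by 32")
    return x * 1.0

forward(torch.randn(1, 3, 224, 224))                            # eager call: check runs
traced = torch.jit.trace(forward, torch.randn(1, 3, 224, 224))  # tracing: check skipped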
