|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "metadata": {}, |
| 6 | + "source": [ |
| 7 | + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/convert_to_onnx.ipynb)" |
| 8 | + ] |
| 9 | + }, |
| 10 | + { |
| 11 | + "cell_type": "code", |
| 12 | + "execution_count": null, |
| 13 | + "metadata": {}, |
| 14 | + "outputs": [], |
| 15 | + "source": [ |
| 16 | + "# ONNX export/inference dependencies; %pip installs into the running kernel's environment\n", |
| 17 | + "%pip install onnx onnxruntime" |
| 18 | + ] |
| 19 | + }, |
| 20 | + { |
| 21 | + "cell_type": "markdown", |
| 22 | + "metadata": {}, |
| 23 | + "source": [ |
| 24 | + "See the complete tutorial in the PyTorch docs:\n", |
| 25 | + " - https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html" |
| 26 | + ] |
| 27 | + }, |
| 28 | + { |
| 29 | + "cell_type": "code", |
| 30 | + "execution_count": 1, |
| 31 | + "metadata": {}, |
| 32 | + "outputs": [], |
| 33 | + "source": [ |
| 34 | + "import onnx\n", |
| 35 | + "import onnxruntime\n", |
| 36 | + "import numpy as np\n", |
| 37 | + "\n", |
| 38 | + "import torch\n", |
| 39 | + "import segmentation_models_pytorch as smp" |
| 40 | + ] |
| 41 | + }, |
| 42 | + { |
| 43 | + "cell_type": "markdown", |
| 44 | + "metadata": {}, |
| 45 | + "source": [ |
| 46 | + "### Create random model (or load your own model)" |
| 47 | + ] |
| 48 | + }, |
| 49 | + { |
| 50 | + "cell_type": "code", |
| 51 | + "execution_count": 2, |
| 52 | + "metadata": {}, |
| 53 | + "outputs": [], |
| 54 | + "source": [ |
| 55 | + "# U-Net with a resnet34 encoder (imagenet weights) and one output class, put in eval mode\n", |
| 56 | + "model = smp.Unet(encoder_name=\"resnet34\", encoder_weights=\"imagenet\", classes=1).eval()" |
| 57 | + ] |
| 58 | + }, |
| 59 | + { |
| 60 | + "cell_type": "markdown", |
| 61 | + "metadata": {}, |
| 62 | + "source": [ |
| 63 | + "### Export the model to ONNX" |
| 64 | + ] |
| 65 | + }, |
| 66 | + { |
| 67 | + "cell_type": "code", |
| 68 | + "execution_count": 3, |
| 69 | + "metadata": {}, |
| 70 | + "outputs": [], |
| 71 | + "source": [ |
| 72 | + "# axes whose size may vary at inference time (this could be just the batch axis)\n", |
| 73 | + "dynamic_axes = {0: \"batch_size\", 2: \"height\", 3: \"width\"}\n", |
| 74 | + "\n", |
| 75 | + "onnx_model_name = \"unet_resnet34.onnx\"\n", |
| 76 | + "# only the file on disk is needed: the return value is discarded and `;` keeps it out of the cell output\n", |
| 77 | + "torch.onnx.export(\n", |
| 78 | + "    model,  # model being run\n", |
| 79 | + "    torch.randn(1, 3, 224, 224),  # example model input\n", |
| 80 | + "    onnx_model_name,  # where to save the model (can be a file or file-like object)\n", |
| 81 | + "    export_params=True,  # store the trained parameter weights inside the model file\n", |
| 82 | + "    opset_version=17,  # the ONNX opset version to target\n", |
| 83 | + "    do_constant_folding=True,  # whether to execute constant folding for optimization\n", |
| 84 | + "    input_names=[\"input\"],  # the model's input names\n", |
| 85 | + "    output_names=[\"output\"],  # the model's output names\n", |
| 86 | + "    dynamic_axes={  # variable length axes, shared by input and output\n", |
| 87 | + "        \"input\": dynamic_axes,\n", |
| 88 | + "        \"output\": dynamic_axes,\n", |
| 89 | + "    },\n", |
| 90 | + ");" |
| 91 | + ] |
| 92 | + }, |
| 93 | + { |
| 94 | + "cell_type": "code", |
| 95 | + "execution_count": 4, |
| 96 | + "metadata": {}, |
| 97 | + "outputs": [], |
| 98 | + "source": [ |
| 99 | + "# reload the exported file and run the onnx checker on it before using it for inference\n", |
| 100 | + "onnx_model = onnx.load(onnx_model_name)\n", |
| 101 | + "onnx.checker.check_model(model=onnx_model)" |
| 102 | + ] |
| 103 | + }, |
| 104 | + { |
| 105 | + "cell_type": "markdown", |
| 106 | + "metadata": {}, |
| 107 | + "source": [ |
| 108 | + "### Run with onnxruntime" |
| 109 | + ] |
| 110 | + }, |
| 111 | + { |
| 112 | + "cell_type": "code", |
| 113 | + "execution_count": 5, |
| 114 | + "metadata": {}, |
| 115 | + "outputs": [ |
| 116 | + { |
| 117 | + "data": { |
| 118 | + "text/plain": [ |
| 119 | + "[array([[[[-1.41701847e-01, -4.63768840e-03, 1.21411584e-01, ...,\n", |
| 120 | + " 5.22197843e-01, 3.40217263e-01, 8.52423906e-02],\n", |
| 121 | + " [-2.29843616e-01, 2.19401851e-01, 3.53053480e-01, ...,\n", |
| 122 | + " 2.79466838e-01, 3.20288718e-01, -2.22393833e-02],\n", |
| 123 | + " [-3.12503517e-01, -3.66358161e-02, 1.19251609e-02, ...,\n", |
| 124 | + " -5.48991561e-02, 3.71140465e-02, -1.82842150e-01],\n", |
| 125 | + " ...,\n", |
| 126 | + " [-3.02772015e-01, -4.22928065e-01, -1.49621412e-01, ...,\n", |
| 127 | + " -1.42241001e-01, -9.90390778e-02, -1.33311331e-01],\n", |
| 128 | + " [-1.08293816e-01, -1.28070369e-01, -5.43620177e-02, ...,\n", |
| 129 | + " -8.64556879e-02, -1.74177170e-01, 6.03154302e-03],\n", |
| 130 | + " [-1.29619062e-01, -2.96604559e-02, -2.86361389e-03, ...,\n", |
| 131 | + " -1.91345289e-01, -1.82653710e-01, 1.17175849e-02]]],\n", |
| 132 | + " \n", |
| 133 | + " \n", |
| 134 | + " [[[-6.16237633e-02, 1.12350248e-01, 1.59193069e-01, ...,\n", |
| 135 | + " 4.03313845e-01, 2.26862252e-01, 7.33022243e-02],\n", |
| 136 | + " [-1.60109222e-01, 1.21696621e-01, 1.84655115e-01, ...,\n", |
| 137 | + " 1.20978586e-01, 2.45723248e-01, 1.00066036e-01],\n", |
| 138 | + " [-2.11992145e-01, 1.71708465e-02, -1.57656223e-02, ...,\n", |
| 139 | + " -1.11918494e-01, -1.64519548e-01, -1.73958957e-01],\n", |
| 140 | + " ...,\n", |
| 141 | + " [-2.79706120e-01, -2.87421644e-01, -5.19880295e-01, ...,\n", |
| 142 | + " -8.30744207e-02, -3.48939300e-02, 1.26617640e-01],\n", |
| 143 | + " [-2.62198627e-01, -2.91804910e-01, -2.82318443e-01, ...,\n", |
| 144 | + " 1.81179233e-02, 2.32534595e-02, 1.85002953e-01],\n", |
| 145 | + " [-9.28771719e-02, -5.16399741e-05, -9.53909755e-03, ...,\n", |
| 146 | + " -2.28582099e-02, -5.09671569e-02, 2.05268264e-02]]]],\n", |
| 147 | + " dtype=float32)]" |
| 148 | + ] |
| 149 | + }, |
| 150 | + "execution_count": 5, |
| 151 | + "metadata": {}, |
| 152 | + "output_type": "execute_result" |
| 153 | + } |
| 154 | + ], |
| 155 | + "source": [ |
| 156 | + "# use a batch size, height and width that differ from the (1, 3, 224, 224)\n", |
| 157 | + "# export input, so the dynamic axes are actually exercised\n", |
| 158 | + "sample = torch.randn(2, 3, 512, 512)\n", |
| 159 | + "\n", |
| 160 | + "ort_session = onnxruntime.InferenceSession(\n", |
| 161 | + "    onnx_model_name,\n", |
| 162 | + "    providers=[\"CPUExecutionProvider\"],\n", |
| 163 | + ")\n", |
| 164 | + "\n", |
| 165 | + "# output_names=None asks for every model output; the feed is keyed by the export-time input name\n", |
| 166 | + "ort_outputs = ort_session.run(None, {\"input\": sample.numpy()})\n", |
| 167 | + "ort_outputs" |
| 168 | + ] |
| 169 | + }, |
| 170 | + { |
| 171 | + "cell_type": "markdown", |
| 172 | + "metadata": {}, |
| 173 | + "source": [ |
| 174 | + "### Verify the output matches the PyTorch model" |
| 175 | + ] |
| 176 | + }, |
| 177 | + { |
| 178 | + "cell_type": "code", |
| 179 | + "execution_count": 6, |
| 180 | + "metadata": {}, |
| 181 | + "outputs": [ |
| 182 | + { |
| 183 | + "name": "stdout", |
| 184 | + "output_type": "stream", |
| 185 | + "text": [ |
| 186 | + "Exported model has been tested with ONNXRuntime, and the result looks good!\n" |
| 187 | + ] |
| 188 | + } |
| 189 | + ], |
| 190 | + "source": [ |
| 191 | + "# reference prediction from the PyTorch model on the same sample, without gradient tracking\n", |
| 192 | + "with torch.no_grad():\n", |
| 193 | + "    torch_out = model(sample)\n", |
| 194 | + "\n", |
| 195 | + "# raises AssertionError if the two backends disagree beyond the tolerances\n", |
| 196 | + "np.testing.assert_allclose(actual=torch_out.numpy(), desired=ort_outputs[0], rtol=1e-03, atol=1e-05)\n", |
| 197 | + "\n", |
| 198 | + "print(\"Exported model has been tested with ONNXRuntime, and the result looks good!\")" |
| 199 | + ] |
| 200 | + } |
| 201 | + ], |
| 202 | + "metadata": { |
| 203 | + "kernelspec": { |
| 204 | + "display_name": ".venv", |
| 205 | + "language": "python", |
| 206 | + "name": "python3" |
| 207 | + }, |
| 208 | + "language_info": { |
| 209 | + "codemirror_mode": { |
| 210 | + "name": "ipython", |
| 211 | + "version": 3 |
| 212 | + }, |
| 213 | + "file_extension": ".py", |
| 214 | + "mimetype": "text/x-python", |
| 215 | + "name": "python", |
| 216 | + "nbconvert_exporter": "python", |
| 217 | + "pygments_lexer": "ipython3", |
| 218 | + "version": "3.10.12" |
| 219 | + } |
| 220 | + }, |
| 221 | + "nbformat": 4, |
| 222 | + "nbformat_minor": 2 |
| 223 | +} |
0 commit comments