Skip to content

Commit 187101a

Browse files
gchananlara-hdrneginraoof
authored
[v1.4.0] Minimal changes in interpolate to support Keypointrcnn (pytorch#32010)
* Fix interpolate * add keypointrcnn test * update ort versio for test * pin tv version * Update test.sh * Get rid of onnxruntime test changes. * [v1.4.0] Added torchvision tests as part of ORT tests (pytorch#31835) Summary: Added torchvision tests as part of ORT tests Pull Request resolved: pytorch#31835 Reviewed By: hl475 Differential Revision: D19278607 Pulled By: houseroad fbshipit-source-id: 18a6a85ce3019bcc9aee9517af1378964b585afd * Remove faster_rcnn and mask_rcnn tests. Co-authored-by: Lara Haidar <[email protected]> Co-authored-by: Negin Raoof <[email protected]>
1 parent e011d4a commit 187101a

File tree

3 files changed

+55
-30
lines changed

3 files changed

+55
-30
lines changed

.jenkins/caffe2/test.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,15 +133,15 @@ pip install --user pytest-sugar
133133
# torchvision tests #
134134
#####################
135135
if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then
136-
pip install -q --user git+https://github.com/pytorch/vision.git
136+
pip install -q --user git+https://github.com/pytorch/vision.git@v0.5.0
137137
pip install -q --user ninja
138138
# JIT C++ extensions require ninja, so put it into PATH.
139139
export PATH="/var/lib/jenkins/.local/bin:$PATH"
140140
if [[ "$BUILD_ENVIRONMENT" == *py3* ]]; then
141141
# default pip version is too old(9.0.2), unable to support tag `manylinux2010`.
142142
# Fix the pip error: Couldn't find a version that satisfies the requirement
143143
sudo pip install --upgrade pip
144-
pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==1.0.0.dev1104
144+
pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==1.1.0.dev1228
145145
fi
146146
"$ROOT_DIR/scripts/onnx/test.sh"
147147
fi

test/onnx/test_pytorch_onnx_onnxruntime.py

Lines changed: 49 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -183,31 +183,6 @@ def test_deeplab(self):
183183
x = torch.randn(2, 3, 224, 224, requires_grad=True)
184184
self.run_test(model, (x,), rtol=1e-3, atol=1e-5)
185185

186-
def test_googlenet_quantization(self):
187-
model = torchvision.models.quantization.googlenet(pretrained=True)
188-
x = torch.randn(2, 3, 224, 224, requires_grad=True)
189-
self.run_test(model, (x,), rtol=1e-3, atol=1e-5)
190-
191-
def test_inception_quantization(self):
192-
model = torchvision.models.quantization.inception_v3(pretrained=True)
193-
x = torch.randn(2, 3, 224, 224, requires_grad=True)
194-
self.run_test(model, (x,), rtol=1e-3, atol=1e-5)
195-
196-
def test_mobilenet_quantization(self):
197-
model = torchvision.models.quantization.mobilenet_v2(pretrained=True)
198-
x = torch.randn(2, 3, 224, 224, requires_grad=True)
199-
self.run_test(model, (x,), rtol=1e-3, atol=1e-5)
200-
201-
def test_resnet_quantization(self):
202-
model = torchvision.models.quantization.resnet50(pretrained=True)
203-
x = torch.randn(2, 3, 224, 224, requires_grad=True)
204-
self.run_test(model, (x,))
205-
206-
def test_shufflenet_quantization(self):
207-
model = torchvision.models.quantization.shufflenet_v2_x1_0(pretrained=True)
208-
x = torch.randn(2, 3, 224, 224, requires_grad=True)
209-
self.run_test(model, (x,), rtol=1e-3, atol=1e-5)
210-
211186
def test_r3d_18_video(self):
212187
model = torchvision.models.video.r3d_18(pretrained=True)
213188
x = torch.randn(1, 3, 4, 112, 112, requires_grad=True)
@@ -238,6 +213,55 @@ def run_word_language_model(self, model_name):
238213
# Only support CPU version, since tracer is not working in GPU RNN.
239214
self.run_test(model, (x, model.hidden))
240215

216+
def get_image_from_url(self, url):
217+
import sys
218+
import os
219+
if sys.version_info < (3,):
220+
from urlparse import urlsplit
221+
import urllib2
222+
request = urllib2
223+
else:
224+
from urllib.parse import urlsplit
225+
from urllib import request
226+
from PIL import Image
227+
from torchvision import transforms
228+
from torch._utils_internal import get_writable_path
229+
230+
filename = os.path.basename(urlsplit(url)[2])
231+
data_dir = get_writable_path(os.path.join(os.path.dirname(__file__)))
232+
path = os.path.join(data_dir, filename)
233+
data = request.urlopen(url, timeout=15).read()
234+
with open(path, 'wb') as f:
235+
f.write(data)
236+
image = Image.open(path).convert("RGB")
237+
image = image.resize((300, 200), Image.BILINEAR)
238+
to_tensor = transforms.ToTensor()
239+
return to_tensor(image)
240+
241+
def get_test_images(self):
242+
image_url = "http://farm3.staticflickr.com/2469/3915380994_2e611b1779_z.jpg"
243+
image = self.get_image_from_url(url=image_url)
244+
images = [image]
245+
return images
246+
247+
@skipIfUnsupportedMinOpsetVersion(11)
248+
def test_keypoint_rcnn(self):
249+
class KeyPointRCNN(torch.nn.Module):
250+
def __init__(self):
251+
super(KeyPointRCNN, self).__init__()
252+
self.model = torchvision.models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True,
253+
min_size=200,
254+
max_size=300)
255+
256+
def forward(self, images):
257+
output = self.model(images)
258+
# TODO: The keypoints_scores require the use of Argmax that is updated in ONNX.
259+
# For now we are testing all the output of KeypointRCNN except keypoints_scores.
260+
# Enable When Argmax is updated in ONNX Runtime.
261+
return output[0]['boxes'], output[0]['labels'], output[0]['scores'], output[0]['keypoints']
262+
images = self.get_test_images()
263+
self.run_test(KeyPointRCNN(), (images,), rtol=1e-3, atol=1e-5)
264+
241265
def test_word_language_model_RNN_TANH(self):
242266
self.run_word_language_model("RNN_TANH")
243267

torch/onnx/symbolic_opset11.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,10 @@ def symbolic_fn(g, input, output_size, align_corners=None):
5656
coordinate_transformation_mode = "asymmetric" if interpolate_mode == "nearest" \
5757
else "align_corners" if align_corners else "pytorch_half_pixel"
5858
empty_tensor = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))
59-
input_size = input.type().sizes()
60-
input_size = g.op("Constant", value_t=torch.tensor(input_size[0:2], dtype=torch.int64))
59+
input_size = g.op("Shape", input)
60+
input_size_beg = sym_help._slice_helper(g, input_size, axes=[0], ends=[2], starts=[0])
6161
output_size = g.op("Cast", output_size, to_i=sym_help.cast_pytorch_to_onnx["Long"])
62-
output_size = g.op("Concat", input_size, output_size, axis_i=0)
62+
output_size = g.op("Concat", input_size_beg, output_size, axis_i=0)
6363

6464
return g.op("Resize",
6565
input,
@@ -115,6 +115,7 @@ def __interpolate(g, input, size, scale_factor, mode, align_corners):
115115
size = unsqueeze(g, size, 0)
116116
size = [size for i in range(input.type().dim() - 2)]
117117
size = g.op("Concat", *size, axis_i=0)
118+
size = g.op("Cast", size, to_i=sym_help.cast_pytorch_to_onnx['Long'])
118119
size = g.op("Concat", input_size, size, axis_i=0)
119120
scales = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))
120121
return g.op("Resize",

0 commit comments

Comments
 (0)