|
6 | 6 | import torch
|
7 | 7 | import torch_tensorrt as torchtrt
|
8 | 8 | import torchvision.models as models
|
9 |
| -from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity |
| 9 | +from torch_tensorrt.dynamo.utils import ( |
| 10 | + COSINE_THRESHOLD, |
| 11 | + cosine_similarity, |
| 12 | + get_model_device, |
| 13 | +) |
10 | 14 |
|
11 | 15 | assertions = unittest.TestCase()
|
12 | 16 |
|
@@ -283,6 +287,53 @@ def test_resnet18(ir):
|
283 | 287 | )
|
284 | 288 |
|
285 | 289 |
|
| 290 | +@pytest.mark.unit |
| 291 | +def test_resnet18_cpu_offload(ir): |
| 292 | + """ |
| 293 | + This tests export save and load functionality on Resnet18 model |
| 294 | + """ |
| 295 | + model = models.resnet18().eval().cuda() |
| 296 | + input = torch.randn((1, 3, 224, 224)).to("cuda") |
| 297 | + |
| 298 | + compile_spec = { |
| 299 | + "inputs": [ |
| 300 | + torchtrt.Input( |
| 301 | + input.shape, dtype=torch.float, format=torch.contiguous_format |
| 302 | + ) |
| 303 | + ], |
| 304 | + "ir": ir, |
| 305 | + "min_block_size": 1, |
| 306 | + "cache_built_engines": False, |
| 307 | + "reuse_cached_engines": False, |
| 308 | + "offload_module_to_cpu": True, |
| 309 | + } |
| 310 | + |
| 311 | + exp_program = torchtrt.dynamo.trace(model, **compile_spec) |
| 312 | + trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) |
| 313 | + assertions.assertTrue( |
| 314 | + get_model_device(model).type == "cpu", |
| 315 | + msg="Model should be offloaded to CPU", |
| 316 | + ) |
| 317 | + model.cuda() |
| 318 | + torchtrt.save(trt_module, trt_ep_path) |
| 319 | + |
| 320 | + deser_trt_module = torchtrt.load(trt_ep_path).module() |
| 321 | + outputs_pyt = model(input) |
| 322 | + outputs_trt = trt_module(input) |
| 323 | + cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0]) |
| 324 | + assertions.assertTrue( |
| 325 | + cos_sim > COSINE_THRESHOLD, |
| 326 | + msg=f"test_resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", |
| 327 | + ) |
| 328 | + |
| 329 | + outputs_trt_deser = deser_trt_module(input) |
| 330 | + cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser[0]) |
| 331 | + assertions.assertTrue( |
| 332 | + cos_sim > COSINE_THRESHOLD, |
| 333 | + msg=f"test_resnet18 deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", |
| 334 | + ) |
| 335 | + |
| 336 | + |
286 | 337 | @pytest.mark.unit
|
287 | 338 | def test_resnet18_dynamic(ir):
|
288 | 339 | """
|
@@ -381,6 +432,67 @@ def forward(self, x):
|
381 | 432 | )
|
382 | 433 |
|
383 | 434 |
|
| 435 | +@pytest.mark.unit |
| 436 | +def test_hybrid_conv_fallback_cpu_offload(ir): |
| 437 | + """ |
| 438 | + This tests export save and load functionality on a hybrid |
| 439 | + model where a conv (a weighted layer) has been forced to fallback to Pytorch. |
| 440 | + """ |
| 441 | + |
| 442 | + class MyModule(torch.nn.Module): |
| 443 | + def __init__(self): |
| 444 | + super().__init__() |
| 445 | + self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True) |
| 446 | + self.relu = torch.nn.ReLU() |
| 447 | + |
| 448 | + def forward(self, x): |
| 449 | + conv = self.conv(x) |
| 450 | + relu = self.relu(conv) |
| 451 | + mul = relu * 0.5 |
| 452 | + return mul |
| 453 | + |
| 454 | + model = MyModule().eval().cuda() |
| 455 | + input = torch.randn((1, 3, 224, 224)).to("cuda") |
| 456 | + |
| 457 | + compile_spec = { |
| 458 | + "inputs": [ |
| 459 | + torchtrt.Input( |
| 460 | + input.shape, dtype=torch.float, format=torch.contiguous_format |
| 461 | + ) |
| 462 | + ], |
| 463 | + "ir": ir, |
| 464 | + "min_block_size": 1, |
| 465 | + "torch_executed_ops": {"torch.ops.aten.convolution.default"}, |
| 466 | + "cache_built_engines": False, |
| 467 | + "reuse_cached_engines": False, |
| 468 | + "offload_module_to_cpu": True, |
| 469 | + } |
| 470 | + |
| 471 | + exp_program = torchtrt.dynamo.trace(model, **compile_spec) |
| 472 | + trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) |
| 473 | + model.cuda() |
| 474 | + torchtrt.save(trt_module, trt_ep_path) |
| 475 | + |
| 476 | + deser_trt_module = torchtrt.load(trt_ep_path).module() |
| 477 | + outputs_pyt = model(input) |
| 478 | + outputs_trt = trt_module(input) |
| 479 | + |
| 480 | + for idx in range(len(outputs_pyt)): |
| 481 | + cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx]) |
| 482 | + assertions.assertTrue( |
| 483 | + cos_sim > COSINE_THRESHOLD, |
| 484 | + msg=f"test_hybrid_conv_fallback TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", |
| 485 | + ) |
| 486 | + |
| 487 | + outputs_trt_deser = deser_trt_module(input) |
| 488 | + for idx in range(len(outputs_pyt)): |
| 489 | + cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx]) |
| 490 | + assertions.assertTrue( |
| 491 | + cos_sim > COSINE_THRESHOLD, |
| 492 | + msg=f"test_hybrid_conv_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", |
| 493 | + ) |
| 494 | + |
| 495 | + |
384 | 496 | @pytest.mark.unit
|
385 | 497 | def test_arange_export(ir):
|
386 | 498 | """
|
|
0 commit comments