
Commit 0feda1c

Fixed comments
1 parent ef8288a commit 0feda1c

3 files changed: +180 -3 lines changed

py/torch_tensorrt/dynamo/_compiler.py (+11 -3)
@@ -500,6 +500,7 @@ def compile(
         enable_weight_streaming (bool): Enable weight streaming.
         tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
+        offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -678,17 +679,24 @@ def compile(
     )
 
     gm = exported_program.module()
-    # Move the weights in the state_dict to CPU
     logger.debug("Input graph: " + str(gm.graph))
 
     # Apply lowering on the graph module
     gm = post_lowering(gm, settings)
     logger.debug("Lowered Input graph: " + str(gm.graph))
+
+    # Move the weights in the state_dict to CPU
     if offload_module_to_cpu:
         exported_program.module().to(CPU_DEVICE)
         logger.info(
-            "The model is offloaded to CPU during compilation. If you want to keep the model on GPU, set offload_module_to_cpu=False."
+            "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
         )
+    else:
+        remaining_memory, total_memory = torch.cuda.mem_get_info()
+        if remaining_memory < total_memory // 2:
+            logger.warning(
+                "Remaining GPU memory may not be enough to compile the TensorRT engine for this model, resulting in an OOM error. Consider setting offload_module_to_cpu=True"
+            )
     trt_gm = compile_module(
         gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
     )
@@ -833,7 +841,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
                 str(name),
                 str(submodule.graph),
             )
-            submodule.to(torch.cuda.current_device())
+            submodule.to(to_torch_device(settings.device))
             continue
 
         if name not in submodule_node_dict:
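
For reference, here is a minimal sketch (not part of this commit) of how the new offload_module_to_cpu flag and the free-memory check might be exercised from user code. It assumes a CUDA device and uses a torchvision ResNet18 purely as an example model:

import torch
import torch_tensorrt as torchtrt
import torchvision.models as models

model = models.resnet18().eval().cuda()
inputs = (torch.randn((1, 3, 224, 224), device="cuda"),)

# torch.cuda.mem_get_info() returns (free, total) bytes for the current device;
# the warning added above fires when less than half of GPU memory is free.
free, total = torch.cuda.mem_get_info()

exp_program = torch.export.export(model, inputs)
trt_gm = torchtrt.dynamo.compile(
    exp_program,
    inputs,
    # Offload the PyTorch weights to CPU so TensorRT can use the whole GPU
    offload_module_to_cpu=free < total // 2,
)

# With offload enabled, the original module ends up on the CPU; move it back
# if it is still needed on the GPU (the new tests below do the same).
model.cuda()
out = trt_gm(*inputs)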

tests/py/dynamo/models/test_export_serde.py (+101)
@@ -321,6 +321,46 @@ def test_resnet18_dynamic(ir):
     )
 
 
+@pytest.mark.unit
+def test_resnet18_dynamic_cpu_offload(ir):
+    """
+    This tests export save and load functionality on Resnet18 model
+    """
+    model = models.resnet18().eval().cuda()
+    input = torch.randn((1, 3, 224, 224)).to("cuda")
+
+    compile_spec = {
+        "inputs": [
+            torchtrt.Input(
+                min_shape=(1, 3, 224, 224),
+                opt_shape=(4, 3, 224, 224),
+                max_shape=(8, 3, 224, 224),
+                dtype=torch.float32,
+                name="x",
+            )
+        ],
+        "ir": ir,
+        "min_block_size": 1,
+        "cache_built_engines": False,
+        "reuse_cached_engines": False,
+        "offload_module_to_cpu": True,
+    }
+
+    exp_program = torchtrt.dynamo.trace(model, **compile_spec)
+    trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
+    model.cuda()
+    torchtrt.save(trt_module, trt_ep_path)
+    # TODO: Enable this once serialization issues are fixed
+    # deser_trt_module = torchtrt.load(trt_ep_path).module()
+    outputs_pyt = model(input)
+    outputs_trt = trt_module(input)
+    cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0])
+    assertions.assertTrue(
+        cos_sim > COSINE_THRESHOLD,
+        msg=f"test_resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+    )
+
+
 @pytest.mark.unit
 def test_hybrid_conv_fallback(ir):
     """
@@ -381,6 +421,67 @@ def forward(self, x):
     )
 
 
+@pytest.mark.unit
+def test_hybrid_conv_fallback_cpu_offload(ir):
+    """
+    This tests export save and load functionality on a hybrid
+    model where a conv (a weighted layer) has been forced to fall back to PyTorch.
+    """
+
+    class MyModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
+            self.relu = torch.nn.ReLU()
+
+        def forward(self, x):
+            conv = self.conv(x)
+            relu = self.relu(conv)
+            mul = relu * 0.5
+            return mul
+
+    model = MyModule().eval().cuda()
+    input = torch.randn((1, 3, 224, 224)).to("cuda")
+
+    compile_spec = {
+        "inputs": [
+            torchtrt.Input(
+                input.shape, dtype=torch.float, format=torch.contiguous_format
+            )
+        ],
+        "ir": ir,
+        "min_block_size": 1,
+        "torch_executed_ops": {"torch.ops.aten.convolution.default"},
+        "cache_built_engines": False,
+        "reuse_cached_engines": False,
+        "offload_module_to_cpu": True,
+    }
+
+    exp_program = torchtrt.dynamo.trace(model, **compile_spec)
+    trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
+    model.cuda()
+    torchtrt.save(trt_module, trt_ep_path)
+
+    deser_trt_module = torchtrt.load(trt_ep_path).module()
+    outputs_pyt = model(input)
+    outputs_trt = trt_module(input)
+
+    for idx in range(len(outputs_pyt)):
+        cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
+        assertions.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"test_hybrid_conv_fallback TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        )
+
+    outputs_trt_deser = deser_trt_module(input)
+    for idx in range(len(outputs_pyt)):
+        cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
+        assertions.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"test_hybrid_conv_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        )
+
+
 @pytest.mark.unit
 def test_arange_export(ir):
     """

tests/py/dynamo/models/test_model_refit.py (+68)
@@ -492,6 +492,74 @@ def forward(self, x):
     torch._dynamo.reset()
 
 
+@pytest.mark.unit
+def test_refit_multiple_engine_with_weightmap_cpu_offload():
+    class net(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = nn.Conv2d(3, 12, 3, padding=1)
+            self.bn = nn.BatchNorm2d(12)
+            self.conv2 = nn.Conv2d(12, 12, 3, padding=1)
+            self.fc1 = nn.Linear(12 * 56 * 56, 10)
+
+        def forward(self, x):
+            x = self.conv1(x)
+            x = F.relu(x)
+            x = self.bn(x)
+            x = F.max_pool2d(x, (2, 2))
+            x = self.conv2(x)
+            x = F.relu(x)
+            x = F.max_pool2d(x, (2, 2))
+            x = torch.flatten(x, 1)
+            return self.fc1(x)
+
+    model = net().eval().to("cuda")
+    model2 = net().eval().to("cuda")
+
+    inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
+    enabled_precisions = {torch.float}
+    debug = False
+    min_block_size = 1
+    use_python_runtime = False
+
+    exp_program = torch.export.export(model, tuple(inputs))
+    exp_program2 = torch.export.export(model2, tuple(inputs))
+
+    torch_executed_ops = {"torch.ops.aten.convolution.default"}
+    trt_gm = torchtrt.dynamo.compile(
+        exp_program,
+        tuple(inputs),
+        use_python_runtime=use_python_runtime,
+        enabled_precisions=enabled_precisions,
+        debug=debug,
+        min_block_size=min_block_size,
+        immutable_weights=False,
+        torch_executed_ops=torch_executed_ops,
+        reuse_cached_engines=False,
+        offload_module_to_cpu=True,
+    )
+
+    new_trt_gm = refit_module_weights(
+        compiled_module=trt_gm,
+        new_weight_module=exp_program2,
+        arg_inputs=inputs,
+        use_weight_map_cache=True,
+    )
+    model2.cuda()
+    # Check the output
+    expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
+        *inputs
+    )
+    for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
+        assertions.assertTrue(
+            torch.allclose(expected_output, refitted_output, 1e-2, 1e-2),
+            "Refit Result is not correct. Refit failed",
+        )
+    # Clean up model env
+
+    torch._dynamo.reset()
+
+
 @unittest.skipIf(
     not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
     "TorchScript Frontend is not available",