Skip to content

Commit 92ae47d

Browse files
committed
Finalize the refit revision
1 parent 3e8323f commit 92ae47d

File tree

4 files changed

+11
-10
lines changed

4 files changed

+11
-10
lines changed

examples/apps/flux-demo.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
"use_fp32_acc": True,
4242
"use_explicit_typing": True,
4343
"debug": False,
44-
"use_python_runtime": False,
44+
"use_python_runtime": True,
4545
"immutable_weights": False,
4646
# "cache_built_engines": True,
4747
# "reuse_cached_engines": True,

examples/dynamo/torch_export_flux_dev.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@
112112
min_block_size=1,
113113
use_fp32_acc=True,
114114
use_explicit_typing=True,
115+
use_python_runtime=True,
115116
)
116117

117118
# %%
@@ -126,7 +127,7 @@
126127
torch.cuda.empty_cache()
127128
pipe.transformer = trt_gm
128129
pipe.transformer.config = config
129-
130+
trt_gm.device = torch.device("cuda")
130131
# %%
131132
# Image generation using prompt
132133
# ---------------------------

py/torch_tensorrt/dynamo/_refit.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -509,23 +509,22 @@ def refit_module_weights(
509509
serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
510510
serialized_engine = engine.serialize_with_config(serialization_config)
511511

512-
del engine
513-
gc.collect()
514-
torch.cuda.empty_cache()
515-
516-
if isinstance(
517-
compiled_submodule, (PythonTorchTensorRTModule, TorchTensorRTModule)
518-
):
512+
if isinstance(compiled_submodule, PythonTorchTensorRTModule):
513+
compiled_submodule.serialized_engine = bytes(serialized_engine)
514+
elif isinstance(compiled_submodule, TorchTensorRTModule):
519515
compiled_submodule.engine = None # Clear the engine for TorchTensorRTModule, otherwise it won't be updated
520516
compiled_submodule.serialized_engine = bytes(serialized_engine)
521517
compiled_submodule.setup_engine()
522-
523518
elif inline_module:
524519
new_engine_info = list(engine_info)
525520
new_engine_info[ENGINE_IDX] = bytes(serialized_engine)
526521
refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info))
527522
setattr(compiled_module, f"{name}_engine", refitted_engine)
528523

524+
del engine
525+
gc.collect()
526+
torch.cuda.empty_cache()
527+
529528
# TODO: Memory control prototyping. Under discussion
530529
if settings.offload_module_to_cpu:
531530
del new_partitioned_module

tests/py/dynamo/models/test_model_refit.py

+1
Original file line numberDiff line numberDiff line change
@@ -763,6 +763,7 @@ def forward(self, x):
763763
debug=True,
764764
min_block_size=1,
765765
immutable_weights=False,
766+
offload_module_to_cpu=False,
766767
)
767768

768769
num_pyt_segments = len(

0 commit comments

Comments
 (0)