Skip to content

Commit 044acdf

Browse files
committed
Optimized FLUX compilation memory usage
1 parent 27dee53 commit 044acdf

File tree

2 files changed

+4 -3 lines changed

examples/apps/flux-demo.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
 "black-forest-labs/FLUX.1-dev",
 torch_dtype=torch.float16,
 )
-pipe.to(DEVICE).to(torch.float16)
+pipe.to(torch.float16)
 backbone = pipe.transformer
 
 
@@ -44,10 +44,11 @@
 "immutable_weights": False,
 "enable_cuda_graph": True,
 }
-
+backbone.to(DEVICE)
 trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
 trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes)
 pipe.transformer = trt_gm
+pipe.to(DEVICE)
 
 
 def generate_image(prompt, inference_step, batch_size=2):

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -892,7 +892,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
 parse_graph_io(submodule, subgraph_data)
 dryrun_tracker.tensorrt_graph_count += 1
 dryrun_tracker.per_subgraph_data.append(subgraph_data)
-
+torch.cuda.empty_cache()
 # Create TRT engines from submodule
 if not settings.dryrun:
 trt_module = convert_module(

0 commit comments

Comments
 (0)