File tree Expand file tree Collapse file tree 2 files changed +4
-3
lines changed Expand file tree Collapse file tree 2 files changed +4
-3
lines changed Original file line number Diff line number Diff line change 10
10
"black-forest-labs/FLUX.1-dev" ,
11
11
torch_dtype = torch .float16 ,
12
12
)
13
- pipe .to (DEVICE ). to ( torch .float16 )
13
+ pipe .to (torch .float16 )
14
14
backbone = pipe .transformer
15
15
16
16
44
44
"immutable_weights" : False ,
45
45
"enable_cuda_graph" : True ,
46
46
}
47
-
47
+ backbone . to ( DEVICE )
48
48
trt_gm = torch_tensorrt .MutableTorchTensorRTModule (backbone , ** settings )
49
49
trt_gm .set_expected_dynamic_shape_range ((), dynamic_shapes )
50
50
pipe .transformer = trt_gm
51
+ pipe .to (DEVICE )
51
52
52
53
53
54
def generate_image (prompt , inference_step , batch_size = 2 ):
Original file line number Diff line number Diff line change @@ -892,7 +892,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
892
892
parse_graph_io (submodule , subgraph_data )
893
893
dryrun_tracker .tensorrt_graph_count += 1
894
894
dryrun_tracker .per_subgraph_data .append (subgraph_data )
895
-
895
+ torch . cuda . empty_cache ()
896
896
# Create TRT engines from submodule
897
897
if not settings .dryrun :
898
898
trt_module = convert_module (
You can’t perform that action at this time.
0 commit comments