Skip to content

Commit ddc7bc6

Browse files
committed
Fixed Cuda Error
1 parent 3bd44ca commit ddc7bc6

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

py/torch_tensorrt/dynamo/_refit.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def construct_refit_mapping(
109109

110110

111111
def construct_refit_mapping_from_weight_name_map(
112-
weight_name_map: dict[Any, Any], state_dict: dict[Any, Any]
112+
weight_name_map: dict[Any, Any], state_dict: dict[Any, Any], device: torch.device
113113
) -> dict[Any, Any]:
114114
engine_weight_map = {}
115115
for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
@@ -120,7 +120,11 @@ def construct_refit_mapping_from_weight_name_map(
120120
# If weights is not in sd, we can leave it unchanged
121121
continue
122122
else:
123-
engine_weight_map[engine_weight_name] = state_dict[sd_weight_name]
123+
engine_weight_map[engine_weight_name] = (
124+
state_dict[sd_weight_name]
125+
if state_dict[sd_weight_name].device == device
126+
else state_dict[sd_weight_name].to("device")
127+
)
124128

125129
engine_weight_map[engine_weight_name] = (
126130
engine_weight_map[engine_weight_name]
@@ -162,7 +166,7 @@ def _refit_single_trt_engine_with_gm(
162166
"constant_mapping", {}
163167
) # type: ignore
164168
mapping = construct_refit_mapping_from_weight_name_map(
165-
weight_name_map, new_gm.state_dict()
169+
weight_name_map, new_gm.state_dict(), torch_device
166170
)
167171
constant_mapping_with_type = {}
168172

0 commit comments

Comments
 (0)