@@ -109,7 +109,7 @@ def construct_refit_mapping(
109
109
110
110
111
111
def construct_refit_mapping_from_weight_name_map (
112
- weight_name_map : dict [Any , Any ], state_dict : dict [Any , Any ]
112
+ weight_name_map : dict [Any , Any ], state_dict : dict [Any , Any ], device : torch . device
113
113
) -> dict [Any , Any ]:
114
114
engine_weight_map = {}
115
115
for engine_weight_name , (sd_weight_name , np_weight_type ) in weight_name_map .items ():
@@ -120,7 +120,11 @@ def construct_refit_mapping_from_weight_name_map(
120
120
# If weights is not in sd, we can leave it unchanged
121
121
continue
122
122
else :
123
- engine_weight_map [engine_weight_name ] = state_dict [sd_weight_name ]
123
+ engine_weight_map [engine_weight_name ] = (
124
+ state_dict [sd_weight_name ]
125
+ if state_dict [sd_weight_name ].device == device
126
+ else state_dict [sd_weight_name ].to ("device" )
127
+ )
124
128
125
129
engine_weight_map [engine_weight_name ] = (
126
130
engine_weight_map [engine_weight_name ]
@@ -162,7 +166,7 @@ def _refit_single_trt_engine_with_gm(
162
166
"constant_mapping" , {}
163
167
) # type: ignore
164
168
mapping = construct_refit_mapping_from_weight_name_map (
165
- weight_name_map , new_gm .state_dict ()
169
+ weight_name_map , new_gm .state_dict (), torch_device
166
170
)
167
171
constant_mapping_with_type = {}
168
172
0 commit comments