@@ -99,7 +99,7 @@ torch::jit::Node* getUpstreamCastNode(torch::jit::Value* val) {
   return nullptr;
 }
 
-torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool is_input) {
+torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool is_input, std::string device) {
   auto cast_raw_value = is_input ? seg_block.raw_inputs()[index] : seg_block.raw_outputs()[index];
   auto cast_subgraph_value = is_input ? seg_block.inputs()[index] : seg_block.outputs()[index];
   torch::jit::Node* cast_node = getUpstreamCastNode(cast_raw_value);
@@ -125,8 +125,11 @@ torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool i
     auto const_type = is_input ? g->insertConstant(4) : g->insertConstant(3);
     auto const_zero = g->insertConstant(0);
     const_zero->setType(torch::jit::BoolType::get());
+    auto cuda = g->insertConstant(device);
+    cuda->setType(torch::jit::DeviceObjType::get());
     auto none_val = g->insertNode(g->createNone())->output();
-    cast_node = g->create(torch::jit::aten::to, {cast_subgraph_value, const_type, const_zero, const_zero, none_val});
+    cast_node =
+        g->create(torch::jit::aten::to, {cast_subgraph_value, cuda, const_type, const_zero, const_zero, none_val});
   }
   return cast_node;
 }
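
For context (not part of the patch): with the extra Device constant, the six node inputs line up with the aten::to.device overload, whose schema is roughly aten::to.device(Tensor self, Device device, ScalarType dtype, bool non_blocking, bool copy, MemoryFormat? memory_format). The previous five inputs matched the aten::to.dtype overload, which carries no device and leaves the cast tensor on whatever device self occupies. A minimal eager-mode sketch of the node being built, assuming libtorch and a device string such as "cuda:0" (the function name is illustrative only, not part of the patch):

#include <torch/torch.h>

// Illustrative sketch only: the eager equivalent of the aten::to node the
// patch creates. const_type above is 4 (kLong) for inputs and 3 (kInt) for
// outputs; kLong is shown here.
torch::Tensor cast_like_patched_node(const torch::Tensor& self, const std::string& device) {
  return self.to(torch::Device(device), /*dtype=*/torch::kLong,
                 /*non_blocking=*/false, /*copy=*/false);
}
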
@@ -217,6 +220,8 @@ void getSegmentsOutputByRunning(
     ivalues_maps[output] = jit_results[idx++];
   }
 
+  auto target_device = partitioning_info.getGPUDeviceString();
+
   // auto int64 <=> int32 conversion
   if (seg_block.target() == SegmentedBlock::kTorch && partitioning_info.truncate_long_and_double) {
     // First, check if there is Int64 input
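
getGPUDeviceString() is provided by PartitioningInfo elsewhere in this change; its body is not shown in these hunks. A hypothetical sketch of what such a helper would return, assuming the target GPU id is known (the name, parameter, and body here are assumptions, not the actual implementation):

#include <string>

// Hypothetical sketch; the real helper lives on PartitioningInfo and may differ.
std::string getGPUDeviceStringSketch(int64_t gpu_id) {
  return "cuda:" + std::to_string(gpu_id);  // e.g. "cuda:0", parseable by torch::Device
}
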
@@ -226,7 +231,7 @@ void getSegmentsOutputByRunning(
         at::ScalarType t = cur_ivalue.toTensor().scalar_type();
         if (t == at::kLong) {
           // we add a cast operation to cast the type to Int64
-          auto cast_node = createCastNode(seg_block, i, true);
+          auto cast_node = createCastNode(seg_block, i, true, target_device);
           seg_block.g()->prependNode(cast_node);
           seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
         }
@@ -237,7 +242,7 @@ void getSegmentsOutputByRunning(
         auto cur_ivalue = ivalues_maps[seg_block.raw_outputs()[i]];
         at::ScalarType t = cur_ivalue.toTensor().scalar_type();
         if (t == at::kLong) {
-          auto cast_node = createCastNode(seg_block, i, false);
+          auto cast_node = createCastNode(seg_block, i, false, target_device);
           seg_block.g()->appendNode(cast_node);
           seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
         }