Commit af39c65

fix: Automatically send truncated long ints to cuda at shape analysis time (#1541)
1 parent 236b30e · commit af39c65

File tree (4 files changed: +21 −5 lines)

  core/lowering/lowering.h
  core/partitioning/partitioninginfo/PartitioningInfo.h
  core/partitioning/shape_analysis.cpp
  cpp/src/compile_spec.cpp

core/lowering/lowering.h

Lines changed: 1 addition & 1 deletion
```diff
@@ -20,7 +20,7 @@ struct LowerInfo {
   std::vector<std::string> forced_fallback_modules;
   friend std::ostream& operator<<(std::ostream& os, const LowerInfo& l);
 
-  std::string getGPUDeviceString() {
+  std::string getGPUDeviceString() const {
     return "cuda:" + std::to_string(target_device.gpu_id);
   };
 };
```
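The only functional change here is the `const` qualifier on `getGPUDeviceString()`. A minimal sketch of why it matters, using a hypothetical `Info` struct as a stand-in for `LowerInfo` (the real struct has more fields): a non-const member function cannot be called through a `const` reference, which is how these settings structs are typically passed.

```cpp
// Minimal const-correctness sketch; `Info` and `Device` are simplified
// stand-ins for LowerInfo and its target_device, not the real types.
#include <cstdint>
#include <iostream>
#include <string>

struct Device {
  int64_t gpu_id = 0;
};

struct Info {
  Device target_device;

  // const-qualified, as in the patch: callable on const objects.
  std::string getGPUDeviceString() const {
    return "cuda:" + std::to_string(target_device.gpu_id);
  }
};

// Takes a const reference; this compiles only because the getter is const.
void report(const Info& info) {
  std::cout << info.getGPUDeviceString() << "\n"; // prints "cuda:1"
}

int main() {
  Info info;
  info.target_device.gpu_id = 1;
  report(info);
}
```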

core/partitioning/partitioninginfo/PartitioningInfo.h

Lines changed: 5 additions & 0 deletions
```diff
@@ -16,6 +16,11 @@ struct PartitioningInfo {
   uint64_t min_block_size = 1;
   std::vector<std::string> forced_fallback_operators;
   bool truncate_long_and_double;
+  ir::Device target_device;
+
+  std::string getGPUDeviceString() const {
+    return "cuda:" + std::to_string(target_device.gpu_id);
+  };
 };
 
 std::ostream& operator<<(std::ostream& os, const PartitioningInfo& s);
```
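`PartitioningInfo` now carries its own `ir::Device`, so shape analysis can derive the target device string without reaching into the lowering settings. The `"cuda:" + gpu_id` format is the same one `torch::Device` parses; a quick self-contained check (assumes libtorch is available to build against):

```cpp
// Sanity check that the "cuda:N" strings built by getGPUDeviceString()
// are valid torch device strings. Assumes libtorch.
#include <torch/torch.h>
#include <cassert>
#include <cstdint>
#include <string>

int main() {
  int64_t gpu_id = 1;
  std::string s = "cuda:" + std::to_string(gpu_id); // same as the helper
  torch::Device d(s);                               // parses "cuda:1"
  assert(d.is_cuda());
  assert(d.index() == 1);
}
```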

core/partitioning/shape_analysis.cpp

Lines changed: 9 additions & 4 deletions
```diff
@@ -99,7 +99,7 @@ torch::jit::Node* getUpstreamCastNode(torch::jit::Value* val) {
   return nullptr;
 }
 
-torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool is_input) {
+torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool is_input, std::string device) {
   auto cast_raw_value = is_input ? seg_block.raw_inputs()[index] : seg_block.raw_outputs()[index];
   auto cast_subgraph_value = is_input ? seg_block.inputs()[index] : seg_block.outputs()[index];
   torch::jit::Node* cast_node = getUpstreamCastNode(cast_raw_value);
@@ -125,8 +125,11 @@ torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool i
     auto const_type = is_input ? g->insertConstant(4) : g->insertConstant(3);
     auto const_zero = g->insertConstant(0);
     const_zero->setType(torch::jit::BoolType::get());
+    auto cuda = g->insertConstant(device);
+    cuda->setType(torch::jit::DeviceObjType::get());
     auto none_val = g->insertNode(g->createNone())->output();
-    cast_node = g->create(torch::jit::aten::to, {cast_subgraph_value, const_type, const_zero, const_zero, none_val});
+    cast_node =
+        g->create(torch::jit::aten::to, {cast_subgraph_value, cuda, const_type, const_zero, const_zero, none_val});
   }
   return cast_node;
 }
```
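At graph level, the new branch emits the device-aware `aten::to(Tensor, Device, ScalarType, bool non_blocking, bool copy, int? memory_format)` overload instead of the deviceless one, pinning the truncation cast to the target GPU. A self-contained sketch of building the same node (assumes libtorch; the one-input graph is a toy stand-in for the real segmented block):

```cpp
// Sketch of the node the patched createCastNode emits. Assumes libtorch;
// the graph is a toy stand-in, not the real segmented block.
#include <torch/csrc/jit/ir/ir.h>
#include <iostream>
#include <memory>
#include <string>

int main() {
  auto g = std::make_shared<torch::jit::Graph>();
  auto x = g->addInput("x"); // stand-in for the segment-boundary tensor

  // Same trick as the patch: insert a string constant, then retype its
  // output value to Device so aten::to schema-matches a Device argument.
  auto cuda = g->insertConstant(std::string("cuda:0"));
  cuda->setType(torch::jit::DeviceObjType::get());

  // 4 == at::ScalarType::Long and 3 == at::ScalarType::Int -- the codes
  // the patch inserts for input-side and output-side casts respectively.
  auto dtype = g->insertConstant(4);
  auto const_false = g->insertConstant(0);
  const_false->setType(torch::jit::BoolType::get());
  auto none_val = g->insertNode(g->createNone())->output();

  // aten::to(Tensor, Device, ScalarType, non_blocking, copy, memory_format)
  auto cast = g->create(
      torch::jit::aten::to, {x, cuda, dtype, const_false, const_false, none_val});
  g->insertNode(cast);
  g->registerOutput(cast->output());

  std::cout << *g << std::endl; // dump the graph to inspect the aten::to
}
```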
```diff
@@ -217,6 +220,8 @@ void getSegmentsOutputByRunning(
     ivalues_maps[output] = jit_results[idx++];
   }
 
+  auto target_device = partitioning_info.getGPUDeviceString();
+
   // auto int64 <=> int32 conversion
   if (seg_block.target() == SegmentedBlock::kTorch && partitioning_info.truncate_long_and_double) {
     // First, check if there is Int64 input
@@ -226,7 +231,7 @@ void getSegmentsOutputByRunning(
       at::ScalarType t = cur_ivalue.toTensor().scalar_type();
       if (t == at::kLong) {
         // we add a cast operation to cast the type to Int64
-        auto cast_node = createCastNode(seg_block, i, true);
+        auto cast_node = createCastNode(seg_block, i, true, target_device);
         seg_block.g()->prependNode(cast_node);
         seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
       }
@@ -237,7 +242,7 @@ void getSegmentsOutputByRunning(
       auto cur_ivalue = ivalues_maps[seg_block.raw_outputs()[i]];
       at::ScalarType t = cur_ivalue.toTensor().scalar_type();
       if (t == at::kLong) {
-        auto cast_node = createCastNode(seg_block, i, false);
+        auto cast_node = createCastNode(seg_block, i, false, target_device);
         seg_block.g()->appendNode(cast_node);
         seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
       }
```
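The two call sites differ only in direction: a Torch-fallback segment's inputs are widened from the truncated int32 back to int64, and its outputs are narrowed again before re-entering TensorRT, which does not execute int64. A small illustration of both casts and of the dtype codes 4 and 3 used above (assumes libtorch):

```cpp
// Illustration of the two truncation casts at segment boundaries and of
// the integer ScalarType codes inserted as graph constants. Assumes libtorch.
#include <torch/torch.h>
#include <cassert>

int main() {
  // Input side: a TRT segment hands int32 to a Torch segment, which is
  // widened to int64 (dtype code 4).
  auto trt_out = torch::ones({2, 2}, torch::kInt);
  auto torch_in = trt_out.to(torch::kLong);
  assert(torch_in.scalar_type() == torch::kLong);

  // Output side: the Torch segment's int64 result is narrowed back to
  // int32 (dtype code 3) before re-entering TensorRT.
  auto torch_out = torch::ones({2, 2}, torch::kLong);
  auto trt_in = torch_out.to(torch::kInt);
  assert(trt_in.scalar_type() == torch::kInt);

  // The literal constants in the patch are these ScalarType codes.
  assert(static_cast<int>(torch::kLong) == 4);
  assert(static_cast<int>(torch::kInt) == 3);
}
```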

cpp/src/compile_spec.cpp

Lines changed: 6 additions & 0 deletions
```diff
@@ -111,6 +111,7 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
   internal.convert_info.engine_settings.truncate_long_and_double = external.truncate_long_and_double;
   internal.convert_info.engine_settings.device.allow_gpu_fallback = external.device.allow_gpu_fallback;
   internal.lower_info.target_device.allow_gpu_fallback = external.device.allow_gpu_fallback;
+  internal.partitioning_info.target_device.allow_gpu_fallback = external.device.allow_gpu_fallback;
 
   TORCHTRT_CHECK(
       !(external.require_full_compilation && (external.torch_executed_ops.size() > 0)),
@@ -132,11 +133,13 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
     case Device::DeviceType::kDLA:
       internal.convert_info.engine_settings.device.device_type = nvinfer1::DeviceType::kDLA;
       internal.lower_info.target_device.device_type = nvinfer1::DeviceType::kDLA;
+      internal.partitioning_info.target_device.device_type = nvinfer1::DeviceType::kDLA;
       break;
     case Device::DeviceType::kGPU:
     default:
       internal.convert_info.engine_settings.device.device_type = nvinfer1::DeviceType::kGPU;
       internal.lower_info.target_device.device_type = nvinfer1::DeviceType::kGPU;
+      internal.partitioning_info.target_device.device_type = nvinfer1::DeviceType::kGPU;
   }
 
   switch (external.capability) {
@@ -155,6 +158,9 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
   internal.convert_info.engine_settings.device.dla_core = external.device.dla_core;
   internal.lower_info.target_device.gpu_id = external.device.gpu_id;
   internal.lower_info.target_device.dla_core = external.device.dla_core;
+  internal.partitioning_info.target_device.gpu_id = external.device.gpu_id;
+  internal.partitioning_info.target_device.dla_core = external.device.dla_core;
+
   internal.convert_info.engine_settings.num_avg_timing_iters = external.num_avg_timing_iters;
   internal.convert_info.engine_settings.workspace_size = external.workspace_size;
   internal.convert_info.engine_settings.dla_sram_size = external.dla_sram_size;
```
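For context, `external` is the public C++ `CompileSpec`; these new assignments make the user's device selection visible to partitioning as well as lowering and conversion. A hedged usage sketch (namespace and field names as exposed by Torch-TensorRT's C++ frontend; the model path and input shape are placeholders):

```cpp
// Usage sketch: truncate_long_and_double = true is what activates the
// cast insertion patched above; device.gpu_id now also reaches
// partitioning_info.target_device. Paths and shapes are placeholders.
#include <torch/script.h>
#include <torch_tensorrt/torch_tensorrt.h>
#include <cstdint>
#include <vector>

int main() {
  auto module = torch::jit::load("model.ts"); // placeholder model

  std::vector<std::vector<int64_t>> shapes = {{1, 3, 224, 224}};
  torch_tensorrt::ts::CompileSpec spec(shapes);
  spec.truncate_long_and_double = true; // enable int64 <=> int32 handling
  spec.device.gpu_id = 0;               // propagated to partitioning_info

  auto trt_module = torch_tensorrt::ts::compile(module, spec);
  (void)trt_module;
}
```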
