Skip to content

Commit 8ed2ae7

Browse files
authored
fix: fix failed test cases caused by partition API changes (#1460)
Signed-off-by: Bo Wang <[email protected]> Signed-off-by: Bo Wang <[email protected]>
1 parent 975f638 commit 8ed2ae7

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

tests/core/partitioning/test_shape_analysis.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ bool checkSegmentedBlockInputShape(
1111
if (segmented_blocks.size() != in_shape.size())
1212
return false;
1313
for (size_t i = 0; i < segmented_blocks.size(); ++i) {
14-
auto cur_block_in_shapes = segmented_blocks[i].in_shapes();
14+
auto cur_block_in_shapes = segmented_blocks[i].in_opt_shapes();
1515
if (cur_block_in_shapes.size() != in_shape[i].size())
1616
return false;
1717
for (size_t j = 0; j < cur_block_in_shapes.size(); ++j) {

tests/core/partitioning/test_type_auto_conversion.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,11 @@ TEST(Partitioning, ExplicitNodeAutoConversionCorrectly) {
5151
inputs_map.insert({g->inputs()[1], {inputs[1]}});
5252
input_types.insert({g->inputs()[1], {{at::kInt}}});
5353

54-
auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
55-
54+
partitioning_info.collection_input_spec_map = inputs_map;
5655
torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
57-
torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map);
56+
ctx.input_types_map = input_types;
57+
torch_tensorrt::core::partitioning::populateInputIValues(&ctx);
58+
torch_tensorrt::core::partitioning::partition(&ctx);
5859
auto segmented_blocks = ctx.partitioned_blocks.begin()->second;
5960

6061
for (auto& seg_block : segmented_blocks) {
@@ -93,10 +94,12 @@ TEST(Partitioning, ImplicitAutoConversionCorrectly) {
9394
inputs_map.insert({g->inputs()[0], {inputs[0]}});
9495
input_types.insert({g->inputs()[0], {{at::kFloat}}});
9596

96-
auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
97-
97+
partitioning_info.collection_input_spec_map = inputs_map;
9898
torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
99-
torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map);
99+
ctx.input_types_map = input_types;
100+
101+
torch_tensorrt::core::partitioning::populateInputIValues(&ctx);
102+
torch_tensorrt::core::partitioning::partition(&ctx);
100103
auto segmented_blocks = ctx.partitioned_blocks.begin()->second;
101104

102105
for (auto& seg_block : segmented_blocks) {

0 commit comments

Comments
 (0)