@@ -16,7 +16,6 @@ limitations under the License.
16
16
17
17
#include " xla/service/hlo_computation_deduplicator.h"
18
18
19
- #include < algorithm>
20
19
#include < cstdint>
21
20
#include < memory>
22
21
#include < string>
@@ -27,16 +26,12 @@ limitations under the License.
27
26
#include " xla/hlo/ir/hlo_computation.h"
28
27
#include " xla/hlo/ir/hlo_instruction.h"
29
28
#include " xla/hlo/ir/hlo_opcode.h"
30
- #include " xla/hlo/utils/hlo_matchers.h"
31
- #include " xla/layout_util.h"
32
- #include " xla/literal.h"
33
- #include " xla/service/hlo_pass_fix.h"
29
+ #include " xla/literal_util.h"
30
+ #include " xla/shape.h"
34
31
#include " xla/shape_util.h"
35
32
#include " xla/test.h"
36
33
#include " xla/tests/hlo_test_base.h"
37
- #include " xla/types.h"
38
34
#include " xla/xla_data.pb.h"
39
- #include " tsl/lib/core/status_test_util.h"
40
35
41
36
namespace xla {
42
37
namespace {
@@ -100,6 +95,7 @@ TEST_F(HloComputationDeduplicatorTest, RemoveRegionBandC) {
100
95
}
101
96
EXPECT_EQ (computation_names.size (), 2 );
102
97
}
98
+
103
99
TEST_F (HloComputationDeduplicatorTest, RemoveRegionBExactCopy) {
104
100
const std::string_view text = R"(
105
101
HloModule DeDupTest, entry_computation_layout={(s32[10]{0},s32[15]{0})->s32[]}
@@ -181,7 +177,7 @@ TEST_F(HloComputationDeduplicatorTest, RemoveRegionsWithSameSubcomp) {
181
177
rd1 = s32[] call(Arg_0, Arg_1), to_apply=main.15
182
178
rd2 = s32[] call(Arg_0, Arg_1), to_apply=main.16
183
179
ROOT ret = add(rd1, rd2)
184
- }
180
+ }
185
181
)" ;
186
182
187
183
auto computation_names = RunDeduplicatePass (text, /* expect_true=*/ true );
@@ -195,6 +191,7 @@ TEST_F(HloComputationDeduplicatorTest, RemoveRegionsWithSameSubcomp) {
195
191
}
196
192
EXPECT_EQ (computation_names.size (), 3 );
197
193
}
194
+
198
195
TEST_F (HloComputationDeduplicatorTest, DontRemoveRegionsWithDifferentSubcomp) {
199
196
const std::string_view text = R"(
200
197
HloModule DeDupTest, entry_computation_layout={(s32[10]{0},s32[15]{0})->s32[]}
@@ -334,7 +331,7 @@ TEST_F(HloComputationDeduplicatorTest, DontRemoveRegionBCommutative) {
334
331
)" ;
335
332
336
333
auto computation_names = RunDeduplicatePass (text, /* expect_true=*/ false );
337
- // Will also take into account commutativety .
334
+ // Will also take into account commutativity .
338
335
int region_b_count = 0 ;
339
336
for (auto name : computation_names) {
340
337
region_b_count += (name == " region_B" );
@@ -343,6 +340,54 @@ TEST_F(HloComputationDeduplicatorTest, DontRemoveRegionBCommutative) {
343
340
EXPECT_EQ (computation_names.size (), 3 );
344
341
}
345
342
343
+ TEST_F (HloComputationDeduplicatorTest,
344
+ DontRemoveRegionBDifferentExecutionThread) {
345
+ const std::string_view text = R"(
346
+ HloModule DeDupTest, entry_computation_layout={(s32[10]{0},s32[15]{0})->s32[]}
347
+
348
+ region_A {
349
+ Arg_0 = s32[] parameter(0)
350
+ Arg_1 = s32[] parameter(1)
351
+ ROOT add = s32[] add(Arg_0, Arg_1)
352
+ }
353
+
354
+ region_B {
355
+ Arg_0 = s32[] parameter(0)
356
+ Arg_1 = s32[] parameter(1)
357
+ ROOT add = s32[] add(Arg_0, Arg_1)
358
+ }
359
+
360
+ called_computation {
361
+ Arg_0 = s32[15]{0} parameter(0)
362
+ Cst = s32[] constant(0)
363
+ ROOT rd2 = s32[] reduce(Arg_0, Cst), dimensions={0}, to_apply=region_B
364
+ }, execution_thread="parallel_thread"
365
+
366
+ ENTRY main.15 {
367
+ Arg_0 = s32[10]{0} parameter(0)
368
+ constant.3 = s32[] constant(0)
369
+ rd1 = s32[] reduce(Arg_0, constant.3), dimensions={0}, to_apply=region_A
370
+
371
+ Arg_1 = s32[15]{0} parameter(1)
372
+ call-start = ((s32[15]{0}), s32[], s32[]) call-start(Arg_1),
373
+ async_execution_thread="parallel_thread",
374
+ to_apply=%called_computation
375
+ call-done = s32[] call-done(call-start)
376
+
377
+ ROOT multiply.14 = s32[] multiply(rd1, call-done)
378
+ }
379
+ )" ;
380
+
381
+ auto computation_names = RunDeduplicatePass (text, /* expect_true=*/ false );
382
+ // Will also take into account commutativity.
383
+ int region_b_count = 0 ;
384
+ for (auto name : computation_names) {
385
+ region_b_count += (name == " region_B" );
386
+ }
387
+ EXPECT_EQ (region_b_count, 1 );
388
+ EXPECT_EQ (computation_names.size (), 5 );
389
+ }
390
+
346
391
TEST_F (HloComputationDeduplicatorTest, DontRemoveRegionLargeConstant) {
347
392
const std::string_view text = R"(
348
393
HloModule DeDupTest, entry_computation_layout={(s32[10]{0},s32[15]{0})->s32[]}
@@ -618,5 +663,6 @@ TEST_F(HloComputationDeduplicatorTest, DontDeduplicateReduceAllReduce) {
618
663
auto computation_names = RunDeduplicatePass (text, /* expect_true=*/ false );
619
664
EXPECT_EQ (computation_names.size (), 3 );
620
665
}
666
+
621
667
} // namespace
622
668
} // namespace xla
0 commit comments