Skip to content

Commit 63acc78

Browse files
ezhulenevtensorflower-gardener
authored andcommitted
[xla] Add a test for HLO deduplication + execution threads
PiperOrigin-RevId: 647816688
1 parent a3f86f1 commit 63acc78

File tree

4 files changed

+81
-14
lines changed

4 files changed

+81
-14
lines changed

third_party/xla/xla/service/BUILD

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,7 +1159,16 @@ cc_library(
11591159
hdrs = ["hlo_computation_deduplicator.h"],
11601160
deps = [
11611161
":hlo_pass",
1162+
"//xla:shape_util",
1163+
"//xla:status_macros",
1164+
"//xla:util",
11621165
"//xla/hlo/ir:hlo",
1166+
"@com_google_absl//absl/container:flat_hash_map",
1167+
"@com_google_absl//absl/container:flat_hash_set",
1168+
"@com_google_absl//absl/status:statusor",
1169+
"@com_google_absl//absl/strings",
1170+
"@com_google_absl//absl/strings:string_view",
1171+
"@local_tsl//tsl/platform:logging",
11631172
],
11641173
)
11651174

@@ -1171,6 +1180,7 @@ xla_cc_test(
11711180
":hlo_computation_deduplicator",
11721181
":hlo_pass",
11731182
"//xla:literal",
1183+
"//xla:literal_util",
11741184
"//xla:shape_util",
11751185
"//xla:test",
11761186
"//xla:types",
@@ -1181,6 +1191,7 @@ xla_cc_test(
11811191
"//xla/tests:xla_internal_test_main",
11821192
"@com_google_googletest//:gtest_main",
11831193
"@local_tsl//tsl/lib/core:status_test_util",
1194+
"@local_tsl//tsl/platform:statusor",
11841195
],
11851196
)
11861197

third_party/xla/xla/service/hlo_computation_deduplicator.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,17 @@ limitations under the License.
1515

1616
#include "xla/service/hlo_computation_deduplicator.h"
1717

18-
#include <algorithm>
1918
#include <string>
2019
#include <utility>
2120

21+
#include "absl/container/flat_hash_map.h"
22+
#include "absl/container/flat_hash_set.h"
23+
#include "absl/status/statusor.h"
24+
#include "absl/strings/string_view.h"
2225
#include "xla/hlo/ir/hlo_computation.h"
2326
#include "xla/hlo/ir/hlo_instruction.h"
27+
#include "xla/shape_util.h"
28+
#include "tsl/platform/logging.h"
2429

2530
namespace xla {
2631

@@ -36,6 +41,7 @@ bool HloComputationDeduplicator::ContainsLargeConstants(HloComputation* comp) {
3641
}
3742
return false;
3843
}
44+
3945
absl::StatusOr<bool> HloComputationDeduplicator::Run(
4046
HloModule* module,
4147
const absl::flat_hash_set<absl::string_view>& execution_threads) {

third_party/xla/xla/service/hlo_computation_deduplicator.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,17 @@ limitations under the License.
1616
#ifndef XLA_SERVICE_HLO_COMPUTATION_DEDUPLICATOR_H_
1717
#define XLA_SERVICE_HLO_COMPUTATION_DEDUPLICATOR_H_
1818

19+
#include "absl/container/flat_hash_set.h"
20+
#include "absl/status/statusor.h"
21+
#include "absl/strings/string_view.h"
22+
#include "xla/hlo/ir/hlo_computation.h"
1923
#include "xla/service/hlo_pass_interface.h"
2024

2125
namespace xla {
2226

2327
// Deduplicate computations inside a `HloModule`: If two computations are
2428
// identical then keep the first one (in postorder terms) and remove the rest.
2529
class HloComputationDeduplicator : public HloModulePass {
26-
private:
27-
bool ContainsLargeConstants(HloComputation* comp);
28-
bool mark_fusion_duplications_;
29-
3030
public:
3131
// Setting mark_fusion_duplications to true will only process fusions in the
3232
// HLO. The comparator in this pass will mark duplicate fusions which is
@@ -40,6 +40,10 @@ class HloComputationDeduplicator : public HloModulePass {
4040
absl::StatusOr<bool> Run(
4141
HloModule* module,
4242
const absl::flat_hash_set<absl::string_view>& execution_threads) override;
43+
44+
private:
45+
bool ContainsLargeConstants(HloComputation* comp);
46+
bool mark_fusion_duplications_;
4347
};
4448

4549
} // namespace xla

third_party/xla/xla/service/hlo_computation_deduplicator_test.cc

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ limitations under the License.
1616

1717
#include "xla/service/hlo_computation_deduplicator.h"
1818

19-
#include <algorithm>
2019
#include <cstdint>
2120
#include <memory>
2221
#include <string>
@@ -27,16 +26,12 @@ limitations under the License.
2726
#include "xla/hlo/ir/hlo_computation.h"
2827
#include "xla/hlo/ir/hlo_instruction.h"
2928
#include "xla/hlo/ir/hlo_opcode.h"
30-
#include "xla/hlo/utils/hlo_matchers.h"
31-
#include "xla/layout_util.h"
32-
#include "xla/literal.h"
33-
#include "xla/service/hlo_pass_fix.h"
29+
#include "xla/literal_util.h"
30+
#include "xla/shape.h"
3431
#include "xla/shape_util.h"
3532
#include "xla/test.h"
3633
#include "xla/tests/hlo_test_base.h"
37-
#include "xla/types.h"
3834
#include "xla/xla_data.pb.h"
39-
#include "tsl/lib/core/status_test_util.h"
4035

4136
namespace xla {
4237
namespace {
@@ -100,6 +95,7 @@ TEST_F(HloComputationDeduplicatorTest, RemoveRegionBandC) {
10095
}
10196
EXPECT_EQ(computation_names.size(), 2);
10297
}
98+
10399
TEST_F(HloComputationDeduplicatorTest, RemoveRegionBExactCopy) {
104100
const std::string_view text = R"(
105101
HloModule DeDupTest, entry_computation_layout={(s32[10]{0},s32[15]{0})->s32[]}
@@ -181,7 +177,7 @@ TEST_F(HloComputationDeduplicatorTest, RemoveRegionsWithSameSubcomp) {
181177
rd1 = s32[] call(Arg_0, Arg_1), to_apply=main.15
182178
rd2 = s32[] call(Arg_0, Arg_1), to_apply=main.16
183179
ROOT ret = add(rd1, rd2)
184-
}
180+
}
185181
)";
186182

187183
auto computation_names = RunDeduplicatePass(text, /*expect_true=*/true);
@@ -195,6 +191,7 @@ TEST_F(HloComputationDeduplicatorTest, RemoveRegionsWithSameSubcomp) {
195191
}
196192
EXPECT_EQ(computation_names.size(), 3);
197193
}
194+
198195
TEST_F(HloComputationDeduplicatorTest, DontRemoveRegionsWithDifferentSubcomp) {
199196
const std::string_view text = R"(
200197
HloModule DeDupTest, entry_computation_layout={(s32[10]{0},s32[15]{0})->s32[]}
@@ -334,7 +331,7 @@ TEST_F(HloComputationDeduplicatorTest, DontRemoveRegionBCommutative) {
334331
)";
335332

336333
auto computation_names = RunDeduplicatePass(text, /*expect_true=*/false);
337-
// Will also take into account commutativety.
334+
// Will also take into account commutativity.
338335
int region_b_count = 0;
339336
for (auto name : computation_names) {
340337
region_b_count += (name == "region_B");
@@ -343,6 +340,54 @@ TEST_F(HloComputationDeduplicatorTest, DontRemoveRegionBCommutative) {
343340
EXPECT_EQ(computation_names.size(), 3);
344341
}
345342

343+
TEST_F(HloComputationDeduplicatorTest,
344+
DontRemoveRegionBDifferentExecutionThread) {
345+
const std::string_view text = R"(
346+
HloModule DeDupTest, entry_computation_layout={(s32[10]{0},s32[15]{0})->s32[]}
347+
348+
region_A {
349+
Arg_0 = s32[] parameter(0)
350+
Arg_1 = s32[] parameter(1)
351+
ROOT add = s32[] add(Arg_0, Arg_1)
352+
}
353+
354+
region_B {
355+
Arg_0 = s32[] parameter(0)
356+
Arg_1 = s32[] parameter(1)
357+
ROOT add = s32[] add(Arg_0, Arg_1)
358+
}
359+
360+
called_computation {
361+
Arg_0 = s32[15]{0} parameter(0)
362+
Cst = s32[] constant(0)
363+
ROOT rd2 = s32[] reduce(Arg_0, Cst), dimensions={0}, to_apply=region_B
364+
}, execution_thread="parallel_thread"
365+
366+
ENTRY main.15 {
367+
Arg_0 = s32[10]{0} parameter(0)
368+
constant.3 = s32[] constant(0)
369+
rd1 = s32[] reduce(Arg_0, constant.3), dimensions={0}, to_apply=region_A
370+
371+
Arg_1 = s32[15]{0} parameter(1)
372+
call-start = ((s32[15]{0}), s32[], s32[]) call-start(Arg_1),
373+
async_execution_thread="parallel_thread",
374+
to_apply=%called_computation
375+
call-done = s32[] call-done(call-start)
376+
377+
ROOT multiply.14 = s32[] multiply(rd1, call-done)
378+
}
379+
)";
380+
381+
auto computation_names = RunDeduplicatePass(text, /*expect_true=*/false);
382+
// Will also take into account commutativity.
383+
int region_b_count = 0;
384+
for (auto name : computation_names) {
385+
region_b_count += (name == "region_B");
386+
}
387+
EXPECT_EQ(region_b_count, 1);
388+
EXPECT_EQ(computation_names.size(), 5);
389+
}
390+
346391
TEST_F(HloComputationDeduplicatorTest, DontRemoveRegionLargeConstant) {
347392
const std::string_view text = R"(
348393
HloModule DeDupTest, entry_computation_layout={(s32[10]{0},s32[15]{0})->s32[]}
@@ -618,5 +663,6 @@ TEST_F(HloComputationDeduplicatorTest, DontDeduplicateReduceAllReduce) {
618663
auto computation_names = RunDeduplicatePass(text, /*expect_true=*/false);
619664
EXPECT_EQ(computation_names.size(), 3);
620665
}
666+
621667
} // namespace
622668
} // namespace xla

0 commit comments

Comments
 (0)