Skip to content

Commit ad8d093

Browse files
Tongfei-Guo and copybara-github
authored and committed
[XLA:SPMD] Support shard-as propagation with unspecified_dims.
PiperOrigin-RevId: 629857357
1 parent c3366f8 commit ad8d093

File tree

2 files changed

+153
-5
lines changed

2 files changed

+153
-5
lines changed

xla/service/sharding_propagation.cc

Lines changed: 79 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1225,6 +1225,49 @@ bool InferUnspecifiedDimsFromUsers(HloInstruction* annotate_op,
12251225
return changed;
12261226
}
12271227

1228+
// Refines the sharding of `annotate_op` using the shardings of the other
// members of `shard_group`.
//
// For each other member that passes IsSpatiallyPartitioned and whose sharding
// is tiled, the member's sharding is passed through
// PartiallyReplicateTiledShardingOnAllDimsExcept with `unspecified_dims` and
// the result is merged into the current sharding of `annotate_op` when
// MergeShardingIfCompatible accepts it. ShardBarrierFrom custom-calls are never
// used as a source.
//
// Returns true if the sharding of `annotate_op` was updated at least once.
bool InferUnspecifiedDimsFromShardGroup(
    HloInstruction* annotate_op, absl::Span<const int64_t> unspecified_dims,
    const absl::flat_hash_set<HloInstruction*>& shard_group) {
  // ProcessShardingInstruction will either keep the "Sharding" custom call as
  // is or replace it with a copy. Because this CHECK admits only those two
  // forms, `annotate_op` can never be a ShardBarrierTo custom-call past this
  // point, so no separate guard against propagating into one is needed.
  CHECK(annotate_op->IsCustomCall("Sharding") ||
        annotate_op->opcode() == HloOpcode::kCopy);

  bool changed = false;
  for (const HloInstruction* member : shard_group) {
    // Skip the op itself, ShardBarrierFrom custom-calls (sharding must not be
    // propagated from them), and members that are not spatially partitioned.
    if (member == annotate_op ||
        member->IsCustomCall(spmd::kShardBarrierFrom) ||
        !IsSpatiallyPartitioned(member)) {
      continue;
    }
    const HloSharding& member_sharding = member->sharding();
    if (!member_sharding.IsTiled()) {
      continue;
    }
    const HloSharding partial_replicated =
        hlo_sharding_util::PartiallyReplicateTiledShardingOnAllDimsExcept(
            member_sharding, unspecified_dims);
    // Work on a copy: MergeShardingIfCompatible takes the destination by
    // pointer, and the op is only updated when the merge succeeds.
    // NOTE(review): the `NumTiles() + 1` argument presumably requires the
    // merged result to have more tiles than the current sharding — confirm
    // against MergeShardingIfCompatible.
    HloSharding sharding = annotate_op->sharding();
    if (!hlo_sharding_util::MergeShardingIfCompatible(
            partial_replicated, sharding.NumTiles() + 1, &sharding)) {
      continue;
    }
    // Later members merge against this refined sharding.
    annotate_op->set_sharding(sharding);
    changed = true;
  }
  return changed;
}
1270+
12281271
// Returns whether an op is a target for CSE prevention.
12291272
bool IsCSEPreventionTarget(const HloInstruction* instruction) {
12301273
// Scalar broadcasts are the most common CSE target that causes cross-layer
@@ -1582,7 +1625,7 @@ absl::StatusOr<bool> ProcessShardingInstruction(
15821625
if (instruction->IsCustomCall("Sharding") && !replaced_with_copy) {
15831626
// Pass shard group to operand sharding custom-call if it's not
15841627
// replaced with a copy, meaning that the shardings are to annotate
1585-
// shard_group or shard_barrier only.
1628+
// shard_group.
15861629
HloSharding operand_sharding = instruction->operand(0)->has_sharding()
15871630
? instruction->operand(0)->sharding()
15881631
: HloSharding::Unknown();
@@ -2238,7 +2281,8 @@ bool ShardingPropagation::InferShardingFromShardGroup(
22382281
// Propagate manual sharding.
22392282
if (!instruction->has_sharding() || instruction->sharding().IsTileMaximal()) {
22402283
for (const HloInstruction* member : shard_group) {
2241-
if (!member->has_sharding() || !member->sharding().IsManual()) {
2284+
if (!member->has_sharding() || !member->sharding().IsManual() ||
2285+
member == instruction) {
22422286
continue;
22432287
}
22442288
instruction->set_sharding(member->sharding());
@@ -2249,7 +2293,9 @@ bool ShardingPropagation::InferShardingFromShardGroup(
22492293
const bool may_combine_partial_sharding = is_spmd_ && aggressiveness > 0;
22502294
bool changed = false;
22512295
for (const HloInstruction* member : shard_group) {
2252-
if (member->IsCustomCall(spmd::kShardBarrierFrom)) {
2296+
// Do not propagate sharding from ShardBarrierFrom custom-call.
2297+
if (member == instruction ||
2298+
member->IsCustomCall(spmd::kShardBarrierFrom)) {
22532299
continue;
22542300
}
22552301
changed |= MaybeImproveInstructionSharding(member->sharding(), instruction,
@@ -3309,6 +3355,20 @@ absl::StatusOr<bool> ShardingPropagation::Run(
33093355
? shard_group_id_to_shard_as_group.at(shard_group_id)
33103356
: shard_group_id_to_shard_like_group.at(shard_group_id);
33113357
if (provided_shardings.contains(instruction)) {
3358+
if (!may_merge_partial) {
3359+
continue;
3360+
}
3361+
auto it = unspecified_dims.find(instruction);
3362+
if (it != unspecified_dims.end() &&
3363+
InferUnspecifiedDimsFromShardGroup(instruction, it->second,
3364+
shard_group)) {
3365+
++inferred_from_shard_group_counter;
3366+
VLOG(2) << "Refined partial sharding (shard group): "
3367+
<< instruction->ToString();
3368+
clear_cache(instruction);
3369+
already_inferred_from_shard_group.insert(instruction);
3370+
changed_last_iter = true;
3371+
}
33123372
continue;
33133373
}
33143374
already_inferred_from_shard_group.insert(instruction);
@@ -3469,9 +3529,23 @@ absl::StatusOr<bool> ShardingPropagation::Run(
34693529
VLOG(2) << "Aligning shard group: " << shard_as_group_id
34703530
<< " to sharding:" << common_sharding.ToString();
34713531
for (HloInstruction* member : shard_as_group) {
3472-
if (!member->IsCustomCall(spmd::kShardBarrierTo)) {
3473-
member->set_sharding(common_sharding);
3532+
if (member->IsCustomCall(spmd::kShardBarrierTo)) {
3533+
continue;
3534+
}
3535+
if (provided_shardings.contains(member)) {
3536+
auto it = unspecified_dims.find(member);
3537+
if (it != unspecified_dims.end()) {
3538+
HloSharding partial_replicated =
3539+
hlo_sharding_util::PartiallyReplicateTiledShardingOnAllDimsExcept(
3540+
common_sharding, it->second);
3541+
HloSharding sharding = member->sharding();
3542+
if (hlo_sharding_util::MergeShardingIfCompatible(
3543+
partial_replicated, sharding.NumTiles() + 1, &sharding)) {
3544+
member->set_sharding(sharding);
3545+
}
3546+
}
34743547
}
3548+
member->set_sharding(common_sharding);
34753549
}
34763550
}
34773551

xla/service/sharding_propagation_test.cc

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11793,5 +11793,79 @@ ENTRY entry_computation {
1179311793
}
1179411794
}
1179511795

11796+
// Shard-as propagation combined with unspecified_dims and a ShardBarrierTo.
// In the module below, custom-call.3 carries a partial sharding with
// unspecified_dims=[1,2]; custom-call.6 (fed by custom-call.3) and
// custom-call.5 (fed by broadcast.4) are both annotated `shard_as 1`, and a
// ShardBarrierTo custom-call sits between custom-call.6 and slice.7. The test
// expects broadcast.4 and the instruction named "copy" to end up with the same
// refined sharding.
TEST_F(ShardingPropagationTest, ShardAsWithShardBarrier) {
11797+
const char* const hlo_string = R"(
11798+
HloModule pjit_f
11799+
11800+
ENTRY main.11 {
11801+
Arg_0.1 = bf16[384,1408]{1,0} parameter(0), sharding={devices=[1,16,512]<=[8,16,64]T(1,0,2) last_tile_dim_replicate}
11802+
broadcast.4 = bf16[8,384,1408]{2,1,0} broadcast(Arg_0.1), dimensions={1,2}
11803+
custom-call.5 = bf16[8,384,1408]{2,1,0} custom-call(broadcast.4), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 1}
11804+
broadcast.2 = bf16[8,384,1408]{2,1,0} broadcast(Arg_0.1), dimensions={1,2}
11805+
custom-call.3 = bf16[8,384,1408]{2,1,0} custom-call(broadcast.2), custom_call_target="Sharding", sharding={devices=[8,1,1,1024]<=[8192] last_tile_dim_replicate}, backend_config="unspecified_dims=[1,2]"
11806+
custom-call.6 = bf16[8,384,1408]{2,1,0} custom-call(custom-call.3), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 1}
11807+
%shard-barrier-to = bf16[8,384,1408]{2,1,0} custom-call(%custom-call.6), custom_call_target="ShardBarrierTo", custom_call_has_side_effect=true
11808+
slice.7 = bf16[1,384,1408]{2,1,0} slice(shard-barrier-to), slice={[1:2], [0:384], [0:1408]}
11809+
reshape.8 = bf16[384,1408]{1,0} reshape(slice.7)
11810+
tuple.9 = (bf16[384,1408]{1,0}) tuple(reshape.8)
11811+
get-tuple-element.10 = bf16[384,1408]{1,0} get-tuple-element(tuple.9), index=0, sharding={devices=[16,1,512]<=[8,16,64]T(1,0,2) last_tile_dim_replicate}
11812+
ROOT tuple.13 = (bf16[384,1408]{1,0}, bf16[8,384,1408]{2,1,0}) tuple(get-tuple-element.10, custom-call.5)
11813+
})";
11814+
// Parse and verify the HLO text above.
TF_ASSERT_OK_AND_ASSIGN(auto module,
11815+
ParseAndReturnVerifiedModule(hlo_string));
11816+
// Run the pass in SPMD mode with propagation to the output enabled and
// propagation to both parameters disabled.
TF_ASSERT_OK_AND_ASSIGN(
11817+
bool changed,
11818+
ShardingPropagation(
11819+
/*is_spmd=*/true, /*propagate_metadata=*/true,
11820+
/*allow_spmd_sharding_propagation_to_output=*/{true},
11821+
/*allow_spmd_sharding_propagation_to_parameters=*/{false, false})
11822+
.Run(module.get()));
11823+
EXPECT_TRUE(changed);
11824+
11825+
XLA_VLOG_LINES(1, module->ToString());
11826+
// broadcast.4 feeds custom-call.5, one of the two shard_as-1 annotations.
auto* broadcast_4 = FindInstruction(module.get(), "broadcast.4");
11827+
ASSERT_NE(broadcast_4, nullptr);
11828+
EXPECT_THAT(
11829+
broadcast_4,
11830+
op::Sharding("{devices=[8,1,16,64]<=[8192] last_tile_dim_replicate}"));
11831+
// NOTE(review): "copy" is presumably the copy that replaced the annotated
// custom-call.3 during the pass — confirm against ProcessShardingInstruction.
auto* copy = FindInstruction(module.get(), "copy");
11832+
ASSERT_NE(copy, nullptr);
11833+
EXPECT_THAT(
11834+
copy,
11835+
op::Sharding("{devices=[8,1,16,64]<=[8192] last_tile_dim_replicate}"));
11836+
}
11837+
11838+
// Shard-as propagation combined with unspecified_dims and a ShardBarrierFrom.
// In the module below, custom-call.0 (unspecified_dims=[1,2,3]) feeds a
// ShardBarrierFrom custom-call whose result is annotated `shard_as 1` by
// custom-call.2; custom-call.1 (unspecified_dims=[0]) is annotated
// `shard_as 1` by custom-call.3. The root tuple is (custom-call.0,
// custom-call.3), and the test checks the sharding of that root tuple.
TEST_F(ShardingPropagationTest, ShardAsWithShardBarrier2) {
11839+
const char* const hlo_string = R"(
11840+
HloModule module
11841+
ENTRY %elementwise {
11842+
%param0 = f32[5,7,11,13]{3,2,1,0} parameter(0)
11843+
%custom-call.0 = f32[5,7,11,13]{3,2,1,0} custom-call(param0), custom_call_target="Sharding", sharding={devices=[2,1,1,1,4]<=[8] last_tile_dim_replicate}, backend_config="unspecified_dims=[1,2,3]"
11844+
%shard-barrier-from = f32[5,7,11,13]{3,2,1,0} custom-call(%custom-call.0), custom_call_target="ShardBarrierFrom", custom_call_has_side_effect=true
11845+
%custom-call.2 = f32[5,7,11,13]{3,2,1,0} custom-call(shard-barrier-from), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 1}
11846+
%param1 = f32[5,7,11,13]{3,2,1,0} parameter(1)
11847+
%custom-call.1 = f32[5,7,11,13]{3,2,1,0} custom-call(param1), custom_call_target="Sharding", sharding={devices=[1,2,2,1,2]<=[2,4]T(1,0) last_tile_dim_replicate}, backend_config="unspecified_dims=[0]"
11848+
%custom-call.3 = f32[5,7,11,13]{3,2,1,0} custom-call(custom-call.1), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 1}
11849+
ROOT %tuple = (f32[5,7,11,13]{3,2,1,0}, f32[5,7,11,13]{3,2,1,0}) tuple(%custom-call.0, %custom-call.3)
11850+
})";
11851+
// Parse and verify the HLO text above.
TF_ASSERT_OK_AND_ASSIGN(auto module,
11852+
ParseAndReturnVerifiedModule(hlo_string));
11853+
// Run the pass in SPMD mode with propagation to the output enabled and
// propagation to both parameters disabled.
TF_ASSERT_OK_AND_ASSIGN(
11854+
bool changed,
11855+
ShardingPropagation(
11856+
/*is_spmd=*/true, /*propagate_metadata=*/true,
11857+
/*allow_spmd_sharding_propagation_to_output=*/{true},
11858+
/*allow_spmd_sharding_propagation_to_parameters=*/{false, false})
11859+
.Run(module.get()));
11860+
EXPECT_TRUE(changed);
11861+
11862+
XLA_VLOG_LINES(1, module->ToString());
11863+
// The expected tuple sharding has one entry per root operand: the first for
// custom-call.0, the second for custom-call.3. The second entry equals the
// sharding written on custom-call.1 in the HLO text above.
EXPECT_THAT(
11864+
module->entry_computation()->root_instruction(),
11865+
op::Sharding(
11866+
"{{devices=[2,2,2,1]<=[8]}, {devices=[1,2,2,1,2]<=[2,4]T(1,0) "
11867+
"last_tile_dim_replicate}}"));
11868+
}
11869+
1179611870
} // namespace
1179711871
} // namespace xla

0 commit comments

Comments
 (0)