@@ -11793,5 +11793,79 @@ ENTRY entry_computation {
11793
11793
}
11794
11794
}
11795
11795
11796
+ TEST_F (ShardingPropagationTest, ShardAsWithShardBarrier) {
11797
+ const char * const hlo_string = R"(
11798
+ HloModule pjit_f
11799
+
11800
+ ENTRY main.11 {
11801
+ Arg_0.1 = bf16[384,1408]{1,0} parameter(0), sharding={devices=[1,16,512]<=[8,16,64]T(1,0,2) last_tile_dim_replicate}
11802
+ broadcast.4 = bf16[8,384,1408]{2,1,0} broadcast(Arg_0.1), dimensions={1,2}
11803
+ custom-call.5 = bf16[8,384,1408]{2,1,0} custom-call(broadcast.4), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 1}
11804
+ broadcast.2 = bf16[8,384,1408]{2,1,0} broadcast(Arg_0.1), dimensions={1,2}
11805
+ custom-call.3 = bf16[8,384,1408]{2,1,0} custom-call(broadcast.2), custom_call_target="Sharding", sharding={devices=[8,1,1,1024]<=[8192] last_tile_dim_replicate}, backend_config="unspecified_dims=[1,2]"
11806
+ custom-call.6 = bf16[8,384,1408]{2,1,0} custom-call(custom-call.3), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 1}
11807
+ %shard-barrier-to = bf16[8,384,1408]{2,1,0} custom-call(%custom-call.6), custom_call_target="ShardBarrierTo", custom_call_has_side_effect=true
11808
+ slice.7 = bf16[1,384,1408]{2,1,0} slice(shard-barrier-to), slice={[1:2], [0:384], [0:1408]}
11809
+ reshape.8 = bf16[384,1408]{1,0} reshape(slice.7)
11810
+ tuple.9 = (bf16[384,1408]{1,0}) tuple(reshape.8)
11811
+ get-tuple-element.10 = bf16[384,1408]{1,0} get-tuple-element(tuple.9), index=0, sharding={devices=[16,1,512]<=[8,16,64]T(1,0,2) last_tile_dim_replicate}
11812
+ ROOT tuple.13 = (bf16[384,1408]{1,0}, bf16[8,384,1408]{2,1,0}) tuple(get-tuple-element.10, custom-call.5)
11813
+ })" ;
11814
+ TF_ASSERT_OK_AND_ASSIGN (auto module ,
11815
+ ParseAndReturnVerifiedModule (hlo_string));
11816
+ TF_ASSERT_OK_AND_ASSIGN (
11817
+ bool changed,
11818
+ ShardingPropagation (
11819
+ /* is_spmd=*/ true , /* propagate_metadata=*/ true ,
11820
+ /* allow_spmd_sharding_propagation_to_output=*/ {true },
11821
+ /* allow_spmd_sharding_propagation_to_parameters=*/ {false , false })
11822
+ .Run (module .get ()));
11823
+ EXPECT_TRUE (changed);
11824
+
11825
+ XLA_VLOG_LINES (1 , module ->ToString ());
11826
+ auto * broadcast_4 = FindInstruction (module .get (), " broadcast.4" );
11827
+ ASSERT_NE (broadcast_4, nullptr );
11828
+ EXPECT_THAT (
11829
+ broadcast_4,
11830
+ op::Sharding (" {devices=[8,1,16,64]<=[8192] last_tile_dim_replicate}" ));
11831
+ auto * copy = FindInstruction (module .get (), " copy" );
11832
+ ASSERT_NE (copy, nullptr );
11833
+ EXPECT_THAT (
11834
+ copy,
11835
+ op::Sharding (" {devices=[8,1,16,64]<=[8192] last_tile_dim_replicate}" ));
11836
+ }
11837
+
11838
+ TEST_F (ShardingPropagationTest, ShardAsWithShardBarrier2) {
11839
+ const char * const hlo_string = R"(
11840
+ HloModule module
11841
+ ENTRY %elementwise {
11842
+ %param0 = f32[5,7,11,13]{3,2,1,0} parameter(0)
11843
+ %custom-call.0 = f32[5,7,11,13]{3,2,1,0} custom-call(param0), custom_call_target="Sharding", sharding={devices=[2,1,1,1,4]<=[8] last_tile_dim_replicate}, backend_config="unspecified_dims=[1,2,3]"
11844
+ %shard-barrier-from = f32[5,7,11,13]{3,2,1,0} custom-call(%custom-call.0), custom_call_target="ShardBarrierFrom", custom_call_has_side_effect=true
11845
+ %custom-call.2 = f32[5,7,11,13]{3,2,1,0} custom-call(shard-barrier-from), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 1}
11846
+ %param1 = f32[5,7,11,13]{3,2,1,0} parameter(1)
11847
+ %custom-call.1 = f32[5,7,11,13]{3,2,1,0} custom-call(param1), custom_call_target="Sharding", sharding={devices=[1,2,2,1,2]<=[2,4]T(1,0) last_tile_dim_replicate}, backend_config="unspecified_dims=[0]"
11848
+ %custom-call.3 = f32[5,7,11,13]{3,2,1,0} custom-call(custom-call.1), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 1}
11849
+ ROOT %tuple = (f32[5,7,11,13]{3,2,1,0}, f32[5,7,11,13]{3,2,1,0}) tuple(%custom-call.0, %custom-call.3)
11850
+ })" ;
11851
+ TF_ASSERT_OK_AND_ASSIGN (auto module ,
11852
+ ParseAndReturnVerifiedModule (hlo_string));
11853
+ TF_ASSERT_OK_AND_ASSIGN (
11854
+ bool changed,
11855
+ ShardingPropagation (
11856
+ /* is_spmd=*/ true , /* propagate_metadata=*/ true ,
11857
+ /* allow_spmd_sharding_propagation_to_output=*/ {true },
11858
+ /* allow_spmd_sharding_propagation_to_parameters=*/ {false , false })
11859
+ .Run (module .get ()));
11860
+ EXPECT_TRUE (changed);
11861
+
11862
+ XLA_VLOG_LINES (1 , module ->ToString ());
11863
+ EXPECT_THAT (
11864
+ module ->entry_computation ()->root_instruction (),
11865
+ op::Sharding (
11866
+ " {{devices=[2,2,2,1]<=[8]}, {devices=[1,2,2,1,2]<=[2,4]T(1,0) "
11867
+ " last_tile_dim_replicate}}" ));
11868
+ }
11869
+
11796
11870
} // namespace
11797
11871
} // namespace xla
0 commit comments