@@ -60,7 +60,6 @@ limitations under the License.
60
60
#include " xla/shape_util.h"
61
61
#include " xla/xla_data.pb.h"
62
62
#include " tsl/platform/errors.h"
63
- #include " tsl/platform/status.h"
64
63
65
64
namespace xla {
66
65
namespace spmd {
@@ -1852,68 +1851,69 @@ std::vector<int64_t> VectorGreaterThanOneElementIndices(
1852
1851
return result;
1853
1852
}
1854
1853
1855
- int64_t ByteSizeOfShapeWithSharding (const Shape& shape,
1854
+ // Given a sharding, and a shape index, obtains the subsharding corresponding to
1855
+ // that shape index. This function works whether or not the provided sharding is
1856
+ // a tuple, unlike HloSharding::GetSubSharding.
1857
+ HloSharding GetSubSharding (const HloSharding& sharding,
1858
+ const Shape& original_tuple_shape,
1859
+ const ShapeIndex& index) {
1860
+ return sharding.IsTuple ()
1861
+ ? sharding.GetSubSharding (original_tuple_shape, index)
1862
+ : sharding;
1863
+ }
1864
+
1865
+ int64_t ByteSizeOfShapeWithSharding (const Shape& original_shape,
1856
1866
std::optional<HloSharding> sharding) {
1857
- if (shape.IsTuple ()) {
1858
- int64_t size =
1867
+ int64_t total_size = 0 ;
1868
+ auto add_to_total_size = [&total_size](const Shape& shape) {
1869
+ total_size +=
1859
1870
ShapeUtil::ByteSizeOf (shape, /* pointer_size=*/ kAutoShardingPointerSize );
1860
- for (size_t i = 0 ; i < shape.tuple_shapes_size (); i++) {
1861
- const Shape& subshape = shape.tuple_shapes ().at (i);
1862
- if (sharding) {
1863
- const HloSharding& sub_sharding =
1864
- sharding->IsTuple ()
1865
- ? sharding->GetSubSharding (shape,
1866
- ShapeIndex{static_cast <int64_t >(i)})
1867
- : *sharding;
1868
- size += ByteSizeOfShapeWithSharding (subshape, sub_sharding);
1869
- } else {
1870
- size += ByteSizeOfShapeWithSharding (subshape, std::nullopt);
1871
- }
1871
+ };
1872
+ ShapeUtil::ForEachSubshape (original_shape, [&total_size, &add_to_total_size,
1873
+ &sharding, &original_shape](
1874
+ const Shape& subshape,
1875
+ const ShapeIndex& index) {
1876
+ if (subshape.IsTuple ()) {
1877
+ add_to_total_size (subshape);
1878
+ } else if (subshape.IsArray () && sharding.has_value ()) {
1879
+ add_to_total_size (
1880
+ GetSubSharding (*sharding, original_shape, index).TileShape (subshape));
1881
+ } else if (subshape.IsArray ()) {
1882
+ add_to_total_size (subshape);
1883
+ } else if (subshape.IsToken ()) {
1884
+ // Tokens are considered to have a size of 0
1885
+ } else {
1886
+ total_size += kAutoShardingPointerSize ;
1872
1887
}
1873
- return size;
1874
- } else if (shape.IsArray ()) {
1875
- return ShapeUtil::ByteSizeOf (sharding ? sharding->TileShape (shape) : shape);
1876
- } else if (shape.IsToken ()) {
1877
- return 0 ;
1878
- } else {
1879
- return kAutoShardingPointerSize ;
1880
- }
1888
+ });
1889
+ return total_size;
1881
1890
}
1882
1891
1883
1892
int64_t GetShardedInstructionSize (const Shape& shape, int64_t num_devices,
1884
1893
std::optional<HloSharding> sharding) {
1885
- if (sharding && sharding->IsUnknown ()) {
1886
- sharding = HloSharding::Replicate ();
1887
- }
1888
- if (shape.IsTuple ()) {
1889
- int64_t size =
1890
- ShapeUtil::ByteSizeOf (shape, /* pointer_size=*/ kAutoShardingPointerSize );
1891
- for (size_t i = 0 ; i < shape.tuple_shapes_size (); i++) {
1892
- const Shape& subshape = shape.tuple_shapes ().at (i);
1893
- size += GetShardedInstructionSize (
1894
- subshape,
1895
- sharding.has_value ()
1896
- ? sharding
1897
- ->GetSubSharding (shape, ShapeIndex{static_cast <int64_t >(i)})
1898
- .NumTiles ()
1899
- : num_devices);
1900
- }
1901
- return size;
1902
- }
1903
- if (sharding && sharding->NumTiles () > 0 ) {
1904
- return GetBytes (shape) / sharding->NumTiles ();
1905
- }
1906
- bool shardable = false ;
1907
- for (const auto dim : shape.dimensions ()) {
1908
- if (dim >= num_devices) {
1909
- shardable = true ;
1910
- break ;
1911
- }
1912
- }
1913
- if (shardable) {
1914
- return GetBytes (shape) / num_devices;
1915
- }
1916
- return GetBytes (shape);
1894
+ if (sharding.has_value ()) {
1895
+ return ByteSizeOfShapeWithSharding (shape, sharding);
1896
+ }
1897
+
1898
+ int64_t total_size = 0 ;
1899
+ ShapeUtil::ForEachSubshape (
1900
+ shape, [&total_size, &num_devices](const Shape& subshape,
1901
+ const ShapeIndex& index) {
1902
+ if (subshape.IsTuple ()) {
1903
+ total_size += ShapeUtil::ByteSizeOf (
1904
+ subshape, /* pointer_size=*/ kAutoShardingPointerSize );
1905
+ return ;
1906
+ }
1907
+ int64_t byte_size = ByteSizeOfShape (subshape);
1908
+ absl::Span<const int64_t > subshape_dims = subshape.dimensions ();
1909
+ auto max_dim_it = absl::c_max_element (subshape_dims);
1910
+ if (max_dim_it != subshape_dims.end () && *max_dim_it >= num_devices) {
1911
+ byte_size /= num_devices;
1912
+ }
1913
+ total_size += byte_size;
1914
+ });
1915
+
1916
+ return total_size;
1917
1917
}
1918
1918
1919
1919
HloInstruction* FindInstruction (
0 commit comments