Skip to content

Commit e851bbf

Browse files
In this change we
1. use ShapeUtil::ForEachSubshape to recursively compute shape byte sizes, rather than explicit recursion. 2. and defer to ByteSizeOfShapeWithSharding when GetShardedInstructionSize is invoked with a sharding object PiperOrigin-RevId: 647831296
1 parent 63acc78 commit e851bbf

File tree

2 files changed

+64
-57
lines changed

2 files changed

+64
-57
lines changed

third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc

Lines changed: 56 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ limitations under the License.
6060
#include "xla/shape_util.h"
6161
#include "xla/xla_data.pb.h"
6262
#include "tsl/platform/errors.h"
63-
#include "tsl/platform/status.h"
6463

6564
namespace xla {
6665
namespace spmd {
@@ -1852,68 +1851,69 @@ std::vector<int64_t> VectorGreaterThanOneElementIndices(
18521851
return result;
18531852
}
18541853

1855-
int64_t ByteSizeOfShapeWithSharding(const Shape& shape,
1854+
// Given a sharding, and a shape index, obtains the subsharding corresponding to
1855+
// that shape index. This function works whether or not the provided sharding is
1856+
// a tuple, unlike HloSharding::GetSubSharding.
1857+
HloSharding GetSubSharding(const HloSharding& sharding,
1858+
const Shape& original_tuple_shape,
1859+
const ShapeIndex& index) {
1860+
return sharding.IsTuple()
1861+
? sharding.GetSubSharding(original_tuple_shape, index)
1862+
: sharding;
1863+
}
1864+
1865+
int64_t ByteSizeOfShapeWithSharding(const Shape& original_shape,
18561866
std::optional<HloSharding> sharding) {
1857-
if (shape.IsTuple()) {
1858-
int64_t size =
1867+
int64_t total_size = 0;
1868+
auto add_to_total_size = [&total_size](const Shape& shape) {
1869+
total_size +=
18591870
ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/kAutoShardingPointerSize);
1860-
for (size_t i = 0; i < shape.tuple_shapes_size(); i++) {
1861-
const Shape& subshape = shape.tuple_shapes().at(i);
1862-
if (sharding) {
1863-
const HloSharding& sub_sharding =
1864-
sharding->IsTuple()
1865-
? sharding->GetSubSharding(shape,
1866-
ShapeIndex{static_cast<int64_t>(i)})
1867-
: *sharding;
1868-
size += ByteSizeOfShapeWithSharding(subshape, sub_sharding);
1869-
} else {
1870-
size += ByteSizeOfShapeWithSharding(subshape, std::nullopt);
1871-
}
1871+
};
1872+
ShapeUtil::ForEachSubshape(original_shape, [&total_size, &add_to_total_size,
1873+
&sharding, &original_shape](
1874+
const Shape& subshape,
1875+
const ShapeIndex& index) {
1876+
if (subshape.IsTuple()) {
1877+
add_to_total_size(subshape);
1878+
} else if (subshape.IsArray() && sharding.has_value()) {
1879+
add_to_total_size(
1880+
GetSubSharding(*sharding, original_shape, index).TileShape(subshape));
1881+
} else if (subshape.IsArray()) {
1882+
add_to_total_size(subshape);
1883+
} else if (subshape.IsToken()) {
1884+
// Tokens are considered to have a size of 0
1885+
} else {
1886+
total_size += kAutoShardingPointerSize;
18721887
}
1873-
return size;
1874-
} else if (shape.IsArray()) {
1875-
return ShapeUtil::ByteSizeOf(sharding ? sharding->TileShape(shape) : shape);
1876-
} else if (shape.IsToken()) {
1877-
return 0;
1878-
} else {
1879-
return kAutoShardingPointerSize;
1880-
}
1888+
});
1889+
return total_size;
18811890
}
18821891

18831892
int64_t GetShardedInstructionSize(const Shape& shape, int64_t num_devices,
18841893
std::optional<HloSharding> sharding) {
1885-
if (sharding && sharding->IsUnknown()) {
1886-
sharding = HloSharding::Replicate();
1887-
}
1888-
if (shape.IsTuple()) {
1889-
int64_t size =
1890-
ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/kAutoShardingPointerSize);
1891-
for (size_t i = 0; i < shape.tuple_shapes_size(); i++) {
1892-
const Shape& subshape = shape.tuple_shapes().at(i);
1893-
size += GetShardedInstructionSize(
1894-
subshape,
1895-
sharding.has_value()
1896-
? sharding
1897-
->GetSubSharding(shape, ShapeIndex{static_cast<int64_t>(i)})
1898-
.NumTiles()
1899-
: num_devices);
1900-
}
1901-
return size;
1902-
}
1903-
if (sharding && sharding->NumTiles() > 0) {
1904-
return GetBytes(shape) / sharding->NumTiles();
1905-
}
1906-
bool shardable = false;
1907-
for (const auto dim : shape.dimensions()) {
1908-
if (dim >= num_devices) {
1909-
shardable = true;
1910-
break;
1911-
}
1912-
}
1913-
if (shardable) {
1914-
return GetBytes(shape) / num_devices;
1915-
}
1916-
return GetBytes(shape);
1894+
if (sharding.has_value()) {
1895+
return ByteSizeOfShapeWithSharding(shape, sharding);
1896+
}
1897+
1898+
int64_t total_size = 0;
1899+
ShapeUtil::ForEachSubshape(
1900+
shape, [&total_size, &num_devices](const Shape& subshape,
1901+
const ShapeIndex& index) {
1902+
if (subshape.IsTuple()) {
1903+
total_size += ShapeUtil::ByteSizeOf(
1904+
subshape, /*pointer_size=*/kAutoShardingPointerSize);
1905+
return;
1906+
}
1907+
int64_t byte_size = ByteSizeOfShape(subshape);
1908+
absl::Span<const int64_t> subshape_dims = subshape.dimensions();
1909+
auto max_dim_it = absl::c_max_element(subshape_dims);
1910+
if (max_dim_it != subshape_dims.end() && *max_dim_it >= num_devices) {
1911+
byte_size /= num_devices;
1912+
}
1913+
total_size += byte_size;
1914+
});
1915+
1916+
return total_size;
19171917
}
19181918

19191919
HloInstruction* FindInstruction(

third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,9 @@ std::string ToString(absl::Span<T> span) {
159159

160160
// Get the number of bytes of a shape.
161161
inline double GetBytes(const Shape& shape) {
162-
return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/8);
162+
return static_cast<double>(
163+
ShapeUtil::ByteSizeOf(shape,
164+
/*pointer_size=*/kAutoShardingPointerSize));
163165
}
164166

165167
// Return whether two shapes are equal in dimension.
@@ -593,6 +595,11 @@ inline int64_t ByteSizeOfShape(const Shape& shape) {
593595
return ByteSizeOfShapeWithSharding(shape, /*sharding=*/std::nullopt);
594596
}
595597

598+
// Compute the byte size of a shape recursively if it is sharded across a given
599+
// number of devices per an optionally provided sharding. If the sharding is
600+
// provided, this function behaves the same as ByteSizeOfShapeWithSharding
601+
// above. If not, it will give a lower bound on the bytes size of the shape if
602+
// sharded across `num_devices` devices.
596603
int64_t GetShardedInstructionSize(
597604
const Shape& shape, int64_t num_devices,
598605
std::optional<HloSharding> sharding = std::nullopt);

0 commit comments

Comments
 (0)