Skip to content

Commit ddf444b

Browse files
If the dynamic_dimensions parameter is empty in the Shape ctor, assume all dimensions are static.
Some callers call the `Shape(element_type, dimensions, dynamic_dimensions)` ctor with a non-empty `dimensions` and an empty `dynamic_dimensions`. This breaks the shape object's invariant that the two should have the same size. We have two options for fixing this: 1. Force the caller to always provide a `dynamic_dimensions` whose size matches that of `dimensions`. 2. Provide a sensible default behavior when `dynamic_dimensions` is empty. I chose #2 as: 1. #1 is more risky as it may cause the compiler to crash in production (e.g. if we don't have adequate test coverage). 2. It's very common for an array to have only static dimensions. Therefore it's good to optimize the user experience for this common case. PiperOrigin-RevId: 739197635
1 parent c0c6cfd commit ddf444b

File tree

3 files changed

+15
-2
lines changed

3 files changed

+15
-2
lines changed

xla/shape.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,10 @@ Shape::Shape(const PrimitiveType element_type,
5959
dynamic_dimensions.end()) {
6060
CHECK(primitive_util::IsArrayType(element_type_))
6161
<< "Invalid element type for array shape: " << element_type_;
62-
if (!dynamic_dimensions.empty()) {
62+
if (dynamic_dimensions_.empty()) {
63+
// Assume all dimensions are static.
64+
dynamic_dimensions_.resize(dimensions_.size(), false);
65+
} else {
6366
CHECK_EQ(dimensions_.size(), dynamic_dimensions_.size())
6467
<< "If dynamic_dimensions is provided, it must have the same size as "
6568
"dimensions.";

xla/shape.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class Shape {
6767
// Precondition:
6868
// - `element_type` must be a valid array type.
6969
// - `dynamic_dimensions` must be either empty or have the same size as
70-
// `dimensions`.
70+
// `dimensions`. If it's empty, all dimensions are static.
7171
Shape(PrimitiveType element_type, absl::Span<const int64_t> dimensions,
7272
absl::Span<const bool> dynamic_dimensions);
7373

xla/shape_test.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,16 @@ class ShapeTest : public ::testing::Test {
4646
ShapeUtil::MakeShape(F32, {Shape::kUnboundedSize, 784}, {true, false});
4747
};
4848

49+
// Tests that if the dynamic_dimensions parameter empty in the Shape
50+
// constructor, it's treated as all dimensions are static.
51+
TEST(Shape, ArrayCtorTreatsEmptyDynamicDimensionsAsAllStatic) {
52+
const Shape shape(F32, {1, 2, 3}, {});
53+
EXPECT_TRUE(shape.is_static());
54+
EXPECT_TRUE(shape.is_static_dimension(0));
55+
EXPECT_TRUE(shape.is_static_dimension(1));
56+
EXPECT_TRUE(shape.is_static_dimension(2));
57+
}
58+
4959
TEST_F(ShapeTest, ShapeToFromProto) {
5060
for (const Shape& shape :
5161
{opaque_, token_, scalar_, matrix_, matrix2_, tuple_, nested_tuple_,

0 commit comments

Comments
 (0)