Skip to content

Commit 5042d6f

Browse files
committed
address reviewer comments
1 parent 8615d8a commit 5042d6f

File tree

2 files changed

+55
-48
lines changed

2 files changed

+55
-48
lines changed

mlir/include/mlir/Dialect/Ptr/IR/MemorySpace.h

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@
2121
namespace mlir {
2222
class Operation;
2323
namespace ptr {
24-
/// This method checks if it's valid to perform an `addrspacecast` op in the
24+
/// Checks if it's valid to perform an `addrspacecast` op in the
2525
/// memory space.
2626
/// Compatible types are:
2727
/// Vectors of rank 1, or scalars of `ptr` type.
2828
LogicalResult isValidAddrSpaceCastImpl(Type tgt, Type src,
2929
Operation *diagnosticOp);
3030

31-
/// This method checks if it's valid to perform a `ptrtoint` or `inttoptr` op in
31+
/// Checks if it's valid to perform a `ptrtoint` or `inttoptr` op in
3232
/// the memory space.
3333
/// Compatible types are:
3434
/// IntLikeTy: Vectors of rank 1, or scalars of integer types or `index` type.
@@ -52,28 +52,29 @@ class MemorySpace {
5252
MemorySpace() = default;
5353
MemorySpace(std::nullptr_t) {}
5454
MemorySpace(MemorySpaceAttrInterface memorySpace)
55-
: memorySpaceAttr(memorySpace), memorySpace(memorySpace) {}
56-
MemorySpace(Attribute memorySpace)
57-
: memorySpaceAttr(memorySpace),
55+
: underlyingMemorySpace(memorySpace), memorySpace(memorySpace) {}
56+
explicit MemorySpace(Attribute memorySpace)
57+
: underlyingMemorySpace(memorySpace),
5858
memorySpace(dyn_cast_or_null<MemorySpaceAttrInterface>(memorySpace)) {}
5959

60-
operator Attribute() const { return memorySpaceAttr; }
60+
operator Attribute() const { return underlyingMemorySpace; }
6161
operator MemorySpaceAttrInterface() const { return memorySpace; }
6262
bool operator==(const MemorySpace &memSpace) const {
63-
return memSpace.memorySpaceAttr == memorySpaceAttr;
63+
return memSpace.underlyingMemorySpace == underlyingMemorySpace;
6464
}
6565

6666
/// Returns the underlying memory space.
67-
Attribute getUnderlyingSpace() const { return memorySpaceAttr; }
67+
Attribute getUnderlyingSpace() const { return underlyingMemorySpace; }
6868

69-
/// Returns true if the underlying memory space is null.
69+
/// Returns true if the memory space is null.
7070
bool isDefaultModel() const { return memorySpace == nullptr; }
7171

7272
/// Returns the memory space as an integer, or 0 if using the default space.
7373
unsigned getAddressSpace() const {
7474
if (memorySpace)
7575
return memorySpace.getAddressSpace();
76-
if (auto intAttr = llvm::dyn_cast_or_null<IntegerAttr>(memorySpaceAttr))
76+
if (auto intAttr =
77+
llvm::dyn_cast_or_null<IntegerAttr>(underlyingMemorySpace))
7778
return intAttr.getInt();
7879
return 0;
7980
}
@@ -84,9 +85,9 @@ class MemorySpace {
8485
return memorySpace ? memorySpace.getDefaultMemorySpace() : nullptr;
8586
}
8687

87-
/// This method checks if it's valid to load a value from the memory space
88-
/// with a specific type, alignment, and atomic ordering. The default model
89-
/// assumes all values are loadable.
88+
/// Checks if it's valid to load a value from the memory space with a specific
89+
/// type, alignment, and atomic ordering. The default model assumes all values
90+
/// can be loaded.
9091
LogicalResult isValidLoad(Type type, AtomicOrdering ordering,
9192
IntegerAttr alignment,
9293
Operation *diagnosticOp = nullptr) const {
@@ -95,9 +96,9 @@ class MemorySpace {
9596
: success();
9697
}
9798

98-
/// This method checks if it's valid to store a value in the memory space with
99-
/// a specific type, alignment, and atomic ordering. The default model assumes
100-
/// all values are loadable.
99+
/// Checks if it's valid to store a value in the memory space with a specific
100+
/// type, alignment, and atomic ordering. The default model assumes all values
101+
/// can be stored.
101102
LogicalResult isValidStore(Type type, AtomicOrdering ordering,
102103
IntegerAttr alignment,
103104
Operation *diagnosticOp = nullptr) const {
@@ -106,8 +107,8 @@ class MemorySpace {
106107
: success();
107108
}
108109

109-
/// This method checks if it's valid to perform an atomic operation in the
110-
/// memory space with a specific type, alignment, and atomic ordering.
110+
/// Checks if it's valid to perform an atomic operation in the memory space
111+
/// with a specific type, alignment, and atomic ordering.
111112
LogicalResult isValidAtomicOp(AtomicBinOp op, Type type,
112113
AtomicOrdering ordering, IntegerAttr alignment,
113114
Operation *diagnosticOp = nullptr) const {
@@ -116,8 +117,8 @@ class MemorySpace {
116117
: success();
117118
}
118119

119-
/// This method checks if it's valid to perform an atomic operation in the
120-
/// memory space with a specific type, alignment, and atomic ordering.
120+
/// Checks if it's valid to perform an atomic exchange operation in the memory
121+
/// space with a specific type, alignment, and atomic ordering.
121122
LogicalResult isValidAtomicXchg(Type type, AtomicOrdering successOrdering,
122123
AtomicOrdering failureOrdering,
123124
IntegerAttr alignment,
@@ -128,17 +129,16 @@ class MemorySpace {
128129
: success();
129130
}
130131

131-
/// This method checks if it's valid to perform an `addrspacecast` op in the
132-
/// memory space.
132+
/// Checks if it's valid to perform an `addrspacecast` op in the memory space.
133133
LogicalResult isValidAddrSpaceCast(Type tgt, Type src,
134134
Operation *diagnosticOp = nullptr) const {
135135
return memorySpace
136136
? memorySpace.isValidAddrSpaceCast(tgt, src, diagnosticOp)
137137
: isValidAddrSpaceCastImpl(tgt, src, diagnosticOp);
138138
}
139139

140-
/// This method checks if it's valid to perform a `ptrtoint` or `inttoptr` op
141-
/// in the memory space.
140+
/// Checks if it's valid to perform a `ptrtoint` or `inttoptr` op in the
141+
/// memory space.
142142
LogicalResult isValidPtrIntCast(Type intLikeTy, Type ptrLikeTy,
143143
Operation *diagnosticOp = nullptr) const {
144144
return memorySpace
@@ -149,7 +149,7 @@ class MemorySpace {
149149

150150
protected:
151151
/// Underlying memory space.
152-
Attribute memorySpaceAttr{};
152+
Attribute underlyingMemorySpace{};
153153
/// Memory space.
154154
MemorySpaceAttrInterface memorySpace{};
155155
};

mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ void PtrDialect::initialize() {
3939
// Pointer API.
4040
//===----------------------------------------------------------------------===//
4141

42+
// Error constants for vector data types.
43+
constexpr const static unsigned kInvalidRankError = -1;
44+
constexpr const static unsigned kScalableDimsError = -2;
45+
4246
// Returns a pair containing:
4347
// The underlying type of a vector or the type itself if it's not a vector.
4448
// The number of elements in the vector or an error code if the type is not
@@ -49,26 +53,28 @@ static std::pair<Type, int64_t> getVecOrScalarInfo(Type ty) {
4953
// Vectors of rank greater than one or with scalable dimensions are not
5054
// supported.
5155
if (vecTy.getRank() != 1)
52-
return {elemTy, -1};
56+
return {elemTy, kInvalidRankError};
5357
else if (vecTy.getScalableDims()[0])
54-
return {elemTy, -2};
58+
return {elemTy, kScalableDimsError};
5559
return {elemTy, vecTy.getShape()[0]};
5660
}
5761
// `ty` is a scalar type.
5862
return {ty, 0};
5963
}
6064

61-
LogicalResult mlir::ptr::isValidAddrSpaceCastImpl(Type tgt, Type src,
62-
Operation *op) {
63-
std::pair<Type, int64_t> tgtInfo = getVecOrScalarInfo(tgt);
64-
std::pair<Type, int64_t> srcInfo = getVecOrScalarInfo(src);
65-
if (!isa<PtrType>(tgtInfo.first) || !isa<PtrType>(srcInfo.first))
66-
return op ? op->emitError("invalid ptr-like operand") : failure();
65+
/// Checks whether the shape of the operands is compatible with the operation.
66+
/// Operands must be scalars or have the same vector shape, additionally only
67+
/// vectors of rank 1 are supported.
68+
static LogicalResult verifyShapeInfo(mlir::Operation *op,
69+
const std::pair<Type, int64_t> &tgtInfo,
70+
const std::pair<Type, int64_t> &srcInfo) {
6771
// Check shape validity.
68-
if (tgtInfo.second == -1 || srcInfo.second == -1)
72+
if (tgtInfo.second == kInvalidRankError ||
73+
srcInfo.second == kInvalidRankError)
6974
return op ? op->emitError("vectors of rank != 1 are not supported")
7075
: failure();
71-
if (tgtInfo.second == -2 || srcInfo.second == -2)
76+
if (tgtInfo.second == kScalableDimsError ||
77+
srcInfo.second == kScalableDimsError)
7278
return op ? op->emitError(
7379
"vectors with scalable dimensions are not supported")
7480
: failure();
@@ -77,29 +83,30 @@ LogicalResult mlir::ptr::isValidAddrSpaceCastImpl(Type tgt, Type src,
7783
return success();
7884
}
7985

86+
LogicalResult mlir::ptr::isValidAddrSpaceCastImpl(Type tgt, Type src,
87+
Operation *op) {
88+
std::pair<Type, int64_t> tgtInfo = getVecOrScalarInfo(tgt);
89+
std::pair<Type, int64_t> srcInfo = getVecOrScalarInfo(src);
90+
if (!isa<PtrType>(tgtInfo.first) || !isa<PtrType>(srcInfo.first))
91+
return op ? op->emitError("invalid ptr-like operand") : failure();
92+
// Verify shape validity.
93+
return verifyShapeInfo(op, tgtInfo, srcInfo);
94+
}
95+
8096
LogicalResult mlir::ptr::isValidPtrIntCastImpl(Type intLikeTy, Type ptrLikeTy,
8197
Operation *op) {
8298
// Check int-like type.
8399
std::pair<Type, int64_t> intInfo = getVecOrScalarInfo(intLikeTy);
100+
// The int-like operand is invalid.
84101
if (!intInfo.first.isSignlessIntOrIndex())
85-
/// The int-like operand is invalid.
86102
return op ? op->emitError("invalid int-like type") : failure();
87103
// Check ptr-like type.
88104
std::pair<Type, int64_t> ptrInfo = getVecOrScalarInfo(ptrLikeTy);
105+
// The pointer-like operand is invalid.
89106
if (!isa<PtrType>(ptrInfo.first))
90-
/// The pointer-like operand is invalid.
91107
return op ? op->emitError("invalid ptr-like type") : failure();
92-
// Check shape validity.
93-
if (intInfo.second == -1 || ptrInfo.second == -1)
94-
return op ? op->emitError("vectors of rank != 1 are not supported")
95-
: failure();
96-
if (intInfo.second == -2 || ptrInfo.second == -2)
97-
return op ? op->emitError(
98-
"vectors with scalable dimensions are not supported")
99-
: failure();
100-
if (intInfo.second != ptrInfo.second)
101-
return op ? op->emitError("incompatible operand shapes") : failure();
102-
return success();
108+
// Verify shape validity.
109+
return verifyShapeInfo(op, intInfo, ptrInfo);
103110
}
104111

105112
#include "mlir/Dialect/Ptr/IR/PtrOpsDialect.cpp.inc"

0 commit comments

Comments
 (0)