-
Notifications
You must be signed in to change notification settings - Fork 60
/
Copy pathPtrAnalysis.h
275 lines (221 loc) · 11.2 KB
/
PtrAnalysis.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
//===----------------------------------------------------------------------===//
//
// Copyright (c) Microsoft Corporation, Meta Platforms.
// Licensed under the MIT license.
//
//===----------------------------------------------------------------------===//
#ifndef TRITON_ANALYSISSTRUCTURED_PTRANALYSIS_H
#define TRITON_ANALYSISSTRUCTURED_PTRANALYSIS_H
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include <cstddef>
#include <set>
namespace mlir {
class OpBuilder;
namespace tts {
const extern std::string ptrAnalysisAttr;
// Data structure used to decode pointer arithmetics. offsets, sizes, and
// strides are in unit of elements in a linearly laid-out memory, which is the
// same as pointer arithmetic operations in Triton language. scalar is a
// shortcut used when the entire state describes a single scalar value. source
// is the base pointer. If order is present, PtrState describes block pointer;
// otherwise it describes non-block pointers. When it describes block pointer,
// shape field means the same field as tt.make_tensor_ptr; when it describes a
// non-block pointer, shape field indicates how address wraps around (i.e.,
// modulo); a constant 0 indicates no modulo for the dimension.
struct PtrState {
  // Per-dimension start offsets from `source`, in units of elements.
  SmallVector<OpFoldResult> offsets;
  // Per-dimension sizes (extents) of the described tensor of pointers.
  SmallVector<OpFoldResult> sizes;
  // Per-dimension strides, in units of elements.
  SmallVector<OpFoldResult> strides;
  // For block pointers: same meaning as the shape field of tt.make_tensor_ptr.
  // For non-block pointers: how the address wraps around per dimension
  // (modulo); a constant 0 means no modulo for that dimension.
  SmallVector<OpFoldResult> shape;
  // Dimension ordering. Present (non-empty) only when this state describes a
  // block pointer; see isBlockPtr().
  SmallVector<int32_t> order;
  // The base pointer that offsets/strides are applied to.
  Value source;
  // Shortcut used when the entire state describes a single scalar value.
  Value scalar;

  // Number of dimensions described by this state.
  int32_t getRank() const;
  // True when no dimensions / source have been populated yet.
  bool isEmpty() const;
  // True if any dimension has a (non-zero) modulo in `shape`.
  bool hasModulo() const;
  // True if dimension `dim` has a (non-zero) modulo in `shape`.
  bool dimHasModulo(uint32_t dim) const;
  // True when this state describes a block pointer (i.e., `order` is present).
  bool isBlockPtr() const;

  // Process addition of two PtrStates.
  LogicalResult addState(const PtrState &lhsState, const PtrState &rhsState,
                         Operation *op, OpBuilder &builder);

  // Process multiplication of two PtrStates
  LogicalResult mulState(const PtrState &lhsState, const PtrState &rhsState,
                         Operation *op, OpBuilder &builder);

  // Materialize this state as a tts.make_tptr op at `loc`.
  tts::MakeTensorPtrOp createTTSMakeTensorPtrOp(OpBuilder &builder,
                                                Location loc);
};
class PtrAnalysis {
  // This function is internally used by getLoopIterArgPtrState and
  // getLoopResultPtrState to get the correct PtrState for either an iter-arg or
  // a loop's result.
  //
  // A PtrState of an scf.for's iter-arg is the same as its corresponding
  // init-arg, except that the strides and offsets have to point to the loop's
  // iter-args that were created to carry the offsets and strides.
  //
  // For instance, for a pointer with index i and rank 2, 4 additional args
  // starting at index i + 1 are created. The PtrState's strides and offsets
  // value of the pointer's iter-arg must point to these 4 additionally created
  // iter-args.
  //
  // A similar process is used for getting the PtrState of the loop's i'th
  // result: its strides and offsets have to point to the corresponding stride
  // and offset values returned by the loop.
  //
  // `getReplacementVal` maps (forOp, index) to the replacement Value used for
  // each stride/offset slot.
  PtrState reconcileLoopPtrState(
      scf::ForOp forOp, size_t ptrArgIndex, const PtrState &state,
      llvm::function_ref<Value(scf::ForOp op, size_t)> getReplacementVal);

  // Values that may participate in structured pointer arithmetic; populated by
  // initializeMaybeStructuredArgs.
  DenseSet<Value> maybeStructuredArgs;

public:
  void initializeMaybeStructuredArgs(Operation *op);

  // Cache of PtrStates computed so far, keyed by the analyzed Value.
  llvm::SmallDenseMap<Value, PtrState> knownPtrs;

  // Maps original pointer-producing values to their rewritten counterparts.
  IRMapping ptrMap;

  // Recursively parse a Value; call the corresponding
  // function based on the defining operation and argument type.
  LogicalResult visitOperand(Value operand, PtrState &state, const Location loc,
                             OpBuilder &builder);

  // Operand is a result of an scf.for. Such cases occur when there are multiple
  // levels of nested loops where the results of the inner scf.for (pointer) are
  // yielded by the outer loop.
  LogicalResult visitOperandForOp(scf::ForOp forOp, Value operand,
                                  PtrState &state, const Location loc,
                                  OpBuilder &builder);

  // Operand is the result of arith.addi. Process both arguments and insert any
  // arith.addi instruction as needed.
  // Main assumptions:
  //   Only one of lhsState and rhsState has source field set
  //   Current PtrState should be empty
  // Expected result:
  //   source = lhsState.source ? lhsState.source : rhsState.source
  //   sizes[i] = lhsState.sizes[i] (which should match rhsState.sizes[i])
  //   offsets[i] = lhsState.offsets[i] + rhsState.offsets[i]
  //   strides[i] = lhsState.strides[i] + rhsState.strides[i]
  LogicalResult visitOperandAdd(arith::AddIOp addOp, PtrState &state,
                                const Location loc, OpBuilder &builder);

  // Operand is the result of arith.muli. Process both arguments and insert any
  // arith.muli instruction as needed.
  // Main assumptions:
  //   Neither lhsState nor rhsState has source field set
  //   Current PtrState should be empty
  //   Currently only support one of the operand is a scalar index
  // Expected result (scalar and tensorState represent the two operands):
  //   source = null
  //   sizes[i] = tensorState.sizes[i]
  //   offsets[i] = tensorState.offsets[i] * scalar
  //   strides[i] = tensorState.strides[i] * scalar
  LogicalResult visitOperandMul(arith::MulIOp mulOp, PtrState &state,
                                const Location loc, OpBuilder &builder);

  // Operand is the result of arith.remsi; populate state's modulo (shape)
  // information from it.
  LogicalResult visitOperandRem(arith::RemSIOp remOp, PtrState &state,
                                const Location loc, OpBuilder &builder);

  // Operand is the result of make_range.
  // Main assumptions:
  //   start, end, and shape are all statically known
  //   The output of make_range is 1-dimensional
  //   Does not check validity of inputs (e.g., stride > 0)
  // Expected result:
  //   source = null
  //   sizes[0] = shape[0]
  //   offset[0] = start
  //   strides[0] = ceiling( (end - start) / shape[0] )
  LogicalResult visitOperandMakeRange(triton::MakeRangeOp rangeOp,
                                      PtrState &state, Location loc,
                                      OpBuilder &builder);

  // Operand is the result of expand_dims
  // Main assumptions:
  //   Only 1 dimension changes for each invocation of reshape
  //   The changed dimension must have size of 1
  // Expected result:
  //   Insert a dimension of size 1, stride 0, and offset 0
  LogicalResult visitOperandExpandDims(triton::ExpandDimsOp expandDimsOp,
                                       PtrState &state, const Location loc,
                                       OpBuilder &builder);

  // Operand is the result of broadcast
  // Main assumptions:
  //   Rank of source and result is the same
  // Expected result:
  //   Update sizes[i] only, no changes to other fields
  LogicalResult visitOperandBroadcast(triton::BroadcastOp broadcastOp,
                                      PtrState &state, const Location loc,
                                      OpBuilder &builder);

  // Operand is the result of splat
  // Main assumptions:
  //   Source is a scalar value (i.e., an integer or a pointer, not a tensor)
  // Expected result:
  //   sizes[i] reflect the shape of the result, strides[i] = 0, offsets[i] = 0
  //   if source is an integer, offset[0] = scalar = source
  LogicalResult visitOperandSplat(triton::SplatOp splatOp, PtrState &state,
                                  const Location loc, OpBuilder &builder);

  // Operand is the result of arith.constant that is a splat
  // Main assumptions:
  //   Source is a constant op that produces a constant dense tensor where all
  //   elements are the same (i.e.: a constant that is splatted)
  // Expected result:
  //   sizes[i] reflect the shape of the result, strides[i] = 0, offsets[i] =
  //   splat value if i == 0, otherwise 0
  LogicalResult visitOperandConstSplat(arith::ConstantOp op, PtrState &state,
                                       const Location loc, OpBuilder &builder);

  // Operand is the result of arith.extsi; the state is derived from the
  // extension's input operand.
  LogicalResult visitOperandExtSI(arith::ExtSIOp, PtrState &state,
                                  const Location loc, OpBuilder &builder);

  // Operand is the result of addptr.
  // Main assumptions:
  //   The ptr field should populate the source field
  //   ptr and offset fields should result in same rank
  // Expected result:
  //   The resulting state for ptr and offset will be added
  LogicalResult visitOperandAddptr(triton::AddPtrOp addptrOp, PtrState &state,
                                   const Location loc, OpBuilder &builder);

  // Operand is the result of tts.make_tptr.
  // Main assumptions:
  //   This function is only called when rewriting a loop
  // Expected result:
  //   Directly grab all corresponding fields from tts.make_tptr.
  LogicalResult visitOperandMakeTPtr(tts::MakeTensorPtrOp makeTPtrOp,
                                     PtrState &state, const Location loc,
                                     OpBuilder &builder);

  // Operand is the result of tt.make_tensor_ptr.
  // Expected result:
  //   Parse source pointer and grab results
  LogicalResult visitOperandMakeTensorPtr(triton::MakeTensorPtrOp makeTPtrOp,
                                          PtrState &state, const Location loc,
                                          OpBuilder &builder);

  // Get the computed PtrState for the forOp's init-arg at the provided index.
  FailureOr<PtrState> getLoopInitArgPtrState(scf::ForOp forOp, size_t index);

  // Get the computed PtrState for the forOp's iter-arg at the provided index.
  FailureOr<PtrState> getLoopIterArgPtrState(scf::ForOp forOp, size_t index);

  // Get the computed PtrState for the forOp's result at the provided index.
  FailureOr<PtrState> getLoopResultPtrState(scf::ForOp forOp, size_t index);

  // After PtrAnalysis finishes, rewrite the GetStructuredStateOp by creating
  // the correct initialization ops for offsets and strides and passing them to
  // any loop's init-args.
  LogicalResult rewriteGetStructuredStateOp(tts::GetStructuredStateOp op);

  // Parse the state of AddPtrOp, insert any instruction needed to
  // calculate strides and offsets, build PtrState for this operand, and record
  // PtrState for knownPtrs.
  LogicalResult rewriteAddptrOp(triton::AddPtrOp op);

  LogicalResult rewriteMakeTensorPtrOp(triton::MakeTensorPtrOp op);

  LogicalResult rewriteAdvanceOp(triton::AdvanceOp op);

  // Parse the state of YieldOp, insert any instruction needed to calculate
  // strides and offsets, build PtrState for this operand, and record PtrState
  // in knownPtrs.
  LogicalResult
  rewriteYieldOp(scf::YieldOp op,
                 llvm::SmallDenseMap<int, PtrState> &knownPtrsFor);

  // Rewrite eligible tt.addptr in loop init args so loop can update such
  // pointers over iterations. Insert any instruction needed to calculate
  // strides, offsets, and modulos.
  LogicalResult rewriteForOp(scf::ForOp op);

  LogicalResult rewriteLoadOp(triton::LoadOp op, bool useUnsafeMask = false);

  LogicalResult rewriteStoreOp(triton::StoreOp op, bool useUnsafeMask = false);

  // Only rewrite if a scalar ptr is splatted into a tensor of ptr
  LogicalResult rewriteSplatOp(triton::SplatOp op);

  // Dispatch to the appropriate rewrite* method based on the op's kind.
  LogicalResult rewriteOp(Operation *op, bool useUnsafeMask = false);
};
} // namespace tts
} // namespace mlir
#endif