@@ -87,6 +87,23 @@ convertArrayAttrToGlobalConstant(MLIRContext *ctx, Location loc,
87
87
}
88
88
89
89
namespace {
90
+
91
+ // This pattern replaces a cc.const_array with a global constant. It can
92
+ // recognize a couple of usage patterns and will generate efficient IR in those
93
+ // cases.
94
+ //
95
+ // Pattern 1: The entire constant array is stored to stack variable(s). Here
96
+ // we can eliminate the stack allocation and use the global constant.
97
+ //
98
+ // Pattern 2: Individual elements at dynamic offsets are extracted from the
99
+ // constant array and used. This can be replaced with a compute pointer
100
+ // operation using the global constant and a load of the element at the computed
101
+ // offset.
102
+ //
103
+ // Default: If the usage is not recognized, the constant array value is replaced
104
+ // with a load of the entire global variable. In this case, we rely on LLVM's
105
+ // optimizations to demote the (possibly large) sequence value to primitive
106
+ // memory address arithmetic.
90
107
struct ConstantArrayPattern
91
108
: public OpRewritePattern<cudaq::cc::ConstantArrayOp> {
92
109
explicit ConstantArrayPattern (MLIRContext *ctx, ModuleOp module,
@@ -95,21 +112,30 @@ struct ConstantArrayPattern
95
112
96
113
LogicalResult matchAndRewrite (cudaq::cc::ConstantArrayOp conarr,
97
114
PatternRewriter &rewriter) const override {
115
+ auto func = conarr->getParentOfType <func::FuncOp>();
116
+ if (!func)
117
+ return failure ();
118
+
98
119
SmallVector<cudaq::cc::AllocaOp> allocas;
99
120
SmallVector<cudaq::cc::StoreOp> stores;
121
+ SmallVector<cudaq::cc::ExtractValueOp> extracts;
122
+ bool loadAsValue = false ;
100
123
for (auto *usr : conarr->getUsers ()) {
101
124
auto store = dyn_cast<cudaq::cc::StoreOp>(usr);
102
- if (!store)
103
- return failure ();
104
- auto alloca = store.getPtrvalue ().getDefiningOp <cudaq::cc::AllocaOp>();
105
- if (!alloca )
106
- return failure ();
107
- stores.push_back (store);
108
- allocas.push_back (alloca );
125
+ auto extract = dyn_cast<cudaq::cc::ExtractValueOp>(usr);
126
+ if (store) {
127
+ auto alloca = store.getPtrvalue ().getDefiningOp <cudaq::cc::AllocaOp>();
128
+ if (alloca ) {
129
+ stores.push_back (store);
130
+ allocas.push_back (alloca );
131
+ continue ;
132
+ }
133
+ } else if (extract) {
134
+ extracts.push_back (extract);
135
+ continue ;
136
+ }
137
+ loadAsValue = true ;
109
138
}
110
- auto func = conarr->getParentOfType <func::FuncOp>();
111
- if (!func)
112
- return failure ();
113
139
std::string globalName =
114
140
func.getName ().str () + " .rodata_" + std::to_string (counter++);
115
141
auto *ctx = rewriter.getContext ();
@@ -118,12 +144,39 @@ struct ConstantArrayPattern
118
144
if (failed (convertArrayAttrToGlobalConstant (ctx, conarr.getLoc (), valueAttr,
119
145
module, globalName, eleTy)))
120
146
return failure ();
121
- for (auto alloca : allocas)
122
- rewriter.replaceOpWithNewOp <cudaq::cc::AddressOfOp>(
123
- alloca , alloca .getType (), globalName);
124
- for (auto store : stores)
125
- rewriter.eraseOp (store);
126
- rewriter.eraseOp (conarr);
147
+ auto loc = conarr.getLoc ();
148
+ if (!extracts.empty ()) {
149
+ auto base = rewriter.create <cudaq::cc::AddressOfOp>(
150
+ loc, cudaq::cc::PointerType::get (conarr.getType ()), globalName);
151
+ auto elePtrTy = cudaq::cc::PointerType::get (eleTy);
152
+ for (auto extract : extracts) {
153
+ SmallVector<cudaq::cc::ComputePtrArg> args;
154
+ unsigned i = 0 ;
155
+ for (auto arg : extract.getRawConstantIndices ()) {
156
+ if (arg == cudaq::cc::ExtractValueOp::getDynamicIndexValue ())
157
+ args.push_back (extract.getDynamicIndices ()[i++]);
158
+ else
159
+ args.push_back (arg);
160
+ }
161
+ OpBuilder::InsertionGuard guard (rewriter);
162
+ rewriter.setInsertionPoint (extract);
163
+ auto addrVal =
164
+ rewriter.create <cudaq::cc::ComputePtrOp>(loc, elePtrTy, base, args);
165
+ rewriter.replaceOpWithNewOp <cudaq::cc::LoadOp>(extract, addrVal);
166
+ }
167
+ }
168
+ if (!stores.empty ()) {
169
+ for (auto alloca : allocas)
170
+ rewriter.replaceOpWithNewOp <cudaq::cc::AddressOfOp>(
171
+ alloca , alloca .getType (), globalName);
172
+ for (auto store : stores)
173
+ rewriter.eraseOp (store);
174
+ }
175
+ if (loadAsValue) {
176
+ auto base = rewriter.create <cudaq::cc::AddressOfOp>(
177
+ loc, cudaq::cc::PointerType::get (conarr.getType ()), globalName);
178
+ rewriter.replaceOpWithNewOp <cudaq::cc::LoadOp>(conarr, base);
179
+ }
127
180
return success ();
128
181
}
129
182
0 commit comments