Skip to content

Commit 0a8c67e

Browse files
authored
Add WriteAfterWriteElimination pass (#2572)
* Add RemoveUselessStores pass Signed-off-by: Anna Gringauze <[email protected]> * Address some CR comments Signed-off-by: Anna Gringauze <[email protected]> * Address CR comments Signed-off-by: Anna Gringauze <[email protected]> --------- Signed-off-by: Anna Gringauze <[email protected]>
1 parent 10933d5 commit 0a8c67e

File tree

4 files changed

+266
-0
lines changed

4 files changed

+266
-0
lines changed

include/cudaq/Optimizer/Transforms/Passes.td

+25
Original file line numberDiff line numberDiff line change
@@ -1038,4 +1038,29 @@ def UpdateRegisterNames : Pass<"update-register-names"> {
10381038
}];
10391039
}
10401040

1041+
// Declares the `write-after-write-elimination` pass. The `summary` and
// `description` fields below are its user-facing documentation; the
// description's example shows the first of two stores through the same
// pointer being dropped when nothing uses the pointer in between.
def WriteAfterWriteElimination : Pass<"write-after-write-elimination"> {
1042+
let summary = "Remove stores that are overridden by subsequent store";
1043+
let description = [{
1044+
Remove stores to a location on the stack that have a subsequent store
1045+
to the same location without a use between them:
1046+
1047+
Example:
1048+
```mlir
1049+
%1 = cc.alloca !cc.array<i64 x 1>
1050+
%2 = cc.cast %1 : (!cc.ptr<!cc.array<i64 x 1>>) -> !cc.ptr<i64>
1051+
cc.store %c0_i64, %2 : !cc.ptr<i64>
1052+
// nothing using %2 until the next instruction
1053+
cc.store %c1_i64, %2 : !cc.ptr<i64>
1054+
```
1055+
1056+
would be converted to
1057+
1058+
```mlir
1059+
%1 = cc.alloca !cc.array<i64 x 1>
1060+
%2 = cc.cast %1 : (!cc.ptr<!cc.array<i64 x 1>>) -> !cc.ptr<i64>
1061+
cc.store %c1_i64, %2 : !cc.ptr<i64>
1062+
```
1063+
}];
1064+
}
1065+
10411066
#endif // CUDAQ_OPT_OPTIMIZER_TRANSFORMS_PASSES

lib/Optimizer/Transforms/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ add_cudaq_library(OptTransforms
5555
StatePreparation.cpp
5656
UnitarySynthesis.cpp
5757
WiresToWiresets.cpp
58+
WriteAfterWriteElimination.cpp
5859

5960
DEPENDS
6061
OptTransformsPassIncGen
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
/*******************************************************************************
2+
* Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. *
3+
* All rights reserved. *
4+
* *
5+
* This source code and the accompanying materials are made available under *
6+
* the terms of the Apache License 2.0 which accompanies this distribution. *
7+
******************************************************************************/
8+
9+
#include "PassDetails.h"
10+
#include "cudaq/Optimizer/Builder/Intrinsics.h"
11+
#include "cudaq/Optimizer/Dialect/CC/CCOps.h"
12+
#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h"
13+
#include "cudaq/Optimizer/Transforms/Passes.h"
14+
#include "mlir/Dialect/Complex/IR/Complex.h"
15+
#include "mlir/IR/BuiltinOps.h"
16+
#include "mlir/IR/Dominance.h"
17+
#include "mlir/IR/PatternMatch.h"
18+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19+
#include "mlir/Transforms/Passes.h"
20+
21+
namespace cudaq::opt {
22+
#define GEN_PASS_DEF_WRITEAFTERWRITEELIMINATION
23+
#include "cudaq/Optimizer/Transforms/Passes.h.inc"
24+
} // namespace cudaq::opt
25+
26+
#define DEBUG_TYPE "write-after-write-elimination"
27+
28+
using namespace mlir;
29+
30+
namespace {
31+
/// Remove stores followed by a store to the same pointer
32+
/// if the pointer is not used in between.
33+
/// ```
34+
/// cc.store %c0_i64, %1 : !cc.ptr<i64>
35+
/// // no use of %1 until next line
36+
/// cc.store %0, %1 : !cc.ptr<i64>
37+
/// ───────────────────────────────────────────
38+
/// cc.store %0, %1 : !cc.ptr<i64>
39+
/// ```
40+
class SimplifyWritesAnalysis {
41+
public:
42+
SimplifyWritesAnalysis(DominanceInfo &di, Operation *op) : dom(di) {
43+
for (auto &region : op->getRegions())
44+
for (auto &b : region)
45+
collectBlockInfo(&b);
46+
}
47+
48+
/// Remove stores followed by a store to the same pointer if the pointer is
49+
/// not used in between, using collected block info.
50+
void removeOverriddenStores() {
51+
SmallVector<Operation *> toErase;
52+
53+
for (const auto &[block, ptrToStores] : blockInfo) {
54+
for (const auto &[ptr, stores] : ptrToStores) {
55+
if (stores.size() > 1) {
56+
auto replacement = stores.back();
57+
for (auto it = stores.rend(); it != stores.rbegin(); it++) {
58+
auto store = *it;
59+
if (isReplacement(ptr, *store, *replacement)) {
60+
LLVM_DEBUG(llvm::dbgs() << "replacing store " << store
61+
<< " by: " << replacement << '\n');
62+
toErase.push_back(store->getOperation());
63+
}
64+
}
65+
}
66+
}
67+
}
68+
69+
for (auto *op : toErase)
70+
op->erase();
71+
}
72+
73+
private:
74+
/// Detect if value is used in the op or its nested blocks.
75+
bool isReplacement(Value ptr, cudaq::cc::StoreOp store,
76+
cudaq::cc::StoreOp replacement) const {
77+
// Check that there are no stores dominated by the store and not dominated
78+
// by the replacement (i.e. used in between the store and the replacement)
79+
for (auto *user : ptr.getUsers()) {
80+
if (user != store && user != replacement) {
81+
if (dom.dominates(store, user) && !dom.dominates(replacement, user)) {
82+
LLVM_DEBUG(llvm::dbgs() << "store " << replacement
83+
<< " is used before: " << store << '\n');
84+
return false;
85+
}
86+
}
87+
}
88+
return true;
89+
}
90+
91+
/// Collect all stores to a pointer for a block.
92+
void collectBlockInfo(Block *block) {
93+
for (auto &op : *block) {
94+
for (auto &region : op.getRegions())
95+
for (auto &b : region)
96+
collectBlockInfo(&b);
97+
98+
if (auto store = dyn_cast<cudaq::cc::StoreOp>(&op)) {
99+
auto ptr = store.getPtrvalue();
100+
if (isStoreToStack(store)) {
101+
auto ptrToStores = blockInfo.FindAndConstruct(block).second;
102+
auto stores = ptrToStores.FindAndConstruct(ptr).second;
103+
stores.push_back(&store);
104+
}
105+
}
106+
}
107+
}
108+
109+
/// Detect stores to stack locations, for example:
110+
/// ```
111+
/// %1 = cc.alloca !cc.array<i64 x 2>
112+
///
113+
/// %2 = cc.cast %1 : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
114+
/// cc.store %c0_i64, %2 : !cc.ptr<i64>
115+
///
116+
/// %3 = cc.compute_ptr %1[1] : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
117+
/// cc.store %c0_i64, %3 : !cc.ptr<i64>
118+
/// ```
119+
static bool isStoreToStack(cudaq::cc::StoreOp store) {
120+
auto ptrOp = store.getPtrvalue();
121+
if (auto cast = ptrOp.getDefiningOp<cudaq::cc::CastOp>())
122+
ptrOp = cast.getOperand();
123+
124+
if (auto computePtr = ptrOp.getDefiningOp<cudaq::cc::ComputePtrOp>())
125+
ptrOp = computePtr.getBase();
126+
127+
return isa_and_present<cudaq::cc::AllocaOp>(ptrOp.getDefiningOp());
128+
}
129+
130+
DominanceInfo &dom;
131+
DenseMap<Block *, DenseMap<Value, SmallVector<cudaq::cc::StoreOp *>>>
132+
blockInfo;
133+
};
134+
135+
class WriteAfterWriteEliminationPass
136+
: public cudaq::opt::impl::WriteAfterWriteEliminationBase<
137+
WriteAfterWriteEliminationPass> {
138+
public:
139+
using WriteAfterWriteEliminationBase::WriteAfterWriteEliminationBase;
140+
141+
void runOnOperation() override {
142+
auto op = getOperation();
143+
DominanceInfo domInfo(op);
144+
145+
LLVM_DEBUG(llvm::dbgs()
146+
<< "Before write after write elimination: " << *op << '\n');
147+
148+
auto analysis = SimplifyWritesAnalysis(domInfo, op);
149+
analysis.removeOverriddenStores();
150+
151+
LLVM_DEBUG(llvm::dbgs()
152+
<< "After write after write elimination: " << *op << '\n');
153+
}
154+
};
155+
} // namespace
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// ========================================================================== //
2+
// Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. //
3+
// All rights reserved. //
4+
// //
5+
// This source code and the accompanying materials are made available under //
6+
// the terms of the Apache License 2.0 which accompanies this distribution. //
7+
// ========================================================================== //
8+
9+
// RUN: cudaq-opt -write-after-write-elimination %s | FileCheck %s
10+
11+
// Two stores through the same pointer with no use in between: only the second
// store may remain. CHECK-NEXT pins the surviving store directly after the
// cast and directly before the load; with plain CHECK lines FileCheck would
// skip over a leftover first store and the test would pass even if nothing
// had been removed.
func.func @test_two_stores_same_pointer() {
  %c0_i64 = arith.constant 0 : i64
  %0 = quake.alloca !quake.veq<2>
  %1 = cc.const_array [1] : !cc.array<i64 x 1>
  %2 = cc.extract_value %1[0] : (!cc.array<i64 x 1>) -> i64
  %3 = cc.alloca !cc.array<i64 x 1>
  %4 = cc.cast %3 : (!cc.ptr<!cc.array<i64 x 1>>) -> !cc.ptr<i64>
  cc.store %c0_i64, %4 : !cc.ptr<i64>
  cc.store %2, %4 : !cc.ptr<i64>
  %5 = cc.load %4 : !cc.ptr<i64>
  %6 = quake.extract_ref %0[%5] : (!quake.veq<2>, i64) -> !quake.ref
  quake.x %6 : (!quake.ref) -> ()
  return
}

// CHECK-LABEL:   func.func @test_two_stores_same_pointer() {
// CHECK:           %[[VAL_1:.*]] = quake.alloca !quake.veq<2>
// CHECK:           %[[VAL_2:.*]] = cc.const_array [1] : !cc.array<i64 x 1>
// CHECK:           %[[VAL_3:.*]] = cc.extract_value %[[VAL_2]][0] : (!cc.array<i64 x 1>) -> i64
// CHECK:           %[[VAL_4:.*]] = cc.alloca !cc.array<i64 x 1>
// CHECK:           %[[VAL_5:.*]] = cc.cast %[[VAL_4]] : (!cc.ptr<!cc.array<i64 x 1>>) -> !cc.ptr<i64>
// CHECK-NEXT:      cc.store %[[VAL_3]], %[[VAL_5]] : !cc.ptr<i64>
// CHECK-NEXT:      %[[VAL_6:.*]] = cc.load %[[VAL_5]] : !cc.ptr<i64>
// CHECK:           %[[VAL_7:.*]] = quake.extract_ref %[[VAL_1]][%[[VAL_6]]] : (!quake.veq<2>, i64) -> !quake.ref
// CHECK:           quake.x %[[VAL_7]] : (!quake.ref) -> ()
// CHECK:           return
// CHECK:         }
38+
39+
// Negative case: the two stores write through the results of two different
// `cc.alloca i64` operations, so neither overrides the other. The CHECK lines
// below expect both stores to still be present after the pass.
func.func @test_two_stores_different_pointers() {
40+
%c0_i64 = arith.constant 0 : i64
41+
%c1_i64 = arith.constant 1 : i64
42+
%0 = quake.alloca !quake.veq<2>
43+
%1 = cc.alloca !cc.array<i64 x 1>
44+
%2 = cc.alloca i64
45+
cc.store %c0_i64, %2 : !cc.ptr<i64>
46+
%3 = cc.alloca i64
47+
cc.store %c1_i64, %3 : !cc.ptr<i64>
48+
return
49+
}
50+
51+
// CHECK-LABEL: func.func @test_two_stores_different_pointers() {
52+
// CHECK: %[[VAL_0:.*]] = arith.constant 0 : i64
53+
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : i64
54+
// CHECK: %[[VAL_2:.*]] = quake.alloca !quake.veq<2>
55+
// CHECK: %[[VAL_3:.*]] = cc.alloca !cc.array<i64 x 1>
56+
// CHECK: %[[VAL_4:.*]] = cc.alloca i64
57+
// CHECK: cc.store %[[VAL_0]], %[[VAL_4]] : !cc.ptr<i64>
58+
// CHECK: %[[VAL_5:.*]] = cc.alloca i64
59+
// CHECK: cc.store %[[VAL_1]], %[[VAL_5]] : !cc.ptr<i64>
60+
// CHECK: return
61+
// CHECK: }
62+
63+
// Stores to two different elements of the same array are interleaved. For
// each pointer the first store is overridden by the second with no use in
// between, so only the two `%c1_i64` stores may remain. The CHECK-NEXT chain
// from the cast to the return leaves no room for a leftover store; with plain
// CHECK lines the unmodified input would also have matched.
func.func @test_two_stores_same_pointer_interleaving() {
  %c0_i64 = arith.constant 0 : i64
  %c1_i64 = arith.constant 1 : i64
  %1 = cc.alloca !cc.array<i64 x 2>
  %2 = cc.cast %1 : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
  cc.store %c0_i64, %2 : !cc.ptr<i64>
  %3 = cc.compute_ptr %1[1] : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
  cc.store %c0_i64, %3 : !cc.ptr<i64>
  cc.store %c1_i64, %2 : !cc.ptr<i64>
  cc.store %c1_i64, %3 : !cc.ptr<i64>
  return
}

// CHECK-LABEL:   func.func @test_two_stores_same_pointer_interleaving() {
// CHECK:           %[[VAL_0:.*]] = arith.constant 1 : i64
// CHECK:           %[[VAL_1:.*]] = cc.alloca !cc.array<i64 x 2>
// CHECK:           %[[VAL_2:.*]] = cc.cast %[[VAL_1]] : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
// CHECK-NEXT:      %[[VAL_3:.*]] = cc.compute_ptr %[[VAL_1]][1] : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
// CHECK-NEXT:      cc.store %[[VAL_0]], %[[VAL_2]] : !cc.ptr<i64>
// CHECK-NEXT:      cc.store %[[VAL_0]], %[[VAL_3]] : !cc.ptr<i64>
// CHECK-NEXT:      return
// CHECK:         }
85+

0 commit comments

Comments
 (0)