Skip to content

Commit c6a892e

Browse files
authored
[mlir][SMT] restore custom builder for forall/exists (#135470)
This reverts commit 54e70ac which itself fixed an [asan leak](https://lab.llvm.org/buildbot/#/builders/55/builds/9761) from the original upstreaming commit. The leak was due to op allocations not being `free`ed. ~~The necessary change was to explicitly `->destroy()` the ops at the end of the tests. I believe this is because the rewriter used in the tests doesn't actually insert them into a module and so without an explicit `->destroy()` no bookkeeping process is able to take care of them.~~ The necessary change was to use `OwningOpRef` which calls `op->erase()` in its [own destructor](https://github.com/makslevental/llvm-project/blob/89cfae41ecc043f8c47be4dea4b7c740d4f950b3/mlir/include/mlir/IR/OwningOpRef.h#L39).
1 parent 33e5305 commit c6a892e

File tree

4 files changed

+220
-0
lines changed

4 files changed

+220
-0
lines changed

mlir/include/mlir/Dialect/SMT/IR/SMTOps.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,18 @@ class QuantifierOp<string mnemonic> : SMTOp<mnemonic, [
448448
VariadicRegion<SizedRegion<1>>:$patterns);
449449
let results = (outs BoolType:$result);
450450

451+
let builders = [
452+
OpBuilder<(ins
453+
"TypeRange":$boundVarTypes,
454+
"function_ref<Value(OpBuilder &, Location, ValueRange)>":$bodyBuilder,
455+
CArg<"std::optional<ArrayRef<StringRef>>", "std::nullopt">:$boundVarNames,
456+
CArg<"function_ref<ValueRange(OpBuilder &, Location, ValueRange)>",
457+
"{}">:$patternBuilder,
458+
CArg<"uint32_t", "0">:$weight,
459+
CArg<"bool", "false">:$noPattern)>
460+
];
461+
let skipDefaultBuilders = true;
462+
451463
let assemblyFormat = [{
452464
($boundVarNames^)? (`no_pattern` $noPattern^)? (`weight` $weight^)?
453465
attr-dict-with-keyword $body (`patterns` $patterns^)?

mlir/lib/Dialect/SMT/IR/SMTOps.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,16 @@ LogicalResult ForallOp::verifyRegions() {
432432
return verifyQuantifierRegions(*this);
433433
}
434434

435+
void ForallOp::build(
436+
OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
437+
function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
438+
std::optional<ArrayRef<StringRef>> boundVarNames,
439+
function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder,
440+
uint32_t weight, bool noPattern) {
441+
buildQuantifier<Properties>(odsBuilder, odsState, boundVarTypes, bodyBuilder,
442+
boundVarNames, patternBuilder, weight, noPattern);
443+
}
444+
435445
//===----------------------------------------------------------------------===//
436446
// ExistsOp
437447
//===----------------------------------------------------------------------===//
@@ -448,5 +458,15 @@ LogicalResult ExistsOp::verifyRegions() {
448458
return verifyQuantifierRegions(*this);
449459
}
450460

461+
void ExistsOp::build(
462+
OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
463+
function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
464+
std::optional<ArrayRef<StringRef>> boundVarNames,
465+
function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder,
466+
uint32_t weight, bool noPattern) {
467+
buildQuantifier<Properties>(odsBuilder, odsState, boundVarTypes, bodyBuilder,
468+
boundVarNames, patternBuilder, weight, noPattern);
469+
}
470+
451471
#define GET_OP_CLASSES
452472
#include "mlir/Dialect/SMT/IR/SMT.cpp.inc"

mlir/unittests/Dialect/SMT/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_mlir_unittest(MLIRSMTTests
22
AttributeTest.cpp
3+
QuantifierTest.cpp
34
TypeTest.cpp
45
)
56

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
//===- QuantifierTest.cpp - SMT quantifier operation unit tests -----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/SMT/IR/SMTOps.h"
10+
#include "gtest/gtest.h"
11+
12+
using namespace mlir;
13+
using namespace smt;
14+
15+
namespace {
16+
17+
//===----------------------------------------------------------------------===//
18+
// Test custom builders of ExistsOp
19+
//===----------------------------------------------------------------------===//
20+
21+
TEST(QuantifierTest, ExistsBuilderWithPattern) {
22+
MLIRContext context;
23+
context.loadDialect<SMTDialect>();
24+
Location loc(UnknownLoc::get(&context));
25+
26+
OpBuilder builder(&context);
27+
auto boolTy = BoolType::get(&context);
28+
29+
OwningOpRef<ExistsOp> existsOp = builder.create<ExistsOp>(
30+
loc, TypeRange{boolTy, boolTy},
31+
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
32+
return builder.create<AndOp>(loc, boundVars);
33+
},
34+
std::nullopt,
35+
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
36+
return boundVars;
37+
},
38+
/*weight=*/2);
39+
40+
SmallVector<char, 1024> buffer;
41+
llvm::raw_svector_ostream stream(buffer);
42+
existsOp->print(stream);
43+
44+
ASSERT_STREQ(
45+
stream.str().str().c_str(),
46+
"%0 = smt.exists weight 2 {\n^bb0(%arg0: !smt.bool, "
47+
"%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield %0 : "
48+
"!smt.bool\n} patterns {\n^bb0(%arg0: !smt.bool, %arg1: !smt.bool):\n "
49+
"smt.yield %arg0, %arg1 : !smt.bool, !smt.bool\n}\n");
50+
}
51+
52+
TEST(QuantifierTest, ExistsBuilderNoPattern) {
53+
MLIRContext context;
54+
context.loadDialect<SMTDialect>();
55+
Location loc(UnknownLoc::get(&context));
56+
57+
OpBuilder builder(&context);
58+
auto boolTy = BoolType::get(&context);
59+
60+
OwningOpRef<ExistsOp> existsOp = builder.create<ExistsOp>(
61+
loc, TypeRange{boolTy, boolTy},
62+
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
63+
return builder.create<AndOp>(loc, boundVars);
64+
},
65+
ArrayRef<StringRef>{"a", "b"}, nullptr, /*weight=*/0, /*noPattern=*/true);
66+
67+
SmallVector<char, 1024> buffer;
68+
llvm::raw_svector_ostream stream(buffer);
69+
existsOp->print(stream);
70+
71+
ASSERT_STREQ(stream.str().str().c_str(),
72+
"%0 = smt.exists [\"a\", \"b\"] no_pattern {\n^bb0(%arg0: "
73+
"!smt.bool, %arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n "
74+
"smt.yield %0 : !smt.bool\n}\n");
75+
}
76+
77+
TEST(QuantifierTest, ExistsBuilderDefault) {
78+
MLIRContext context;
79+
context.loadDialect<SMTDialect>();
80+
Location loc(UnknownLoc::get(&context));
81+
82+
OpBuilder builder(&context);
83+
auto boolTy = BoolType::get(&context);
84+
85+
OwningOpRef<ExistsOp> existsOp = builder.create<ExistsOp>(
86+
loc, TypeRange{boolTy, boolTy},
87+
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
88+
return builder.create<AndOp>(loc, boundVars);
89+
},
90+
ArrayRef<StringRef>{"a", "b"});
91+
92+
SmallVector<char, 1024> buffer;
93+
llvm::raw_svector_ostream stream(buffer);
94+
existsOp->print(stream);
95+
96+
ASSERT_STREQ(stream.str().str().c_str(),
97+
"%0 = smt.exists [\"a\", \"b\"] {\n^bb0(%arg0: !smt.bool, "
98+
"%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield "
99+
"%0 : !smt.bool\n}\n");
100+
}
101+
102+
//===----------------------------------------------------------------------===//
103+
// Test custom builders of ForallOp
104+
//===----------------------------------------------------------------------===//
105+
106+
TEST(QuantifierTest, ForallBuilderWithPattern) {
107+
MLIRContext context;
108+
context.loadDialect<SMTDialect>();
109+
Location loc(UnknownLoc::get(&context));
110+
111+
OpBuilder builder(&context);
112+
auto boolTy = BoolType::get(&context);
113+
114+
OwningOpRef<ForallOp> forallOp = builder.create<ForallOp>(
115+
loc, TypeRange{boolTy, boolTy},
116+
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
117+
return builder.create<AndOp>(loc, boundVars);
118+
},
119+
ArrayRef<StringRef>{"a", "b"},
120+
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
121+
return boundVars;
122+
},
123+
/*weight=*/2);
124+
125+
SmallVector<char, 1024> buffer;
126+
llvm::raw_svector_ostream stream(buffer);
127+
forallOp->print(stream);
128+
129+
ASSERT_STREQ(
130+
stream.str().str().c_str(),
131+
"%0 = smt.forall [\"a\", \"b\"] weight 2 {\n^bb0(%arg0: !smt.bool, "
132+
"%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield %0 : "
133+
"!smt.bool\n} patterns {\n^bb0(%arg0: !smt.bool, %arg1: !smt.bool):\n "
134+
"smt.yield %arg0, %arg1 : !smt.bool, !smt.bool\n}\n");
135+
}
136+
137+
TEST(QuantifierTest, ForallBuilderNoPattern) {
138+
MLIRContext context;
139+
context.loadDialect<SMTDialect>();
140+
Location loc(UnknownLoc::get(&context));
141+
142+
OpBuilder builder(&context);
143+
auto boolTy = BoolType::get(&context);
144+
145+
OwningOpRef<ForallOp> forallOp = builder.create<ForallOp>(
146+
loc, TypeRange{boolTy, boolTy},
147+
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
148+
return builder.create<AndOp>(loc, boundVars);
149+
},
150+
ArrayRef<StringRef>{"a", "b"}, nullptr, /*weight=*/0, /*noPattern=*/true);
151+
152+
SmallVector<char, 1024> buffer;
153+
llvm::raw_svector_ostream stream(buffer);
154+
forallOp->print(stream);
155+
156+
ASSERT_STREQ(stream.str().str().c_str(),
157+
"%0 = smt.forall [\"a\", \"b\"] no_pattern {\n^bb0(%arg0: "
158+
"!smt.bool, %arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n "
159+
"smt.yield %0 : !smt.bool\n}\n");
160+
}
161+
162+
TEST(QuantifierTest, ForallBuilderDefault) {
163+
MLIRContext context;
164+
context.loadDialect<SMTDialect>();
165+
Location loc(UnknownLoc::get(&context));
166+
167+
OpBuilder builder(&context);
168+
auto boolTy = BoolType::get(&context);
169+
170+
OwningOpRef<ForallOp> forallOp = builder.create<ForallOp>(
171+
loc, TypeRange{boolTy, boolTy},
172+
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
173+
return builder.create<AndOp>(loc, boundVars);
174+
},
175+
std::nullopt);
176+
177+
SmallVector<char, 1024> buffer;
178+
llvm::raw_svector_ostream stream(buffer);
179+
forallOp->print(stream);
180+
181+
ASSERT_STREQ(stream.str().str().c_str(),
182+
"%0 = smt.forall {\n^bb0(%arg0: !smt.bool, "
183+
"%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield "
184+
"%0 : !smt.bool\n}\n");
185+
}
186+
187+
} // namespace

0 commit comments

Comments
 (0)