Skip to content

Commit

Permalink
Merge pull request #723 from egraphs-good/oflatt-terms-perf
Browse files Browse the repository at this point in the history
Improve performance of terms ruleset greatly
  • Loading branch information
oflatt authored Feb 11, 2025
2 parents f0da7d6 + dd17c46 commit ea983e4
Show file tree
Hide file tree
Showing 12 changed files with 256 additions and 99 deletions.
19 changes: 10 additions & 9 deletions dag_in_context/src/optimizations/select.egg
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
(ruleset select_opt)


;; inlined (Get thn i) makes the query faster ):
(rule
(
(= if_e (If pred inputs thn els))
(ContextOf if_e ctx)

(= thn_out (Get thn i))
(= els_out (Get els i))
(ExprIsPure thn_out)
(ExprIsPure els_out)
(ExprIsPure (Get thn i))
(ExprIsPure (Get els i))

(> 10 (Expr-size thn_out)) ; TODO: Tune these size limits
(> 10 (Expr-size els_out))
(= (TCPair t1 c1) (ExtractedExpr thn_out))
(= (TCPair t2 c2) (ExtractedExpr els_out))
(> 10 (Expr-size (Get thn i))) ; TODO: Tune these size limits
(> 10 (Expr-size (Get els i)))
(= (TCPair t1 c1) (ExtractedExpr (Get thn i)))
(= (TCPair t2 c2) (ExtractedExpr (Get els i)))

(ContextOf if_e ctx)
)
(
(union (Get if_e i)
Expand Down
6 changes: 5 additions & 1 deletion dag_in_context/src/schedule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ pub(crate) fn helpers() -> String {
(saturate canon)
(saturate interval-analysis)
(saturate terms)
(saturate
terms
(saturate
terms-helpers
(saturate terms-helpers-helpers)))
;; memory-helpers TODO run memory helpers for memory optimizations
;; finally, subsume now that helpers are done
Expand Down
24 changes: 24 additions & 0 deletions dag_in_context/src/type_analysis.egg
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,30 @@
((panic "(Load) expected pointer, received tuple"))
:ruleset error-checking)

(rule (
(= lhs (Top (Select) pred v1 v2))
)
((ExpectType pred (Base (BoolT)) "(Select)"))
:ruleset type-analysis)

(rule (
(= lhs (Top (Select) pred v1 v2))
(HasType v1 ty)
(HasType v2 ty)
)
((HasType lhs ty))
:ruleset type-analysis)

(rule (
(= lhs (Top (Select) pred v1 v2))
(HasType v1 ty1)
(HasType v2 ty2)
(!= ty1 ty2)
)
((panic "(Select) branches had different types"))
:ruleset error-checking)


; Binary ops

;; Operators that have type Type -> Type -> Type
Expand Down
61 changes: 49 additions & 12 deletions dag_in_context/src/utility/terms.egg
Original file line number Diff line number Diff line change
@@ -1,43 +1,80 @@
(ruleset terms)
;; helpers keeps track of the new best extracted terms
(ruleset terms-helpers)
;; helpers-helpers runs `Smaller` rules, resolving the merge function for helpers
(ruleset terms-helpers-helpers)

(sort TermAndCost)
(function Smaller (TermAndCost TermAndCost) TermAndCost)

(function ExtractedExpr (Expr) TermAndCost
:merge (Smaller old new))
;; potential extractions- use so that when the costs are equal, we don't change the term
;; this preserves egglog's timestamp of when the last time ExtractedExpr was changed, fixing a big performance problem
(relation PotentialExtractedExpr (Expr TermAndCost))

(function TCPair (Term i64) TermAndCost)

(function NoTerm () Term)

;; set extracted expr to default value
(rule ((PotentialExtractedExpr expr termandcost))
((set (ExtractedExpr expr) (TCPair (NoTerm) 10000000000000000)))
:ruleset terms-helpers)

;; set extracted expr to new value as long as not equal
(rule ((PotentialExtractedExpr expr (TCPair term cost))
(= (ExtractedExpr expr) (TCPair oldterm oldcost))
(< cost oldcost))
((set (ExtractedExpr expr) (TCPair term cost)))
:ruleset terms-helpers)

;; if the cost is negative panic, terms got too big
(rule ((PotentialExtractedExpr expr (TCPair term cost))
(< cost 0))
((panic "Negative cost"))
:ruleset terms-helpers)

;; Resolve Smaller
(rule (
(= lhs (Smaller (TCPair t1 cost1) (TCPair t2 cost2)))
(<= cost1 cost2)
(< cost1 cost2)
)
((union lhs (TCPair t1 cost1)))
:ruleset terms)
:ruleset terms-helpers-helpers)

(rule (
(= lhs (Smaller (TCPair t1 cost1) (TCPair t2 cost2)))
(> cost1 cost2)
)
((union lhs (TCPair t2 cost2)))
:ruleset terms)
:ruleset terms-helpers-helpers)


(rule (
(= lhs (Smaller (TCPair t1 cost1) (TCPair t2 cost2)))
(= cost1 cost2)
)
;; arbitrarily pick first one
((union lhs (TCPair t1 cost1)))
:ruleset terms-helpers-helpers)


; Compute smallest Expr bottom-up
(rule ((= lhs (Const c ty ass)))
((set (ExtractedExpr lhs) (TCPair (TermConst c) 1)))
((PotentialExtractedExpr lhs (TCPair (TermConst c) 1)))
:ruleset terms)

(rule ((= lhs (Arg ty ass)))
((set (ExtractedExpr lhs) (TCPair (TermArg) 1)))
((PotentialExtractedExpr lhs (TCPair (TermArg) 1)))
:ruleset terms)

(rule (
(= lhs (Bop o e1 e2))
(= (TCPair t1 c1) (ExtractedExpr e1))
(= (TCPair t2 c2) (ExtractedExpr e2))
)
((set (ExtractedExpr lhs) (TCPair (TermBop o t1 t2) (+ 1 (+ c1 c2)))))
((PotentialExtractedExpr lhs (TCPair (TermBop o t1 t2) (+ 1 (+ c1 c2)))))
:ruleset terms)

(rule (
Expand All @@ -46,22 +83,22 @@
(= (TCPair t2 c2) (ExtractedExpr e2))
(= (TCPair t3 c3) (ExtractedExpr e3))
)
((set (ExtractedExpr lhs) (TCPair (TermTop o t1 t2 t3) (+ (+ 1 c1) (+ c2 c3)))))
((PotentialExtractedExpr lhs (TCPair (TermTop o t1 t2 t3) (+ (+ 1 c1) (+ c2 c3)))))
:ruleset terms)

(rule (
(= lhs (Uop o e1))
(= (TCPair t1 c1) (ExtractedExpr e1))
)
((set (ExtractedExpr lhs) (TCPair (TermUop o t1) (+ 1 c1))))
((PotentialExtractedExpr lhs (TCPair (TermUop o t1) (+ 1 c1))))
:ruleset terms)

(rule (
(= lhs (Get tup i))
(= (TCPair t1 c1) (ExtractedExpr tup))
)
; cost of the get is the same as the cost of the whole tuple
((set (ExtractedExpr lhs) (TCPair (TermGet t1 i) c1)))
((PotentialExtractedExpr lhs (TCPair (TermGet t1 i) c1)))
:ruleset terms)

; todo Alloc
Expand All @@ -73,7 +110,7 @@
(= (TCPair t1 c1) (ExtractedExpr e1))
)
; cost of single is same as cost of the element
((set (ExtractedExpr lhs) (TCPair (TermSingle t1) c1)))
((PotentialExtractedExpr lhs (TCPair (TermSingle t1) c1)))
:ruleset terms)

(rule (
Expand All @@ -82,7 +119,7 @@
(= (TCPair t2 c2) (ExtractedExpr e2))
)
; cost of concat is sum of the costs
((set (ExtractedExpr lhs) (TCPair (TermConcat t1 t2) (+ c1 c2))))
((PotentialExtractedExpr lhs (TCPair (TermConcat t1 t2) (+ c1 c2))))
:ruleset terms)


Expand All @@ -95,7 +132,7 @@
; (= (TCPair t4 c4) (ExtractedExpr els))
; )
; ; cost of if is 10 + cost of pred + cost of input + max of branch costs
; ((set (ExtractedExpr lhs) (TCPair (TermIf t1 t2 t3 t4) (+ 10 (+ (+ c1 c2) (max c3 c4))))))
; ((PotentialExtractedExpr lhs (TCPair (TermIf t1 t2 t3 t4) (+ 10 (+ (+ c1 c2) (max c3 c4))))))
; :ruleset terms)

(sort Node)
Expand Down
20 changes: 20 additions & 0 deletions tests/passing/small/simple_select_after_block_diamond.bril
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# ARGS: 1
@main(v0: int) {
c1_: int = const 1;
c2_: int = const 2;
v3_: bool = lt v0 c2_;
c4_: int = const 4;
v5_: int = select v3_ c4_ c1_;
v6_: int = id v5_;
v7_: int = id c1_;
br v3_ .b8_ .b9_;
.b9_:
v10_: int = add c2_ v5_;
v6_: int = id v10_;
v7_: int = id c1_;
.b8_:
v11_: int = add c1_ v6_;
print v11_;
ret;
}

32 changes: 16 additions & 16 deletions tests/snapshots/files__block-diamond-optimize-sequential.snap
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,30 @@ expression: visualization.result
---
# ARGS: 1
@main(v0: int) {
c1_: int = const 2;
v2_: bool = lt v0 c1_;
c3_: int = const 1;
v4_: int = id c3_;
v5_: int = id c3_;
v6_: int = id c1_;
br v2_ .b7_ .b8_;
c1_: int = const 1;
c2_: int = const 2;
v3_: bool = lt v0 c2_;
v4_: int = id c1_;
v5_: int = id c1_;
v6_: int = id c2_;
br v3_ .b7_ .b8_;
.b7_:
c9_: bool = const true;
c10_: int = const 4;
v11_: int = select c9_ c10_ c1_;
v11_: int = select c9_ c10_ c2_;
v4_: int = id v11_;
v5_: int = id c3_;
v6_: int = id c1_;
v12_: int = add c1_ v4_;
v13_: int = select v2_ v4_ v12_;
v14_: int = add c3_ v13_;
v5_: int = id c1_;
v6_: int = id c2_;
v12_: int = add c2_ v4_;
v13_: int = select v3_ v4_ v12_;
v14_: int = add c1_ v13_;
print v14_;
ret;
jmp .b15_;
.b8_:
v12_: int = add c1_ v4_;
v13_: int = select v2_ v4_ v12_;
v14_: int = add c3_ v13_;
v12_: int = add c2_ v4_;
v13_: int = select v3_ v4_ v12_;
v14_: int = add c1_ v13_;
print v14_;
ret;
.b15_:
Expand Down
45 changes: 25 additions & 20 deletions tests/snapshots/files__branch_hoisting-optimize-sequential.snap
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,33 @@ expression: visualization.result
@main(v0: int) {
c1_: int = const 0;
c2_: int = const 500;
v3_: int = id c1_;
v3_: bool = eq c1_ v0;
v4_: int = id c1_;
v5_: int = id v0;
v6_: int = id c1_;
v7_: int = id c2_;
.b8_:
v9_: bool = eq v5_ v6_;
c10_: int = const 2;
v11_: int = mul c10_ v4_;
c12_: int = const 3;
v13_: int = mul c12_ v4_;
v14_: int = select v9_ v11_ v13_;
c15_: int = const 1;
v16_: int = add c15_ v4_;
v17_: bool = lt v16_ v7_;
v3_: int = id v14_;
v4_: int = id v16_;
v5_: int = id v5_;
v5_: int = id c1_;
v6_: int = id v0;
v7_: int = id c1_;
v8_: int = id c2_;
v9_: bool = id v3_;
.b10_:
c11_: int = const 1;
v12_: int = add c11_ v5_;
v13_: int = add c11_ v12_;
v14_: int = add c11_ v13_;
c15_: int = const 2;
v16_: int = mul c15_ v14_;
c17_: int = const 3;
v18_: int = mul c17_ v14_;
v19_: int = select v9_ v16_ v18_;
v20_: int = add c11_ v14_;
v21_: bool = lt v20_ v8_;
v4_: int = id v19_;
v5_: int = id v20_;
v6_: int = id v6_;
v7_: int = id v7_;
br v17_ .b8_ .b18_;
.b18_:
print v3_;
v8_: int = id v8_;
v9_: bool = id v9_;
br v21_ .b10_ .b22_;
.b22_:
print v4_;
ret;
}
16 changes: 8 additions & 8 deletions tests/snapshots/files__branch_hoisting-optimize.snap
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@ expression: visualization.result
v13_: int = id c3_;
.b14_:
v15_: bool = eq v11_ v12_;
c16_: int = const 2;
c17_: int = const 1;
v18_: int = add c17_ v10_;
v19_: int = add c17_ v18_;
v20_: int = add c17_ v19_;
v21_: int = mul c16_ v20_;
c16_: int = const 1;
v17_: int = add c16_ v10_;
v18_: int = add c16_ v17_;
v19_: int = add c16_ v18_;
c20_: int = const 2;
v21_: int = mul c20_ v19_;
c22_: int = const 3;
v23_: int = mul c22_ v20_;
v23_: int = mul c22_ v19_;
v24_: int = select v15_ v21_ v23_;
v25_: int = add c17_ v20_;
v25_: int = add c16_ v19_;
v26_: bool = lt v25_ v13_;
v9_: int = id v24_;
v10_: int = id v25_;
Expand Down
Loading

0 comments on commit ea983e4

Please sign in to comment.