Skip to content

Commit e5d9ed2

Browse files
committed
refactor impl 2
1 parent 0133ba9 commit e5d9ed2

File tree

2 files changed

+33
-15
lines changed

2 files changed

+33
-15
lines changed

sqlglot/optimizer/simplify.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,6 @@ class UnsupportedUnit(Exception):
3636
pass
3737

3838

39-
def annotate_boolean(expression: exp.Expression) -> exp.Expression:
40-
if expression and not expression.type:
41-
expression.type = exp.DataType.Type.BOOLEAN
42-
return expression
43-
44-
4539
def simplify(
4640
expression: exp.Expression,
4741
constant_propagation: bool = False,
@@ -138,7 +132,6 @@ def _simplify(expression):
138132

139133
if coalesce_simplification:
140134
new_node = simplify_coalesce(new_node, dialect)
141-
142135
new_node.parent = parent
143136

144137
new_node = simplify_literals(new_node, root)
@@ -183,14 +176,18 @@ def rewrite_between(expression: exp.Expression) -> exp.Expression:
183176

184177
expression = annotate_boolean(
185178
exp.and_(
186-
exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
187-
exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
179+
annotate_boolean(
180+
exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
181+
),
182+
annotate_boolean(
183+
exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
184+
),
188185
copy=False,
189186
)
190187
)
191188

192189
if negate:
193-
expression = exp.paren(expression, copy=False)
190+
expression = annotate_boolean(exp.paren(expression, copy=False))
194191

195192
return expression
196193

@@ -457,7 +454,9 @@ def uniq_sort(expression, root=True):
457454
else:
458455
# we didn't have to sort but maybe we need to dedup
459456
if deduped and len(deduped) < len(flattened):
460-
expression = annotate_boolean(result_func(*deduped.values(), copy=False))
457+
expression = annotate_boolean(
458+
result_func(*(e for e in deduped.values()), copy=False)
459+
)
461460

462461
return expression
463462

@@ -867,7 +866,8 @@ def simplify_coalesce(expression: exp.Expression, dialect: DialectType) -> exp.E
867866
)
868867
)
869868

870-
return exp.paren(annotate_boolean(exp.or_(and_rhs, and_lhs, copy=False)))
869+
or_expr = annotate_boolean(exp.or_(and_rhs, and_lhs, copy=False))
870+
return annotate_boolean(exp.paren(or_expr))
871871

872872

873873
CONCATS = (exp.Concat, exp.DPipe)
@@ -1105,12 +1105,13 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr
11051105
ranges = merge_ranges(ranges)
11061106
target_type = extract_type(*rs)
11071107

1108-
return annotate_boolean(
1108+
or_expr = annotate_boolean(
11091109
exp.or_(
11101110
*[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges],
11111111
copy=False,
11121112
)
11131113
)
1114+
return or_expr
11141115

11151116
return expression
11161117

@@ -1192,7 +1193,21 @@ def is_null(a: exp.Expression) -> bool:
11921193

11931194

11941195
def is_boolean(expression: exp.Expression) -> bool:
1195-
return expression.unnest().is_type(exp.DataType.Type.BOOLEAN)
1196+
return expression.is_type(exp.DataType.Type.BOOLEAN)
1197+
1198+
1199+
def annotate_boolean(expression: exp.Expression) -> exp.Expression:
1200+
if not expression.type:
1201+
expression.type = (
1202+
expression.this.type if isinstance(expression, exp.Paren) else exp.DataType.Type.BOOLEAN
1203+
)
1204+
1205+
if isinstance(expression, exp.Connector):
1206+
if isinstance(left := expression.left, exp.Paren):
1207+
left.type = exp.DataType.Type.BOOLEAN
1208+
if isinstance(right := expression.right, exp.Paren):
1209+
right.type = exp.DataType.Type.BOOLEAN
1210+
return expression
11961211

11971212

11981213
def eval_boolean(expression, a, b):
@@ -1359,6 +1374,9 @@ def _flat_simplify(expression, simplifier, root=True):
13591374
queue = deque(expression.flatten(unnest=False))
13601375
size = len(queue)
13611376

1377+
for operand in queue:
1378+
annotate_boolean(operand)
1379+
13621380
while queue:
13631381
a = queue.popleft()
13641382

tests/fixtures/optimizer/pushdown_predicates.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ SELECT x.a FROM (SELECT * FROM x) AS x CROSS JOIN y WHERE y.a = 1 OR (x.a = 1 AN
88
SELECT x.a FROM (SELECT * FROM x) AS x CROSS JOIN y WHERE (x.a = 1 AND x.b = 1) OR y.a = 1;
99

1010
SELECT x.a FROM (SELECT * FROM x) AS x JOIN y WHERE (x.a = y.a AND x.a = 1 AND x.b = 1) OR x.a = y.a;
11-
SELECT x.a FROM (SELECT * FROM x) AS x JOIN y ON FALSE OR x.a = y.a WHERE TRUE;
11+
SELECT x.a FROM (SELECT * FROM x) AS x JOIN y ON x.a = y.a WHERE TRUE;
1212

1313
SELECT x.a FROM (SELECT * FROM x) AS x JOIN y WHERE (x.a = y.a AND x.a = 1 AND x.b = 1) OR x.a = y.b;
1414
SELECT x.a FROM (SELECT * FROM x) AS x JOIN y ON (x.a = 1 AND x.a = y.a AND x.b = 1) OR x.a = y.b WHERE (x.a = 1 AND x.a = y.a AND x.b = 1) OR x.a = y.b;

0 commit comments

Comments
 (0)