@@ -36,12 +36,6 @@ class UnsupportedUnit(Exception):
3636 pass
3737
3838
39- def annotate_boolean (expression : exp .Expression ) -> exp .Expression :
40- if expression and not expression .type :
41- expression .type = exp .DataType .Type .BOOLEAN
42- return expression
43-
44-
4539def simplify (
4640 expression : exp .Expression ,
4741 constant_propagation : bool = False ,
@@ -138,7 +132,6 @@ def _simplify(expression):
138132
139133 if coalesce_simplification :
140134 new_node = simplify_coalesce (new_node , dialect )
141-
142135 new_node .parent = parent
143136
144137 new_node = simplify_literals (new_node , root )
@@ -183,14 +176,18 @@ def rewrite_between(expression: exp.Expression) -> exp.Expression:
183176
184177 expression = annotate_boolean (
185178 exp .and_ (
186- exp .GTE (this = expression .this .copy (), expression = expression .args ["low" ]),
187- exp .LTE (this = expression .this .copy (), expression = expression .args ["high" ]),
179+ annotate_boolean (
180+ exp .GTE (this = expression .this .copy (), expression = expression .args ["low" ])
181+ ),
182+ annotate_boolean (
183+ exp .LTE (this = expression .this .copy (), expression = expression .args ["high" ])
184+ ),
188185 copy = False ,
189186 )
190187 )
191188
192189 if negate :
193- expression = exp .paren (expression , copy = False )
190+ expression = annotate_boolean ( exp .paren (expression , copy = False ) )
194191
195192 return expression
196193
@@ -457,7 +454,9 @@ def uniq_sort(expression, root=True):
457454 else :
458455 # we didn't have to sort but maybe we need to dedup
459456 if deduped and len (deduped ) < len (flattened ):
460- expression = annotate_boolean (result_func (* deduped .values (), copy = False ))
457+ expression = annotate_boolean (
458+ result_func (* (e for e in deduped .values ()), copy = False )
459+ )
461460
462461 return expression
463462
@@ -867,7 +866,8 @@ def simplify_coalesce(expression: exp.Expression, dialect: DialectType) -> exp.E
867866 )
868867 )
869868
870- return exp .paren (annotate_boolean (exp .or_ (and_rhs , and_lhs , copy = False )))
869+ or_expr = annotate_boolean (exp .or_ (and_rhs , and_lhs , copy = False ))
870+ return annotate_boolean (exp .paren (or_expr ))
871871
872872
873873CONCATS = (exp .Concat , exp .DPipe )
@@ -1105,12 +1105,13 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr
11051105 ranges = merge_ranges (ranges )
11061106 target_type = extract_type (* rs )
11071107
1108- return annotate_boolean (
1108+ or_expr = annotate_boolean (
11091109 exp .or_ (
11101110 * [_datetrunc_eq_expression (l , drange , target_type ) for drange in ranges ],
11111111 copy = False ,
11121112 )
11131113 )
1114+ return or_expr
11141115
11151116 return expression
11161117
@@ -1192,7 +1193,21 @@ def is_null(a: exp.Expression) -> bool:
11921193
11931194
11941195def is_boolean (expression : exp .Expression ) -> bool :
1195- return expression .unnest ().is_type (exp .DataType .Type .BOOLEAN )
1196+ return expression .is_type (exp .DataType .Type .BOOLEAN )
1197+
1198+
1199+ def annotate_boolean (expression : exp .Expression ) -> exp .Expression :
1200+ if not expression .type :
1201+ expression .type = (
1202+ expression .this .type if isinstance (expression , exp .Paren ) else exp .DataType .Type .BOOLEAN
1203+ )
1204+
1205+ if isinstance (expression , exp .Connector ):
1206+ if isinstance (left := expression .left , exp .Paren ):
1207+ left .type = exp .DataType .Type .BOOLEAN
1208+ if isinstance (right := expression .right , exp .Paren ):
1209+ right .type = exp .DataType .Type .BOOLEAN
1210+ return expression
11961211
11971212
11981213def eval_boolean (expression , a , b ):
@@ -1359,6 +1374,9 @@ def _flat_simplify(expression, simplifier, root=True):
13591374 queue = deque (expression .flatten (unnest = False ))
13601375 size = len (queue )
13611376
1377+ for operand in queue :
1378+ annotate_boolean (operand )
1379+
13621380 while queue :
13631381 a = queue .popleft ()
13641382
0 commit comments