Skip to content

Commit 0778d24

Browse files
Michael Lee authored and committed
generalize annotation logic for parameterized numeric types
1 parent e2f9829 commit 0778d24

File tree

2 files changed

+264
-41
lines changed

2 files changed

+264
-41
lines changed

sqlglot/typing/snowflake.py

Lines changed: 172 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import typing as t
44

5+
from decimal import Decimal
56
from sqlglot import exp
67
from sqlglot.typing import EXPRESSION_METADATA
78

@@ -81,7 +82,7 @@ def _annotate_decode_case(self: TypeAnnotator, expression: exp.DecodeCase) -> ex
8182
return expression
8283

8384

84-
def _extract_precision_scale(data_type: exp.DataType) -> t.Tuple[t.Optional[int], t.Optional[int]]:
85+
def _extract_type_precision_scale(data_type: exp.DataType) -> t.Tuple[t.Optional[int], t.Optional[int]]:
8586
"""Extract precision and scale from a parameterized numeric type."""
8687
expressions = data_type.expressions
8788
if not expressions:
@@ -103,78 +104,204 @@ def _extract_precision_scale(data_type: exp.DataType) -> t.Tuple[t.Optional[int]
103104
return precision, scale
104105

105106

106-
def _compute_nullif_result_type(
107-
base_type: exp.DataType, p1: int, s1: int, p2: int, s2: int
107+
def _extract_literal_precision_scale(num_str: str) -> t.Tuple[int, int]:
108+
d = Decimal(num_str).normalize()
109+
s = format(d, "f").lstrip("-")
110+
111+
if "." in s:
112+
int_part, frac_part = s.split(".", 1)
113+
precision = len(int_part + frac_part)
114+
scale = len(frac_part.rstrip("0"))
115+
else:
116+
precision = len(s)
117+
scale = 0
118+
return precision, scale
119+
120+
121+
def _is_float(type_: t.Union[exp.DataType, exp.DataType.Type, None]) -> bool:
    """Return True when `type_` is an `exp.DataType` instance of type FLOAT."""
    if not isinstance(type_, exp.DataType):
        return False
    return type_.is_type(exp.DataType.Type.FLOAT)
123+
124+
125+
def _is_parameterized_numeric(type_: t.Union[exp.DataType, exp.DataType.Type, None]) -> bool:
    """
    Return True when `type_` is a numeric `exp.DataType` that carries type
    parameters (i.e. its `expressions` list is non-empty).
    """
    if not isinstance(type_, exp.DataType):
        return False
    if not type_.is_type(*exp.DataType.NUMERIC_TYPES):
        return False
    return bool(type_.expressions)
131+
132+
133+
def _get_normalized_type(
    expression: exp.Expression,
) -> t.Union[exp.DataType, exp.DataType.Type, None]:
    """
    Normalize a numeric expression to its parameterized representation.

    - Numeric literals become DECIMAL(p, s), with p and s derived from the literal text.
    - Unparameterized integer types become DECIMAL(38, 0).
    - Anything else is returned unchanged, including a missing (None) type.

    Args:
        expression: The already-annotated expression whose type should be normalized.

    Returns:
        The normalized type, or the expression's own type when no rule applies.
    """
    if isinstance(expression, exp.Literal) and expression.is_number:
        precision, scale = _extract_literal_precision_scale(expression.this)
    else:
        type_ = expression.type

        # An un-annotated expression has no type to inspect; pass None through
        # instead of dereferencing it.
        if (
            type_ is None
            or not type_.is_type(*exp.DataType.INTEGER_TYPES)
            or type_.expressions
        ):
            return type_

        precision, scale = 38, 0

    return exp.DataType(
        this=exp.DataType.Type.DECIMAL,
        expressions=[
            exp.DataTypeParam(this=exp.Literal.number(precision)),
            exp.DataTypeParam(this=exp.Literal.number(scale)),
        ],
    )
161+
162+
163+
def _coerce_two_parameterized_types(
    type1: exp.DataType, p1: int, s1: int, type2: exp.DataType, p2: int, s2: int
) -> t.Optional[exp.DataType]:
    """
    Coerce two parameterized numeric types using Snowflake's type coercion rules.

    Rules:
    - If p1 >= p2 AND s1 >= s2: return type1
    - If p2 >= p1 AND s2 >= s1: return type2
    - Otherwise: return NUMBER(min(38, max(p1, p2) + |s2 - s1|), max(s1, s2))

    Args:
        type1: The first parameterized numeric type.
        p1: Precision of type1.
        s1: Scale of type1.
        type2: The second parameterized numeric type.
        p2: Precision of type2.
        s2: Scale of type2.

    Returns:
        A copy of whichever input type dominates, or a new type built on type1's
        base type when neither input dominates the other.
    """
    if p1 >= p2 and s1 >= s2:
        return type1.copy()

    if p2 >= p1 and s2 >= s1:
        return type2.copy()

    # Neither type dominates: keep the larger scale and widen the precision by the
    # scale difference, capped at 38.
    result_scale = max(s1, s2)
    result_precision = min(38, max(p1, p2) + abs(s2 - s1))

    return exp.DataType(
        this=type1.this,
        expressions=[
            exp.DataTypeParam(this=exp.Literal.number(result_precision)),
            exp.DataTypeParam(this=exp.Literal.number(result_scale)),
        ],
    )
140191

141192

142-
def _annotate_nullif(self: TypeAnnotator, expression: exp.Nullif) -> exp.Nullif:
193+
def _coerce_parameterized_numeric_types(
    self: TypeAnnotator, types: t.List[t.Union[exp.DataType, exp.DataType.Type, None]]
) -> t.Optional[t.Union[exp.DataType, exp.DataType.Type]]:
    """
    Generalized function to coerce multiple parameterized numeric types.

    The types are folded left to right. A pair is combined with
    `_coerce_two_parameterized_types` when both sides are parameterized numeric
    types whose precision and scale could be extracted; any other pair is
    delegated to `self._maybe_coerce`. Falsy entries in `types` are skipped.

    Args:
        types: The candidate types, in argument order.

    Returns:
        The combined type, or None when `types` holds no usable entry.
    """
    result_type = None

    for current_type in types:
        if not current_type:
            continue

        if result_type is None:
            # First usable type seeds the fold.
            result_type = current_type
            continue

        params = None
        if _is_parameterized_numeric(result_type) and _is_parameterized_numeric(current_type):
            params = (
                *_extract_type_precision_scale(result_type),
                *_extract_type_precision_scale(current_type),
            )

        if params is not None and all(param is not None for param in params):
            p1, s1, p2, s2 = params
            result_type = _coerce_two_parameterized_types(
                result_type, p1, s1, current_type, p2, s2
            )
        else:
            # Either side is not a parameterized numeric, or a precision/scale is
            # missing: use the annotator's generic coercion.
            result_type = self._maybe_coerce(result_type, current_type)

    return result_type
227+
152228

229+
def _apply_numeric_coercion(
    self: TypeAnnotator,
    expression: exp.Expression,
    expressions_to_coerce: t.List[exp.Expression],
) -> t.Optional[exp.Expression]:
    """
    Set the type of `expression` from the numeric types of `expressions_to_coerce`.

    - If any argument is FLOAT, the result is FLOAT.
    - Otherwise, if any argument is a parameterized numeric type, every argument
      is normalized and the normalized types are coerced into one result type.

    Returns:
        `expression` once its type has been set, or None when neither rule
        produced a type, so the caller can fall back to another strategy.
    """
    if any(_is_float(arg.type) for arg in expressions_to_coerce):
        self._set_type(expression, exp.DataType.Type.FLOAT)
        return expression

    if not any(_is_parameterized_numeric(arg.type) for arg in expressions_to_coerce):
        return None

    coerced_type = _coerce_parameterized_numeric_types(
        self, [_get_normalized_type(arg) for arg in expressions_to_coerce]
    )
    if not coerced_type:
        return None

    self._set_type(expression, coerced_type)
    return expression
246+
247+
248+
def _annotate_nullif(self: TypeAnnotator, expression: exp.Nullif) -> exp.Nullif:
    """
    Annotate NULLIF by applying numeric coercion to its two arguments.

    Falls back to `_annotate_by_args` over "this" and "expression" when numeric
    coercion does not apply.
    """
    self._annotate_args(expression)

    # Only arguments that are actually present take part in coercion.
    expressions_to_coerce = [
        arg for arg in (expression.this, expression.expression) if arg
    ]

    coerced_result = _apply_numeric_coercion(self, expression, expressions_to_coerce)
    if coerced_result is None:
        return self._annotate_by_args(expression, "this", "expression")

    return coerced_result
262+
263+
264+
def _annotate_iff(self: TypeAnnotator, expression: exp.If) -> exp.If:
    """
    Annotate an `exp.If` by applying numeric coercion to its "true" and "false"
    branches, falling back to `_annotate_by_args` when coercion does not apply.
    """
    self._annotate_args(expression)

    branches = (expression.args.get("true"), expression.args.get("false"))
    present_branches = [branch for branch in branches if branch]

    annotated = _apply_numeric_coercion(self, expression, present_branches)
    if annotated is not None:
        return annotated

    return self._annotate_by_args(expression, "true", "false")
281+
282+
283+
def _annotate_with_numeric_coercion(
    self: TypeAnnotator, expression: exp.Expression
) -> exp.Expression:
    """
    Generic annotator for functions that return one of their numeric arguments.

    These functions all have the same structure: 'this' + 'expressions' arguments,
    and they all need to coerce all argument types to find a common result type.
    Falls back to `_annotate_by_args` over "this" and "expressions" when numeric
    coercion does not apply.
    """
    self._annotate_args(expression)

    expressions_to_coerce = []
    if expression.this:
        expressions_to_coerce.append(expression.this)
    if expression.expressions:
        expressions_to_coerce.extend(expression.expressions)

    coerced_result = _apply_numeric_coercion(self, expression, expressions_to_coerce)
    if coerced_result is None:
        return self._annotate_by_args(expression, "this", "expressions")

    return coerced_result
178305

179306

180307
EXPRESSION_METADATA = {
@@ -344,6 +471,7 @@ def _annotate_nullif(self: TypeAnnotator, expression: exp.Nullif) -> exp.Nullif:
344471
exp.Uuid,
345472
}
346473
},
474+
exp.Coalesce: {"annotator": _annotate_with_numeric_coercion},
347475
exp.ConcatWs: {"annotator": lambda self, e: self._annotate_by_args(e, "expressions")},
348476
exp.ConvertTimezone: {
349477
"annotator": lambda self, e: self._annotate_with_type(
@@ -355,9 +483,12 @@ def _annotate_nullif(self: TypeAnnotator, expression: exp.Nullif) -> exp.Nullif:
355483
},
356484
exp.DateAdd: {"annotator": _annotate_date_or_time_add},
357485
exp.DecodeCase: {"annotator": _annotate_decode_case},
486+
exp.Greatest: {"annotator": _annotate_with_numeric_coercion},
358487
exp.GreatestIgnoreNulls: {
359488
"annotator": lambda self, e: self._annotate_by_args(e, "expressions")
360489
},
490+
exp.If: {"annotator": _annotate_iff},
491+
exp.Least: {"annotator": _annotate_with_numeric_coercion},
361492
exp.LeastIgnoreNulls: {"annotator": lambda self, e: self._annotate_by_args(e, "expressions")},
362493
exp.Nullif: {"annotator": _annotate_nullif},
363494
exp.Reverse: {"annotator": _annotate_reverse},

0 commit comments

Comments
 (0)