22
33import typing as t
44
5+ from decimal import Decimal
56from sqlglot import exp
67from sqlglot .typing import EXPRESSION_METADATA
78
@@ -81,7 +82,7 @@ def _annotate_decode_case(self: TypeAnnotator, expression: exp.DecodeCase) -> ex
8182 return expression
8283
8384
84- def _extract_precision_scale (data_type : exp .DataType ) -> t .Tuple [t .Optional [int ], t .Optional [int ]]:
85+ def _extract_type_precision_scale (data_type : exp .DataType ) -> t .Tuple [t .Optional [int ], t .Optional [int ]]:
8586 """Extract precision and scale from a parameterized numeric type."""
8687 expressions = data_type .expressions
8788 if not expressions :
@@ -103,78 +104,204 @@ def _extract_precision_scale(data_type: exp.DataType) -> t.Tuple[t.Optional[int]
103104 return precision , scale
104105
105106
106- def _compute_nullif_result_type (
107- base_type : exp .DataType , p1 : int , s1 : int , p2 : int , s2 : int
107+ def _extract_literal_precision_scale (num_str : str ) -> t .Tuple [int , int ]:
108+ d = Decimal (num_str ).normalize ()
109+ s = format (d , "f" ).lstrip ("-" )
110+
111+ if "." in s :
112+ int_part , frac_part = s .split ("." , 1 )
113+ precision = len (int_part + frac_part )
114+ scale = len (frac_part .rstrip ("0" ))
115+ else :
116+ precision = len (s )
117+ scale = 0
118+ return precision , scale
119+
120+
121+ def _is_float (type_ : t .Union [exp .DataType , exp .DataType .Type , None ]) -> bool :
122+ return isinstance (type_ , exp .DataType ) and type_ .is_type (exp .DataType .Type .FLOAT )
123+
124+
125+ def _is_parameterized_numeric (type_ : t .Union [exp .DataType , exp .DataType .Type , None ]) -> bool :
126+ return (
127+ isinstance (type_ , exp .DataType )
128+ and type_ .is_type (* exp .DataType .NUMERIC_TYPES )
129+ and bool (type_ .expressions )
130+ )
131+
132+
133+ def _get_normalized_type (
134+ expression : exp .Expression ,
135+ ) -> t .Union [exp .DataType , exp .DataType .Type , None ]:
136+ """
137+ Normalizes numeric expressions to their parameterized representation.
138+ For literal numbers, return the parameterized representation.
139+ For integer types, return NUMBER(38, 0).
140+ """
141+ if isinstance (expression , exp .Literal ) and expression .is_number :
142+ precision , scale = _extract_literal_precision_scale (expression .this )
143+ return exp .DataType (
144+ this = exp .DataType .Type .DECIMAL ,
145+ expressions = [
146+ exp .DataTypeParam (this = exp .Literal .number (precision )),
147+ exp .DataTypeParam (this = exp .Literal .number (scale )),
148+ ],
149+ )
150+
151+ if expression .type .is_type (* exp .DataType .INTEGER_TYPES ) and not expression .type .expressions :
152+ return exp .DataType (
153+ this = exp .DataType .Type .DECIMAL ,
154+ expressions = [
155+ exp .DataTypeParam (this = exp .Literal .number (38 )),
156+ exp .DataTypeParam (this = exp .Literal .number (0 )),
157+ ],
158+ )
159+
160+ return expression .type
161+
162+
163+ def _coerce_two_parameterized_types (
164+ type1 : exp .DataType , p1 : int , s1 : int , type2 : exp .DataType , p2 : int , s2 : int
108165) -> t .Optional [exp .DataType ]:
109166 """
110- Compute result type for NULLIF with two parameterized numeric types .
167+ Coerce two parameterized numeric types using Snowflake's type coercion rules .
111168
112169 Rules:
113170 - If p1 >= p2 AND s1 >= s2: return type1
114171 - If p2 >= p1 AND s2 >= s1: return type2
115- - Otherwise: return DECIMAL( max(p1, p2) + |s2 - s1|, max(s1, s2))
172+ - Otherwise: return NUMBER(min(38, max(p1, p2) + |s2 - s1|) , max(s1, s2))
116173 """
117174
118175 if p1 >= p2 and s1 >= s2 :
119- return base_type .copy ()
176+ return type1 .copy ()
120177
121178 if p2 >= p1 and s2 >= s1 :
122- return exp .DataType (
123- this = base_type .this ,
124- expressions = [
125- exp .DataTypeParam (this = exp .Literal .number (p2 )),
126- exp .DataTypeParam (this = exp .Literal .number (s2 )),
127- ],
128- )
179+ return type2 .copy ()
129180
130181 result_scale = max (s1 , s2 )
131- result_precision = max (p1 , p2 ) + abs (s2 - s1 )
182+ result_precision = min ( 38 , max (p1 , p2 ) + abs (s2 - s1 ) )
132183
133184 return exp .DataType (
134- this = base_type .this ,
185+ this = type1 .this ,
135186 expressions = [
136187 exp .DataTypeParam (this = exp .Literal .number (result_precision )),
137188 exp .DataTypeParam (this = exp .Literal .number (result_scale )),
138189 ],
139190 )
140191
141192
142- def _annotate_nullif (self : TypeAnnotator , expression : exp .Nullif ) -> exp .Nullif :
193+ def _coerce_parameterized_numeric_types (
194+ self : TypeAnnotator , types : t .List [t .Union [exp .DataType , exp .DataType .Type , None ]]
195+ ) -> t .Optional [t .Union [exp .DataType , exp .DataType .Type ]]:
143196 """
144- Annotate NULLIF with Snowflake-specific type coercion rules for parameterized numeric types.
145-
146- When both arguments are parameterized numeric types (e.g., DECIMAL(p, s)):
147- - If one type dominates (p1 >= p2 AND s1 >= s2), use that type
148- - Otherwise, compute new type with:
149- - scale = max(s1, s2)
150- - precision = max(p1, p2) + |s2 - s1|
197+ Generalized function to coerce multiple parameterized numeric types.
198+ Applies Snowflake's coercion logic pairwise across all types.
151199 """
200+ if not types :
201+ return None
202+
203+ result_type = None
204+
205+ for current_type in types :
206+ if not current_type :
207+ continue
208+
209+ if result_type is None :
210+ result_type = current_type
211+ continue
212+
213+ if _is_parameterized_numeric (result_type ) and _is_parameterized_numeric (current_type ):
214+ p1 , s1 = _extract_type_precision_scale (result_type )
215+ p2 , s2 = _extract_type_precision_scale (current_type )
216+
217+ if p1 is not None and s1 is not None and p2 is not None and s2 is not None :
218+ result_type = _coerce_two_parameterized_types (
219+ result_type , p1 , s1 , current_type , p2 , s2
220+ )
221+ else :
222+ result_type = self ._maybe_coerce (result_type , current_type )
223+ else :
224+ result_type = self ._maybe_coerce (result_type , current_type )
225+
226+ return result_type
227+
152228
229+ def _apply_numeric_coercion (
230+ self : TypeAnnotator ,
231+ expression : exp .Expression ,
232+ expressions_to_coerce : t .List [exp .Expression ],
233+ ) -> t .Optional [exp .Expression ]:
234+ if any (_is_float (e .type ) for e in expressions_to_coerce ):
235+ self ._set_type (expression , exp .DataType .Type .FLOAT )
236+ return expression
237+
238+ if any (_is_parameterized_numeric (e .type ) for e in expressions_to_coerce ):
239+ normalized_types = [_get_normalized_type (e ) for e in expressions_to_coerce ]
240+ result_type = _coerce_parameterized_numeric_types (self , normalized_types )
241+ if result_type :
242+ self ._set_type (expression , result_type )
243+ return expression
244+
245+ return None
246+
247+
248+ def _annotate_nullif (self : TypeAnnotator , expression : exp .Nullif ) -> exp .Nullif :
153249 self ._annotate_args (expression )
154250
155- this_type = expression .this .type
156- expr_type = expression .expression .type
251+ expressions_to_coerce = []
252+ if expression .this :
253+ expressions_to_coerce .append (expression .this )
254+ if expression .expression :
255+ expressions_to_coerce .append (expression .expression )
157256
158- if not this_type or not expr_type :
257+ coerced_result = _apply_numeric_coercion (self , expression , expressions_to_coerce )
258+ if coerced_result is None :
159259 return self ._annotate_by_args (expression , "this" , "expression" )
160260
161- # Snowflake specific type coercion for NULLIF with parameterized numeric types
162- if (
163- this_type .is_type (* exp .DataType .NUMERIC_TYPES )
164- and expr_type .is_type (* exp .DataType .NUMERIC_TYPES )
165- and this_type .expressions
166- and expr_type .expressions
167- ):
168- p1 , s1 = _extract_precision_scale (this_type )
169- p2 , s2 = _extract_precision_scale (expr_type )
261+ return coerced_result
262+
263+
264+ def _annotate_iff (self : TypeAnnotator , expression : exp .If ) -> exp .If :
265+ self ._annotate_args (expression )
266+
267+ expressions_to_coerce = []
268+ true_expr = expression .args .get ("true" )
269+ false_expr = expression .args .get ("false" )
270+
271+ if true_expr :
272+ expressions_to_coerce .append (true_expr )
273+ if false_expr :
274+ expressions_to_coerce .append (false_expr )
275+
276+ coerced_result = _apply_numeric_coercion (self , expression , expressions_to_coerce )
277+ if coerced_result is None :
278+ return self ._annotate_by_args (expression , "true" , "false" )
279+
280+ return coerced_result
281+
282+
283+ def _annotate_with_numeric_coercion (
284+ self : TypeAnnotator , expression : exp .Expression
285+ ) -> exp .Expression :
286+ """
287+ Generic annotator for functions that return one of their numeric arguments.
288+
289+ These functions all have the same structure: 'this' + 'expressions' arguments,
290+ and they all need to coerce all argument types to find a common result type.
291+ """
292+ self ._annotate_args (expression )
293+
294+ expressions_to_coerce = []
295+ if expression .this :
296+ expressions_to_coerce .append (expression .this )
297+ if expression .expressions :
298+ expressions_to_coerce .extend (expression .expressions )
170299
171- if p1 is not None and s1 is not None and p2 is not None and s2 is not None :
172- result_type = _compute_nullif_result_type (this_type , p1 , s1 , p2 , s2 )
173- if result_type :
174- self ._set_type (expression , result_type )
175- return expression
300+ coerced_result = _apply_numeric_coercion (self , expression , expressions_to_coerce )
301+ if coerced_result is None :
302+ return self ._annotate_by_args (expression , "this" , "expressions" )
176303
177- return self . _annotate_by_args ( expression , "this" , "expression" )
304+ return coerced_result
178305
179306
180307EXPRESSION_METADATA = {
@@ -344,6 +471,7 @@ def _annotate_nullif(self: TypeAnnotator, expression: exp.Nullif) -> exp.Nullif:
344471 exp .Uuid ,
345472 }
346473 },
474+ exp .Coalesce : {"annotator" : _annotate_with_numeric_coercion },
347475 exp .ConcatWs : {"annotator" : lambda self , e : self ._annotate_by_args (e , "expressions" )},
348476 exp .ConvertTimezone : {
349477 "annotator" : lambda self , e : self ._annotate_with_type (
@@ -355,9 +483,12 @@ def _annotate_nullif(self: TypeAnnotator, expression: exp.Nullif) -> exp.Nullif:
355483 },
356484 exp .DateAdd : {"annotator" : _annotate_date_or_time_add },
357485 exp .DecodeCase : {"annotator" : _annotate_decode_case },
486+ exp .Greatest : {"annotator" : _annotate_with_numeric_coercion },
358487 exp .GreatestIgnoreNulls : {
359488 "annotator" : lambda self , e : self ._annotate_by_args (e , "expressions" )
360489 },
490+ exp .If : {"annotator" : _annotate_iff },
491+ exp .Least : {"annotator" : _annotate_with_numeric_coercion },
361492 exp .LeastIgnoreNulls : {"annotator" : lambda self , e : self ._annotate_by_args (e , "expressions" )},
362493 exp .Nullif : {"annotator" : _annotate_nullif },
363494 exp .Reverse : {"annotator" : _annotate_reverse },
0 commit comments