@@ -19,10 +19,13 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
19
19
use arrow:: compute:: kernels:: cast_utils:: {
20
20
parse_interval_month_day_nano_config, IntervalParseConfig , IntervalUnit ,
21
21
} ;
22
- use arrow:: datatypes:: DECIMAL128_MAX_PRECISION ;
23
- use arrow_schema:: DataType ;
22
+ use arrow:: datatypes:: { i256, DECIMAL128_MAX_PRECISION } ;
23
+ use arrow_schema:: { DataType , DECIMAL256_MAX_PRECISION } ;
24
+ use bigdecimal:: num_bigint:: BigInt ;
25
+ use bigdecimal:: { BigDecimal , Signed , ToPrimitive } ;
24
26
use datafusion_common:: {
25
- internal_err, not_impl_err, plan_err, DFSchema , DataFusionError , Result , ScalarValue ,
27
+ internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema ,
28
+ DataFusionError , Result , ScalarValue ,
26
29
} ;
27
30
use datafusion_expr:: expr:: { BinaryExpr , Placeholder } ;
28
31
use datafusion_expr:: planner:: PlannerResult ;
@@ -31,6 +34,9 @@ use log::debug;
31
34
use sqlparser:: ast:: { BinaryOperator , Expr as SQLExpr , Interval , UnaryOperator , Value } ;
32
35
use sqlparser:: parser:: ParserError :: ParserError ;
33
36
use std:: borrow:: Cow ;
37
+ use std:: cmp:: Ordering ;
38
+ use std:: ops:: Neg ;
39
+ use std:: str:: FromStr ;
34
40
35
41
impl < S : ContextProvider > SqlToRel < ' _ , S > {
36
42
pub ( crate ) fn parse_value (
@@ -84,7 +90,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
84
90
}
85
91
86
92
if self . options . parse_float_as_decimal {
87
- parse_decimal_128 ( unsigned_number, negative)
93
+ parse_decimal ( unsigned_number, negative)
88
94
} else {
89
95
signed_number. parse :: < f64 > ( ) . map ( lit) . map_err ( |_| {
90
96
DataFusionError :: from ( ParserError ( format ! (
@@ -315,45 +321,84 @@ const fn try_decode_hex_char(c: u8) -> Option<u8> {
315
321
}
316
322
}
317
323
318
- /// Parse Decimal128 from a string
319
- ///
320
- /// TODO: support parsing from scientific notation
321
- fn parse_decimal_128 ( unsigned_number : & str , negative : bool ) -> Result < Expr > {
322
- // remove leading zeroes
323
- let trimmed = unsigned_number. trim_start_matches ( '0' ) ;
324
- // Parse precision and scale, remove decimal point if exists
325
- let ( precision, scale, replaced_str) = if trimmed == "." {
326
- // Special cases for numbers such as “0.”, “000.”, and so on.
327
- ( 1 , 0 , Cow :: Borrowed ( "0" ) )
328
- } else if let Some ( i) = trimmed. find ( '.' ) {
329
- (
330
- trimmed. len ( ) - 1 ,
331
- trimmed. len ( ) - i - 1 ,
332
- Cow :: Owned ( trimmed. replace ( '.' , "" ) ) ,
333
- )
334
- } else {
335
- // No decimal point, keep as is
336
- ( trimmed. len ( ) , 0 , Cow :: Borrowed ( trimmed) )
337
- } ;
324
+ /// Returns None if the value can't be converted to i256.
325
+ /// Modified from <https://github.com/apache/arrow-rs/blob/c4dbf0d8af6ca5a19b8b2ea777da3c276807fc5e/arrow-buffer/src/bigint/mod.rs#L303>
326
+ fn bigint_to_i256 ( v : & BigInt ) -> Option < i256 > {
327
+ let v_bytes = v. to_signed_bytes_le ( ) ;
328
+ match v_bytes. len ( ) . cmp ( & 32 ) {
329
+ Ordering :: Less => {
330
+ let mut bytes = if v. is_negative ( ) {
331
+ [ 255_u8 ; 32 ]
332
+ } else {
333
+ [ 0 ; 32 ]
334
+ } ;
335
+ bytes[ 0 ..v_bytes. len ( ) ] . copy_from_slice ( & v_bytes[ ..v_bytes. len ( ) ] ) ;
336
+ Some ( i256:: from_le_bytes ( bytes) )
337
+ }
338
+ Ordering :: Equal => Some ( i256:: from_le_bytes ( v_bytes. try_into ( ) . unwrap ( ) ) ) ,
339
+ Ordering :: Greater => None ,
340
+ }
341
+ }
338
342
339
- let number = replaced_str. parse :: < i128 > ( ) . map_err ( |e| {
343
+ fn parse_decimal ( unsigned_number : & str , negative : bool ) -> Result < Expr > {
344
+ let mut dec = BigDecimal :: from_str ( unsigned_number) . map_err ( |e| {
340
345
DataFusionError :: from ( ParserError ( format ! (
341
- "Cannot parse {replaced_str } as i128 when building decimal : {e}"
346
+ "Cannot parse {unsigned_number } as BigDecimal : {e}"
342
347
) ) )
343
348
} ) ?;
344
-
345
- // Check precision overflow
346
- if precision as u8 > DECIMAL128_MAX_PRECISION {
347
- return Err ( DataFusionError :: from ( ParserError ( format ! (
348
- "Cannot parse {replaced_str} as i128 when building decimal: precision overflow"
349
- ) ) ) ) ;
349
+ if negative {
350
+ dec = dec. neg ( ) ;
350
351
}
351
352
352
- Ok ( Expr :: Literal ( ScalarValue :: Decimal128 (
353
- Some ( if negative { -number } else { number } ) ,
354
- precision as u8 ,
355
- scale as i8 ,
356
- ) ) )
353
+ let digits = dec. digits ( ) ;
354
+ let ( int_val, scale) = dec. into_bigint_and_exponent ( ) ;
355
+ if scale < i8:: MIN as i64 {
356
+ return not_impl_err ! (
357
+ "Decimal scale {} exceeds the minimum supported scale: {}" ,
358
+ scale,
359
+ i8 :: MIN
360
+ ) ;
361
+ }
362
+ let precision = if scale > 0 {
363
+ // arrow-rs requires the precision to include the positive scale.
364
+ // See <https://github.com/apache/arrow-rs/blob/123045cc766d42d1eb06ee8bb3f09e39ea995ddc/arrow-array/src/types.rs#L1230>
365
+ std:: cmp:: max ( digits, scale. unsigned_abs ( ) )
366
+ } else {
367
+ digits
368
+ } ;
369
+ if precision <= DECIMAL128_MAX_PRECISION as u64 {
370
+ let val = int_val. to_i128 ( ) . ok_or_else ( || {
371
+ // Failures are unexpected here as we have already checked the precision
372
+ internal_datafusion_err ! (
373
+ "Unexpected overflow when converting {} to i128" ,
374
+ int_val
375
+ )
376
+ } ) ?;
377
+ Ok ( Expr :: Literal ( ScalarValue :: Decimal128 (
378
+ Some ( val) ,
379
+ precision as u8 ,
380
+ scale as i8 ,
381
+ ) ) )
382
+ } else if precision <= DECIMAL256_MAX_PRECISION as u64 {
383
+ let val = bigint_to_i256 ( & int_val) . ok_or_else ( || {
384
+ // Failures are unexpected here as we have already checked the precision
385
+ internal_datafusion_err ! (
386
+ "Unexpected overflow when converting {} to i256" ,
387
+ int_val
388
+ )
389
+ } ) ?;
390
+ Ok ( Expr :: Literal ( ScalarValue :: Decimal256 (
391
+ Some ( val) ,
392
+ precision as u8 ,
393
+ scale as i8 ,
394
+ ) ) )
395
+ } else {
396
+ not_impl_err ! (
397
+ "Decimal precision {} exceeds the maximum supported precision: {}" ,
398
+ precision,
399
+ DECIMAL256_MAX_PRECISION
400
+ )
401
+ }
357
402
}
358
403
359
404
#[ cfg( test) ]
@@ -379,4 +424,79 @@ mod tests {
379
424
assert_eq ! ( output, expect) ;
380
425
}
381
426
}
427
+
428
+ #[ test]
429
+ fn test_bigint_to_i256 ( ) {
430
+ let cases = [
431
+ ( BigInt :: from ( 0 ) , Some ( i256:: from ( 0 ) ) ) ,
432
+ ( BigInt :: from ( 1 ) , Some ( i256:: from ( 1 ) ) ) ,
433
+ ( BigInt :: from ( -1 ) , Some ( i256:: from ( -1 ) ) ) ,
434
+ (
435
+ BigInt :: from_str ( i256:: MAX . to_string ( ) . as_str ( ) ) . unwrap ( ) ,
436
+ Some ( i256:: MAX ) ,
437
+ ) ,
438
+ (
439
+ BigInt :: from_str ( i256:: MIN . to_string ( ) . as_str ( ) ) . unwrap ( ) ,
440
+ Some ( i256:: MIN ) ,
441
+ ) ,
442
+ (
443
+ // Can't fit into i256
444
+ BigInt :: from_str ( ( i256:: MAX . to_string ( ) + "1" ) . as_str ( ) ) . unwrap ( ) ,
445
+ None ,
446
+ ) ,
447
+ ] ;
448
+
449
+ for ( input, expect) in cases {
450
+ let output = bigint_to_i256 ( & input) ;
451
+ assert_eq ! ( output, expect) ;
452
+ }
453
+ }
454
+
455
+ #[ test]
456
+ fn test_parse_decimal ( ) {
457
+ // Supported cases
458
+ let cases = [
459
+ ( "0" , ScalarValue :: Decimal128 ( Some ( 0 ) , 1 , 0 ) ) ,
460
+ ( "1" , ScalarValue :: Decimal128 ( Some ( 1 ) , 1 , 0 ) ) ,
461
+ ( "123.45" , ScalarValue :: Decimal128 ( Some ( 12345 ) , 5 , 2 ) ) ,
462
+ // Digit count is less than scale
463
+ ( "0.001" , ScalarValue :: Decimal128 ( Some ( 1 ) , 3 , 3 ) ) ,
464
+ // Scientific notation
465
+ ( "123.456e-2" , ScalarValue :: Decimal128 ( Some ( 123456 ) , 6 , 5 ) ) ,
466
+ // Negative scale
467
+ ( "123456e128" , ScalarValue :: Decimal128 ( Some ( 123456 ) , 6 , -128 ) ) ,
468
+ // Decimal256
469
+ (
470
+ & ( "9" . repeat ( 39 ) + "." + "99999" ) ,
471
+ ScalarValue :: Decimal256 (
472
+ Some ( i256:: from_string ( & "9" . repeat ( 44 ) ) . unwrap ( ) ) ,
473
+ 44 ,
474
+ 5 ,
475
+ ) ,
476
+ ) ,
477
+ ] ;
478
+ for ( input, expect) in cases {
479
+ let output = parse_decimal ( input, true ) . unwrap ( ) ;
480
+ assert_eq ! ( output, Expr :: Literal ( expect. arithmetic_negate( ) . unwrap( ) ) ) ;
481
+
482
+ let output = parse_decimal ( input, false ) . unwrap ( ) ;
483
+ assert_eq ! ( output, Expr :: Literal ( expect) ) ;
484
+ }
485
+
486
+ // scale < i8::MIN
487
+ assert_eq ! (
488
+ parse_decimal( "1e129" , false )
489
+ . unwrap_err( )
490
+ . strip_backtrace( ) ,
491
+ "This feature is not implemented: Decimal scale -129 exceeds the minimum supported scale: -128"
492
+ ) ;
493
+
494
+ // Unsupported precision
495
+ assert_eq ! (
496
+ parse_decimal( & "1" . repeat( 77 ) , false )
497
+ . unwrap_err( )
498
+ . strip_backtrace( ) ,
499
+ "This feature is not implemented: Decimal precision 77 exceeds the maximum supported precision: 76"
500
+ ) ;
501
+ }
382
502
}
0 commit comments