Skip to content

Commit 9b0f4e2

Browse files
alamb and joroKr21 authored
Fix DistinctCount for timestamps with time zone (#10043) (#10105)
* Fix DistinctCount for timestamps with time zone

  Preserve the original data type in the aggregation state

* Add tests for decimal count distinct

Co-authored-by: Georgi Krastev <[email protected]>
1 parent ec4a6da commit 9b0f4e2

File tree

4 files changed

+79
-26
lines changed

4 files changed

+79
-26
lines changed

datafusion/physical-expr/src/aggregate/count_distinct/mod.rs

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -109,12 +109,14 @@ impl AggregateExpr for DistinctCount {
109109
UInt16 => Box::new(PrimitiveDistinctCountAccumulator::<UInt16Type>::new()),
110110
UInt32 => Box::new(PrimitiveDistinctCountAccumulator::<UInt32Type>::new()),
111111
UInt64 => Box::new(PrimitiveDistinctCountAccumulator::<UInt64Type>::new()),
112-
Decimal128(_, _) => {
113-
Box::new(PrimitiveDistinctCountAccumulator::<Decimal128Type>::new())
114-
}
115-
Decimal256(_, _) => {
116-
Box::new(PrimitiveDistinctCountAccumulator::<Decimal256Type>::new())
117-
}
112+
dt @ Decimal128(_, _) => Box::new(
113+
PrimitiveDistinctCountAccumulator::<Decimal128Type>::new()
114+
.with_data_type(dt.clone()),
115+
),
116+
dt @ Decimal256(_, _) => Box::new(
117+
PrimitiveDistinctCountAccumulator::<Decimal256Type>::new()
118+
.with_data_type(dt.clone()),
119+
),
118120

119121
Date32 => Box::new(PrimitiveDistinctCountAccumulator::<Date32Type>::new()),
120122
Date64 => Box::new(PrimitiveDistinctCountAccumulator::<Date64Type>::new()),
@@ -130,18 +132,22 @@ impl AggregateExpr for DistinctCount {
130132
Time64(Nanosecond) => {
131133
Box::new(PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new())
132134
}
133-
Timestamp(Microsecond, _) => Box::new(PrimitiveDistinctCountAccumulator::<
134-
TimestampMicrosecondType,
135-
>::new()),
136-
Timestamp(Millisecond, _) => Box::new(PrimitiveDistinctCountAccumulator::<
137-
TimestampMillisecondType,
138-
>::new()),
139-
Timestamp(Nanosecond, _) => Box::new(PrimitiveDistinctCountAccumulator::<
140-
TimestampNanosecondType,
141-
>::new()),
142-
Timestamp(Second, _) => {
143-
Box::new(PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new())
144-
}
135+
dt @ Timestamp(Microsecond, _) => Box::new(
136+
PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new()
137+
.with_data_type(dt.clone()),
138+
),
139+
dt @ Timestamp(Millisecond, _) => Box::new(
140+
PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new()
141+
.with_data_type(dt.clone()),
142+
),
143+
dt @ Timestamp(Nanosecond, _) => Box::new(
144+
PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new()
145+
.with_data_type(dt.clone()),
146+
),
147+
dt @ Timestamp(Second, _) => Box::new(
148+
PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new()
149+
.with_data_type(dt.clone()),
150+
),
145151

146152
Float16 => Box::new(FloatDistinctCountAccumulator::<Float16Type>::new()),
147153
Float32 => Box::new(FloatDistinctCountAccumulator::<Float32Type>::new()),

datafusion/physical-expr/src/aggregate/count_distinct/native.rs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ use ahash::RandomState;
3030
use arrow::array::ArrayRef;
3131
use arrow_array::types::ArrowPrimitiveType;
3232
use arrow_array::PrimitiveArray;
33+
use arrow_schema::DataType;
3334

3435
use datafusion_common::cast::{as_list_array, as_primitive_array};
3536
use datafusion_common::utils::array_into_list_array;
@@ -45,6 +46,7 @@ where
4546
T::Native: Eq + Hash,
4647
{
4748
values: HashSet<T::Native, RandomState>,
49+
data_type: DataType,
4850
}
4951

5052
impl<T> PrimitiveDistinctCountAccumulator<T>
@@ -55,8 +57,14 @@ where
5557
pub(super) fn new() -> Self {
5658
Self {
5759
values: HashSet::default(),
60+
data_type: T::DATA_TYPE,
5861
}
5962
}
63+
64+
pub(super) fn with_data_type(mut self, data_type: DataType) -> Self {
65+
self.data_type = data_type;
66+
self
67+
}
6068
}
6169

6270
impl<T> Accumulator for PrimitiveDistinctCountAccumulator<T>
@@ -65,9 +73,10 @@ where
6573
T::Native: Eq + Hash,
6674
{
6775
fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
68-
let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
69-
self.values.iter().cloned(),
70-
)) as ArrayRef;
76+
let arr = Arc::new(
77+
PrimitiveArray::<T>::from_iter_values(self.values.iter().cloned())
78+
.with_data_type(self.data_type.clone()),
79+
);
7180
let list = Arc::new(array_into_list_array(arr));
7281
Ok(vec![ScalarValue::List(list)])
7382
}

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1876,18 +1876,22 @@ select
18761876
arrow_cast(column1, 'Timestamp(Microsecond, None)') as micros,
18771877
arrow_cast(column1, 'Timestamp(Millisecond, None)') as millis,
18781878
arrow_cast(column1, 'Timestamp(Second, None)') as secs,
1879+
arrow_cast(column1, 'Timestamp(Nanosecond, Some("UTC"))') as nanos_utc,
1880+
arrow_cast(column1, 'Timestamp(Microsecond, Some("UTC"))') as micros_utc,
1881+
arrow_cast(column1, 'Timestamp(Millisecond, Some("UTC"))') as millis_utc,
1882+
arrow_cast(column1, 'Timestamp(Second, Some("UTC"))') as secs_utc,
18791883
column2 as names,
18801884
column3 as tag
18811885
from t_source;
18821886

18831887
# Demonstrate the contents
1884-
query PPPPTT
1888+
query PPPPPPPPTT
18851889
select * from t;
18861890
----
1887-
2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 X
1888-
2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123 2011-12-13T11:13:10 Row 1 X
1889-
NULL NULL NULL NULL Row 2 Y
1890-
2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 Y
1891+
2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 X
1892+
2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123 2011-12-13T11:13:10 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 X
1893+
NULL NULL NULL NULL NULL NULL NULL NULL Row 2 Y
1894+
2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 Y
18911895

18921896

18931897
# aggregate_timestamps_sum
@@ -1933,6 +1937,17 @@ SELECT tag, max(nanos), max(micros), max(millis), max(secs) FROM t GROUP BY tag
19331937
X 2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10
19341938
Y 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10
19351939

1940+
# aggregate_timestamps_count_distinct_with_tz
1941+
query IIII
1942+
SELECT count(DISTINCT nanos_utc), count(DISTINCT micros_utc), count(DISTINCT millis_utc), count(DISTINCT secs_utc) FROM t;
1943+
----
1944+
3 3 3 3
1945+
1946+
query TIIII
1947+
SELECT tag, count(DISTINCT nanos_utc), count(DISTINCT micros_utc), count(DISTINCT millis_utc), count(DISTINCT secs_utc) FROM t GROUP BY tag ORDER BY tag;
1948+
----
1949+
X 2 2 2 2
1950+
Y 1 1 1 1
19361951

19371952
# aggregate_timestamps_avg
19381953
statement error DataFusion error: Error during planning: No function matches the given name and argument types 'AVG\(Timestamp\(Nanosecond, None\)\)'\. You might need to add explicit type casts\.
@@ -2285,6 +2300,18 @@ select c2, avg(c1), arrow_typeof(avg(c1)) from d_table GROUP BY c2 ORDER BY c2
22852300
A 110.0045 Decimal128(14, 7)
22862301
B -100.0045 Decimal128(14, 7)
22872302

2303+
# aggregate_decimal_count_distinct
2304+
query I
2305+
select count(DISTINCT cast(c1 AS DECIMAL(10, 2))) from d_table
2306+
----
2307+
4
2308+
2309+
query TI
2310+
select c2, count(DISTINCT cast(c1 AS DECIMAL(10, 2))) from d_table GROUP BY c2 ORDER BY c2
2311+
----
2312+
A 2
2313+
B 2
2314+
22882315
# Use PostgreSQL dialect
22892316
statement ok
22902317
set datafusion.sql_parser.dialect = 'Postgres';

datafusion/sqllogictest/test_files/decimal.slt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -720,5 +720,16 @@ select count(*),c1 from decimal256_simple group by c1 order by c1;
720720
4 0.00004
721721
5 0.00005
722722

723+
query I
724+
select count(DISTINCT cast(c1 AS DECIMAL(42, 4))) from decimal256_simple;
725+
----
726+
2
727+
728+
query BI
729+
select c4, count(DISTINCT cast(c1 AS DECIMAL(42, 4))) from decimal256_simple GROUP BY c4 ORDER BY c4;
730+
----
731+
false 2
732+
true 2
733+
723734
statement ok
724735
drop table decimal256_simple;

0 commit comments

Comments
 (0)