Commit f8828ab

feat: topk functionality for aggregates should support utf8view and largeutf8 (#15152)
* feat: topk functionality for aggregates should support utf8view
* Add testing for Utf8view slt
* Add large utf8 support
* Address comments
1 parent 072098e commit f8828ab

4 files changed: +191 -17 lines changed

datafusion/physical-optimizer/src/topk_aggregation.rs

Lines changed: 5 additions & 1 deletion

@@ -56,7 +56,11 @@ impl TopKAggregation {
         }
         let group_key = aggr.group_expr().expr().iter().exactly_one().ok()?;
         let kt = group_key.0.data_type(&aggr.input().schema()).ok()?;
-        if !kt.is_primitive() && kt != DataType::Utf8 {
+        if !kt.is_primitive()
+            && kt != DataType::Utf8
+            && kt != DataType::Utf8View
+            && kt != DataType::LargeUtf8
+        {
             return None;
         }
         if aggr.filter_expr().iter().any(|e| e.is_some()) {
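
The optimizer change above is the entire eligibility gate: the TopK path now admits a single group key that is either a primitive or any of Arrow's three string encodings. A minimal standalone sketch of the widened predicate, where is_topk_group_key is a hypothetical name invented for illustration (the commit inlines this logic directly):

use arrow::datatypes::DataType;

// Hypothetical helper mirroring the widened check in TopKAggregation:
// primitives plus all three Arrow string encodings now qualify.
fn is_topk_group_key(kt: &DataType) -> bool {
    kt.is_primitive()
        || *kt == DataType::Utf8
        || *kt == DataType::Utf8View
        || *kt == DataType::LargeUtf8
}

fn main() {
    assert!(is_topk_group_key(&DataType::Int64));
    assert!(is_topk_group_key(&DataType::Utf8View));
    assert!(is_topk_group_key(&DataType::LargeUtf8));
    assert!(!is_topk_group_key(&DataType::Binary));
}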

datafusion/physical-plan/src/aggregates/topk/hash_table.rs

Lines changed: 58 additions & 14 deletions

@@ -23,7 +23,7 @@ use ahash::RandomState;
 use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano};
 use arrow::array::{
     builder::PrimitiveBuilder, cast::AsArray, downcast_primitive, Array, ArrayRef,
-    ArrowPrimitiveType, PrimitiveArray, StringArray,
+    ArrowPrimitiveType, LargeStringArray, PrimitiveArray, StringArray, StringViewArray,
 };
 use arrow::datatypes::{i256, DataType};
 use datafusion_common::DataFusionError;

@@ -88,6 +88,7 @@ pub struct StringHashTable {
     owned: ArrayRef,
     map: TopKHashTable<Option<String>>,
     rnd: RandomState,
+    data_type: DataType,
 }

 // An implementation of ArrowHashTable for any `ArrowPrimitiveType` key

@@ -101,13 +102,20 @@ where
 }

 impl StringHashTable {
-    pub fn new(limit: usize) -> Self {
+    pub fn new(limit: usize, data_type: DataType) -> Self {
         let vals: Vec<&str> = Vec::new();
-        let owned = Arc::new(StringArray::from(vals));
+        let owned: ArrayRef = match data_type {
+            DataType::Utf8 => Arc::new(StringArray::from(vals)),
+            DataType::Utf8View => Arc::new(StringViewArray::from(vals)),
+            DataType::LargeUtf8 => Arc::new(LargeStringArray::from(vals)),
+            _ => panic!("Unsupported data type"),
+        };
+
         Self {
             owned,
             map: TopKHashTable::new(limit, limit * 10),
             rnd: RandomState::default(),
+            data_type,
         }
     }
 }

@@ -131,7 +139,12 @@ impl ArrowHashTable for StringHashTable {

     unsafe fn take_all(&mut self, indexes: Vec<usize>) -> ArrayRef {
         let ids = self.map.take_all(indexes);
-        Arc::new(StringArray::from(ids))
+        match self.data_type {
+            DataType::Utf8 => Arc::new(StringArray::from(ids)),
+            DataType::LargeUtf8 => Arc::new(LargeStringArray::from(ids)),
+            DataType::Utf8View => Arc::new(StringViewArray::from(ids)),
+            _ => unreachable!(),
+        }
     }

     unsafe fn find_or_insert(

@@ -140,15 +153,44 @@ impl ArrowHashTable for StringHashTable {
         replace_idx: usize,
         mapper: &mut Vec<(usize, usize)>,
     ) -> (usize, bool) {
-        let ids = self
-            .owned
-            .as_any()
-            .downcast_ref::<StringArray>()
-            .expect("StringArray required");
-        let id = if ids.is_null(row_idx) {
-            None
-        } else {
-            Some(ids.value(row_idx))
+        let id = match self.data_type {
+            DataType::Utf8 => {
+                let ids = self
+                    .owned
+                    .as_any()
+                    .downcast_ref::<StringArray>()
+                    .expect("Expected StringArray for DataType::Utf8");
+                if ids.is_null(row_idx) {
+                    None
+                } else {
+                    Some(ids.value(row_idx))
+                }
+            }
+            DataType::LargeUtf8 => {
+                let ids = self
+                    .owned
+                    .as_any()
+                    .downcast_ref::<LargeStringArray>()
+                    .expect("Expected LargeStringArray for DataType::LargeUtf8");
+                if ids.is_null(row_idx) {
+                    None
+                } else {
+                    Some(ids.value(row_idx))
+                }
+            }
+            DataType::Utf8View => {
+                let ids = self
+                    .owned
+                    .as_any()
+                    .downcast_ref::<StringViewArray>()
+                    .expect("Expected StringViewArray for DataType::Utf8View");
+                if ids.is_null(row_idx) {
+                    None
+                } else {
+                    Some(ids.value(row_idx))
+                }
+            }
+            _ => panic!("Unsupported data type"),
         };

         let hash = self.rnd.hash_one(id);

@@ -377,7 +419,9 @@ pub fn new_hash_table(

     downcast_primitive! {
         kt => (downcast_helper, kt),
-        DataType::Utf8 => return Ok(Box::new(StringHashTable::new(limit))),
+        DataType::Utf8 => return Ok(Box::new(StringHashTable::new(limit, DataType::Utf8))),
+        DataType::LargeUtf8 => return Ok(Box::new(StringHashTable::new(limit, DataType::LargeUtf8))),
+        DataType::Utf8View => return Ok(Box::new(StringHashTable::new(limit, DataType::Utf8View))),
         _ => {}
     }
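
The recurring pattern in this file is downcasting the type-erased ArrayRef to the concrete string array that matches the stored DataType, since Utf8, LargeUtf8, and Utf8View have different physical layouts (i32 offsets, i64 offsets, and views respectively). A self-contained sketch of that pattern using the same arrow types; value_at is a hypothetical name, and it folds downcast failure and SQL NULL into one None for brevity:

use std::sync::Arc;
use arrow::array::{Array, ArrayRef, LargeStringArray, StringArray, StringViewArray};
use arrow::datatypes::DataType;

// Downcast-per-DataType lookup, mirroring StringHashTable::find_or_insert:
// the same ArrayRef is viewed through the concrete array type that
// matches the stored DataType.
fn value_at(owned: &ArrayRef, data_type: &DataType, row_idx: usize) -> Option<String> {
    match data_type {
        DataType::Utf8 => {
            let ids = owned.as_any().downcast_ref::<StringArray>()?;
            (!ids.is_null(row_idx)).then(|| ids.value(row_idx).to_string())
        }
        DataType::LargeUtf8 => {
            let ids = owned.as_any().downcast_ref::<LargeStringArray>()?;
            (!ids.is_null(row_idx)).then(|| ids.value(row_idx).to_string())
        }
        DataType::Utf8View => {
            let ids = owned.as_any().downcast_ref::<StringViewArray>()?;
            (!ids.is_null(row_idx)).then(|| ids.value(row_idx).to_string())
        }
        _ => None,
    }
}

fn main() {
    let arr: ArrayRef = Arc::new(StringViewArray::from(vec![Some("a"), None]));
    assert_eq!(value_at(&arr, &DataType::Utf8View, 0), Some("a".to_string()));
    assert_eq!(value_at(&arr, &DataType::Utf8View, 1), None);
}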

datafusion/physical-plan/src/aggregates/topk/priority_map.rs

Lines changed: 71 additions & 1 deletion

@@ -108,11 +108,67 @@ impl PriorityMap {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use arrow::array::{Int64Array, RecordBatch, StringArray};
+    use arrow::array::{
+        Int64Array, LargeStringArray, RecordBatch, StringArray, StringViewArray,
+    };
     use arrow::datatypes::{Field, Schema, SchemaRef};
     use arrow::util::pretty::pretty_format_batches;
     use std::sync::Arc;

+    #[test]
+    fn should_append_with_utf8view() -> Result<()> {
+        let ids: ArrayRef = Arc::new(StringViewArray::from(vec!["1"]));
+        let vals: ArrayRef = Arc::new(Int64Array::from(vec![1]));
+        let mut agg = PriorityMap::new(DataType::Utf8View, DataType::Int64, 1, false)?;
+        agg.set_batch(ids, vals);
+        agg.insert(0)?;
+
+        let cols = agg.emit()?;
+        let batch = RecordBatch::try_new(test_schema_utf8view(), cols)?;
+        let batch_schema = batch.schema();
+        assert_eq!(batch_schema.fields[0].data_type(), &DataType::Utf8View);
+
+        let actual = format!("{}", pretty_format_batches(&[batch])?);
+        let expected = r#"
++----------+--------------+
+| trace_id | timestamp_ms |
++----------+--------------+
+| 1        | 1            |
++----------+--------------+
+        "#
+        .trim();
+        assert_eq!(actual, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn should_append_with_large_utf8() -> Result<()> {
+        let ids: ArrayRef = Arc::new(LargeStringArray::from(vec!["1"]));
+        let vals: ArrayRef = Arc::new(Int64Array::from(vec![1]));
+        let mut agg = PriorityMap::new(DataType::LargeUtf8, DataType::Int64, 1, false)?;
+        agg.set_batch(ids, vals);
+        agg.insert(0)?;
+
+        let cols = agg.emit()?;
+        let batch = RecordBatch::try_new(test_large_schema(), cols)?;
+        let batch_schema = batch.schema();
+        assert_eq!(batch_schema.fields[0].data_type(), &DataType::LargeUtf8);
+
+        let actual = format!("{}", pretty_format_batches(&[batch])?);
+        let expected = r#"
++----------+--------------+
+| trace_id | timestamp_ms |
++----------+--------------+
+| 1        | 1            |
++----------+--------------+
+        "#
+        .trim();
+        assert_eq!(actual, expected);
+
+        Ok(())
+    }
+
     #[test]
     fn should_append() -> Result<()> {
         let ids: ArrayRef = Arc::new(StringArray::from(vec!["1"]));

@@ -370,4 +426,18 @@ mod tests {
             Field::new("timestamp_ms", DataType::Int64, true),
         ]))
     }
+
+    fn test_schema_utf8view() -> SchemaRef {
+        Arc::new(Schema::new(vec![
+            Field::new("trace_id", DataType::Utf8View, true),
+            Field::new("timestamp_ms", DataType::Int64, true),
+        ]))
+    }
+
+    fn test_large_schema() -> SchemaRef {
+        Arc::new(Schema::new(vec![
+            Field::new("trace_id", DataType::LargeUtf8, true),
+            Field::new("timestamp_ms", DataType::Int64, true),
+        ]))
+    }
 }
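
The two new tests pin down not just the rendered values but the emitted schema, which is the regression these changes guard against: emit must hand back an array whose data type matches the group column's original encoding. A shorter sketch of that end-to-end invariant using plain arrow (no PriorityMap), assuming only arrow's public API:

use std::sync::Arc;
use arrow::array::{ArrayRef, Int64Array, RecordBatch, StringViewArray};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;

// A Utf8View group column round-trips through a RecordBatch with its
// data type intact; RecordBatch::try_new errors on any schema mismatch.
fn main() -> Result<(), ArrowError> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("trace_id", DataType::Utf8View, true),
        Field::new("timestamp_ms", DataType::Int64, true),
    ]));
    let ids: ArrayRef = Arc::new(StringViewArray::from(vec!["1"]));
    let vals: ArrayRef = Arc::new(Int64Array::from(vec![1]));
    let batch = RecordBatch::try_new(schema, vec![ids, vals])?;
    assert_eq!(batch.schema().field(0).data_type(), &DataType::Utf8View);
    Ok(())
}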

datafusion/sqllogictest/test_files/aggregates_topk.slt

Lines changed: 57 additions & 1 deletion

@@ -18,7 +18,6 @@
 #######
 # Setup test data table
 #######
-
 # TopK aggregation
 statement ok
 CREATE TABLE traces(trace_id varchar, timestamp bigint, other bigint) AS VALUES

@@ -214,5 +213,62 @@ a -1 -1
 NULL 0 0
 c 1 2

+
+# Map varchar to Utf8View, to test PR https://github.com/apache/datafusion/pull/15152
+# Before that PR, this case would not work because Utf8View was not supported by the TopK aggregation
+statement ok
+CREATE TABLE traces_utf8view
+AS SELECT
+  arrow_cast(trace_id, 'Utf8View') as trace_id,
+  timestamp,
+  other
+FROM traces;
+
+query TT
+explain select trace_id, MAX(timestamp) from traces_utf8view group by trace_id order by MAX(timestamp) desc limit 4;
+----
+logical_plan
+01)Sort: max(traces_utf8view.timestamp) DESC NULLS FIRST, fetch=4
+02)--Aggregate: groupBy=[[traces_utf8view.trace_id]], aggr=[[max(traces_utf8view.timestamp)]]
+03)----TableScan: traces_utf8view projection=[trace_id, timestamp]
+physical_plan
+01)SortPreservingMergeExec: [max(traces_utf8view.timestamp)@1 DESC], fetch=4
+02)--SortExec: TopK(fetch=4), expr=[max(traces_utf8view.timestamp)@1 DESC], preserve_partitioning=[true]
+03)----AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[max(traces_utf8view.timestamp)], lim=[4]
+04)------CoalesceBatchesExec: target_batch_size=8192
+05)--------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4
+06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+07)------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[max(traces_utf8view.timestamp)], lim=[4]
+08)--------------DataSourceExec: partitions=1, partition_sizes=[1]
+
+
+# Also cover LargeUtf8, to test PR https://github.com/apache/datafusion/pull/15152
+# Before that PR, this case would not work because LargeUtf8 was not supported by the TopK aggregation
+statement ok
+CREATE TABLE traces_largeutf8
+AS SELECT
+  arrow_cast(trace_id, 'LargeUtf8') as trace_id,
+  timestamp,
+  other
+FROM traces;
+
+query TT
+explain select trace_id, MAX(timestamp) from traces_largeutf8 group by trace_id order by MAX(timestamp) desc limit 4;
+----
+logical_plan
+01)Sort: max(traces_largeutf8.timestamp) DESC NULLS FIRST, fetch=4
+02)--Aggregate: groupBy=[[traces_largeutf8.trace_id]], aggr=[[max(traces_largeutf8.timestamp)]]
+03)----TableScan: traces_largeutf8 projection=[trace_id, timestamp]
+physical_plan
+01)SortPreservingMergeExec: [max(traces_largeutf8.timestamp)@1 DESC], fetch=4
+02)--SortExec: TopK(fetch=4), expr=[max(traces_largeutf8.timestamp)@1 DESC], preserve_partitioning=[true]
+03)----AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[max(traces_largeutf8.timestamp)], lim=[4]
+04)------CoalesceBatchesExec: target_batch_size=8192
+05)--------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4
+06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+07)------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[max(traces_largeutf8.timestamp)], lim=[4]
+08)--------------DataSourceExec: partitions=1, partition_sizes=[1]
+
 statement ok
 drop table traces;
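
The slt tables are derived with arrow_cast, which corresponds to Arrow's cast kernel. A sketch of the same conversion in Rust, assuming (as holds in recent arrow-rs releases) that the kernel supports Utf8 to Utf8View and Utf8 to LargeUtf8:

use std::sync::Arc;
use arrow::array::{ArrayRef, StringArray};
use arrow::compute::cast;
use arrow::datatypes::DataType;
use arrow::error::ArrowError;

// Re-encode a Utf8 column as Utf8View and LargeUtf8, as the slt's
// arrow_cast(trace_id, 'Utf8View' / 'LargeUtf8') does for the test tables.
fn main() -> Result<(), ArrowError> {
    let utf8: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
    let view = cast(&utf8, &DataType::Utf8View)?;
    let large = cast(&utf8, &DataType::LargeUtf8)?;
    assert_eq!(view.data_type(), &DataType::Utf8View);
    assert_eq!(large.data_type(), &DataType::LargeUtf8);
    Ok(())
}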
