feat: topk functionality for aggregates should support utf8view and largeutf8 #15152

Merged · 6 commits · Mar 14, 2025

6 changes: 5 additions & 1 deletion datafusion/physical-optimizer/src/topk_aggregation.rs
@@ -56,7 +56,11 @@ impl TopKAggregation {
}
let group_key = aggr.group_expr().expr().iter().exactly_one().ok()?;
let kt = group_key.0.data_type(&aggr.input().schema()).ok()?;
if !kt.is_primitive() && kt != DataType::Utf8 {
if !kt.is_primitive()
&& kt != DataType::Utf8
&& kt != DataType::Utf8View
&& kt != DataType::LargeUtf8
{
return None;
}
if aggr.filter_expr().iter().any(|e| e.is_some()) {
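
In the hunk above, the TopK rewrite previously bailed out unless the single group key was a primitive or Utf8 column; the widened condition now also accepts Utf8View and LargeUtf8 keys. A minimal sketch of the equivalent eligibility predicate, assuming only the arrow crate (the helper name is illustrative, not part of the PR):

use arrow::datatypes::DataType;

// The rewrite stays eligible for primitive group keys and for all three
// Arrow string representations; any other key type makes the rule return None.
fn is_supported_group_key(kt: &DataType) -> bool {
    kt.is_primitive()
        || matches!(
            kt,
            DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View
        )
}
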
72 changes: 58 additions & 14 deletions datafusion/physical-plan/src/aggregates/topk/hash_table.rs
@@ -23,7 +23,7 @@ use ahash::RandomState;
use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano};
use arrow::array::{
builder::PrimitiveBuilder, cast::AsArray, downcast_primitive, Array, ArrayRef,
ArrowPrimitiveType, PrimitiveArray, StringArray,
ArrowPrimitiveType, LargeStringArray, PrimitiveArray, StringArray, StringViewArray,
};
use arrow::datatypes::{i256, DataType};
use datafusion_common::DataFusionError;
@@ -88,6 +88,7 @@ pub struct StringHashTable {
owned: ArrayRef,
map: TopKHashTable<Option<String>>,
rnd: RandomState,
data_type: DataType,
}

// An implementation of ArrowHashTable for any `ArrowPrimitiveType` key
@@ -101,13 +102,20 @@
}

impl StringHashTable {
pub fn new(limit: usize) -> Self {
pub fn new(limit: usize, data_type: DataType) -> Self {
let vals: Vec<&str> = Vec::new();
let owned = Arc::new(StringArray::from(vals));
let owned: ArrayRef = match data_type {
DataType::Utf8 => Arc::new(StringArray::from(vals)),
DataType::Utf8View => Arc::new(StringViewArray::from(vals)),
DataType::LargeUtf8 => Arc::new(LargeStringArray::from(vals)),
_ => panic!("Unsupported data type"),
};

Self {
owned,
map: TopKHashTable::new(limit, limit * 10),
rnd: RandomState::default(),
data_type,
}
}
}
@@ -131,7 +139,12 @@ impl ArrowHashTable for StringHashTable {

unsafe fn take_all(&mut self, indexes: Vec<usize>) -> ArrayRef {
let ids = self.map.take_all(indexes);
Arc::new(StringArray::from(ids))
match self.data_type {
DataType::Utf8 => Arc::new(StringArray::from(ids)),
DataType::LargeUtf8 => Arc::new(LargeStringArray::from(ids)),
DataType::Utf8View => Arc::new(StringViewArray::from(ids)),
_ => unreachable!(),
}
}

unsafe fn find_or_insert(
@@ -140,15 +153,44 @@
replace_idx: usize,
mapper: &mut Vec<(usize, usize)>,
) -> (usize, bool) {
let ids = self
.owned
.as_any()
.downcast_ref::<StringArray>()
.expect("StringArray required");
let id = if ids.is_null(row_idx) {
None
} else {
Some(ids.value(row_idx))
let id = match self.data_type {
DataType::Utf8 => {
let ids = self
.owned
.as_any()
.downcast_ref::<StringArray>()
.expect("Expected StringArray for DataType::Utf8");
if ids.is_null(row_idx) {
None
} else {
Some(ids.value(row_idx))
}
}
DataType::LargeUtf8 => {
let ids = self
.owned
.as_any()
.downcast_ref::<LargeStringArray>()
.expect("Expected LargeStringArray for DataType::LargeUtf8");
if ids.is_null(row_idx) {
None
} else {
Some(ids.value(row_idx))
}
}
DataType::Utf8View => {
let ids = self
.owned
.as_any()
.downcast_ref::<StringViewArray>()
.expect("Expected StringViewArray for DataType::Utf8View");
if ids.is_null(row_idx) {
None
} else {
Some(ids.value(row_idx))
}
}
_ => panic!("Unsupported data type"),
};

let hash = self.rnd.hash_one(id);
@@ -377,7 +419,9 @@ pub fn new_hash_table(

downcast_primitive! {
kt => (downcast_helper, kt),
DataType::Utf8 => return Ok(Box::new(StringHashTable::new(limit))),
DataType::Utf8 => return Ok(Box::new(StringHashTable::new(limit, DataType::Utf8))),
DataType::LargeUtf8 => return Ok(Box::new(StringHashTable::new(limit, DataType::LargeUtf8))),
DataType::Utf8View => return Ok(Box::new(StringHashTable::new(limit, DataType::Utf8View))),
_ => {}
}

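
The three `find_or_insert` branches above share one pattern: downcast the owned ArrayRef to the concrete string array type named by `data_type`, then read the row value unless it is null. A self-contained sketch of that pattern, assuming only the arrow crate (the `value_at` helper is illustrative, not part of the PR):

use std::sync::Arc;
use arrow::array::{Array, ArrayRef, LargeStringArray, StringArray, StringViewArray};
use arrow::datatypes::DataType;

// Read row `row` from any of the three string array kinds as an owned String.
fn value_at(arr: &ArrayRef, row: usize) -> Option<String> {
    if arr.is_null(row) {
        return None;
    }
    let s = match arr.data_type() {
        DataType::Utf8 => arr.as_any().downcast_ref::<StringArray>().unwrap().value(row),
        DataType::LargeUtf8 => arr.as_any().downcast_ref::<LargeStringArray>().unwrap().value(row),
        DataType::Utf8View => arr.as_any().downcast_ref::<StringViewArray>().unwrap().value(row),
        _ => panic!("unsupported string type"),
    };
    Some(s.to_string())
}

fn main() {
    // The same logical value stored in the three string representations.
    let arrays: Vec<ArrayRef> = vec![
        Arc::new(StringArray::from(vec!["1"])) as ArrayRef,
        Arc::new(LargeStringArray::from(vec!["1"])) as ArrayRef,
        Arc::new(StringViewArray::from(vec!["1"])) as ArrayRef,
    ];
    for a in &arrays {
        assert_eq!(value_at(a, 0), Some("1".to_string()));
    }
}
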
72 changes: 71 additions & 1 deletion datafusion/physical-plan/src/aggregates/topk/priority_map.rs
@@ -108,11 +108,67 @@ impl PriorityMap {
#[cfg(test)]
mod tests {
use super::*;
use arrow::array::{Int64Array, RecordBatch, StringArray};
use arrow::array::{
Int64Array, LargeStringArray, RecordBatch, StringArray, StringViewArray,
};
use arrow::datatypes::{Field, Schema, SchemaRef};
use arrow::util::pretty::pretty_format_batches;
use std::sync::Arc;

#[test]
fn should_append_with_utf8view() -> Result<()> {
let ids: ArrayRef = Arc::new(StringViewArray::from(vec!["1"]));
let vals: ArrayRef = Arc::new(Int64Array::from(vec![1]));
let mut agg = PriorityMap::new(DataType::Utf8View, DataType::Int64, 1, false)?;
agg.set_batch(ids, vals);
agg.insert(0)?;

let cols = agg.emit()?;
let batch = RecordBatch::try_new(test_schema_utf8view(), cols)?;
let batch_schema = batch.schema();
assert_eq!(batch_schema.fields[0].data_type(), &DataType::Utf8View);

let actual = format!("{}", pretty_format_batches(&[batch])?);
let expected = r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 1 | 1 |
+----------+--------------+
"#
.trim();
assert_eq!(actual, expected);

Ok(())
}

#[test]
fn should_append_with_large_utf8() -> Result<()> {
let ids: ArrayRef = Arc::new(LargeStringArray::from(vec!["1"]));
let vals: ArrayRef = Arc::new(Int64Array::from(vec![1]));
let mut agg = PriorityMap::new(DataType::LargeUtf8, DataType::Int64, 1, false)?;
agg.set_batch(ids, vals);
agg.insert(0)?;

let cols = agg.emit()?;
let batch = RecordBatch::try_new(test_large_schema(), cols)?;
let batch_schema = batch.schema();
assert_eq!(batch_schema.fields[0].data_type(), &DataType::LargeUtf8);

let actual = format!("{}", pretty_format_batches(&[batch])?);
let expected = r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 1 | 1 |
+----------+--------------+
"#
.trim();
assert_eq!(actual, expected);

Ok(())
}

#[test]
fn should_append() -> Result<()> {
let ids: ArrayRef = Arc::new(StringArray::from(vec!["1"]));
@@ -370,4 +426,18 @@ mod tests {
Field::new("timestamp_ms", DataType::Int64, true),
]))
}

fn test_schema_utf8view() -> SchemaRef {
Arc::new(Schema::new(vec![
Field::new("trace_id", DataType::Utf8View, true),
Field::new("timestamp_ms", DataType::Int64, true),
]))
}

fn test_large_schema() -> SchemaRef {
Arc::new(Schema::new(vec![
Field::new("trace_id", DataType::LargeUtf8, true),
Field::new("timestamp_ms", DataType::Int64, true),
]))
}
}
58 changes: 57 additions & 1 deletion datafusion/sqllogictest/test_files/aggregates_topk.slt
@@ -18,7 +18,6 @@
#######
# Setup test data table
#######

# TopK aggregation
statement ok
CREATE TABLE traces(trace_id varchar, timestamp bigint, other bigint) AS VALUES
@@ -214,5 +213,62 @@ a -1 -1
NULL 0 0
c 1 2


# Create a table that casts trace_id from varchar to Utf8View, to test PR https://github.com/apache/datafusion/pull/15152
# Before that PR, this test case would fail because Utf8View was not supported by the TopK aggregation
statement ok
CREATE TABLE traces_utf8view
AS SELECT
arrow_cast(trace_id, 'Utf8View') as trace_id,
timestamp,
other
FROM traces;

query TT
explain select trace_id, MAX(timestamp) from traces_utf8view group by trace_id order by MAX(timestamp) desc limit 4;
----
logical_plan
01)Sort: max(traces_utf8view.timestamp) DESC NULLS FIRST, fetch=4
02)--Aggregate: groupBy=[[traces_utf8view.trace_id]], aggr=[[max(traces_utf8view.timestamp)]]
03)----TableScan: traces_utf8view projection=[trace_id, timestamp]
physical_plan
01)SortPreservingMergeExec: [max(traces_utf8view.timestamp)@1 DESC], fetch=4
02)--SortExec: TopK(fetch=4), expr=[max(traces_utf8view.timestamp)@1 DESC], preserve_partitioning=[true]
03)----AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[max(traces_utf8view.timestamp)], lim=[4]
04)------CoalesceBatchesExec: target_batch_size=8192
05)--------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4
06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
07)------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[max(traces_utf8view.timestamp)], lim=[4]
08)--------------DataSourceExec: partitions=1, partition_sizes=[1]


# Also create a LargeUtf8 table to test PR https://github.com/apache/datafusion/pull/15152
# Before that PR, this test case would fail because LargeUtf8 was not supported by the TopK aggregation
statement ok
CREATE TABLE traces_largeutf8
AS SELECT
arrow_cast(trace_id, 'LargeUtf8') as trace_id,
timestamp,
other
FROM traces;

query TT
explain select trace_id, MAX(timestamp) from traces_largeutf8 group by trace_id order by MAX(timestamp) desc limit 4;
----
logical_plan
01)Sort: max(traces_largeutf8.timestamp) DESC NULLS FIRST, fetch=4
02)--Aggregate: groupBy=[[traces_largeutf8.trace_id]], aggr=[[max(traces_largeutf8.timestamp)]]
03)----TableScan: traces_largeutf8 projection=[trace_id, timestamp]
physical_plan
01)SortPreservingMergeExec: [max(traces_largeutf8.timestamp)@1 DESC], fetch=4
02)--SortExec: TopK(fetch=4), expr=[max(traces_largeutf8.timestamp)@1 DESC], preserve_partitioning=[true]
03)----AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[max(traces_largeutf8.timestamp)], lim=[4]
04)------CoalesceBatchesExec: target_batch_size=8192
05)--------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4
06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
07)------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[max(traces_largeutf8.timestamp)], lim=[4]
08)--------------DataSourceExec: partitions=1, partition_sizes=[1]


statement ok
drop table traces;
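
To reproduce the .slt scenario outside sqllogictest, here is a minimal, hypothetical end-to-end sketch using the DataFusion Rust API. It assumes the datafusion, arrow, and tokio crates are available; the table and column names merely mirror the test fixtures above, and none of this code is part of the PR:

use std::sync::Arc;

use arrow::array::{ArrayRef, Int64Array, RecordBatch, StringViewArray};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    // A tiny table whose group-by key is Utf8View, like traces_utf8view above.
    let schema = Arc::new(Schema::new(vec![
        Field::new("trace_id", DataType::Utf8View, true),
        Field::new("timestamp", DataType::Int64, true),
    ]));
    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(StringViewArray::from(vec!["a", "b", "c", "a"])) as ArrayRef,
            Arc::new(Int64Array::from(vec![1, 2, 3, 4])) as ArrayRef,
        ],
    )?;

    let ctx = SessionContext::new();
    ctx.register_batch("traces_utf8view", batch)?;

    // With this PR, the TopK-aware aggregation path also covers Utf8View keys.
    ctx.sql(
        "SELECT trace_id, MAX(timestamp)
         FROM traces_utf8view
         GROUP BY trace_id
         ORDER BY MAX(timestamp) DESC
         LIMIT 4",
    )
    .await?
    .show()
    .await?;

    Ok(())
}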