Skip to content

Commit b2c36e8

Browse files
committed
fix batch equal to logic.
1 parent 379132a commit b2c36e8

File tree

1 file changed

+34
-14
lines changed

1 file changed

+34
-14
lines changed

datafusion/physical-plan/src/aggregates/group_values/aggregation_hashtable.rs

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,29 +28,37 @@ struct AggregationHashTable<T: AggregationHashTableEntry> {
2828
/// Pending indices that need to be batch compared afterwards
2929
/// If equal after comparing, this group exists
3030
/// If not equal after comparing, probe the next bucket
31-
pending_equal_to_indices: Vec<usize>,
31+
pending_equal_to_ctxs: Vec<EqualToContext>,
3232

3333
/// The non equal indices in one round
3434
non_equal_indices: Vec<usize>,
3535
}
3636

37+
pub struct EqualToContext {
38+
row_index: usize,
39+
group_index: usize,
40+
}
41+
3742
#[derive(Debug, Clone)]
3843
pub enum AggregationHashTableEntryState<T: AggregationHashTableEntry> {
3944
Entry(T),
40-
PlaceHolder,
45+
PlaceHolder(usize),
4146
Empty,
4247
}
4348

4449
trait AggregationHashTableEntry: fmt::Debug + Clone + Default {
4550
fn check_hash(&self, target_hash: u64) -> bool;
4651

47-
fn get_hash(&self, random_state: &RandomState) -> u64;
52+
fn hash_value(&self, random_state: &RandomState) -> u64;
53+
54+
fn group_index(&self) -> usize;
4855
}
4956

5057
impl<T: AggregationHashTableEntry> AggregationHashTable<T> {
5158
fn get_or_create_groups<A, E, C>(
5259
&mut self,
5360
cols: &[ArrayRef],
61+
mut current_total_groups: usize,
5462
batch_hashes: &[u64],
5563
random_state: &RandomState,
5664
mut batch_append: A,
@@ -59,7 +67,7 @@ impl<T: AggregationHashTableEntry> AggregationHashTable<T> {
5967
groups: &mut Vec<usize>,
6068
) where
6169
A: FnMut(&[ArrayRef], &[usize], &mut Vec<usize>),
62-
E: FnMut(&[ArrayRef], &[usize], &mut Vec<usize>, &mut Vec<usize>),
70+
E: FnMut(&[ArrayRef], &[EqualToContext], &mut Vec<usize>, &mut Vec<usize>),
6371
C: Fn(usize, u64) -> T,
6472
{
6573
let num_input_rows = cols[0].len();
@@ -78,7 +86,7 @@ impl<T: AggregationHashTableEntry> AggregationHashTable<T> {
7886
self.remaining_indices.extend(0..num_input_rows);
7987
while self.remaining_indices.len() > 0 {
8088
self.pending_append_indices.clear();
81-
self.pending_equal_to_indices.clear();
89+
self.pending_equal_to_ctxs.clear();
8290
self.non_equal_indices.clear();
8391

8492
// 3. Find entry in `raw_table` first, and check:
@@ -92,19 +100,31 @@ impl<T: AggregationHashTableEntry> AggregationHashTable<T> {
92100

93101
let table_entry = match &table_entry {
94102
AggregationHashTableEntryState::Empty => {
95-
*table_entry = AggregationHashTableEntryState::PlaceHolder;
103+
self.pending_append_indices.push(row_idx);
104+
*table_entry = AggregationHashTableEntryState::PlaceHolder(
105+
current_total_groups,
106+
);
107+
current_total_groups += 1;
96108
continue;
97109
}
98-
AggregationHashTableEntryState::PlaceHolder => {
99-
self.non_equal_indices.push(row_idx);
110+
AggregationHashTableEntryState::PlaceHolder(group_index) => {
111+
let equal_to_ctx = EqualToContext {
112+
row_index: row_idx,
113+
group_index: *group_index,
114+
};
115+
self.pending_equal_to_ctxs.push(equal_to_ctx);
100116
continue;
101117
}
102118
AggregationHashTableEntryState::Entry(entry) => entry,
103119
};
104120

105121
let target_hash = batch_hashes[row_idx];
106122
if table_entry.check_hash(target_hash) {
107-
self.pending_equal_to_indices.push(row_idx);
123+
let equal_to_ctx = EqualToContext {
124+
row_index: row_idx,
125+
group_index: table_entry.group_index(),
126+
};
127+
self.pending_equal_to_ctxs.push(equal_to_ctx);
108128
} else {
109129
self.non_equal_indices.push(row_idx);
110130
}
@@ -122,7 +142,7 @@ impl<T: AggregationHashTableEntry> AggregationHashTable<T> {
122142
let offset = self.table_offsets[row_idx];
123143
debug_assert!(matches!(
124144
&self.raw_table[offset],
125-
AggregationHashTableEntryState::PlaceHolder
145+
AggregationHashTableEntryState::PlaceHolder(_)
126146
));
127147
self.raw_table[offset] =
128148
AggregationHashTableEntryState::Entry(new_entry(group_idx, hash));
@@ -133,10 +153,10 @@ impl<T: AggregationHashTableEntry> AggregationHashTable<T> {
133153
}
134154

135155
// 5. Batch equal to
136-
if !self.pending_equal_to_indices.is_empty() {
156+
if !self.pending_equal_to_ctxs.is_empty() {
137157
batch_equal_to(
138158
&cols,
139-
&self.pending_equal_to_indices,
159+
&self.pending_equal_to_ctxs,
140160
&mut self.non_equal_indices,
141161
groups,
142162
);
@@ -179,10 +199,10 @@ impl<T: AggregationHashTableEntry> AggregationHashTable<T> {
179199
for state in old_raw_table {
180200
debug_assert!(!matches!(
181201
&state,
182-
AggregationHashTableEntryState::PlaceHolder
202+
AggregationHashTableEntryState::PlaceHolder(_)
183203
));
184204
if let AggregationHashTableEntryState::Entry(entry) = &state {
185-
let hash = entry.get_hash(random_state);
205+
let hash = entry.hash_value(random_state);
186206
let offset = (hash & self.bit_mask) as usize;
187207
self.raw_table[offset] = state;
188208
}

0 commit comments

Comments
 (0)