@@ -28,29 +28,37 @@ struct AggregationHashTable<T: AggregationHashTableEntry> {
28
28
/// Pending indices that need to be batch-compared afterwards
29
29
/// If equal after comparing, this group exists
30
30
/// If not equal after comparing, probe the next bucket
31
- pending_equal_to_indices : Vec < usize > ,
31
+ pending_equal_to_ctxs : Vec < EqualToContext > ,
32
32
33
33
/// The non-equal indices in one round
34
34
non_equal_indices : Vec < usize > ,
35
35
}
36
36
37
+ pub struct EqualToContext {
38
+ row_index : usize ,
39
+ group_index : usize ,
40
+ }
41
+
37
42
#[ derive( Debug , Clone ) ]
38
43
pub enum AggregationHashTableEntryState < T : AggregationHashTableEntry > {
39
44
Entry ( T ) ,
40
- PlaceHolder ,
45
+ PlaceHolder ( usize ) ,
41
46
Empty ,
42
47
}
43
48
44
49
trait AggregationHashTableEntry : fmt:: Debug + Clone + Default {
45
50
fn check_hash ( & self , target_hash : u64 ) -> bool ;
46
51
47
- fn get_hash ( & self , random_state : & RandomState ) -> u64 ;
52
+ fn hash_value ( & self , random_state : & RandomState ) -> u64 ;
53
+
54
+ fn group_index ( & self ) -> usize ;
48
55
}
49
56
50
57
impl < T : AggregationHashTableEntry > AggregationHashTable < T > {
51
58
fn get_or_create_groups < A , E , C > (
52
59
& mut self ,
53
60
cols : & [ ArrayRef ] ,
61
+ mut current_total_groups : usize ,
54
62
batch_hashes : & [ u64 ] ,
55
63
random_state : & RandomState ,
56
64
mut batch_append : A ,
@@ -59,7 +67,7 @@ impl<T: AggregationHashTableEntry> AggregationHashTable<T> {
59
67
groups : & mut Vec < usize > ,
60
68
) where
61
69
A : FnMut ( & [ ArrayRef ] , & [ usize ] , & mut Vec < usize > ) ,
62
- E : FnMut ( & [ ArrayRef ] , & [ usize ] , & mut Vec < usize > , & mut Vec < usize > ) ,
70
+ E : FnMut ( & [ ArrayRef ] , & [ EqualToContext ] , & mut Vec < usize > , & mut Vec < usize > ) ,
63
71
C : Fn ( usize , u64 ) -> T ,
64
72
{
65
73
let num_input_rows = cols[ 0 ] . len ( ) ;
@@ -78,7 +86,7 @@ impl<T: AggregationHashTableEntry> AggregationHashTable<T> {
78
86
self . remaining_indices . extend ( 0 ..num_input_rows) ;
79
87
while self . remaining_indices . len ( ) > 0 {
80
88
self . pending_append_indices . clear ( ) ;
81
- self . pending_equal_to_indices . clear ( ) ;
89
+ self . pending_equal_to_ctxs . clear ( ) ;
82
90
self . non_equal_indices . clear ( ) ;
83
91
84
92
// 3. Find entry in `raw_table` first, and check:
@@ -92,19 +100,31 @@ impl<T: AggregationHashTableEntry> AggregationHashTable<T> {
92
100
93
101
let table_entry = match & table_entry {
94
102
AggregationHashTableEntryState :: Empty => {
95
- * table_entry = AggregationHashTableEntryState :: PlaceHolder ;
103
+ self . pending_append_indices . push ( row_idx) ;
104
+ * table_entry = AggregationHashTableEntryState :: PlaceHolder (
105
+ current_total_groups,
106
+ ) ;
107
+ current_total_groups += 1 ;
96
108
continue ;
97
109
}
98
- AggregationHashTableEntryState :: PlaceHolder => {
99
- self . non_equal_indices . push ( row_idx) ;
110
+ AggregationHashTableEntryState :: PlaceHolder ( group_index) => {
111
+ let equal_to_ctx = EqualToContext {
112
+ row_index : row_idx,
113
+ group_index : * group_index,
114
+ } ;
115
+ self . pending_equal_to_ctxs . push ( equal_to_ctx) ;
100
116
continue ;
101
117
}
102
118
AggregationHashTableEntryState :: Entry ( entry) => entry,
103
119
} ;
104
120
105
121
let target_hash = batch_hashes[ row_idx] ;
106
122
if table_entry. check_hash ( target_hash) {
107
- self . pending_equal_to_indices . push ( row_idx) ;
123
+ let equal_to_ctx = EqualToContext {
124
+ row_index : row_idx,
125
+ group_index : table_entry. group_index ( ) ,
126
+ } ;
127
+ self . pending_equal_to_ctxs . push ( equal_to_ctx) ;
108
128
} else {
109
129
self . non_equal_indices . push ( row_idx) ;
110
130
}
@@ -122,7 +142,7 @@ impl<T: AggregationHashTableEntry> AggregationHashTable<T> {
122
142
let offset = self . table_offsets [ row_idx] ;
123
143
debug_assert ! ( matches!(
124
144
& self . raw_table[ offset] ,
125
- AggregationHashTableEntryState :: PlaceHolder
145
+ AggregationHashTableEntryState :: PlaceHolder ( _ )
126
146
) ) ;
127
147
self . raw_table [ offset] =
128
148
AggregationHashTableEntryState :: Entry ( new_entry ( group_idx, hash) ) ;
@@ -133,10 +153,10 @@ impl<T: AggregationHashTableEntry> AggregationHashTable<T> {
133
153
}
134
154
135
155
// 5. Batch equal to
136
- if !self . pending_equal_to_indices . is_empty ( ) {
156
+ if !self . pending_equal_to_ctxs . is_empty ( ) {
137
157
batch_equal_to (
138
158
& cols,
139
- & self . pending_equal_to_indices ,
159
+ & self . pending_equal_to_ctxs ,
140
160
& mut self . non_equal_indices ,
141
161
groups,
142
162
) ;
@@ -179,10 +199,10 @@ impl<T: AggregationHashTableEntry> AggregationHashTable<T> {
179
199
for state in old_raw_table {
180
200
debug_assert ! ( !matches!(
181
201
& state,
182
- AggregationHashTableEntryState :: PlaceHolder
202
+ AggregationHashTableEntryState :: PlaceHolder ( _ )
183
203
) ) ;
184
204
if let AggregationHashTableEntryState :: Entry ( entry) = & state {
185
- let hash = entry. get_hash ( random_state) ;
205
+ let hash = entry. hash_value ( random_state) ;
186
206
let offset = ( hash & self . bit_mask ) as usize ;
187
207
self . raw_table [ offset] = state;
188
208
}
0 commit comments