Skip to content

Commit d6fe1de

Browse files
authored
fix: overcounting of memory in first/last. (#15924)
When aggregating first/last over a column of lists, the first/last accumulators hold the necessary scalar value as is, which points to the list in the original input buffer. This results in two issues: 1) We prevent the deallocation of the input arrays, which might be significantly larger than the single value we want to hold. 2) During aggregation with groups, many accumulators receive slices of the same input buffer, resulting in all held values pointing to this buffer. Then, when calculating the size of all accumulators, we count the buffer multiple times, since each accumulator considers it to be part of its own allocation.
1 parent e2a5c1e commit d6fe1de

File tree

3 files changed

+167
-50
lines changed

3 files changed

+167
-50
lines changed

datafusion/common/src/scalar/mod.rs

Lines changed: 61 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3449,49 +3449,80 @@ impl ScalarValue {
34493449
.sum::<usize>()
34503450
}
34513451

3452-
/// Performs a deep clone of the ScalarValue, creating new copies of all nested data structures.
3453-
/// This is different from the standard `clone()` which may share data through `Arc`.
3454-
/// Aggregation functions like `max` will cost a lot of memory if the data is not cloned.
3455-
pub fn force_clone(&self) -> Self {
3452+
/// Compacts the allocation referenced by `self` to the minimum, copying the data if
3453+
/// necessary.
3454+
///
3455+
/// This can be relevant when `self` is a list or contains a list as a nested value, as
3456+
/// a single list holds an Arc to its entire original array buffer.
3457+
pub fn compact(&mut self) {
34563458
match self {
3457-
// Complex types need deep clone of their contents
3458-
ScalarValue::List(array) => {
3459-
let array = copy_array_data(&array.to_data());
3460-
let new_array = ListArray::from(array);
3461-
ScalarValue::List(Arc::new(new_array))
3459+
ScalarValue::Null
3460+
| ScalarValue::Boolean(_)
3461+
| ScalarValue::Float16(_)
3462+
| ScalarValue::Float32(_)
3463+
| ScalarValue::Float64(_)
3464+
| ScalarValue::Decimal128(_, _, _)
3465+
| ScalarValue::Decimal256(_, _, _)
3466+
| ScalarValue::Int8(_)
3467+
| ScalarValue::Int16(_)
3468+
| ScalarValue::Int32(_)
3469+
| ScalarValue::Int64(_)
3470+
| ScalarValue::UInt8(_)
3471+
| ScalarValue::UInt16(_)
3472+
| ScalarValue::UInt32(_)
3473+
| ScalarValue::UInt64(_)
3474+
| ScalarValue::Date32(_)
3475+
| ScalarValue::Date64(_)
3476+
| ScalarValue::Time32Second(_)
3477+
| ScalarValue::Time32Millisecond(_)
3478+
| ScalarValue::Time64Microsecond(_)
3479+
| ScalarValue::Time64Nanosecond(_)
3480+
| ScalarValue::IntervalYearMonth(_)
3481+
| ScalarValue::IntervalDayTime(_)
3482+
| ScalarValue::IntervalMonthDayNano(_)
3483+
| ScalarValue::DurationSecond(_)
3484+
| ScalarValue::DurationMillisecond(_)
3485+
| ScalarValue::DurationMicrosecond(_)
3486+
| ScalarValue::DurationNanosecond(_)
3487+
| ScalarValue::Utf8(_)
3488+
| ScalarValue::LargeUtf8(_)
3489+
| ScalarValue::Utf8View(_)
3490+
| ScalarValue::TimestampSecond(_, _)
3491+
| ScalarValue::TimestampMillisecond(_, _)
3492+
| ScalarValue::TimestampMicrosecond(_, _)
3493+
| ScalarValue::TimestampNanosecond(_, _)
3494+
| ScalarValue::Binary(_)
3495+
| ScalarValue::FixedSizeBinary(_, _)
3496+
| ScalarValue::LargeBinary(_)
3497+
| ScalarValue::BinaryView(_) => (),
3498+
ScalarValue::FixedSizeList(arr) => {
3499+
let array = copy_array_data(&arr.to_data());
3500+
*Arc::make_mut(arr) = FixedSizeListArray::from(array);
34623501
}
3463-
ScalarValue::LargeList(array) => {
3464-
let array = copy_array_data(&array.to_data());
3465-
let new_array = LargeListArray::from(array);
3466-
ScalarValue::LargeList(Arc::new(new_array))
3502+
ScalarValue::List(arr) => {
3503+
let array = copy_array_data(&arr.to_data());
3504+
*Arc::make_mut(arr) = ListArray::from(array);
34673505
}
3468-
ScalarValue::FixedSizeList(arr) => {
3506+
ScalarValue::LargeList(arr) => {
34693507
let array = copy_array_data(&arr.to_data());
3470-
let new_array = FixedSizeListArray::from(array);
3471-
ScalarValue::FixedSizeList(Arc::new(new_array))
3508+
*Arc::make_mut(arr) = LargeListArray::from(array)
34723509
}
34733510
ScalarValue::Struct(arr) => {
34743511
let array = copy_array_data(&arr.to_data());
3475-
let new_array = StructArray::from(array);
3476-
ScalarValue::Struct(Arc::new(new_array))
3512+
*Arc::make_mut(arr) = StructArray::from(array);
34773513
}
34783514
ScalarValue::Map(arr) => {
34793515
let array = copy_array_data(&arr.to_data());
3480-
let new_array = MapArray::from(array);
3481-
ScalarValue::Map(Arc::new(new_array))
3482-
}
3483-
ScalarValue::Union(Some((type_id, value)), fields, mode) => {
3484-
let new_value = Box::new(value.force_clone());
3485-
ScalarValue::Union(Some((*type_id, new_value)), fields.clone(), *mode)
3516+
*Arc::make_mut(arr) = MapArray::from(array);
34863517
}
3487-
ScalarValue::Union(None, fields, mode) => {
3488-
ScalarValue::Union(None, fields.clone(), *mode)
3518+
ScalarValue::Union(val, _, _) => {
3519+
if let Some((_, value)) = val.as_mut() {
3520+
value.compact();
3521+
}
34893522
}
3490-
ScalarValue::Dictionary(key_type, value) => {
3491-
let new_value = Box::new(value.force_clone());
3492-
ScalarValue::Dictionary(key_type.clone(), new_value)
3523+
ScalarValue::Dictionary(_, value) => {
3524+
value.compact();
34933525
}
3494-
_ => self.clone(),
34953526
}
34963527
}
34973528
}

datafusion/functions-aggregate/src/first_last.rs

Lines changed: 92 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -752,7 +752,7 @@ where
752752

753753
fn size(&self) -> usize {
754754
self.vals.capacity() * size_of::<T::Native>()
755-
+ self.null_builder.capacity() / 8 // capacity is in bits, so convert to bytes
755+
+ self.null_builder.capacity() / 8 // capacity is in bits, so convert to bytes
756756
+ self.is_sets.capacity() / 8
757757
+ self.size_of_orderings
758758
+ self.min_of_each_group_buf.0.capacity() * size_of::<usize>()
@@ -827,9 +827,14 @@ impl FirstValueAccumulator {
827827
}
828828

829829
// Updates state with the values in the given row.
830-
fn update_with_new_row(&mut self, row: &[ScalarValue]) {
831-
self.first = row[0].clone();
832-
self.orderings = row[1..].to_vec();
830+
fn update_with_new_row(&mut self, mut row: Vec<ScalarValue>) {
831+
// Ensure any array-based scalars hold only a single value, to reduce memory pressure
832+
row.iter_mut().for_each(|s| {
833+
s.compact();
834+
});
835+
836+
self.first = row.remove(0);
837+
self.orderings = row;
833838
self.is_set = true;
834839
}
835840

@@ -888,7 +893,7 @@ impl Accumulator for FirstValueAccumulator {
888893
if !self.is_set {
889894
if let Some(first_idx) = self.get_first_idx(values)? {
890895
let row = get_row_at_idx(values, first_idx)?;
891-
self.update_with_new_row(&row);
896+
self.update_with_new_row(row);
892897
}
893898
} else if !self.requirement_satisfied {
894899
if let Some(first_idx) = self.get_first_idx(values)? {
@@ -901,7 +906,7 @@ impl Accumulator for FirstValueAccumulator {
901906
)?
902907
.is_gt()
903908
{
904-
self.update_with_new_row(&row);
909+
self.update_with_new_row(row);
905910
}
906911
}
907912
}
@@ -925,7 +930,7 @@ impl Accumulator for FirstValueAccumulator {
925930
let min = (0..filtered_states[0].len()).min_by(|&a, &b| comparator.compare(a, b));
926931

927932
if let Some(first_idx) = min {
928-
let first_row = get_row_at_idx(&filtered_states, first_idx)?;
933+
let mut first_row = get_row_at_idx(&filtered_states, first_idx)?;
929934
// When collecting orderings, we exclude the is_set flag from the state.
930935
let first_ordering = &first_row[1..is_set_idx];
931936
let sort_options = get_sort_options(self.ordering_req.as_ref());
@@ -936,7 +941,9 @@ impl Accumulator for FirstValueAccumulator {
936941
// Update with first value in the state. Note that we should exclude the
937942
// is_set flag from the state. Otherwise, we will end up with a state
938943
// containing two is_set flags.
939-
self.update_with_new_row(&first_row[0..is_set_idx]);
944+
assert!(is_set_idx <= first_row.len());
945+
first_row.resize(is_set_idx, ScalarValue::Null);
946+
self.update_with_new_row(first_row);
940947
}
941948
}
942949
Ok(())
@@ -1226,9 +1233,14 @@ impl LastValueAccumulator {
12261233
}
12271234

12281235
// Updates state with the values in the given row.
1229-
fn update_with_new_row(&mut self, row: &[ScalarValue]) {
1230-
self.last = row[0].clone();
1231-
self.orderings = row[1..].to_vec();
1236+
fn update_with_new_row(&mut self, mut row: Vec<ScalarValue>) {
1237+
// Ensure any array-based scalars hold only a single value, to reduce memory pressure
1238+
row.iter_mut().for_each(|s| {
1239+
s.compact();
1240+
});
1241+
1242+
self.last = row.remove(0);
1243+
self.orderings = row;
12321244
self.is_set = true;
12331245
}
12341246

@@ -1289,7 +1301,7 @@ impl Accumulator for LastValueAccumulator {
12891301
if !self.is_set || self.requirement_satisfied {
12901302
if let Some(last_idx) = self.get_last_idx(values)? {
12911303
let row = get_row_at_idx(values, last_idx)?;
1292-
self.update_with_new_row(&row);
1304+
self.update_with_new_row(row);
12931305
}
12941306
} else if let Some(last_idx) = self.get_last_idx(values)? {
12951307
let row = get_row_at_idx(values, last_idx)?;
@@ -1302,7 +1314,7 @@ impl Accumulator for LastValueAccumulator {
13021314
)?
13031315
.is_lt()
13041316
{
1305-
self.update_with_new_row(&row);
1317+
self.update_with_new_row(row);
13061318
}
13071319
}
13081320

@@ -1326,7 +1338,7 @@ impl Accumulator for LastValueAccumulator {
13261338
let max = (0..filtered_states[0].len()).max_by(|&a, &b| comparator.compare(a, b));
13271339

13281340
if let Some(last_idx) = max {
1329-
let last_row = get_row_at_idx(&filtered_states, last_idx)?;
1341+
let mut last_row = get_row_at_idx(&filtered_states, last_idx)?;
13301342
// When collecting orderings, we exclude the is_set flag from the state.
13311343
let last_ordering = &last_row[1..is_set_idx];
13321344
let sort_options = get_sort_options(self.ordering_req.as_ref());
@@ -1339,7 +1351,9 @@ impl Accumulator for LastValueAccumulator {
13391351
// Update with last value in the state. Note that we should exclude the
13401352
// is_set flag from the state. Otherwise, we will end up with a state
13411353
// containing two is_set flags.
1342-
self.update_with_new_row(&last_row[0..is_set_idx]);
1354+
assert!(is_set_idx <= last_row.len());
1355+
last_row.resize(is_set_idx, ScalarValue::Null);
1356+
self.update_with_new_row(last_row);
13431357
}
13441358
}
13451359
Ok(())
@@ -1382,7 +1396,13 @@ fn convert_to_sort_cols(arrs: &[ArrayRef], sort_exprs: &LexOrdering) -> Vec<Sort
13821396

13831397
#[cfg(test)]
13841398
mod tests {
1385-
use arrow::{array::Int64Array, compute::SortOptions, datatypes::Schema};
1399+
use std::iter::repeat_with;
1400+
1401+
use arrow::{
1402+
array::{Int64Array, ListArray},
1403+
compute::SortOptions,
1404+
datatypes::Schema,
1405+
};
13861406
use datafusion_physical_expr::{expressions::col, PhysicalSortExpr};
13871407

13881408
use super::*;
@@ -1772,4 +1792,60 @@ mod tests {
17721792

17731793
Ok(())
17741794
}
1795+
1796+
#[test]
1797+
fn test_first_list_acc_size() -> Result<()> {
1798+
fn size_after_batch(values: &[ArrayRef]) -> Result<usize> {
1799+
let mut first_accumulator = FirstValueAccumulator::try_new(
1800+
&DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))),
1801+
&[],
1802+
LexOrdering::default(),
1803+
false,
1804+
)?;
1805+
1806+
first_accumulator.update_batch(values)?;
1807+
1808+
Ok(first_accumulator.size())
1809+
}
1810+
1811+
let batch1 = ListArray::from_iter_primitive::<Int32Type, _, _>(
1812+
repeat_with(|| Some(vec![Some(1)])).take(10000),
1813+
);
1814+
let batch2 =
1815+
ListArray::from_iter_primitive::<Int32Type, _, _>([Some(vec![Some(1)])]);
1816+
1817+
let size1 = size_after_batch(&[Arc::new(batch1)])?;
1818+
let size2 = size_after_batch(&[Arc::new(batch2)])?;
1819+
assert_eq!(size1, size2);
1820+
1821+
Ok(())
1822+
}
1823+
1824+
#[test]
1825+
fn test_last_list_acc_size() -> Result<()> {
1826+
fn size_after_batch(values: &[ArrayRef]) -> Result<usize> {
1827+
let mut last_accumulator = LastValueAccumulator::try_new(
1828+
&DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))),
1829+
&[],
1830+
LexOrdering::default(),
1831+
false,
1832+
)?;
1833+
1834+
last_accumulator.update_batch(values)?;
1835+
1836+
Ok(last_accumulator.size())
1837+
}
1838+
1839+
let batch1 = ListArray::from_iter_primitive::<Int32Type, _, _>(
1840+
repeat_with(|| Some(vec![Some(1)])).take(10000),
1841+
);
1842+
let batch2 =
1843+
ListArray::from_iter_primitive::<Int32Type, _, _>([Some(vec![Some(1)])]);
1844+
1845+
let size1 = size_after_batch(&[Arc::new(batch1)])?;
1846+
let size2 = size_after_batch(&[Arc::new(batch2)])?;
1847+
assert_eq!(size1, size2);
1848+
1849+
Ok(())
1850+
}
17751851
}

datafusion/functions-aggregate/src/min_max.rs

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -646,19 +646,29 @@ fn min_max_batch_generic(array: &ArrayRef, ordering: Ordering) -> Result<ScalarV
646646
}
647647
}
648648
}
649-
// use force_clone to free array reference
650-
Ok(extreme.force_clone())
649+
650+
Ok(extreme)
651651
}
652652

653653
macro_rules! min_max_generic {
654654
($VALUE:expr, $DELTA:expr, $OP:ident) => {{
655655
if $VALUE.is_null() {
656-
$DELTA.clone()
656+
let mut delta_copy = $DELTA.clone();
657+
// When the new value won we want to compact it to
658+
// avoid storing the entire input
659+
delta_copy.compact();
660+
delta_copy
657661
} else if $DELTA.is_null() {
658662
$VALUE.clone()
659663
} else {
660664
match $VALUE.partial_cmp(&$DELTA) {
661-
Some(choose_min_max!($OP)) => $DELTA.clone(),
665+
Some(choose_min_max!($OP)) => {
666+
// When the new value won we want to compact it to
667+
// avoid storing the entire input
668+
let mut delta_copy = $DELTA.clone();
669+
delta_copy.compact();
670+
delta_copy
671+
}
662672
_ => $VALUE.clone(),
663673
}
664674
}

0 commit comments

Comments
 (0)