Skip to content

Commit 52e198e

Browse files
authored
refactor (apache#4391)
Signed-off-by: remzi <[email protected]> Signed-off-by: remzi <[email protected]>
1 parent a31b44e commit 52e198e

File tree

1 file changed

+28
-38
lines changed
  • datafusion/core/src/physical_plan

1 file changed

+28
-38
lines changed

datafusion/core/src/physical_plan/limit.rs

Lines changed: 28 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ impl GlobalLimitExec {
7979
}
8080

8181
/// Maximum number of rows to fetch
82-
pub fn fetch(&self) -> Option<&usize> {
83-
self.fetch.as_ref()
82+
pub fn fetch(&self) -> Option<usize> {
83+
self.fetch
8484
}
8585
}
8686

@@ -365,30 +365,17 @@ impl ExecutionPlan for LocalLimitExec {
365365
}
366366
}
367367

368-
/// Truncate a RecordBatch to maximum of n rows
369-
pub fn truncate_batch(batch: &RecordBatch, n: usize) -> RecordBatch {
370-
let limited_columns: Vec<ArrayRef> = (0..batch.num_columns())
371-
.map(|i| limit(batch.column(i), n))
372-
.collect();
373-
374-
RecordBatch::try_new(batch.schema(), limited_columns).unwrap()
375-
}
376-
377368
/// A Limit stream skips `skip` rows, and then fetches up to `fetch` rows.
378369
struct LimitStream {
379-
/// The number of rows to skip
370+
/// The remaining number of rows to skip
380371
skip: usize,
381-
/// The maximum number of rows to produce, after `skip` are skipped
372+
/// The remaining number of rows to produce
382373
fetch: usize,
383374
/// The input to read from. This is set to None once the limit is
384375
/// reached to enable early termination
385376
input: Option<SendableRecordBatchStream>,
386377
/// Copy of the input schema
387378
schema: SchemaRef,
388-
/// Number of rows that have already been skipped
389-
current_skipped: usize,
390-
/// the current number of rows which have been produced
391-
current_fetched: usize,
392379
/// Execution time metrics
393380
baseline_metrics: BaselineMetrics,
394381
}
@@ -406,8 +393,6 @@ impl LimitStream {
406393
fetch: fetch.unwrap_or(usize::MAX),
407394
input: Some(input),
408395
schema,
409-
current_skipped: 0,
410-
current_fetched: 0,
411396
baseline_metrics,
412397
}
413398
}
@@ -420,47 +405,52 @@ impl LimitStream {
420405
loop {
421406
let poll = input.poll_next_unpin(cx);
422407
let poll = poll.map_ok(|batch| {
423-
if batch.num_rows() + self.current_skipped <= self.skip {
424-
self.current_skipped += batch.num_rows();
408+
if batch.num_rows() <= self.skip {
409+
self.skip -= batch.num_rows();
425410
RecordBatch::new_empty(input.schema())
426411
} else {
427-
let offset = self.skip - self.current_skipped;
428-
let new_batch = batch.slice(offset, batch.num_rows() - offset);
429-
self.current_skipped = self.skip;
412+
let new_batch = batch.slice(self.skip, batch.num_rows() - self.skip);
413+
self.skip = 0;
430414
new_batch
431415
}
432416
});
433417

434418
match &poll {
435-
Poll::Ready(Some(Ok(batch)))
436-
if batch.num_rows() > 0 && self.current_skipped == self.skip =>
437-
{
438-
break poll
419+
Poll::Ready(Some(Ok(batch))) => {
420+
if batch.num_rows() > 0 && self.skip == 0 {
421+
break poll;
422+
} else {
423+
// continue to poll input stream
424+
}
439425
}
440426
Poll::Ready(Some(Err(_e))) => break poll,
441427
Poll::Ready(None) => break poll,
442428
Poll::Pending => break poll,
443-
_ => {
444-
// continue to poll input stream
445-
}
446429
}
447430
}
448431
}
449432

433+
/// Applies the remaining `fetch` limit to `batch`, truncating it if necessary; returns `None` once the limit is exhausted
450434
fn stream_limit(&mut self, batch: RecordBatch) -> Option<RecordBatch> {
451435
// records time on drop
452436
let _timer = self.baseline_metrics.elapsed_compute().timer();
453-
if self.current_fetched == self.fetch {
437+
if self.fetch == 0 {
454438
self.input = None; // clear input so it can be dropped early
455439
None
456-
} else if self.current_fetched + batch.num_rows() <= self.fetch {
457-
self.current_fetched += batch.num_rows();
440+
} else if batch.num_rows() <= self.fetch {
441+
self.fetch -= batch.num_rows();
458442
Some(batch)
459443
} else {
460-
let batch_rows = self.fetch - self.current_fetched;
461-
self.current_fetched = self.fetch;
444+
let batch_rows = self.fetch;
445+
self.fetch = 0;
462446
self.input = None; // clear input so it can be dropped early
463-
Some(truncate_batch(&batch, batch_rows))
447+
448+
let limited_columns: Vec<ArrayRef> = batch
449+
.columns()
450+
.iter()
451+
.map(|col| limit(col, batch_rows))
452+
.collect();
453+
Some(RecordBatch::try_new(batch.schema(), limited_columns).unwrap())
464454
}
465455
}
466456
}
@@ -472,7 +462,7 @@ impl Stream for LimitStream {
472462
mut self: Pin<&mut Self>,
473463
cx: &mut Context<'_>,
474464
) -> Poll<Option<Self::Item>> {
475-
let fetch_started = self.current_skipped == self.skip;
465+
let fetch_started = self.skip == 0;
476466
let poll = match &mut self.input {
477467
Some(input) => {
478468
let poll = if fetch_started {

0 commit comments

Comments
 (0)