 //! CSV format abstractions

 use std::any::Any;
-
 use std::collections::HashSet;
+use std::fmt;
+use std::fmt::{Debug, Display};
 use std::sync::Arc;

+use arrow::csv::WriterBuilder;
 use arrow::datatypes::{DataType, Field, Fields, Schema};
 use arrow::{self, datatypes::SchemaRef};
-use async_trait::async_trait;
-use bytes::{Buf, Bytes};
-
+use arrow_array::RecordBatch;
 use datafusion_common::DataFusionError;
-
+use datafusion_execution::TaskContext;
 use datafusion_physical_expr::PhysicalExpr;
+
+use async_trait::async_trait;
+use bytes::{Buf, Bytes};
 use futures::stream::BoxStream;
 use futures::{pin_mut, Stream, StreamExt, TryStreamExt};
 use object_store::{delimited::newline_delimited_stream, ObjectMeta, ObjectStore};
+use tokio::io::{AsyncWrite, AsyncWriteExt};

 use super::FileFormat;
 use crate::datasource::file_format::file_type::FileCompressionType;
-use crate::datasource::file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD;
+use crate::datasource::file_format::FileWriterMode;
+use crate::datasource::file_format::{
+    AbortMode, AbortableWrite, AsyncPutWriter, BatchSerializer, MultiPart,
+    DEFAULT_SCHEMA_INFER_MAX_RECORD,
+};
 use crate::error::Result;
 use crate::execution::context::SessionState;
-use crate::physical_plan::file_format::{CsvExec, FileScanConfig};
-use crate::physical_plan::ExecutionPlan;
+use crate::physical_plan::file_format::{
+    CsvExec, FileGroupDisplay, FileMeta, FileScanConfig, FileSinkConfig,
+};
+use crate::physical_plan::insert::{DataSink, InsertExec};
 use crate::physical_plan::Statistics;
+use crate::physical_plan::{ExecutionPlan, SendableRecordBatchStream};

 /// The default file extension of csv files
 pub const DEFAULT_CSV_EXTENSION: &str = ".csv";
@@ -220,6 +231,22 @@ impl FileFormat for CsvFormat {
         );
         Ok(Arc::new(exec))
     }
+
+    async fn create_writer_physical_plan(
+        &self,
+        input: Arc<dyn ExecutionPlan>,
+        _state: &SessionState,
+        conf: FileSinkConfig,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        let sink = Arc::new(CsvSink::new(
+            conf,
+            self.has_header,
+            self.delimiter,
+            self.file_compression_type.clone(),
+        ));
+
+        Ok(Arc::new(InsertExec::new(input, sink)) as _)
+    }
 }

 impl CsvFormat {
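For orientation, a minimal, hypothetical sketch of how the new hook above might be driven from calling code; it assumes an already-built input `ExecutionPlan` named `input`, a `SessionState` named `state`, and a populated `FileSinkConfig` named `conf`, none of which are constructed in this diff:

    // Hypothetical caller of the new write path; `input`, `state`, and `conf`
    // are assumed to exist elsewhere.
    let format = CsvFormat::default();
    let write_plan = format
        .create_writer_physical_plan(input, &state, conf)
        .await?;
    // `write_plan` is an InsertExec wrapping a CsvSink: executing it drives
    // CsvSink::write_all, which streams the input batches into the configured
    // CSV files and returns the number of rows written.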
@@ -324,6 +351,243 @@ fn build_schema_helper(names: Vec<String>, types: &[HashSet<DataType>]) -> Schem
     Schema::new(fields)
 }

+impl Default for CsvSerializer {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+/// Define a struct for serializing CSV records to a stream
+pub struct CsvSerializer {
+    // CSV writer builder
+    builder: WriterBuilder,
+    // Inner buffer for avoiding reallocation
+    buffer: Vec<u8>,
+    // Flag to indicate whether there will be a header
+    header: bool,
+}
+
+impl CsvSerializer {
+    /// Constructor for the CsvSerializer object
+    pub fn new() -> Self {
+        Self {
+            builder: WriterBuilder::new(),
+            header: true,
+            buffer: Vec::with_capacity(4096),
+        }
+    }
+
+    /// Method for setting the CSV writer builder
+    pub fn with_builder(mut self, builder: WriterBuilder) -> Self {
+        self.builder = builder;
+        self
+    }
+
+    /// Method for setting the CSV writer header status
+    pub fn with_header(mut self, header: bool) -> Self {
+        self.header = header;
+        self
+    }
+}
+
+#[async_trait]
+impl BatchSerializer for CsvSerializer {
+    async fn serialize(&mut self, batch: RecordBatch) -> Result<Bytes> {
+        let builder = self.builder.clone();
+        let mut writer = builder.has_headers(self.header).build(&mut self.buffer);
+        writer.write(&batch)?;
+        drop(writer);
+        self.header = false;
+        Ok(Bytes::from(self.buffer.drain(..).collect::<Vec<u8>>()))
+    }
+}
+
+async fn check_for_errors<T, W: AsyncWrite + Unpin + Send>(
+    result: Result<T>,
+    writers: &mut [AbortableWrite<W>],
+) -> Result<T> {
+    match result {
+        Ok(value) => Ok(value),
+        Err(e) => {
+            // Abort all writers before returning the error:
+            for writer in writers {
+                let mut abort_future = writer.abort_writer();
+                if let Ok(abort_future) = &mut abort_future {
+                    let _ = abort_future.await;
+                }
+                // Ignore errors that occur while aborting; we still try to
+                // abort the remaining writers before returning.
+            }
+            // After aborting the writers, return the original error.
+            Err(e)
+        }
+    }
+}
+
+/// Implements [`DataSink`] for writing to a CSV file.
+struct CsvSink {
+    /// Config options for writing data
+    config: FileSinkConfig,
+    has_header: bool,
+    delimiter: u8,
+    file_compression_type: FileCompressionType,
+}
+
+impl Debug for CsvSink {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("CsvSink")
+            .field("has_header", &self.has_header)
+            .field("delimiter", &self.delimiter)
+            .field("file_compression_type", &self.file_compression_type)
+            .finish()
+    }
+}
+
+impl Display for CsvSink {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(
+            f,
+            "CsvSink(writer_mode={:?}, file_groups={})",
+            self.config.writer_mode,
+            FileGroupDisplay(&self.config.file_groups),
+        )
+    }
+}
+
+impl CsvSink {
+    fn new(
+        config: FileSinkConfig,
+        has_header: bool,
+        delimiter: u8,
+        file_compression_type: FileCompressionType,
+    ) -> Self {
+        Self {
+            config,
+            has_header,
+            delimiter,
+            file_compression_type,
+        }
+    }
+
+    // Create a writer for CSV files
+    async fn create_writer(
+        &self,
+        file_meta: FileMeta,
+        object_store: Arc<dyn ObjectStore>,
+    ) -> Result<AbortableWrite<Box<dyn AsyncWrite + Send + Unpin>>> {
+        let object = &file_meta.object_meta;
+        match self.config.writer_mode {
+            // If the mode is append, call the store's append method and return wrapped in
+            // a boxed trait object.
+            FileWriterMode::Append => {
+                let writer = object_store
+                    .append(&object.location)
+                    .await
+                    .map_err(DataFusionError::ObjectStore)?;
+                let writer = AbortableWrite::new(
+                    self.file_compression_type.convert_async_writer(writer)?,
+                    AbortMode::Append,
+                );
+                Ok(writer)
+            }
+            // If the mode is put, create a new AsyncPut writer and return it wrapped in
+            // a boxed trait object
+            FileWriterMode::Put => {
+                let writer = Box::new(AsyncPutWriter::new(object.clone(), object_store));
+                let writer = AbortableWrite::new(
+                    self.file_compression_type.convert_async_writer(writer)?,
+                    AbortMode::Put,
+                );
+                Ok(writer)
+            }
+            // If the mode is put multipart, call the store's put_multipart method and
+            // return the writer wrapped in a boxed trait object.
+            FileWriterMode::PutMultipart => {
+                let (multipart_id, writer) = object_store
+                    .put_multipart(&object.location)
+                    .await
+                    .map_err(DataFusionError::ObjectStore)?;
+                Ok(AbortableWrite::new(
+                    self.file_compression_type.convert_async_writer(writer)?,
+                    AbortMode::MultiPart(MultiPart::new(
+                        object_store,
+                        multipart_id,
+                        object.location.clone(),
+                    )),
+                ))
+            }
+        }
+    }
+}
+
+#[async_trait]
+impl DataSink for CsvSink {
+    async fn write_all(
+        &self,
+        mut data: SendableRecordBatchStream,
+        context: &Arc<TaskContext>,
+    ) -> Result<u64> {
+        let num_partitions = self.config.file_groups.len();
+
+        let object_store = context
+            .runtime_env()
+            .object_store(&self.config.object_store_url)?;
+
+        // Construct serializer and writer for each file group
+        let mut serializers = vec![];
+        let mut writers = vec![];
+        for file_group in &self.config.file_groups {
+            // In append mode, consider has_header flag only when file is empty (at the start).
+            // For other modes, use has_header flag as is.
+            let header = self.has_header
+                && (!matches!(&self.config.writer_mode, FileWriterMode::Append)
+                    || file_group.object_meta.size == 0);
+            let builder = WriterBuilder::new().with_delimiter(self.delimiter);
+            let serializer = CsvSerializer::new()
+                .with_builder(builder)
+                .with_header(header);
+            serializers.push(serializer);
+
+            let file = file_group.clone();
+            let writer = self
+                .create_writer(file.object_meta.clone().into(), object_store.clone())
+                .await?;
+            writers.push(writer);
+        }
+
+        let mut idx = 0;
+        let mut row_count = 0;
+        // Map write errors to DataFusionError.
+        let err_converter =
+            |_| DataFusionError::Internal("Unexpected FileSink Error".to_string());
+        while let Some(maybe_batch) = data.next().await {
+            // Write data to files in a round robin fashion:
+            idx = (idx + 1) % num_partitions;
+            let serializer = &mut serializers[idx];
+            let batch = check_for_errors(maybe_batch, &mut writers).await?;
+            row_count += batch.num_rows();
+            let bytes =
+                check_for_errors(serializer.serialize(batch).await, &mut writers).await?;
+            let writer = &mut writers[idx];
+            check_for_errors(
+                writer.write_all(&bytes).await.map_err(err_converter),
+                &mut writers,
+            )
+            .await?;
+        }
+        // Perform cleanup:
+        let n_writers = writers.len();
+        for idx in 0..n_writers {
+            check_for_errors(
+                writers[idx].shutdown().await.map_err(err_converter),
+                &mut writers,
+            )
+            .await?;
+        }
+        Ok(row_count as u64)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::super::test_util::scan_format;
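Taken together, the serializer pieces above compose as in the following minimal, hypothetical sketch; it assumes a `RecordBatch` named `batch` and an async context, neither of which is part of this diff, and mirrors how `CsvSink::write_all` builds its per-file serializers:

    // Illustrative only: serialize a RecordBatch to pipe-delimited CSV with no
    // header line; `batch` is assumed to exist.
    let builder = WriterBuilder::new().with_delimiter(b'|');
    let mut serializer = CsvSerializer::new()
        .with_builder(builder)
        .with_header(false);
    let bytes = serializer.serialize(batch).await?;
    // `bytes` holds the encoded rows; later calls on the same serializer reuse
    // its internal buffer, and a header is emitted at most once because
    // serialize() resets the header flag after the first batch.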
@@ -333,6 +597,7 @@ mod tests {
     use crate::physical_plan::collect;
     use crate::prelude::{CsvReadOptions, SessionConfig, SessionContext};
     use crate::test_util::arrow_test_data;
+    use arrow::compute::concat_batches;
     use bytes::Bytes;
     use chrono::DateTime;
     use datafusion_common::cast::as_string_array;
@@ -606,4 +871,52 @@ mod tests {
         let format = CsvFormat::default();
         scan_format(state, &format, &root, file_name, projection, limit).await
     }
+
+    #[tokio::test]
+    async fn test_csv_serializer() -> Result<()> {
+        let ctx = SessionContext::new();
+        let df = ctx
+            .read_csv(
+                &format!("{}/csv/aggregate_test_100.csv", arrow_test_data()),
+                CsvReadOptions::default().has_header(true),
+            )
+            .await?;
+        let batches = df
+            .select_columns(&["c2", "c3"])?
+            .limit(0, Some(10))?
+            .collect()
+            .await?;
+        let batch = concat_batches(&batches[0].schema(), &batches)?;
+        let mut serializer = CsvSerializer::new();
+        let bytes = serializer.serialize(batch).await?;
+        assert_eq!(
+            "c2,c3\n2,1\n5,-40\n1,29\n1,-85\n5,-82\n4,-111\n3,104\n3,13\n1,38\n4,-38\n",
+            String::from_utf8(bytes.into()).unwrap()
+        );
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_csv_serializer_no_header() -> Result<()> {
+        let ctx = SessionContext::new();
+        let df = ctx
+            .read_csv(
+                &format!("{}/csv/aggregate_test_100.csv", arrow_test_data()),
+                CsvReadOptions::default().has_header(true),
+            )
+            .await?;
+        let batches = df
+            .select_columns(&["c2", "c3"])?
+            .limit(0, Some(10))?
+            .collect()
+            .await?;
+        let batch = concat_batches(&batches[0].schema(), &batches)?;
+        let mut serializer = CsvSerializer::new().with_header(false);
+        let bytes = serializer.serialize(batch).await?;
+        assert_eq!(
+            "2,1\n5,-40\n1,29\n1,-85\n5,-82\n4,-111\n3,104\n3,13\n1,38\n4,-38\n",
+            String::from_utf8(bytes.into()).unwrap()
+        );
+        Ok(())
+    }
 }