@@ -27,6 +27,7 @@ use arrow_buffer::*;
27
27
use arrow_schema:: {
28
28
ArrowError , DataType , Field as ArrowField , FieldRef , Fields , Schema as ArrowSchema , SchemaRef ,
29
29
} ;
30
+ use std:: cmp:: Ordering ;
30
31
use std:: collections:: HashMap ;
31
32
use std:: io:: Read ;
32
33
use std:: sync:: Arc ;
@@ -114,6 +115,13 @@ enum Decoder {
114
115
StringView ( OffsetBufferBuilder < i32 > , Vec < u8 > ) ,
115
116
List ( FieldRef , OffsetBufferBuilder < i32 > , Box < Decoder > ) ,
116
117
Record ( Fields , Vec < Decoder > ) ,
118
+ Map (
119
+ FieldRef ,
120
+ OffsetBufferBuilder < i32 > ,
121
+ OffsetBufferBuilder < i32 > ,
122
+ Vec < u8 > ,
123
+ Box < Decoder > ,
124
+ ) ,
117
125
Nullable ( Nullability , NullBufferBuilder , Box < Decoder > ) ,
118
126
}
119
127
@@ -169,6 +177,25 @@ impl Decoder {
169
177
}
170
178
Self :: Record ( arrow_fields. into ( ) , encodings)
171
179
}
180
+ Codec :: Map ( child) => {
181
+ let val_field = child. field_with_name ( "value" ) . with_nullable ( true ) ;
182
+ let map_field = Arc :: new ( ArrowField :: new (
183
+ "entries" ,
184
+ DataType :: Struct ( Fields :: from ( vec ! [
185
+ ArrowField :: new( "key" , DataType :: Utf8 , false ) ,
186
+ val_field,
187
+ ] ) ) ,
188
+ false ,
189
+ ) ) ;
190
+ let val_dec = Self :: try_new ( child) ?;
191
+ Self :: Map (
192
+ map_field,
193
+ OffsetBufferBuilder :: new ( DEFAULT_CAPACITY ) ,
194
+ OffsetBufferBuilder :: new ( DEFAULT_CAPACITY ) ,
195
+ Vec :: with_capacity ( DEFAULT_CAPACITY ) ,
196
+ Box :: new ( val_dec) ,
197
+ )
198
+ }
172
199
} ;
173
200
174
201
Ok ( match data_type. nullability ( ) {
@@ -201,6 +228,9 @@ impl Decoder {
201
228
e. append_null ( ) ;
202
229
}
203
230
Self :: Record ( _, e) => e. iter_mut ( ) . for_each ( |e| e. append_null ( ) ) ,
231
+ Self :: Map ( _, _koff, moff, _, _) => {
232
+ moff. push_length ( 0 ) ;
233
+ }
204
234
Self :: Nullable ( _, _, _) => unreachable ! ( "Nulls cannot be nested" ) ,
205
235
}
206
236
}
@@ -236,6 +266,15 @@ impl Decoder {
236
266
encoding. decode ( buf) ?;
237
267
}
238
268
}
269
+ Self :: Map ( _, koff, moff, kdata, valdec) => {
270
+ let newly_added = read_map_blocks ( buf, |cur| {
271
+ let kb = cur. get_bytes ( ) ?;
272
+ koff. push_length ( kb. len ( ) ) ;
273
+ kdata. extend_from_slice ( kb) ;
274
+ valdec. decode ( cur)
275
+ } ) ?;
276
+ moff. push_length ( newly_added) ;
277
+ }
239
278
Self :: Nullable ( nullability, nulls, e) => {
240
279
let is_valid = buf. get_bool ( ) ? == matches ! ( nullability, Nullability :: NullFirst ) ;
241
280
nulls. append ( is_valid) ;
@@ -273,7 +312,6 @@ impl Decoder {
273
312
) ,
274
313
Self :: Float32 ( values) => Arc :: new ( flush_primitive :: < Float32Type > ( values, nulls) ) ,
275
314
Self :: Float64 ( values) => Arc :: new ( flush_primitive :: < Float64Type > ( values, nulls) ) ,
276
-
277
315
Self :: Binary ( offsets, values) => {
278
316
let offsets = flush_offsets ( offsets) ;
279
317
let values = flush_values ( values) . into ( ) ;
@@ -313,10 +351,89 @@ impl Decoder {
313
351
. collect :: < Result < Vec < _ > , _ > > ( ) ?;
314
352
Arc :: new ( StructArray :: new ( fields. clone ( ) , arrays, nulls) )
315
353
}
354
+ Self :: Map ( map_field, k_off, m_off, kdata, valdec) => {
355
+ let moff = flush_offsets ( m_off) ;
356
+ let koff = flush_offsets ( k_off) ;
357
+ let kd = flush_values ( kdata) . into ( ) ;
358
+ let val_arr = valdec. flush ( None ) ?;
359
+ let key_arr = StringArray :: new ( koff, kd, None ) ;
360
+ if key_arr. len ( ) != val_arr. len ( ) {
361
+ return Err ( ArrowError :: InvalidArgumentError ( format ! (
362
+ "Map keys length ({}) != map values length ({})" ,
363
+ key_arr. len( ) ,
364
+ val_arr. len( )
365
+ ) ) ) ;
366
+ }
367
+ let final_len = moff. len ( ) - 1 ;
368
+ if let Some ( n) = & nulls {
369
+ if n. len ( ) != final_len {
370
+ return Err ( ArrowError :: InvalidArgumentError ( format ! (
371
+ "Map array null buffer length {} != final map length {final_len}" ,
372
+ n. len( )
373
+ ) ) ) ;
374
+ }
375
+ }
376
+ let entries_struct = StructArray :: new (
377
+ Fields :: from ( vec ! [
378
+ Arc :: new( ArrowField :: new( "key" , DataType :: Utf8 , false ) ) ,
379
+ Arc :: new( ArrowField :: new( "value" , val_arr. data_type( ) . clone( ) , true ) ) ,
380
+ ] ) ,
381
+ vec ! [ Arc :: new( key_arr) , val_arr] ,
382
+ None ,
383
+ ) ;
384
+ let map_arr = MapArray :: new ( map_field. clone ( ) , moff, entries_struct, nulls, false ) ;
385
+ Arc :: new ( map_arr)
386
+ }
316
387
} )
317
388
}
318
389
}
319
390
391
+ fn read_map_blocks (
392
+ buf : & mut AvroCursor ,
393
+ decode_entry : impl FnMut ( & mut AvroCursor ) -> Result < ( ) , ArrowError > ,
394
+ ) -> Result < usize , ArrowError > {
395
+ read_blockwise_items ( buf, true , decode_entry)
396
+ }
397
+
398
+ fn read_blockwise_items (
399
+ buf : & mut AvroCursor ,
400
+ read_size_after_negative : bool ,
401
+ mut decode_fn : impl FnMut ( & mut AvroCursor ) -> Result < ( ) , ArrowError > ,
402
+ ) -> Result < usize , ArrowError > {
403
+ let mut total = 0usize ;
404
+ loop {
405
+ // Read the block count
406
+ // positive = that many items
407
+ // negative = that many items + read block size
408
+ // See: https://avro.apache.org/docs/1.11.1/specification/#maps
409
+ let block_count = buf. get_long ( ) ?;
410
+ match block_count. cmp ( & 0 ) {
411
+ Ordering :: Equal => break ,
412
+ Ordering :: Less => {
413
+ // If block_count is negative, read the absolute value of count,
414
+ // then read the block size as a long and discard
415
+ let count = ( -block_count) as usize ;
416
+ if read_size_after_negative {
417
+ let _size_in_bytes = buf. get_long ( ) ?;
418
+ }
419
+ for _ in 0 ..count {
420
+ decode_fn ( buf) ?;
421
+ }
422
+ total += count;
423
+ }
424
+ Ordering :: Greater => {
425
+ // If block_count is positive, decode that many items
426
+ let count = block_count as usize ;
427
+ for _i in 0 ..count {
428
+ decode_fn ( buf) ?;
429
+ }
430
+ total += count;
431
+ }
432
+ }
433
+ }
434
+ Ok ( total)
435
+ }
436
+
320
437
#[ inline]
321
438
fn flush_values < T > ( values : & mut Vec < T > ) -> Vec < T > {
322
439
std:: mem:: replace ( values, Vec :: with_capacity ( DEFAULT_CAPACITY ) )
@@ -336,3 +453,82 @@ fn flush_primitive<T: ArrowPrimitiveType>(
336
453
}
337
454
338
455
const DEFAULT_CAPACITY : usize = 1024 ;
456
+
457
#[cfg(test)]
mod tests {
    use super::*;
    // Import only what these tests use; the broader list previously here
    // (AsArray, Decimal128Array, DictionaryArray, FixedSizeBinaryArray,
    // IntervalMonthDayNanoArray, ListArray) was unused and produced
    // unused-import warnings.
    use arrow_array::{Array, MapArray, StringArray, StructArray};

    /// Zig-zag + varint encodes `value` as an Avro `long`.
    fn encode_avro_long(value: i64) -> Vec<u8> {
        let mut buf = Vec::new();
        let mut v = (value << 1) ^ (value >> 63);
        while v & !0x7F != 0 {
            buf.push(((v & 0x7F) | 0x80) as u8);
            v >>= 7;
        }
        buf.push(v as u8);
        buf
    }

    /// Encodes `bytes` as Avro `bytes`: a long length prefix followed by the raw data.
    fn encode_avro_bytes(bytes: &[u8]) -> Vec<u8> {
        let mut buf = encode_avro_long(bytes.len() as i64);
        buf.extend_from_slice(bytes);
        buf
    }

    /// Builds an `AvroDataType` wrapping `codec` with default metadata and no nullability.
    fn avro_from_codec(codec: Codec) -> AvroDataType {
        AvroDataType::new(codec, Default::default(), None)
    }

    #[test]
    fn test_map_decoding_one_entry() {
        let value_type = avro_from_codec(Codec::Utf8);
        let map_type = avro_from_codec(Codec::Map(Arc::new(value_type)));
        let mut decoder = Decoder::try_new(&map_type).unwrap();
        // Encode a single map with one entry: {"hello": "world"}
        let mut data = Vec::new();
        data.extend_from_slice(&encode_avro_long(1)); // block count: one entry
        data.extend_from_slice(&encode_avro_bytes(b"hello")); // key
        data.extend_from_slice(&encode_avro_bytes(b"world")); // value
        data.extend_from_slice(&encode_avro_long(0)); // end-of-blocks marker
        let mut cursor = AvroCursor::new(&data);
        decoder.decode(&mut cursor).unwrap();
        let array = decoder.flush(None).unwrap();
        let map_arr = array.as_any().downcast_ref::<MapArray>().unwrap();
        assert_eq!(map_arr.len(), 1); // one map
        assert_eq!(map_arr.value_length(0), 1);
        let entries = map_arr.value(0);
        let struct_entries = entries.as_any().downcast_ref::<StructArray>().unwrap();
        assert_eq!(struct_entries.len(), 1);
        let key_arr = struct_entries
            .column_by_name("key")
            .unwrap()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let val_arr = struct_entries
            .column_by_name("value")
            .unwrap()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(key_arr.value(0), "hello");
        assert_eq!(val_arr.value(0), "world");
    }

    #[test]
    fn test_map_decoding_empty() {
        let value_type = avro_from_codec(Codec::Utf8);
        let map_type = avro_from_codec(Codec::Map(Arc::new(value_type)));
        let mut decoder = Decoder::try_new(&map_type).unwrap();
        // An empty map is just the end-of-blocks marker (block count 0).
        let data = encode_avro_long(0);
        decoder.decode(&mut AvroCursor::new(&data)).unwrap();
        let array = decoder.flush(None).unwrap();
        let map_arr = array.as_any().downcast_ref::<MapArray>().unwrap();
        assert_eq!(map_arr.len(), 1); // one (empty) map
        assert_eq!(map_arr.value_length(0), 0);
    }
}
0 commit comments