Skip to content

Commit 05e0d15

Browse files
authored
Add Map support to arrow-avro (#7451)
* Added support for reading Avro Maps types * Fixed lint errors, improved readability of `read_blockwise_items`, added `Map` comments and improved `Map` nullability handling in `data_type` in codec.rs
1 parent e9df239 commit 05e0d15

File tree

2 files changed

+237
-5
lines changed

2 files changed

+237
-5
lines changed

arrow-avro/src/codec.rs

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
use crate::schema::{Attributes, ComplexType, PrimitiveType, Record, Schema, TypeName};
1919
use arrow_schema::{
20-
ArrowError, DataType, Field, FieldRef, IntervalUnit, SchemaBuilder, SchemaRef, TimeUnit,
20+
ArrowError, DataType, Field, FieldRef, Fields, IntervalUnit, SchemaBuilder, SchemaRef, TimeUnit,
2121
};
2222
use std::borrow::Cow;
2323
use std::collections::HashMap;
@@ -45,6 +45,19 @@ pub struct AvroDataType {
4545
}
4646

4747
impl AvroDataType {
48+
/// Create a new [`AvroDataType`] with the given parts.
49+
pub fn new(
50+
codec: Codec,
51+
metadata: HashMap<String, String>,
52+
nullability: Option<Nullability>,
53+
) -> Self {
54+
AvroDataType {
55+
codec,
56+
metadata,
57+
nullability,
58+
}
59+
}
60+
4861
/// Returns an arrow [`Field`] with the given name
4962
pub fn field_with_name(&self, name: &str) -> Field {
5063
let d = self.codec.data_type();
@@ -183,6 +196,8 @@ pub enum Codec {
183196
List(Arc<AvroDataType>),
184197
/// Represents Avro record type, maps to Arrow's Struct data type
185198
Struct(Arc<[AvroField]>),
199+
/// Represents Avro map type, maps to Arrow's Map data type
200+
Map(Arc<AvroDataType>),
186201
/// Represents Avro duration logical type, maps to Arrow's Interval(IntervalUnit::MonthDayNano) data type
187202
Interval,
188203
}
@@ -214,6 +229,22 @@ impl Codec {
214229
DataType::List(Arc::new(f.field_with_name(Field::LIST_FIELD_DEFAULT_NAME)))
215230
}
216231
Self::Struct(f) => DataType::Struct(f.iter().map(|x| x.field()).collect()),
232+
Self::Map(value_type) => {
233+
let val_dt = value_type.codec.data_type();
234+
let val_field = Field::new("value", val_dt, value_type.nullability.is_some())
235+
.with_metadata(value_type.metadata.clone());
236+
DataType::Map(
237+
Arc::new(Field::new(
238+
"entries",
239+
DataType::Struct(Fields::from(vec![
240+
Field::new("key", DataType::Utf8, false),
241+
val_field,
242+
])),
243+
false,
244+
)),
245+
false,
246+
)
247+
}
217248
}
218249
}
219250
}
@@ -390,9 +421,14 @@ fn make_data_type<'a>(
390421
ComplexType::Enum(e) => Err(ArrowError::NotYetImplemented(format!(
391422
"Enum of {e:?} not currently supported"
392423
))),
393-
ComplexType::Map(m) => Err(ArrowError::NotYetImplemented(format!(
394-
"Map of {m:?} not currently supported"
395-
))),
424+
ComplexType::Map(m) => {
425+
let val = make_data_type(&m.values, namespace, resolver)?;
426+
Ok(AvroDataType {
427+
nullability: None,
428+
metadata: m.attributes.field_metadata(),
429+
codec: Codec::Map(Arc::new(val)),
430+
})
431+
}
396432
},
397433
Schema::Type(t) => {
398434
let mut field = make_data_type(

arrow-avro/src/reader/record.rs

Lines changed: 197 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ use arrow_buffer::*;
2727
use arrow_schema::{
2828
ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as ArrowSchema, SchemaRef,
2929
};
30+
use std::cmp::Ordering;
3031
use std::collections::HashMap;
3132
use std::io::Read;
3233
use std::sync::Arc;
@@ -114,6 +115,13 @@ enum Decoder {
114115
StringView(OffsetBufferBuilder<i32>, Vec<u8>),
115116
List(FieldRef, OffsetBufferBuilder<i32>, Box<Decoder>),
116117
Record(Fields, Vec<Decoder>),
118+
Map(
119+
FieldRef,
120+
OffsetBufferBuilder<i32>,
121+
OffsetBufferBuilder<i32>,
122+
Vec<u8>,
123+
Box<Decoder>,
124+
),
117125
Nullable(Nullability, NullBufferBuilder, Box<Decoder>),
118126
}
119127

@@ -169,6 +177,25 @@ impl Decoder {
169177
}
170178
Self::Record(arrow_fields.into(), encodings)
171179
}
180+
Codec::Map(child) => {
181+
let val_field = child.field_with_name("value").with_nullable(true);
182+
let map_field = Arc::new(ArrowField::new(
183+
"entries",
184+
DataType::Struct(Fields::from(vec![
185+
ArrowField::new("key", DataType::Utf8, false),
186+
val_field,
187+
])),
188+
false,
189+
));
190+
let val_dec = Self::try_new(child)?;
191+
Self::Map(
192+
map_field,
193+
OffsetBufferBuilder::new(DEFAULT_CAPACITY),
194+
OffsetBufferBuilder::new(DEFAULT_CAPACITY),
195+
Vec::with_capacity(DEFAULT_CAPACITY),
196+
Box::new(val_dec),
197+
)
198+
}
172199
};
173200

174201
Ok(match data_type.nullability() {
@@ -201,6 +228,9 @@ impl Decoder {
201228
e.append_null();
202229
}
203230
Self::Record(_, e) => e.iter_mut().for_each(|e| e.append_null()),
231+
Self::Map(_, _koff, moff, _, _) => {
232+
moff.push_length(0);
233+
}
204234
Self::Nullable(_, _, _) => unreachable!("Nulls cannot be nested"),
205235
}
206236
}
@@ -236,6 +266,15 @@ impl Decoder {
236266
encoding.decode(buf)?;
237267
}
238268
}
269+
Self::Map(_, koff, moff, kdata, valdec) => {
270+
let newly_added = read_map_blocks(buf, |cur| {
271+
let kb = cur.get_bytes()?;
272+
koff.push_length(kb.len());
273+
kdata.extend_from_slice(kb);
274+
valdec.decode(cur)
275+
})?;
276+
moff.push_length(newly_added);
277+
}
239278
Self::Nullable(nullability, nulls, e) => {
240279
let is_valid = buf.get_bool()? == matches!(nullability, Nullability::NullFirst);
241280
nulls.append(is_valid);
@@ -273,7 +312,6 @@ impl Decoder {
273312
),
274313
Self::Float32(values) => Arc::new(flush_primitive::<Float32Type>(values, nulls)),
275314
Self::Float64(values) => Arc::new(flush_primitive::<Float64Type>(values, nulls)),
276-
277315
Self::Binary(offsets, values) => {
278316
let offsets = flush_offsets(offsets);
279317
let values = flush_values(values).into();
@@ -313,10 +351,89 @@ impl Decoder {
313351
.collect::<Result<Vec<_>, _>>()?;
314352
Arc::new(StructArray::new(fields.clone(), arrays, nulls))
315353
}
354+
Self::Map(map_field, k_off, m_off, kdata, valdec) => {
355+
let moff = flush_offsets(m_off);
356+
let koff = flush_offsets(k_off);
357+
let kd = flush_values(kdata).into();
358+
let val_arr = valdec.flush(None)?;
359+
let key_arr = StringArray::new(koff, kd, None);
360+
if key_arr.len() != val_arr.len() {
361+
return Err(ArrowError::InvalidArgumentError(format!(
362+
"Map keys length ({}) != map values length ({})",
363+
key_arr.len(),
364+
val_arr.len()
365+
)));
366+
}
367+
let final_len = moff.len() - 1;
368+
if let Some(n) = &nulls {
369+
if n.len() != final_len {
370+
return Err(ArrowError::InvalidArgumentError(format!(
371+
"Map array null buffer length {} != final map length {final_len}",
372+
n.len()
373+
)));
374+
}
375+
}
376+
let entries_struct = StructArray::new(
377+
Fields::from(vec![
378+
Arc::new(ArrowField::new("key", DataType::Utf8, false)),
379+
Arc::new(ArrowField::new("value", val_arr.data_type().clone(), true)),
380+
]),
381+
vec![Arc::new(key_arr), val_arr],
382+
None,
383+
);
384+
let map_arr = MapArray::new(map_field.clone(), moff, entries_struct, nulls, false);
385+
Arc::new(map_arr)
386+
}
316387
})
317388
}
318389
}
319390

391+
fn read_map_blocks(
392+
buf: &mut AvroCursor,
393+
decode_entry: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>,
394+
) -> Result<usize, ArrowError> {
395+
read_blockwise_items(buf, true, decode_entry)
396+
}
397+
398+
fn read_blockwise_items(
399+
buf: &mut AvroCursor,
400+
read_size_after_negative: bool,
401+
mut decode_fn: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>,
402+
) -> Result<usize, ArrowError> {
403+
let mut total = 0usize;
404+
loop {
405+
// Read the block count
406+
// positive = that many items
407+
// negative = that many items + read block size
408+
// See: https://avro.apache.org/docs/1.11.1/specification/#maps
409+
let block_count = buf.get_long()?;
410+
match block_count.cmp(&0) {
411+
Ordering::Equal => break,
412+
Ordering::Less => {
413+
// If block_count is negative, read the absolute value of count,
414+
// then read the block size as a long and discard
415+
let count = (-block_count) as usize;
416+
if read_size_after_negative {
417+
let _size_in_bytes = buf.get_long()?;
418+
}
419+
for _ in 0..count {
420+
decode_fn(buf)?;
421+
}
422+
total += count;
423+
}
424+
Ordering::Greater => {
425+
// If block_count is positive, decode that many items
426+
let count = block_count as usize;
427+
for _i in 0..count {
428+
decode_fn(buf)?;
429+
}
430+
total += count;
431+
}
432+
}
433+
}
434+
Ok(total)
435+
}
436+
320437
#[inline]
321438
fn flush_values<T>(values: &mut Vec<T>) -> Vec<T> {
322439
std::mem::replace(values, Vec::with_capacity(DEFAULT_CAPACITY))
@@ -336,3 +453,82 @@ fn flush_primitive<T: ArrowPrimitiveType>(
336453
}
337454

338455
const DEFAULT_CAPACITY: usize = 1024;
456+
457+
#[cfg(test)]
458+
mod tests {
459+
use super::*;
460+
use arrow_array::{
461+
cast::AsArray, Array, Decimal128Array, DictionaryArray, FixedSizeBinaryArray,
462+
IntervalMonthDayNanoArray, ListArray, MapArray, StringArray, StructArray,
463+
};
464+
465+
/// Encode `value` as an Avro `long`: zig-zag mapping followed by
/// variable-length base-128 (little-endian 7-bit groups, high bit set on
/// continuation bytes).
fn encode_avro_long(value: i64) -> Vec<u8> {
    let mut buf = Vec::new();
    // Zig-zag so small-magnitude values (positive or negative) encode short.
    // Continue in u64: with a signed accumulator the arithmetic `>>= 7` keeps
    // the sign bit, so `value == i64::MIN` (zig-zag -1) never terminates.
    let mut v = ((value << 1) ^ (value >> 63)) as u64;
    while v & !0x7F != 0 {
        buf.push(((v & 0x7F) | 0x80) as u8);
        v >>= 7;
    }
    buf.push(v as u8);
    buf
}

/// Encode `bytes` as Avro `bytes`: a `long` length prefix followed by the raw bytes.
fn encode_avro_bytes(bytes: &[u8]) -> Vec<u8> {
    let mut buf = encode_avro_long(bytes.len() as i64);
    buf.extend_from_slice(bytes);
    buf
}
481+
482+
fn avro_from_codec(codec: Codec) -> AvroDataType {
483+
AvroDataType::new(codec, Default::default(), None)
484+
}
485+
486+
#[test]
fn test_map_decoding_one_entry() {
    let map_type = avro_from_codec(Codec::Map(Arc::new(avro_from_codec(Codec::Utf8))));
    let mut decoder = Decoder::try_new(&map_type).unwrap();
    // One block holding a single entry ({"hello": "world"}), then the
    // terminating zero-count block.
    let mut payload = encode_avro_long(1);
    payload.extend_from_slice(&encode_avro_bytes(b"hello")); // key
    payload.extend_from_slice(&encode_avro_bytes(b"world")); // value
    payload.extend_from_slice(&encode_avro_long(0));
    decoder.decode(&mut AvroCursor::new(&payload)).unwrap();
    let array = decoder.flush(None).unwrap();
    let map_arr = array.as_any().downcast_ref::<MapArray>().unwrap();
    // One map row containing exactly one entry.
    assert_eq!(map_arr.len(), 1);
    assert_eq!(map_arr.value_length(0), 1);
    let entries_arr = map_arr.value(0);
    let entries = entries_arr.as_any().downcast_ref::<StructArray>().unwrap();
    assert_eq!(entries.len(), 1);
    let keys = entries
        .column_by_name("key")
        .unwrap()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    let values = entries
        .column_by_name("value")
        .unwrap()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    assert_eq!(keys.value(0), "hello");
    assert_eq!(values.value(0), "world");
}
521+
522+
#[test]
fn test_map_decoding_empty() {
    // An empty Avro map is just the terminating zero-count block; it must
    // still flush as a single (empty) map row.
    let map_type = avro_from_codec(Codec::Map(Arc::new(avro_from_codec(Codec::Utf8))));
    let mut decoder = Decoder::try_new(&map_type).unwrap();
    let payload = encode_avro_long(0);
    decoder.decode(&mut AvroCursor::new(&payload)).unwrap();
    let array = decoder.flush(None).unwrap();
    let map_arr = array.as_any().downcast_ref::<MapArray>().unwrap();
    assert_eq!(map_arr.len(), 1);
    assert_eq!(map_arr.value_length(0), 0);
}
534+
}

0 commit comments

Comments
 (0)