Commit 6134850

add avroTupleSerializationSchema
1 parent de8810a commit 6134850

1 file changed: +365 -0 lines changed
@@ -0,0 +1,365 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.dtstack.flink.sql.sink.kafka.serialization;

import com.dtstack.flink.sql.enums.EUpdateMode;
import org.apache.avro.LogicalType;
import org.apache.avro.LogicalTypes;
import org.apache.avro.Schema;
import org.apache.avro.SchemaParseException;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericDatumWriter;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.generic.IndexedRecord;
import org.apache.avro.io.DatumWriter;
import org.apache.avro.io.Encoder;
import org.apache.avro.io.EncoderFactory;
import org.apache.avro.specific.SpecificData;
import org.apache.avro.specific.SpecificDatumWriter;
import org.apache.avro.specific.SpecificRecord;
import org.apache.avro.util.Utf8;
import org.apache.commons.lang3.StringUtils;
import org.apache.flink.api.common.serialization.SerializationSchema;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.sql.Date;
import java.sql.Time;
import java.sql.Timestamp;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.TimeZone;
import java.util.stream.Collectors;

/**
 * Serialization schema that serializes a change row ({@code Tuple2<Boolean, Row>}) into Avro bytes.
 *
 * <p>Serializes objects that are represented in (nested) Flink rows. It supports types that
 * are compatible with Flink's Table & SQL API.
 *
 * @author maqi
 */
public class AvroTuple2SerializationSchema implements SerializationSchema<Tuple2<Boolean,Row>> {

    /**
     * Used for time conversions from SQL types.
     */
    private static final TimeZone LOCAL_TZ = TimeZone.getDefault();

    /**
     * Avro record class for serialization. Might be null if record class is not available.
     */
    private Class<? extends SpecificRecord> recordClazz;

    /**
     * Schema string for deserialization.
     */
    private String schemaString;

    /**
     * Avro serialization schema.
     */
    private transient Schema schema;

    /**
     * Writer to serialize Avro record into a byte array.
     */
    private transient DatumWriter<IndexedRecord> datumWriter;

    /**
     * Output stream to serialize records into byte array.
     */
    private transient ByteArrayOutputStream arrayOutputStream;

    /**
     * Low-level class for serialization of Avro values.
     */
    private transient Encoder encoder;

    private String updateMode;

    private String retractKey = "retract";

    /**
     * Creates an Avro serialization schema for the given specific record class.
     *
     * @param recordClazz Avro record class used to serialize Flink's row to Avro's record
     * @param updateMode  update mode of the sink; in upsert mode the change flag is written into the "retract" field
     */
    public AvroTuple2SerializationSchema(Class<? extends SpecificRecord> recordClazz, String updateMode) {
        Preconditions.checkNotNull(recordClazz, "Avro record class must not be null.");
        this.recordClazz = recordClazz;
        this.schema = SpecificData.get().getSchema(recordClazz);
        this.schemaString = schema.toString();
        this.datumWriter = new SpecificDatumWriter<>(schema);
        this.arrayOutputStream = new ByteArrayOutputStream();
        this.encoder = EncoderFactory.get().binaryEncoder(arrayOutputStream, null);
        this.updateMode = updateMode;
    }

    /**
     * Creates an Avro serialization schema for the given Avro schema string.
     *
     * @param avroSchemaString Avro schema string used to serialize Flink's row to Avro's record
     * @param updateMode       update mode of the sink; in upsert mode the change flag is written into the "retract" field
     */
    public AvroTuple2SerializationSchema(String avroSchemaString, String updateMode) {
        Preconditions.checkNotNull(avroSchemaString, "Avro schema must not be null.");
        this.recordClazz = null;
        this.schemaString = avroSchemaString;
        try {
            this.schema = new Schema.Parser().parse(avroSchemaString);
        } catch (SchemaParseException e) {
            throw new IllegalArgumentException("Could not parse Avro schema string.", e);
        }
        this.datumWriter = new GenericDatumWriter<>(schema);
        this.arrayOutputStream = new ByteArrayOutputStream();
        this.encoder = EncoderFactory.get().binaryEncoder(arrayOutputStream, null);
        this.updateMode = updateMode;
    }

    @Override
    public byte[] serialize(Tuple2<Boolean,Row> tuple2) {
        try {
            Row row = tuple2.f1;
            boolean change = tuple2.f0;

            // convert to record
            final GenericRecord record = convertRowToAvroRecord(schema, row);

            dealRetractField(change, record);

            arrayOutputStream.reset();
            datumWriter.write(record, encoder);
            encoder.flush();
            return arrayOutputStream.toByteArray();
        } catch (Exception e) {
            throw new RuntimeException("Failed to serialize row.", e);
        }
    }

    protected void dealRetractField(boolean change, GenericRecord record) {
        schema.getFields()
                .stream()
                .filter(field -> StringUtils.equalsIgnoreCase(field.name(), retractKey))
                .findFirst()
                .ifPresent(field -> {
                    if (StringUtils.equalsIgnoreCase(updateMode, EUpdateMode.UPSERT.name())) {
                        record.put(retractKey, convertFlinkType(field.schema(), change));
                    }
                });
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        final AvroTuple2SerializationSchema that = (AvroTuple2SerializationSchema) o;
        return Objects.equals(recordClazz, that.recordClazz) && Objects.equals(schemaString, that.schemaString);
    }

    @Override
    public int hashCode() {
        return Objects.hash(recordClazz, schemaString);
    }

    // --------------------------------------------------------------------------------------------

    private GenericRecord convertRowToAvroRecord(Schema schema, Row row) {

        final List<Schema.Field> fields = schema.getFields()
                .stream()
                .filter(field -> !StringUtils.equalsIgnoreCase(field.name(), retractKey))
                .collect(Collectors.toList());

        final int length = fields.size();
        final GenericRecord record = new GenericData.Record(schema);
        for (int i = 0; i < length; i++) {
            final Schema.Field field = fields.get(i);
            // note: values are written by their index in the filtered field list,
            // so the retract field is expected to be the last field of the Avro schema
            record.put(i, convertFlinkType(field.schema(), row.getField(i)));
        }
        return record;
    }

    private Object convertFlinkType(Schema schema, Object object) {
        if (object == null) {
            return null;
        }
        switch (schema.getType()) {
            case RECORD:
                if (object instanceof Row) {
                    return convertRowToAvroRecord(schema, (Row) object);
                }
                throw new IllegalStateException("Row expected but was: " + object.getClass());
            case ENUM:
                return new GenericData.EnumSymbol(schema, object.toString());
            case ARRAY:
                final Schema elementSchema = schema.getElementType();
                final Object[] array = (Object[]) object;
                final GenericData.Array<Object> convertedArray = new GenericData.Array<>(array.length, schema);
                for (Object element : array) {
                    convertedArray.add(convertFlinkType(elementSchema, element));
                }
                return convertedArray;
            case MAP:
                final Map<?, ?> map = (Map<?, ?>) object;
                final Map<Utf8, Object> convertedMap = new HashMap<>();
                for (Map.Entry<?, ?> entry : map.entrySet()) {
                    convertedMap.put(
                            new Utf8(entry.getKey().toString()),
                            convertFlinkType(schema.getValueType(), entry.getValue()));
                }
                return convertedMap;
            case UNION:
                final List<Schema> types = schema.getTypes();
                final int size = types.size();
                final Schema actualSchema;
                if (size == 2 && types.get(0).getType() == Schema.Type.NULL) {
                    actualSchema = types.get(1);
                } else if (size == 2 && types.get(1).getType() == Schema.Type.NULL) {
                    actualSchema = types.get(0);
                } else if (size == 1) {
                    actualSchema = types.get(0);
                } else {
                    // generic type
                    return object;
                }
                return convertFlinkType(actualSchema, object);
            case FIXED:
                // check for logical type
                if (object instanceof BigDecimal) {
                    return new GenericData.Fixed(
                            schema,
                            convertFromDecimal(schema, (BigDecimal) object));
                }
                return new GenericData.Fixed(schema, (byte[]) object);
            case STRING:
                return new Utf8(object.toString());
            case BYTES:
                // check for logical type
                if (object instanceof BigDecimal) {
                    return ByteBuffer.wrap(convertFromDecimal(schema, (BigDecimal) object));
                }
                return ByteBuffer.wrap((byte[]) object);
            case INT:
                // check for logical types
                if (object instanceof Date) {
                    return convertFromDate(schema, (Date) object);
                } else if (object instanceof Time) {
                    return convertFromTime(schema, (Time) object);
                }
                return object;
            case LONG:
                // check for logical type
                if (object instanceof Timestamp) {
                    return convertFromTimestamp(schema, (Timestamp) object);
                }
                return object;
            case FLOAT:
            case DOUBLE:
            case BOOLEAN:
                return object;
        }
        throw new RuntimeException("Unsupported Avro type:" + schema);
    }

    private byte[] convertFromDecimal(Schema schema, BigDecimal decimal) {
        final LogicalType logicalType = schema.getLogicalType();
        if (logicalType instanceof LogicalTypes.Decimal) {
            final LogicalTypes.Decimal decimalType = (LogicalTypes.Decimal) logicalType;
            // rescale to target type
            final BigDecimal rescaled = decimal.setScale(decimalType.getScale(), BigDecimal.ROUND_UNNECESSARY);
            // byte array must contain the two's-complement representation of the
            // unscaled integer value in big-endian byte order
            return rescaled.unscaledValue().toByteArray();
        } else {
            throw new RuntimeException("Unsupported decimal type.");
        }
    }

    private int convertFromDate(Schema schema, Date date) {
        final LogicalType logicalType = schema.getLogicalType();
        if (logicalType == LogicalTypes.date()) {
            // adopted from Apache Calcite
            final long time = date.getTime();
            final long converted = time + (long) LOCAL_TZ.getOffset(time);
            return (int) (converted / 86400000L);
        } else {
            throw new RuntimeException("Unsupported date type.");
        }
    }

    private int convertFromTime(Schema schema, Time date) {
        final LogicalType logicalType = schema.getLogicalType();
        if (logicalType == LogicalTypes.timeMillis()) {
            // adopted from Apache Calcite
            final long time = date.getTime();
            final long converted = time + (long) LOCAL_TZ.getOffset(time);
            return (int) (converted % 86400000L);
        } else {
            throw new RuntimeException("Unsupported time type.");
        }
    }

    private long convertFromTimestamp(Schema schema, Timestamp date) {
        final LogicalType logicalType = schema.getLogicalType();
        if (logicalType == LogicalTypes.timestampMillis()) {
            // adopted from Apache Calcite
            final long time = date.getTime();
            return time + (long) LOCAL_TZ.getOffset(time);
        } else {
            throw new RuntimeException("Unsupported timestamp type.");
        }
    }

    private void writeObject(ObjectOutputStream outputStream) throws IOException {
        outputStream.writeObject(recordClazz);
        outputStream.writeObject(schemaString); // support for null
        outputStream.writeObject(retractKey);
        outputStream.writeObject(updateMode);
    }

    @SuppressWarnings("unchecked")
    private void readObject(ObjectInputStream inputStream) throws ClassNotFoundException, IOException {
        recordClazz = (Class<? extends SpecificRecord>) inputStream.readObject();
        schemaString = (String) inputStream.readObject();
        if (recordClazz != null) {
            schema = SpecificData.get().getSchema(recordClazz);
        } else {
            schema = new Schema.Parser().parse(schemaString);
        }
        retractKey = (String) inputStream.readObject();
        updateMode = (String) inputStream.readObject();

        datumWriter = new SpecificDatumWriter<>(schema);
        arrayOutputStream = new ByteArrayOutputStream();
        encoder = EncoderFactory.get().binaryEncoder(arrayOutputStream, null);
    }
}
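
For orientation, a minimal usage sketch of the new class follows. It is not part of the commit; the Avro schema string, field names, and "upsert" mode are illustrative assumptions. In upsert mode the Boolean change flag of each Tuple2 is written into the schema's "retract" field, while the remaining fields are filled from the Row by position, so the retract field should come last.

// Hypothetical usage sketch (not part of this commit); schema and values are made up.
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.types.Row;

import com.dtstack.flink.sql.sink.kafka.serialization.AvroTuple2SerializationSchema;

public class AvroTuple2SerializationExample {

    public static void main(String[] args) {
        // Record schema whose last field carries the retract flag.
        String avroSchema =
                "{\"type\":\"record\",\"name\":\"UserEvent\",\"fields\":["
                + "{\"name\":\"id\",\"type\":\"long\"},"
                + "{\"name\":\"name\",\"type\":[\"null\",\"string\"],\"default\":null},"
                + "{\"name\":\"retract\",\"type\":\"boolean\",\"default\":true}]}";

        AvroTuple2SerializationSchema serializationSchema =
                new AvroTuple2SerializationSchema(avroSchema, "upsert");

        // f0 = change flag (true = insert/update, false = retract), f1 = payload row.
        Tuple2<Boolean, Row> element = Tuple2.of(true, Row.of(1L, "alice"));
        byte[] avroBytes = serializationSchema.serialize(element);
        System.out.println("serialized " + avroBytes.length + " Avro bytes");
    }
}

Because the class only implements Flink's SerializationSchema interface, the resulting instance can be handed to whatever Kafka producer wrapper the sink layer uses.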
