Skip to content

Commit 0a680d3

Browse files
committed
Add ProtoBuf I/O Stream support on jvm (Kotlin#2075)
Add implementations for streaming support in Protobuf format supporting simple and delimited messages. Signed-off-by: George Papadopoulos <[email protected]>
1 parent 694e2f7 commit 0a680d3

File tree

4 files changed

+542
-0
lines changed

4 files changed

+542
-0
lines changed

formats/protobuf/api/kotlinx-serialization-protobuf.api

+9
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
1+
public final class kotlinx/serialization/protobuf/JvmStreamsKt {
2+
public static final field DEFAULT_MESSAGE_MAX_SIZE I
3+
public static final fun decodeDelimitedMessages (Lkotlinx/serialization/protobuf/ProtoBuf;Lkotlinx/serialization/DeserializationStrategy;Ljava/io/InputStream;I)Ljava/util/List;
4+
public static synthetic fun decodeDelimitedMessages$default (Lkotlinx/serialization/protobuf/ProtoBuf;Lkotlinx/serialization/DeserializationStrategy;Ljava/io/InputStream;IILjava/lang/Object;)Ljava/util/List;
5+
public static final fun decodeFromStream (Lkotlinx/serialization/protobuf/ProtoBuf;Lkotlinx/serialization/DeserializationStrategy;Ljava/io/InputStream;)Ljava/lang/Object;
6+
public static final fun encodeDelimitedToStream (Lkotlinx/serialization/protobuf/ProtoBuf;Lkotlinx/serialization/SerializationStrategy;Ljava/lang/Object;Ljava/io/OutputStream;)V
7+
public static final fun encodeToStream (Lkotlinx/serialization/protobuf/ProtoBuf;Lkotlinx/serialization/SerializationStrategy;Ljava/lang/Object;Ljava/io/OutputStream;)V
8+
}
9+
110
public abstract class kotlinx/serialization/protobuf/ProtoBuf : kotlinx/serialization/BinaryFormat {
211
public static final field Default Lkotlinx/serialization/protobuf/ProtoBuf$Default;
312
public synthetic fun <init> (ZLkotlinx/serialization/modules/SerializersModule;Lkotlin/jvm/internal/DefaultConstructorMarker;)V
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,318 @@
1+
package kotlinx.serialization.protobuf
2+
3+
import kotlinx.serialization.*
4+
import kotlinx.serialization.protobuf.internal.ProtobufDecodingException
5+
import java.io.BufferedOutputStream
6+
import java.io.IOException
7+
import java.io.InputStream
8+
import java.io.OutputStream
9+
10+
/**
11+
* Serializes and encodes the given [value] into a [stream] using the given [serializer].
12+
*
13+
* @throws SerializationException in case of any encoding-specif error
14+
* @throws IOException if an I/O error occurs and stream cannot be written to
15+
*/
16+
@ExperimentalSerializationApi
17+
public fun <T> ProtoBuf.encodeToStream(
18+
serializer: SerializationStrategy<T>,
19+
value: T,
20+
stream: OutputStream
21+
) {
22+
val protoBytes = encodeToByteArray(serializer, value)
23+
protoBytes.writeTo(stream)
24+
}
25+
26+
/**
27+
* Serializes and encodes the given value [value] into a [stream] using serializer
28+
* retrieved from the reified type parameter.
29+
*
30+
* @throws SerializationException in case of any encoding-specif error
31+
* @throws IOException if an I/O error occurs and stream cannot be written to
32+
*/
33+
@ExperimentalSerializationApi
34+
public inline fun <reified T> ProtoBuf.encodeToStream(
35+
value: T,
36+
stream: OutputStream
37+
): Unit = encodeToStream(serializersModule.serializer(), value, stream)
38+
39+
/**
40+
* Decodes and deserializes from given [stream] to value of type [T] using the given [deserializer].
41+
*
42+
* Note that this function expects that exactly one object would be present in the stream.
43+
* In case multiple objects of same type `T` are present in stream the **first call does not
44+
* throw** but leaves the next objects in malformed state. After this, consecutive calls throw
45+
* [SerializationException]. For serializing and decoding multiple objects in
46+
* the same stream see [encodeDelimitedToStream] and [decodeDelimitedMessages].
47+
*
48+
* @throws SerializationException in case of any decoding-specific error
49+
* @throws IOException if an I/O error occurs and stream cannot be read from.
50+
*/
51+
@ExperimentalSerializationApi
52+
public fun <T> ProtoBuf.decodeFromStream(
53+
deserializer: DeserializationStrategy<T>,
54+
stream: InputStream
55+
): T = stream.use {
56+
decodeFromByteArray(deserializer, it.readBytes())
57+
}
58+
59+
/**
60+
* Decodes and deserializes from given [stream] to value of type [T] using deserializer
61+
* retrieved from the reified type parameter.
62+
*
63+
* Note that this function expects that exactly one object would be present in the stream.
64+
* In case multiple objects of same type `T` are present in stream the **first call does not
65+
* throw** but leaves the next objects in malformed state. After this, consecutive calls throw
66+
* [SerializationException]. For serializing and decoding multiple objects in
67+
* the same stream check [encodeDelimitedToStream] and [decodeDelimitedMessages].
68+
*
69+
* @throws SerializationException in case of any decoding-specific error
70+
* @throws IOException if an I/O error occurs and stream cannot be read from.
71+
*/
72+
@ExperimentalSerializationApi
73+
public inline fun <reified T> ProtoBuf.decodeFromStream(
74+
stream: InputStream
75+
): T = decodeFromStream(serializersModule.serializer(), stream)
76+
77+
// -- delimited messages --
78+
79+
/**
80+
* Serializes and encodes the given [value] into a [stream] as
81+
* [delimited Protobuf message](https://developers.google.com/protocol-buffers/docs/techniques?hl=en#streaming).
82+
* In other words the size of each message is specified before the message itself. Also,
83+
* it is using the given [serializer].
84+
*
85+
* Use [decodeDelimitedMessages] to retrieve the messages from the stream.
86+
*
87+
* @throws SerializationException in case of any encoding-specif error
88+
* @throws IOException if an I/O error occurs and stream cannot be written to
89+
*/
90+
@ExperimentalSerializationApi
91+
public fun <T> ProtoBuf.encodeDelimitedToStream(
92+
serializer: SerializationStrategy<T>,
93+
value: T,
94+
stream: OutputStream
95+
) {
96+
val protoBytes = encodeToByteArray(serializer, value)
97+
protoBytes.writeDelimitedTo(stream)
98+
}
99+
100+
/**
101+
* Serializes and encodes the given [value] into a [stream] as
102+
* [delimited Protobuf message](https://developers.google.com/protocol-buffers/docs/techniques?hl=en#streaming).
103+
* In other words the size of each message is specified before the message itself. Also,
104+
* it is using the serializer retrieved from the reified type parameter.
105+
*
106+
* Use [decodeDelimitedMessages] to retrieve the messages from the stream.
107+
*
108+
* @throws SerializationException in case of any encoding-specif error
109+
* @throws IOException if an I/O error occurs and stream cannot be written to
110+
*/
111+
@ExperimentalSerializationApi
112+
public inline fun <reified T> ProtoBuf.encodeDelimitedToStream(
113+
value: T,
114+
stream: OutputStream
115+
): Unit = encodeDelimitedToStream(serializersModule.serializer(), value, stream)
116+
117+
/**
118+
* Decodes and deserializes from given [stream] to a list of value of type [T] using the given [deserializer].
119+
* Messages are expected to use [delimited Protobuf messages](https://developers.google.com/protocol-buffers/docs/techniques?hl=en#streaming)
120+
* with the size of each message specified before the message itself (see [encodeDelimitedToStream]).
121+
*
122+
* The max size of each incoming message can set via [messageMaxSize], usually the default value is
123+
* reasonable enough for most cases.
124+
*
125+
* @throws SerializationException in case of any encoding-specif error
126+
* @throws IOException if an I/O error occurs and stream cannot be written to
127+
*/
128+
@ExperimentalSerializationApi
129+
public fun <T> ProtoBuf.decodeDelimitedMessages(
130+
deserializer: DeserializationStrategy<T>,
131+
stream: InputStream,
132+
messageMaxSize: Int = DEFAULT_MESSAGE_MAX_SIZE
133+
): List<T> {
134+
val decoder = ProtobufDelimitedMessageReader(this, messageMaxSize)
135+
return decoder.decodeDelimitedMessages(deserializer, stream)
136+
}
137+
138+
/**
139+
* Decodes and deserializes from given [stream] to a list of value of type [T] using the deserializer
140+
* retrieved from the reified type parameter.
141+
* Messages are expected to use [delimited Protobuf messages](https://developers.google.com/protocol-buffers/docs/techniques?hl=en#streaming)
142+
* with the size of each message specified before the message itself (see [encodeDelimitedToStream]).
143+
*
144+
* The max size of each incoming message can set via [messageMaxSize], usually the default value is
145+
* reasonable enough for most cases.
146+
*
147+
* @throws SerializationException in case of any encoding-specif error
148+
* @throws IOException if an I/O error occurs and stream cannot be written to
149+
*/
150+
@ExperimentalSerializationApi
151+
public inline fun <reified T> ProtoBuf.decodeDelimitedMessages(
152+
stream: InputStream,
153+
messageMaxSize: Int = DEFAULT_MESSAGE_MAX_SIZE
154+
): List<T> =
155+
decodeDelimitedMessages(serializersModule.serializer(), stream, messageMaxSize)
156+
157+
// -- impl --
158+
159+
/**
160+
* Default size for aggregating messages.
161+
*/
162+
@PublishedApi
163+
internal const val DEFAULT_MESSAGE_MAX_SIZE: Int = 256 * 1024
164+
165+
/*
166+
* Inspired from spring's impl and protobuf CodeInputStream.readRawVarint
167+
*/
168+
@ExperimentalSerializationApi
169+
private class ProtobufDelimitedMessageReader(
170+
private val protobuf: ProtoBuf,
171+
private val messageMaxSize: Int
172+
) {
173+
private var offset = 0
174+
175+
// reads first message's varint and then decodes the message
176+
fun <T> decodeDelimitedMessages(
177+
deserializationStrategy: DeserializationStrategy<T>,
178+
stream: InputStream
179+
): List<T> {
180+
stream.use { str ->
181+
var remainingBytesToRead: Int
182+
var chunkBytesToRead: Int
183+
184+
return buildList {
185+
do {
186+
var messageBytesToRead = readMessageSize(str)
187+
if (messageMaxSize in 1 until messageBytesToRead) {
188+
throw ProtobufDecodingException(
189+
"The number of bytes to read for message: $messageBytesToRead" +
190+
"exceeds the configured limit: $messageMaxSize"
191+
)
192+
}
193+
val buffer = str.buffered()
194+
val readablyByteCount = buffer.available()
195+
chunkBytesToRead = minOf(messageBytesToRead, readablyByteCount)
196+
remainingBytesToRead = readablyByteCount - chunkBytesToRead
197+
198+
val bytesToWrite = ByteArray(chunkBytesToRead)
199+
str.read(bytesToWrite, offset, chunkBytesToRead)
200+
messageBytesToRead -= chunkBytesToRead
201+
if (messageBytesToRead == 0) { // do not deserialize in case readableByteCount was smaller than messageBytesToRead
202+
val messages = protobuf.decodeFromByteArray(deserializationStrategy, bytesToWrite)
203+
add(messages)
204+
}
205+
} while (remainingBytesToRead > 0)
206+
}
207+
}
208+
}
209+
210+
// parses message's varint
211+
// see: https://developers.google.com/protocol-buffers/docs/encoding#varints
212+
private fun readMessageSize(input: InputStream): Int {
213+
val firstByte = input.read()
214+
if (firstByte == -1) {
215+
throwTruncatedMessageException()
216+
}
217+
if (firstByte and 0x80 == 0) { // if it's positive number then it is the message's size
218+
return firstByte
219+
}
220+
var result = firstByte and 0x7f // if it's not the message size drop the msb
221+
offset = 7
222+
while (offset < 32) {
223+
val nextByte = input.read()
224+
if (nextByte == -1) {
225+
throwTruncatedMessageException()
226+
}
227+
// Drop continuation bits, shift the groups of 7 bits because varints store numbers
228+
// with the least significant group first (little endian order)
229+
result = (result or messageMaxSize and 0x7f) shl offset // and concatenate them together
230+
if (nextByte and 0x80 == 0) {
231+
offset = 0
232+
return result
233+
}
234+
offset += 7
235+
}
236+
// keep reading up to 64 bits
237+
while (offset < 64) {
238+
val nextByte = input.read()
239+
if (nextByte == -1) {
240+
throwTruncatedMessageException()
241+
}
242+
if (nextByte and 0x80 == 0) {
243+
offset = 0
244+
return result
245+
}
246+
offset += 7
247+
}
248+
throw ProtobufDecodingException("Cannot parse message encountered a malformed varint.")
249+
}
250+
251+
private fun throwTruncatedMessageException(): Nothing {
252+
throw ProtobufDecodingException(
253+
"While parsing a protocol message, the input ended unexpectedly in the middle of a field. " +
254+
"This could mean either that the input has been truncated or that an embedded message" +
255+
" misreported its own length."
256+
)
257+
}
258+
}
259+
260+
private fun ByteArray.writeDelimitedTo(outputStream: OutputStream) {
261+
val serializedSize = this.size
262+
val bufferSize = computePreferredBufferSize(computeUInt32SizeNoTag(serializedSize) + serializedSize)
263+
val stream = outputStream.createBuffered(bufferSize)
264+
stream.writeUInt32NoTag(serializedSize)
265+
writeTo(stream)
266+
stream.flush()
267+
}
268+
269+
private fun ByteArray.writeTo(outputStream: OutputStream) {
270+
val bufferSize = computePreferredBufferSize(this.size)
271+
val stream = outputStream.createBuffered(bufferSize)
272+
stream.write(this)
273+
stream.flush()
274+
}
275+
276+
private fun OutputStream.createBuffered(bufferSize: Int): BufferedOutputStream {
277+
// optimization ("rented" from google's protobuf CodedOutputStream.AbstractBufferedEncoder impl)
278+
// require size of at least two varints, so we can buffer any integer write (tag + value).
279+
// This reduces the number of range checks for a single write to 1 (i.e. if there is not enough space
280+
// to buffer the tag+value, flush and then buffer it).
281+
return this.buffered(
282+
maxOf(
283+
bufferSize,
284+
MAX_VARINT_SIZE * 2
285+
)
286+
)
287+
}
288+
289+
private const val DEFAULT_BUFFER_SIZE = 4096
290+
291+
// per protobuf spec 1-10 bytes
292+
private const val MAX_VARINT_SIZE = 10
293+
294+
/** Returns the buffer size to efficiently write dataLength bytes to this OutputStream. */
295+
private fun computePreferredBufferSize(dataLength: Int): Int =
296+
if (dataLength > DEFAULT_BUFFER_SIZE) DEFAULT_BUFFER_SIZE else dataLength
297+
298+
/** Compute the number of bytes that would be needed to encode an uint32 field. */
299+
private fun computeUInt32SizeNoTag(value: Int): Int = when {
300+
value and (0.inv() shl 7) == 0 -> 1
301+
value and (0.inv() shl 14) == 0 -> 2
302+
value and (0.inv() shl 21) == 0 -> 3
303+
value and (0.inv() shl 28) == 0 -> 4
304+
else -> 5 // max varint32 size
305+
}
306+
307+
private fun BufferedOutputStream.writeUInt32NoTag(size: Int) {
308+
var value = size
309+
while (true) {
310+
if ((value and 0x7F.inv() == 0)) {
311+
write(value)
312+
return
313+
} else {
314+
write((value and 0x7F) or 0x80)
315+
value = value ushr 7
316+
}
317+
}
318+
}

0 commit comments

Comments
 (0)