Skip to content

Commit d63e6c6

Browse files
committed
Proof of concept: Add implementation for computing the required bytes to encode a message
This commit introduces a new API which computes the required bytes to encode a message without actually serializing a message. By extension, this API is meant to be used as a cornerstone for implementing Kotlin#2075. Signed-off-by: George Papadopoulos <[email protected]>
1 parent 694e2f7 commit d63e6c6

File tree

10 files changed

+1256
-26
lines changed

10 files changed

+1256
-26
lines changed

formats/protobuf/api/kotlinx-serialization-protobuf.api

+4
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ public final class kotlinx/serialization/protobuf/ProtoBufKt {
2121
public static synthetic fun ProtoBuf$default (Lkotlinx/serialization/protobuf/ProtoBuf;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lkotlinx/serialization/protobuf/ProtoBuf;
2222
}
2323

24+
public final class kotlinx/serialization/protobuf/ProtoBufSerializedSizeCalculatorKt {
25+
public static final fun getOrComputeSerializedSize (Lkotlinx/serialization/protobuf/ProtoBuf;Lkotlinx/serialization/SerializationStrategy;Ljava/lang/Object;)I
26+
}
27+
2428
public final class kotlinx/serialization/protobuf/ProtoIntegerType : java/lang/Enum {
2529
public static final field DEFAULT Lkotlinx/serialization/protobuf/ProtoIntegerType;
2630
public static final field FIXED Lkotlinx/serialization/protobuf/ProtoIntegerType;

formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/ProtoBufSerializedSize.kt

+589
Large diffs are not rendered by default.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
package kotlinx.serialization.protobuf.internal
2+
3+
internal fun computeEnumSizeNoTag(value: Int): Int = computeInt32SizeNoTag(value)
4+
5+
internal fun computeByteArraySizeNoTag(value: ByteArray): Int = computeLengthDelimitedFieldSize(value.size)
6+
7+
internal fun computeStringSizeNoTag(value: String): Int {
8+
// java's implementation uses a custom method for some optimizations.
9+
return computeLengthDelimitedFieldSize(value.length)
10+
}
11+
12+
internal fun computeLengthDelimitedFieldSize(length: Int): Int = computeUInt32SizeNoTag(length) + length
13+
14+
//TODO: should this also be named "compute" for consistency?
15+
internal fun getFixed64SizeNoTag(): Int = FIXED64_SIZE
16+
internal fun computeSInt64SizeNoTag(value: Long): Int = computeUInt64SizeNoTag(encodeZigZag64(value))
17+
internal fun computeInt64SizeNoTag(value: Long): Int = computeUInt64SizeNoTag(value)
18+
19+
//TODO: should this also be named "compute" for consistency?
20+
internal fun getFixed32SizeNoTag() = FIXED32_SIZE
21+
internal fun computeSInt32SizeNoTag(value: Int) = computeUInt32SizeNoTag((encodeZigZag32(value)))
22+
internal fun computeInt32SizeNoTag(value: Int) =
23+
if (value >= 0) computeUInt32SizeNoTag(value) else MAX_VARINT_SIZE
24+
25+
/** Compute the number of bytes that would be needed to encode an uint32 field. */
26+
internal fun computeUInt32SizeNoTag(value: Int): Int = when {
27+
value and (0.inv() shl 7) == 0 -> 1
28+
value and (0.inv() shl 14) == 0 -> 2
29+
value and (0.inv() shl 21) == 0 -> 3
30+
value and (0.inv() shl 28) == 0 -> 4
31+
else -> 5 // max varint32 size
32+
}
33+
34+
private fun computeUInt64SizeNoTag(value: Long): Int {
35+
var _value = value
36+
// handle first two most common cases
37+
if ((_value and (0L.inv() shl 7)) == 0L) {
38+
return 1
39+
}
40+
if (_value < 0L) {
41+
return 10
42+
}
43+
// rest cases
44+
var size = 2
45+
if ((_value and (0L.inv() shl 35)) != 0L) {
46+
size += 4
47+
_value = value ushr 28
48+
}
49+
if ((_value and (0L.inv() shl 21)) != 0L) {
50+
size += 2
51+
_value = value ushr 14
52+
}
53+
if ((_value and (0L.inv() shl 14)) != 0L) {
54+
size += 1
55+
}
56+
return size
57+
}
58+
59+
// helpers
60+
61+
// per protobuf spec 1-10 bytes
62+
internal const val MAX_VARINT_SIZE = 10
63+
64+
// after decoding the varint representing a field, the low 3 bits tell us the wire type,
65+
// and the rest of the integer tells us the field number.
66+
private const val TAG_TYPE_BITS = 3
67+
68+
/**
69+
* See [Scalar type values](https://developers.google.com/protocol-buffers/docs/proto#scalar).
70+
*/
71+
72+
private const val FIXED32_SIZE = 4
73+
private const val FIXED64_SIZE = 8
74+
75+
internal fun computeTagSize(protoId: Int): Int = computeUInt32SizeNoTag(makeTag(protoId, 0))
76+
private fun makeTag(protoId: Int, wireType: Int): Int = protoId shl TAG_TYPE_BITS or wireType
77+
78+
// stream utils
79+
80+
internal fun encodeZigZag64(value: Long): Long = (value shl 1) xor (value shr 63)
81+
82+
internal fun encodeZigZag32(value: Int): Int = ((value shl 1) xor (value shr 31))

formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufTaggedEncoder.kt

+28-26
Original file line numberDiff line numberDiff line change
@@ -33,61 +33,61 @@ internal abstract class ProtobufTaggedEncoder : ProtobufTaggedBase(), Encoder, C
3333

3434
protected open fun encodeTaggedInline(tag: ProtoDesc, inlineDescriptor: SerialDescriptor): Encoder = this.apply { pushTag(tag) }
3535

36-
public final override fun encodeNull() {
36+
final override fun encodeNull() {
3737
if (nullableMode != NullableMode.ACCEPTABLE) {
3838
val message = when (nullableMode) {
3939
NullableMode.OPTIONAL -> "'null' is not supported for optional properties in ProtoBuf"
4040
NullableMode.COLLECTION -> "'null' is not supported for collection types in ProtoBuf"
4141
NullableMode.NOT_NULL -> "'null' is not allowed for not-null properties"
42-
else -> "'null' is not supported in ProtoBuf";
42+
else -> "'null' is not supported in ProtoBuf"
4343
}
4444
throw SerializationException(message)
4545
}
4646
}
4747

48-
public final override fun encodeBoolean(value: Boolean) {
48+
final override fun encodeBoolean(value: Boolean) {
4949
encodeTaggedBoolean(popTagOrDefault(), value)
5050
}
5151

52-
public final override fun encodeByte(value: Byte) {
52+
final override fun encodeByte(value: Byte) {
5353
encodeTaggedByte(popTagOrDefault(), value)
5454
}
5555

56-
public final override fun encodeShort(value: Short) {
56+
final override fun encodeShort(value: Short) {
5757
encodeTaggedShort(popTagOrDefault(), value)
5858
}
5959

60-
public final override fun encodeInt(value: Int) {
60+
final override fun encodeInt(value: Int) {
6161
encodeTaggedInt(popTagOrDefault(), value)
6262
}
6363

64-
public final override fun encodeLong(value: Long) {
64+
final override fun encodeLong(value: Long) {
6565
encodeTaggedLong(popTagOrDefault(), value)
6666
}
6767

68-
public final override fun encodeFloat(value: Float) {
68+
final override fun encodeFloat(value: Float) {
6969
encodeTaggedFloat(popTagOrDefault(), value)
7070
}
7171

72-
public final override fun encodeDouble(value: Double) {
72+
final override fun encodeDouble(value: Double) {
7373
encodeTaggedDouble(popTagOrDefault(), value)
7474
}
7575

76-
public final override fun encodeChar(value: Char) {
76+
final override fun encodeChar(value: Char) {
7777
encodeTaggedChar(popTagOrDefault(), value)
7878
}
7979

80-
public final override fun encodeString(value: String) {
80+
final override fun encodeString(value: String) {
8181
encodeTaggedString(popTagOrDefault(), value)
8282
}
8383

84-
public final override fun encodeEnum(
84+
final override fun encodeEnum(
8585
enumDescriptor: SerialDescriptor,
8686
index: Int
8787
): Unit = encodeTaggedEnum(popTagOrDefault(), enumDescriptor, index)
8888

8989

90-
public final override fun endStructure(descriptor: SerialDescriptor) {
90+
final override fun endStructure(descriptor: SerialDescriptor) {
9191
if (stackSize >= 0) {
9292
popTag()
9393
}
@@ -96,46 +96,48 @@ internal abstract class ProtobufTaggedEncoder : ProtobufTaggedBase(), Encoder, C
9696

9797
protected open fun endEncode(descriptor: SerialDescriptor) {}
9898

99-
public final override fun encodeBooleanElement(descriptor: SerialDescriptor, index: Int, value: Boolean): Unit =
99+
final override fun encodeBooleanElement(descriptor: SerialDescriptor, index: Int, value: Boolean): Unit =
100100
encodeTaggedBoolean(descriptor.getTag(index), value)
101101

102-
public final override fun encodeByteElement(descriptor: SerialDescriptor, index: Int, value: Byte): Unit =
102+
final override fun encodeByteElement(descriptor: SerialDescriptor, index: Int, value: Byte): Unit =
103103
encodeTaggedByte(descriptor.getTag(index), value)
104104

105-
public final override fun encodeShortElement(descriptor: SerialDescriptor, index: Int, value: Short): Unit =
105+
final override fun encodeShortElement(descriptor: SerialDescriptor, index: Int, value: Short): Unit =
106106
encodeTaggedShort(descriptor.getTag(index), value)
107107

108-
public final override fun encodeIntElement(descriptor: SerialDescriptor, index: Int, value: Int): Unit =
108+
final override fun encodeIntElement(descriptor: SerialDescriptor, index: Int, value: Int): Unit =
109109
encodeTaggedInt(descriptor.getTag(index), value)
110110

111-
public final override fun encodeLongElement(descriptor: SerialDescriptor, index: Int, value: Long): Unit =
111+
final override fun encodeLongElement(descriptor: SerialDescriptor, index: Int, value: Long): Unit =
112112
encodeTaggedLong(descriptor.getTag(index), value)
113113

114-
public final override fun encodeFloatElement(descriptor: SerialDescriptor, index: Int, value: Float): Unit =
114+
final override fun encodeFloatElement(descriptor: SerialDescriptor, index: Int, value: Float): Unit =
115115
encodeTaggedFloat(descriptor.getTag(index), value)
116116

117-
public final override fun encodeDoubleElement(descriptor: SerialDescriptor, index: Int, value: Double): Unit =
117+
final override fun encodeDoubleElement(descriptor: SerialDescriptor, index: Int, value: Double): Unit =
118118
encodeTaggedDouble(descriptor.getTag(index), value)
119119

120-
public final override fun encodeCharElement(descriptor: SerialDescriptor, index: Int, value: Char): Unit =
120+
final override fun encodeCharElement(descriptor: SerialDescriptor, index: Int, value: Char): Unit =
121121
encodeTaggedChar(descriptor.getTag(index), value)
122122

123-
public final override fun encodeStringElement(descriptor: SerialDescriptor, index: Int, value: String): Unit =
123+
final override fun encodeStringElement(descriptor: SerialDescriptor, index: Int, value: String): Unit =
124124
encodeTaggedString(descriptor.getTag(index), value)
125125

126-
public final override fun <T : Any?> encodeSerializableElement(
126+
final override fun <T : Any?> encodeSerializableElement(
127127
descriptor: SerialDescriptor,
128128
index: Int,
129129
serializer: SerializationStrategy<T>,
130130
value: T
131131
) {
132132
nullableMode = NullableMode.NOT_NULL
133-
134-
pushTag(descriptor.getTag(index))
133+
val tag = descriptor.getTag(index)
134+
println("will push tag:$tag in stack")
135+
pushTag(tag)
136+
println("total stackSize:${stackSize + 1}")
135137
encodeSerializableValue(serializer, value)
136138
}
137139

138-
public final override fun <T : Any> encodeNullableSerializableElement(
140+
final override fun <T : Any> encodeNullableSerializableElement(
139141
descriptor: SerialDescriptor,
140142
index: Int,
141143
serializer: SerializationStrategy<T>,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package kotlinx.serialization.protobuf
2+
3+
import kotlinx.serialization.descriptors.SerialDescriptor
4+
5+
internal actual fun createSerializedSizeCache(): SerializedSizeCache = JsHashMap()
6+
7+
private class JsHashMap : SerializedSizeCache {
8+
private val cache = mutableMapOf<SerialDescriptor, SerializedData>()
9+
10+
override fun get(descriptor: SerialDescriptor, key: SerializedSizeCacheKey): Int? = cache[descriptor]?.get(key)
11+
12+
override fun set(descriptor: SerialDescriptor, key: SerializedSizeCacheKey, serializedSize: Int) {
13+
cache[descriptor] = mapOf(key to serializedSize)
14+
}
15+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package kotlinx.serialization.protobuf
2+
3+
import kotlinx.serialization.descriptors.SerialDescriptor
4+
import java.util.concurrent.ConcurrentHashMap
5+
6+
internal actual fun createSerializedSizeCache(): SerializedSizeCache {
7+
return ConcurrentHashMapSerializedCache()
8+
}
9+
10+
private class ConcurrentHashMapSerializedCache : SerializedSizeCache {
11+
private val cache = ConcurrentHashMap<SerialDescriptor, SerializedData>()
12+
13+
override fun get(descriptor: SerialDescriptor, key: SerializedSizeCacheKey): Int? = cache[descriptor]?.get(key)
14+
15+
override fun set(descriptor: SerialDescriptor, key: SerializedSizeCacheKey, serializedSize: Int) {
16+
cache[descriptor] = mapOf(key to serializedSize)
17+
}
18+
}

0 commit comments

Comments
 (0)