Skip to content

Commit 7368299

Browse files
authored
Fix serialization of repeated fields with empty messages (#180)
Extend test config and utils to support exclusion of certain json samples from testing for symetry.
1 parent deb623e commit 7368299

File tree

8 files changed

+78
-25
lines changed

8 files changed

+78
-25
lines changed

src/betterproto/__init__.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -739,9 +739,18 @@ def __bytes__(self) -> bytes:
739739
output += _serialize_single(meta.number, TYPE_BYTES, buf)
740740
else:
741741
for item in value:
742-
output += _serialize_single(
743-
meta.number, meta.proto_type, item, wraps=meta.wraps or ""
742+
output += (
743+
_serialize_single(
744+
meta.number,
745+
meta.proto_type,
746+
item,
747+
wraps=meta.wraps or "",
748+
)
749+
# if it's an empty message it still needs to be represented
750+
# as an item in the repeated list
751+
or b"\n\x00"
744752
)
753+
745754
elif isinstance(value, dict):
746755
for k, v in value.items():
747756
assert meta.map_types

tests/inputs/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,10 @@
1919
"example_service",
2020
"empty_service",
2121
}
22+
23+
24+
# Indicate json sample messages to skip when testing that json (de)serialization
25+
# is symmetrical becuase some cases legitimately are not symmetrical.
26+
# Each key references the name of the test scenario and the values in the tuple
27+
# Are the names of the json files.
28+
non_symmetrical_json = {"empty_repeated": ("empty_repeated",)}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"msg": [{"values":[]}]
3+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
syntax = "proto3";
2+
3+
message MessageA {
4+
repeated float values = 1;
5+
}
6+
7+
message Test {
8+
repeated MessageA msg = 1;
9+
}

tests/inputs/oneof/test_oneof.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55

66
def test_which_count():
77
message = Test()
8-
message.from_json(get_test_case_json_data("oneof")[0])
8+
message.from_json(get_test_case_json_data("oneof")[0].json)
99
assert betterproto.which_one_of(message, "foo") == ("pitied", 100)
1010

1111

1212
def test_which_name():
1313
message = Test()
14-
message.from_json(get_test_case_json_data("oneof", "oneof_name.json")[0])
14+
message.from_json(get_test_case_json_data("oneof", "oneof_name.json")[0].json)
1515
assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. T")

tests/inputs/oneof_enum/test_oneof_enum.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def test_which_one_of_returns_enum_with_default_value():
1515
"""
1616
message = Test()
1717
message.from_json(
18-
get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0]
18+
get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0].json
1919
)
2020

2121
assert message.move == Move(
@@ -31,7 +31,7 @@ def test_which_one_of_returns_enum_with_non_default_value():
3131
"""
3232
message = Test()
3333
message.from_json(
34-
get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0]
34+
get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0].json
3535
)
3636
assert message.move == Move(
3737
x=0, y=0
@@ -42,7 +42,7 @@ def test_which_one_of_returns_enum_with_non_default_value():
4242

4343
def test_which_one_of_returns_second_field_when_set():
4444
message = Test()
45-
message.from_json(get_test_case_json_data("oneof_enum")[0])
45+
message.from_json(get_test_case_json_data("oneof_enum")[0].json)
4646
assert message.move == Move(x=2, y=3)
4747
assert message.signal == Signal.PASS
4848
assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3))

tests/test_inputs.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import sys
66
from collections import namedtuple
77
from types import ModuleType
8-
from typing import Any, Dict, List, Set
8+
from typing import Any, Dict, List, Set, Tuple
99

1010
import pytest
1111

@@ -29,7 +29,12 @@
2929

3030

3131
class TestCases:
32-
def __init__(self, path, services: Set[str], xfail: Set[str]):
32+
def __init__(
33+
self,
34+
path,
35+
services: Set[str],
36+
xfail: Set[str],
37+
):
3338
_all = set(get_directories(path)) - {"__pycache__"}
3439
_services = services
3540
_messages = (_all - services) - {"__pycache__"}
@@ -175,15 +180,18 @@ def test_message_json(repeat, test_data: TestData) -> None:
175180
plugin_module, _, json_data = test_data
176181

177182
for _ in range(repeat):
178-
for json_sample in json_data:
183+
for sample in json_data:
184+
if sample.belongs_to(test_input_config.non_symmetrical_json):
185+
continue
186+
179187
message: betterproto.Message = plugin_module.Test()
180188

181-
message.from_json(json_sample)
189+
message.from_json(sample.json)
182190
message_json = message.to_json(0)
183191

184-
assert dict_replace_nans(json.loads(message_json)) == dict_replace_nans(
185-
json.loads(json_sample)
186-
)
192+
assert dict_replace_nans(json.loads(message_json)) == dict_replace_nans(
193+
json.loads(sample.json)
194+
)
187195

188196

189197
@pytest.mark.parametrize("test_data", test_cases.services, indirect=True)
@@ -195,13 +203,13 @@ def test_service_can_be_instantiated(test_data: TestData) -> None:
195203
def test_binary_compatibility(repeat, test_data: TestData) -> None:
196204
plugin_module, reference_module, json_data = test_data
197205

198-
for json_sample in json_data:
199-
reference_instance = Parse(json_sample, reference_module().Test())
206+
for sample in json_data:
207+
reference_instance = Parse(sample.json, reference_module().Test())
200208
reference_binary_output = reference_instance.SerializeToString()
201209

202210
for _ in range(repeat):
203211
plugin_instance_from_json: betterproto.Message = (
204-
plugin_module.Test().from_json(json_sample)
212+
plugin_module.Test().from_json(sample.json)
205213
)
206214
plugin_instance_from_binary = plugin_module.Test.FromString(
207215
reference_binary_output

tests/util.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import asyncio
2+
from dataclasses import dataclass
23
import importlib
34
import os
4-
import pathlib
5-
import sys
65
from pathlib import Path
6+
import sys
77
from types import ModuleType
8-
from typing import Callable, Generator, List, Optional, Union
8+
from typing import Callable, Dict, Generator, List, Optional, Tuple, Union
99

1010
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
1111

@@ -47,11 +47,24 @@ async def protoc(
4747
return stdout, stderr, proc.returncode
4848

4949

50-
def get_test_case_json_data(test_case_name: str, *json_file_names: str) -> List[str]:
50+
@dataclass
51+
class TestCaseJsonFile:
52+
json: str
53+
test_name: str
54+
file_name: str
55+
56+
def belongs_to(self, non_symmetrical_json: Dict[str, Tuple[str, ...]]):
57+
return self.file_name in non_symmetrical_json.get(self.test_name, tuple())
58+
59+
60+
def get_test_case_json_data(
61+
test_case_name: str, *json_file_names: str
62+
) -> List[TestCaseJsonFile]:
5163
"""
5264
:return:
53-
A list of all files found in "inputs_path/test_case_name" with names matching
54-
f"{test_case_name}.json" or f"{test_case_name}_*.json", OR given by json_file_names
65+
A list of all files found in "{inputs_path}/test_case_name" with names matching
66+
f"{test_case_name}.json" or f"{test_case_name}_*.json", OR given by
67+
json_file_names
5568
"""
5669
test_case_dir = inputs_path.joinpath(test_case_name)
5770
possible_file_paths = [
@@ -65,7 +78,11 @@ def get_test_case_json_data(test_case_name: str, *json_file_names: str) -> List[
6578
if not test_data_file_path.exists():
6679
continue
6780
with test_data_file_path.open("r") as fh:
68-
result.append(fh.read())
81+
result.append(
82+
TestCaseJsonFile(
83+
fh.read(), test_case_name, test_data_file_path.name.split(".")[0]
84+
)
85+
)
6986

7087
return result
7188

@@ -86,7 +103,7 @@ def find_module(
86103
if predicate(module):
87104
return module
88105

89-
module_path = pathlib.Path(*module.__path__)
106+
module_path = Path(*module.__path__)
90107

91108
for sub in [sub.parent for sub in module_path.glob("**/__init__.py")]:
92109
if sub == module_path:

0 commit comments

Comments
 (0)