Skip to content

Commit 246bfe2

Browse files
authored
fix: HTTP body field messages with enums or recursive fields (googleapis#1201)
Minor fix for a generating unit tests for methods where the message that is the body field has a field that is an enum or a recursive field message type.
1 parent bd014ff commit 246bfe2

File tree

4 files changed

+123
-24
lines changed

4 files changed

+123
-24
lines changed

gapic/schema/wrappers.py

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -97,29 +97,42 @@ def map(self) -> bool:
9797

9898
@utils.cached_property
9999
def mock_value_original_type(self) -> Union[bool, str, bytes, int, float, Dict[str, Any], List[Any], None]:
100-
# Return messages as dicts and let the message ctor handle the conversion.
101-
if self.message:
102-
if self.map:
103-
# Not worth the hassle, just return an empty map.
104-
return {}
100+
visited_messages = set()
105101

106-
msg_dict = {
107-
f.name: f.mock_value_original_type
108-
for f in self.message.fields.values()
109-
}
102+
def recursive_mock_original_type(field):
103+
if field.message:
104+
# Return messages as dicts and let the message ctor handle the conversion.
105+
if field.message in visited_messages:
106+
return {}
110107

111-
return [msg_dict] if self.repeated else msg_dict
108+
visited_messages.add(field.message)
109+
if field.map:
110+
# Not worth the hassle, just return an empty map.
111+
return {}
112112

113-
answer = self.primitive_mock() or None
113+
msg_dict = {
114+
f.name: recursive_mock_original_type(f)
115+
for f in field.message.fields.values()
116+
}
114117

115-
# If this is a repeated field, then the mock answer should
116-
# be a list.
117-
if self.repeated:
118-
first_item = self.primitive_mock(suffix=1) or None
119-
second_item = self.primitive_mock(suffix=2) or None
120-
answer = [first_item, second_item]
118+
return [msg_dict] if field.repeated else msg_dict
121119

122-
return answer
120+
if field.enum:
121+
# First Truthy value, fallback to the first value
122+
return next((v for v in field.type.values if v.number), field.type.values[0]).number
123+
124+
answer = field.primitive_mock() or None
125+
126+
# If this is a repeated field, then the mock answer should
127+
# be a list.
128+
if field.repeated:
129+
first_item = field.primitive_mock(suffix=1) or None
130+
second_item = field.primitive_mock(suffix=2) or None
131+
answer = [first_item, second_item]
132+
133+
return answer
134+
135+
return recursive_mock_original_type(self)
123136

124137
@utils.cached_property
125138
def mock_value(self) -> str:
@@ -887,8 +900,12 @@ class HttpRule:
887900
def path_fields(self, method: "Method") -> List[Tuple[Field, str, str]]:
888901
"""return list of (name, template) tuples extracted from uri."""
889902
input = method.input
890-
return [(input.get_field(*match.group("name").split(".")), match.group("name"), match.group("template"))
891-
for match in path_template._VARIABLE_RE.finditer(self.uri)]
903+
return [
904+
(input.get_field(*match.group("name").split(".")),
905+
match.group("name"), match.group("template"))
906+
for match in path_template._VARIABLE_RE.finditer(self.uri)
907+
if match.group("name")
908+
]
892909

893910
def sample_request(self, method: "Method") -> Dict[str, Any]:
894911
"""return json dict for sample request matching the uri template."""

test_utils/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def make_enum(
297297
name: str,
298298
package: str = 'foo.bar.v1',
299299
module: str = 'baz',
300-
values: typing.Tuple[str, int] = (),
300+
values: typing.Sequence[typing.Tuple[str, int]] = (),
301301
meta: metadata.Metadata = None,
302302
options: desc.EnumOptions = None,
303303
) -> wrappers.EnumType:

tests/fragments/test_non_primitive_body.proto

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,20 @@ service SmallCompute {
2929
post: "/computation/v1/first_name/{first_name}/last_name/{last_name}"
3030
};
3131
};
32+
33+
rpc EnumBody(EnumBodyRequest) returns (EnumBodyResponse) {
34+
option (google.api.http) = {
35+
body: "resource"
36+
post: "/enum_body/v1/names/{name}"
37+
};
38+
}
39+
40+
rpc RecursiveBody(RecursiveBodyRequest) returns (RecursiveBodyResponse) {
41+
option (google.api.http) = {
42+
body: "resource"
43+
post: "/recursive_body/v1/names/{name}"
44+
};
45+
}
3246
}
3347

3448
message SerialNumber {
@@ -50,4 +64,38 @@ message MethodRequest {
5064

5165
message MethodResponse {
5266
string name = 1;
67+
}
68+
69+
message EnumBodyRequest {
70+
message Resource{
71+
enum Ordering {
72+
UNKNOWN = 0;
73+
CHRONOLOGICAL = 1;
74+
ALPHABETICAL = 2;
75+
DIFFICULTY = 3;
76+
}
77+
78+
Ordering ordering = 1;
79+
}
80+
81+
string name = 1;
82+
Resource resource = 2;
83+
}
84+
85+
message EnumBodyResponse {
86+
string data = 1;
87+
}
88+
89+
message RecursiveBodyRequest {
90+
message Resource {
91+
int32 depth = 1;
92+
Resource child_resource = 2;
93+
}
94+
95+
string name = 1;
96+
Resource resource = 2;
97+
}
98+
99+
message RecursiveBodyResponse {
100+
string data = 1;
53101
}

tests/unit/schema/wrappers/test_field.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from test_utils.test_utils import (
2828
make_field,
2929
make_message,
30+
make_enum,
3031
)
3132

3233

@@ -343,7 +344,41 @@ def test_mock_value_original_type_message():
343344
assert entry_field.mock_value_original_type == {}
344345

345346

346-
def test_mock_value_recursive():
347+
def test_mock_value_original_type_enum():
348+
mollusc_field = make_field(
349+
name="class",
350+
enum=make_enum(
351+
name="Class",
352+
values=[
353+
("UNKNOWN", 0),
354+
("GASTROPOD", 1),
355+
("BIVALVE", 2),
356+
("CEPHALOPOD", 3),
357+
],
358+
),
359+
)
360+
361+
assert mollusc_field.mock_value_original_type == 1
362+
363+
empty_field = make_field(
364+
name="empty",
365+
enum=make_enum(
366+
name="Empty",
367+
values=[("UNKNOWN", 0)],
368+
),
369+
)
370+
371+
assert empty_field.mock_value_original_type == 0
372+
373+
374+
@pytest.mark.parametrize(
375+
"mock_method,expected",
376+
[
377+
("mock_value", "ac_turtle.Turtle(turtle=ac_turtle.Turtle(turtle=turtle.Turtle(turtle=None)))"),
378+
("mock_value_original_type", {"turtle": {}}),
379+
],
380+
)
381+
def test_mock_value_recursive(mock_method, expected):
347382
# The elaborate setup is an unfortunate requirement.
348383
file_pb = descriptor_pb2.FileDescriptorProto(
349384
name="turtle.proto",
@@ -367,8 +402,7 @@ def test_mock_value_recursive():
367402
turtle_field = my_api.messages["animalia.chordata.v2.Turtle"].fields["turtle"]
368403

369404
# If not handled properly, this will run forever and eventually OOM.
370-
actual = turtle_field.mock_value
371-
expected = "ac_turtle.Turtle(turtle=ac_turtle.Turtle(turtle=turtle.Turtle(turtle=None)))"
405+
actual = getattr(turtle_field, mock_method)
372406
assert actual == expected
373407

374408

0 commit comments

Comments
 (0)