Skip to content

Commit 54d184c

Browse files
Async Stub Annotations (#611)
* Generate a set of TypeVars and a generic class TypeVars have defaults to match expected behavior Overload init methods to get expected types back Create AsyncStub type alias Signed-off-by: Aidan Jensen <[email protected]> --------- Signed-off-by: Aidan Jensen <[email protected]>
1 parent f20607f commit 54d184c

File tree

10 files changed

+311
-85
lines changed

10 files changed

+311
-85
lines changed

.flake8

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@ per-file-ignores =
33
*.py: E203, E301, E302, E305, E501
44
*.pyi: E301, E302, E305, E501, E701, E741, F401, F403, F405, F822, Y037
55
*_pb2.pyi: E301, E302, E305, E501, E701, E741, F401, F403, F405, F822, Y037, Y021
6-
*_pb2_grpc.pyi: E301, E302, E305, E501, E701, E741, F401, F403, F405, F822, Y037, Y021
6+
*_pb2_grpc.pyi: E301, E302, E305, E501, E701, E741, F401, F403, F405, F822, Y037, Y021, Y023
77

88
extend_exclude = venv*,*_pb2.py,*_pb2_grpc.py,build/

CHANGELOG.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
- Mark top-level mangled identifiers as `TypeAlias`.
44
- Change the top-level mangling prefix from `global___` to `Global___` to respect
5-
[Y042](https://github.com/PyCQA/flake8-pyi/blob/main/ERRORCODES.md#list-of-warnings) naming convention.
5+
[Y042](https://github.com/PyCQA/flake8-pyi/blob/main/ERRORCODES.md#list-of-warnings) naming convention.
6+
- Support client stub async typing overloads
67

78
## 3.6.0
89

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,7 @@ black .
303303
- [@fergyfresh](https://github.com/fergyfresh)
304304
- [@AlexWaygood](https://github.com/AlexWaygood)
305305
- [@Avasam](https://github.com/Avasam)
306+
- [@artificial-aidan](https://github.com/artificial-aidan)
306307

307308
## Licence etc.
308309

mypy_protobuf/main.py

+83-24
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,16 @@
1313
Iterator,
1414
List,
1515
Optional,
16-
Set,
1716
Sequence,
17+
Set,
1818
Tuple,
1919
)
2020

2121
import google.protobuf.descriptor_pb2 as d
2222
from google.protobuf.compiler import plugin_pb2 as plugin_pb2
2323
from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
2424
from google.protobuf.internal.well_known_types import WKTBASES
25+
2526
from . import extensions_pb2
2627

2728
__version__ = "3.6.0"
@@ -85,6 +86,11 @@
8586
}
8687

8788

89+
def _build_typevar_name(service_name: str, method_name: str) -> str:
90+
# Prefix with underscore to avoid public api error: https://stackoverflow.com/a/78871465
91+
return f"_{service_name}{method_name}Type"
92+
93+
8894
def _mangle_global_identifier(name: str) -> str:
8995
"""
9096
Module level identifiers are mangled and aliased so that they can be disambiguated
@@ -168,9 +174,7 @@ def _import(self, path: str, name: str) -> str:
168174
eg. self._import("typing", "Literal") -> "Literal"
169175
"""
170176
if path == "typing_extensions":
171-
stabilization = {
172-
"TypeAlias": (3, 10),
173-
}
177+
stabilization = {"TypeAlias": (3, 10), "TypeVar": (3, 13)}
174178
assert name in stabilization
175179
if not self.typing_extensions_min or self.typing_extensions_min < stabilization[name]:
176180
self.typing_extensions_min = stabilization[name]
@@ -732,6 +736,46 @@ def write_grpc_async_hacks(self) -> None:
732736
wl("...")
733737
wl("")
734738

739+
def write_grpc_type_vars(self, service: d.ServiceDescriptorProto) -> None:
740+
wl = self._write_line
741+
methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
742+
if not methods:
743+
return
744+
for _, method in methods:
745+
wl("{} = {}(", _build_typevar_name(service.name, method.name), self._import("typing_extensions", "TypeVar"))
746+
with self._indent():
747+
wl("'{}',", _build_typevar_name(service.name, method.name))
748+
wl("{}[", self._callable_type(method, is_async=False))
749+
with self._indent():
750+
wl("{},", self._input_type(method))
751+
wl("{},", self._output_type(method))
752+
wl("],")
753+
wl("{}[", self._callable_type(method, is_async=True))
754+
with self._indent():
755+
wl("{},", self._input_type(method))
756+
wl("{},", self._output_type(method))
757+
wl("],")
758+
wl("default={}[", self._callable_type(method, is_async=False))
759+
with self._indent():
760+
wl("{},", self._input_type(method))
761+
wl("{},", self._output_type(method))
762+
wl("],")
763+
wl(")")
764+
wl("")
765+
766+
def write_self_types(self, service: d.ServiceDescriptorProto, is_async: bool) -> None:
767+
wl = self._write_line
768+
methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
769+
if not methods:
770+
return
771+
for _, method in methods:
772+
with self._indent():
773+
wl("{}[", self._callable_type(method, is_async=is_async))
774+
with self._indent():
775+
wl("{},", self._input_type(method))
776+
wl("{},", self._output_type(method))
777+
wl("],")
778+
735779
def write_grpc_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation) -> None:
736780
wl = self._write_line
737781
methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
@@ -769,11 +813,7 @@ def write_grpc_stub_methods(self, service: d.ServiceDescriptorProto, scl_prefix:
769813
for i, method in methods:
770814
scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
771815

772-
wl("{}: {}[", method.name, self._callable_type(method, is_async=is_async))
773-
with self._indent():
774-
wl("{},", self._input_type(method))
775-
wl("{},", self._output_type(method))
776-
wl("]")
816+
wl("{}: {}", method.name, f"{_build_typevar_name(service.name, method.name)}")
777817
self._write_comments(scl)
778818
wl("")
779819

@@ -791,29 +831,48 @@ def write_grpc_services(
791831

792832
scl = scl_prefix + [i]
793833

834+
# Type vars
835+
self.write_grpc_type_vars(service)
836+
794837
# The stub client
838+
class_name = f"{service.name}Stub"
795839
wl(
796-
"class {}Stub:",
797-
service.name,
840+
"class {}({}[{}]):",
841+
class_name,
842+
self._import("typing", "Generic"),
843+
", ".join(f"{_build_typevar_name(service.name, method.name)}" for method in service.method),
798844
)
799845
with self._indent():
800846
if self._write_comments(scl):
801847
wl("")
802-
# To support casting into FooAsyncStub, allow both Channel and aio.Channel here.
803-
channel = f"{self._import('typing', 'Union')}[{self._import('grpc', 'Channel')}, {self._import('grpc.aio', 'Channel')}]"
804-
wl("def __init__(self, channel: {}) -> None: ...", channel)
848+
849+
# Write sync overload
850+
wl("@{}", self._import("typing", "overload"))
851+
wl("def __init__(self: {}[", class_name)
852+
self.write_self_types(service, False)
853+
wl(
854+
"], channel: {}) -> None: ...",
855+
self._import("grpc", "Channel"),
856+
)
857+
wl("")
858+
859+
# Write async overload
860+
wl("@{}", self._import("typing", "overload"))
861+
wl("def __init__(self: {}[", class_name)
862+
self.write_self_types(service, True)
863+
wl(
864+
"], channel: {}) -> None: ...",
865+
self._import("grpc.aio", "Channel"),
866+
)
867+
wl("")
868+
805869
self.write_grpc_stub_methods(service, scl)
806870

807-
# The (fake) async stub client
808-
wl(
809-
"class {}AsyncStub:",
810-
service.name,
811-
)
812-
with self._indent():
813-
if self._write_comments(scl):
814-
wl("")
815-
# No __init__ since this isn't a real class (yet), and requires manual casting to work.
816-
self.write_grpc_stub_methods(service, scl, is_async=True)
871+
# Write AsyncStub alias
872+
wl("{}AsyncStub: {} = {}[", service.name, self._import("typing_extensions", "TypeAlias"), class_name)
873+
self.write_self_types(service, True)
874+
wl("]")
875+
wl("")
817876

818877
# The service definition interface
819878
wl(

run_test.sh

+2-1
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,8 @@ for PY_VER in $PY_VER_UNIT_TESTS; do
169169
# Write output to file. Make variant w/ omitted line numbers for easy diffing / CR
170170
PY_VER_MYPY_TARGET=$(echo "$1" | cut -d. -f1-2)
171171
export MYPYPATH=$MYPYPATH:test/generated
172-
mypy --custom-typeshed-dir="$CUSTOM_TYPESHED_DIR" --python-executable="venv_$1/bin/python" --python-version="$PY_VER_MYPY_TARGET" "${@: 2}" > "$MYPY_OUTPUT/mypy_output" || true
172+
# Use --no-incremental to avoid caching issues: https://github.com/python/mypy/issues/16363
173+
mypy --custom-typeshed-dir="$CUSTOM_TYPESHED_DIR" --python-executable="venv_$1/bin/python" --no-incremental --python-version="$PY_VER_MYPY_TARGET" "${@: 2}" > "$MYPY_OUTPUT/mypy_output" || true
173174
cut -d: -f1,3- "$MYPY_OUTPUT/mypy_output" > "$MYPY_OUTPUT/mypy_output.omit_linenos"
174175
}
175176

test/generated/testproto/grpc/dummy_pb2_grpc.pyi

+121-32
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,15 @@ import abc
77
import collections.abc
88
import grpc
99
import grpc.aio
10+
import sys
1011
import testproto.grpc.dummy_pb2
1112
import typing
1213

14+
if sys.version_info >= (3, 13):
15+
import typing as typing_extensions
16+
else:
17+
import typing_extensions
18+
1319
_T = typing.TypeVar("_T")
1420

1521
class _MaybeAsyncIterator(collections.abc.AsyncIterator[_T], collections.abc.Iterator[_T], metaclass=abc.ABCMeta): ...
@@ -19,60 +25,143 @@ class _ServicerContext(grpc.ServicerContext, grpc.aio.ServicerContext): # type:
1925

2026
GRPC_GENERATED_VERSION: str
2127
GRPC_VERSION: str
22-
class DummyServiceStub:
23-
"""DummyService"""
24-
25-
def __init__(self, channel: typing.Union[grpc.Channel, grpc.aio.Channel]) -> None: ...
26-
UnaryUnary: grpc.UnaryUnaryMultiCallable[
28+
_DummyServiceUnaryUnaryType = typing_extensions.TypeVar(
29+
'_DummyServiceUnaryUnaryType',
30+
grpc.UnaryUnaryMultiCallable[
2731
testproto.grpc.dummy_pb2.DummyRequest,
2832
testproto.grpc.dummy_pb2.DummyReply,
29-
]
30-
"""UnaryUnary"""
33+
],
34+
grpc.aio.UnaryUnaryMultiCallable[
35+
testproto.grpc.dummy_pb2.DummyRequest,
36+
testproto.grpc.dummy_pb2.DummyReply,
37+
],
38+
default=grpc.UnaryUnaryMultiCallable[
39+
testproto.grpc.dummy_pb2.DummyRequest,
40+
testproto.grpc.dummy_pb2.DummyReply,
41+
],
42+
)
3143

32-
UnaryStream: grpc.UnaryStreamMultiCallable[
44+
_DummyServiceUnaryStreamType = typing_extensions.TypeVar(
45+
'_DummyServiceUnaryStreamType',
46+
grpc.UnaryStreamMultiCallable[
3347
testproto.grpc.dummy_pb2.DummyRequest,
3448
testproto.grpc.dummy_pb2.DummyReply,
35-
]
36-
"""UnaryStream"""
49+
],
50+
grpc.aio.UnaryStreamMultiCallable[
51+
testproto.grpc.dummy_pb2.DummyRequest,
52+
testproto.grpc.dummy_pb2.DummyReply,
53+
],
54+
default=grpc.UnaryStreamMultiCallable[
55+
testproto.grpc.dummy_pb2.DummyRequest,
56+
testproto.grpc.dummy_pb2.DummyReply,
57+
],
58+
)
3759

38-
StreamUnary: grpc.StreamUnaryMultiCallable[
60+
_DummyServiceStreamUnaryType = typing_extensions.TypeVar(
61+
'_DummyServiceStreamUnaryType',
62+
grpc.StreamUnaryMultiCallable[
3963
testproto.grpc.dummy_pb2.DummyRequest,
4064
testproto.grpc.dummy_pb2.DummyReply,
41-
]
42-
"""StreamUnary"""
65+
],
66+
grpc.aio.StreamUnaryMultiCallable[
67+
testproto.grpc.dummy_pb2.DummyRequest,
68+
testproto.grpc.dummy_pb2.DummyReply,
69+
],
70+
default=grpc.StreamUnaryMultiCallable[
71+
testproto.grpc.dummy_pb2.DummyRequest,
72+
testproto.grpc.dummy_pb2.DummyReply,
73+
],
74+
)
4375

44-
StreamStream: grpc.StreamStreamMultiCallable[
76+
_DummyServiceStreamStreamType = typing_extensions.TypeVar(
77+
'_DummyServiceStreamStreamType',
78+
grpc.StreamStreamMultiCallable[
4579
testproto.grpc.dummy_pb2.DummyRequest,
4680
testproto.grpc.dummy_pb2.DummyReply,
47-
]
48-
"""StreamStream"""
81+
],
82+
grpc.aio.StreamStreamMultiCallable[
83+
testproto.grpc.dummy_pb2.DummyRequest,
84+
testproto.grpc.dummy_pb2.DummyReply,
85+
],
86+
default=grpc.StreamStreamMultiCallable[
87+
testproto.grpc.dummy_pb2.DummyRequest,
88+
testproto.grpc.dummy_pb2.DummyReply,
89+
],
90+
)
4991

50-
class DummyServiceAsyncStub:
92+
class DummyServiceStub(typing.Generic[_DummyServiceUnaryUnaryType, _DummyServiceUnaryStreamType, _DummyServiceStreamUnaryType, _DummyServiceStreamStreamType]):
5193
"""DummyService"""
5294

53-
UnaryUnary: grpc.aio.UnaryUnaryMultiCallable[
54-
testproto.grpc.dummy_pb2.DummyRequest,
55-
testproto.grpc.dummy_pb2.DummyReply,
56-
]
95+
@typing.overload
96+
def __init__(self: DummyServiceStub[
97+
grpc.UnaryUnaryMultiCallable[
98+
testproto.grpc.dummy_pb2.DummyRequest,
99+
testproto.grpc.dummy_pb2.DummyReply,
100+
],
101+
grpc.UnaryStreamMultiCallable[
102+
testproto.grpc.dummy_pb2.DummyRequest,
103+
testproto.grpc.dummy_pb2.DummyReply,
104+
],
105+
grpc.StreamUnaryMultiCallable[
106+
testproto.grpc.dummy_pb2.DummyRequest,
107+
testproto.grpc.dummy_pb2.DummyReply,
108+
],
109+
grpc.StreamStreamMultiCallable[
110+
testproto.grpc.dummy_pb2.DummyRequest,
111+
testproto.grpc.dummy_pb2.DummyReply,
112+
],
113+
], channel: grpc.Channel) -> None: ...
114+
115+
@typing.overload
116+
def __init__(self: DummyServiceStub[
117+
grpc.aio.UnaryUnaryMultiCallable[
118+
testproto.grpc.dummy_pb2.DummyRequest,
119+
testproto.grpc.dummy_pb2.DummyReply,
120+
],
121+
grpc.aio.UnaryStreamMultiCallable[
122+
testproto.grpc.dummy_pb2.DummyRequest,
123+
testproto.grpc.dummy_pb2.DummyReply,
124+
],
125+
grpc.aio.StreamUnaryMultiCallable[
126+
testproto.grpc.dummy_pb2.DummyRequest,
127+
testproto.grpc.dummy_pb2.DummyReply,
128+
],
129+
grpc.aio.StreamStreamMultiCallable[
130+
testproto.grpc.dummy_pb2.DummyRequest,
131+
testproto.grpc.dummy_pb2.DummyReply,
132+
],
133+
], channel: grpc.aio.Channel) -> None: ...
134+
135+
UnaryUnary: _DummyServiceUnaryUnaryType
57136
"""UnaryUnary"""
58137

59-
UnaryStream: grpc.aio.UnaryStreamMultiCallable[
60-
testproto.grpc.dummy_pb2.DummyRequest,
61-
testproto.grpc.dummy_pb2.DummyReply,
62-
]
138+
UnaryStream: _DummyServiceUnaryStreamType
63139
"""UnaryStream"""
64140

65-
StreamUnary: grpc.aio.StreamUnaryMultiCallable[
66-
testproto.grpc.dummy_pb2.DummyRequest,
67-
testproto.grpc.dummy_pb2.DummyReply,
68-
]
141+
StreamUnary: _DummyServiceStreamUnaryType
69142
"""StreamUnary"""
70143

71-
StreamStream: grpc.aio.StreamStreamMultiCallable[
144+
StreamStream: _DummyServiceStreamStreamType
145+
"""StreamStream"""
146+
147+
DummyServiceAsyncStub: typing_extensions.TypeAlias = DummyServiceStub[
148+
grpc.aio.UnaryUnaryMultiCallable[
72149
testproto.grpc.dummy_pb2.DummyRequest,
73150
testproto.grpc.dummy_pb2.DummyReply,
74-
]
75-
"""StreamStream"""
151+
],
152+
grpc.aio.UnaryStreamMultiCallable[
153+
testproto.grpc.dummy_pb2.DummyRequest,
154+
testproto.grpc.dummy_pb2.DummyReply,
155+
],
156+
grpc.aio.StreamUnaryMultiCallable[
157+
testproto.grpc.dummy_pb2.DummyRequest,
158+
testproto.grpc.dummy_pb2.DummyReply,
159+
],
160+
grpc.aio.StreamStreamMultiCallable[
161+
testproto.grpc.dummy_pb2.DummyRequest,
162+
testproto.grpc.dummy_pb2.DummyReply,
163+
],
164+
]
76165

77166
class DummyServiceServicer(metaclass=abc.ABCMeta):
78167
"""DummyService"""

0 commit comments

Comments
 (0)