Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,16 @@ protoc \
/usr/local/include/google/protobuf/*.proto
```

### Using grpcio library instead of grpclib

In order to use the `grpcio` library instead of `grpclib`, you can use the `--python_betterproto_opt=USE_GRPCIO`
option when running the `protoc` command.
This will generate stubs compatible with the `grpcio` library.

Example:
```sh
protoc -I . --python_betterproto_out=. --python_betterproto_opt=USE_GRPCIO demo.proto
```
### TODO

- [x] Fixed length fields
Expand Down
172 changes: 94 additions & 78 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dynamic = ["dependencies"]
# The Ruff version is pinned. To update it, also update it in .pre-commit-config.yaml
ruff = { version = "~0.9.1", optional = true }
grpclib = "^0.4.1"
grpcio = { version = ">=1.73.0", optional = true }
jinja2 = { version = ">=3.0.3", optional = true }
python-dateutil = "^2.8"
typing-extensions = "^4.7.1"
Expand All @@ -45,13 +46,15 @@ pydantic = ">=2.0,<3"
protobuf = "^5"
cachelib = "^0.13.0"
tomlkit = ">=0.7.0"
grpcio-testing = "^1.54.2"

[project.scripts]
protoc-gen-python_betterproto = "betterproto.plugin:main"

[project.optional-dependencies]
compiler = ["ruff", "jinja2"]
rust-codec = ["betterproto-rust-codec"]
grpcio = ["grpcio"]

[tool.ruff]
extend-exclude = ["tests/output_*"]
Expand Down
120 changes: 120 additions & 0 deletions src/betterproto/grpc/grpcio_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from abc import ABC
from typing import (
TYPE_CHECKING,
AsyncIterable,
AsyncIterator,
Iterable,
Mapping,
Optional,
Union,
)

import grpc

if TYPE_CHECKING:
from .._types import (
T,
IProtoMessage,
)

Value = Union[str, bytes]
MetadataLike = Union[Mapping[str, Value], Iterable[tuple[str, Value]]]
MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]]


class ServiceStub(ABC):

def __init__(
self,
channel: grpc.aio.Channel,
*,
timeout: Optional[float] = None,
metadata: Optional[MetadataLike] = None,
) -> None:
self.channel = channel
self.timeout = timeout
self.metadata = metadata

def _resolve_request_kwargs(
self,
timeout: Optional[float],
metadata: Optional[MetadataLike],
):
return {
"timeout": self.timeout if timeout is None else timeout,
"metadata": self.metadata if metadata is None else metadata,
}

async def _unary_unary(
self,
stub_method: grpc.aio.UnaryUnaryMultiCallable,
request: "IProtoMessage",
*,
timeout: Optional[float] = None,
metadata: Optional[MetadataLike] = None,
) -> "T":
return await stub_method(
request,
**self._resolve_request_kwargs(timeout, metadata),
)

async def _unary_stream(
self,
stub_method: grpc.aio.UnaryStreamMultiCallable,
request: "IProtoMessage",
*,
timeout: Optional[float] = None,
metadata: Optional[MetadataLike] = None,
) -> AsyncIterator["T"]:
call = stub_method(
request,
**self._resolve_request_kwargs(timeout, metadata),
)
async for response in call:
yield response

async def _stream_unary(
self,
stub_method: grpc.aio.StreamUnaryMultiCallable,
request_iterator: MessageSource,
*,
timeout: Optional[float] = None,
metadata: Optional[MetadataLike] = None,
) -> "T":
call = stub_method(
self._wrap_message_iterator(request_iterator),
**self._resolve_request_kwargs(timeout, metadata),
)
return await call

async def _stream_stream(
self,
stub_method: grpc.aio.StreamStreamMultiCallable,
request_iterator: MessageSource,
*,
timeout: Optional[float] = None,
metadata: Optional[MetadataLike] = None,
) -> AsyncIterator["T"]:
call = stub_method(
self._wrap_message_iterator(request_iterator),
**self._resolve_request_kwargs(timeout, metadata),
)
async for response in call:
yield response

@staticmethod
def _wrap_message_iterator(
messages: MessageSource,
) -> AsyncIterator["IProtoMessage"]:
if hasattr(messages, '__aiter__'):
async def async_wrapper():
async for message in messages:
yield message

return async_wrapper()
else:
async def sync_wrapper():
for message in messages:
yield message

return sync_wrapper()
30 changes: 30 additions & 0 deletions src/betterproto/grpc/grpcio_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Dict


if TYPE_CHECKING:
import grpc


class ServiceBase(ABC):

@property
@abstractmethod
def __rpc_methods__(self) -> Dict[str, "grpc.RpcMethodHandler"]: ...

@property
@abstractmethod
def __proto_path__(self) -> str: ...


def register_servicers(server: "grpc.aio.Server", *servicers: ServiceBase):
from grpc import method_handlers_generic_handler

server.add_generic_rpc_handlers(
tuple(
method_handlers_generic_handler(
servicer.__proto_path__, servicer.__rpc_methods__
)
for servicer in servicers
)
)
17 changes: 10 additions & 7 deletions src/betterproto/plugin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ class OutputTemplate:
imports_type_checking_only: Set[str] = field(default_factory=set)
pydantic_dataclasses: bool = False
output: bool = True
use_grpcio: bool = False
typing_compiler: TypingCompiler = field(default_factory=DirectImportTypingCompiler)

@property
Expand Down Expand Up @@ -697,18 +698,20 @@ class ServiceMethodCompiler(ProtoContentBase):
proto_obj: MethodDescriptorProto
path: List[int] = PLACEHOLDER
comment_indent: int = 8
use_grpcio: bool = False

def __post_init__(self) -> None:
# Add method to service
self.parent.methods.append(self)

self.output_file.imports_type_checking_only.add("import grpclib.server")
self.output_file.imports_type_checking_only.add(
"from betterproto.grpc.grpclib_client import MetadataLike"
)
self.output_file.imports_type_checking_only.add(
"from grpclib.metadata import Deadline"
)
if self.use_grpcio:
imports = ["import grpc.aio", "from betterproto.grpc.grpcio_client import MetadataLike"]
else:
imports = ["import grpclib.server", "from betterproto.grpc.grpclib_client import MetadataLike",
"from grpclib.metadata import Deadline"]

for import_line in imports:
self.output_file.imports_type_checking_only.add(import_line)

super().__post_init__() # check for unset fields

Expand Down
10 changes: 7 additions & 3 deletions src/betterproto/plugin/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@
from .typing_compiler import (
DirectImportTypingCompiler,
NoTyping310TypingCompiler,
TypingCompiler,
TypingImportTypingCompiler,
)

USE_GRPCIO_FLAG = "USE_GRPCIO"


def traverse(
proto_file: FileDescriptorProto,
Expand Down Expand Up @@ -80,6 +81,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
response.supported_features = CodeGeneratorResponseFeature.FEATURE_PROTO3_OPTIONAL

request_data = PluginRequestCompiler(plugin_request_obj=request)
use_grpcio = USE_GRPCIO_FLAG in plugin_options
# Gather output packages
for proto_file in request.proto_file:
output_package_name = proto_file.package
Expand All @@ -90,7 +92,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
)
# Add this input file to the output corresponding to this package
request_data.output_packages[output_package_name].input_files.append(proto_file)

request_data.output_packages[output_package_name].use_grpcio = use_grpcio
if (
proto_file.package == "google.protobuf"
and "INCLUDE_GOOGLE" not in plugin_options
Expand Down Expand Up @@ -143,7 +145,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
for output_package_name, output_package in request_data.output_packages.items():
for proto_input_file in output_package.input_files:
for index, service in enumerate(proto_input_file.service):
read_protobuf_service(proto_input_file, service, index, output_package)
read_protobuf_service(proto_input_file, service, index, output_package, use_grpcio)

# Generate output files
output_paths: Set[pathlib.Path] = set()
Expand Down Expand Up @@ -253,6 +255,7 @@ def read_protobuf_service(
service: ServiceDescriptorProto,
index: int,
output_package: OutputTemplate,
use_grpcio: bool = False,
) -> None:
service_data = ServiceCompiler(
source_file=source_file,
Expand All @@ -266,4 +269,5 @@ def read_protobuf_service(
parent=service_data,
proto_obj=method,
path=[6, index, 2, j],
use_grpcio=use_grpcio,
)
16 changes: 16 additions & 0 deletions src/betterproto/plugin/typing_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def async_iterable(self, type: str) -> str:
def async_iterator(self, type: str) -> str:
raise NotImplementedError()

@abc.abstractmethod
def async_generator(self, type: str) -> str:
raise NotImplementedError()

@abc.abstractmethod
def imports(self) -> Dict[str, Optional[Set[str]]]:
"""
Expand Down Expand Up @@ -93,6 +97,10 @@ def async_iterator(self, type: str) -> str:
self._imports["typing"].add("AsyncIterator")
return f"AsyncIterator[{type}]"

def async_generator(self, type: str) -> str:
self._imports["typing"].add("AsyncGenerator")
return f"AsyncGenerator[{type}, None]"

def imports(self) -> Dict[str, Optional[Set[str]]]:
return {k: v if v else None for k, v in self._imports.items()}

Expand Down Expand Up @@ -129,6 +137,10 @@ def async_iterator(self, type: str) -> str:
self._imported = True
return f"typing.AsyncIterator[{type}]"

def async_generator(self, type: str) -> str:
self._imported = True
return f"typing.AsyncGenerator[{type}, None]"

def imports(self) -> Dict[str, Optional[Set[str]]]:
if self._imported:
return {"typing": None}
Expand Down Expand Up @@ -169,5 +181,9 @@ def async_iterator(self, type: str) -> str:
self._imports["collections.abc"].add("AsyncIterator")
return f'"AsyncIterator[{type}]"'

def async_generator(self, type: str) -> str:
self._imports["collections.abc"].add("AsyncGenerator")
return f'"AsyncGenerator[{type}, None]"'

def imports(self) -> Dict[str, Optional[Set[str]]]:
return {k: v if v else None for k, v in self._imports.items()}
15 changes: 13 additions & 2 deletions src/betterproto/templates/header.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ __all__ = (
{%- for service in output_file.services -%}
"{{ service.py_name }}Stub",
"{{ service.py_name }}Base",
{%- if output_file.use_grpcio -%}
"add_{{ service.py_name }}Servicer_to_server",
{%- endif -%}
{%- endfor -%}
)

Expand All @@ -29,7 +32,7 @@ from dataclasses import dataclass
{% if output_file.datetime_imports %}
from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}

{% endif%}
{% endif %}
{% set typing_imports = output_file.typing_compiler.imports() %}
{% if typing_imports %}
{% for line in output_file.typing_compiler.import_lines() %}
Expand All @@ -42,8 +45,16 @@ from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% i

{% endif %}


{% if output_file.use_grpcio %}
import grpc
from betterproto.grpc.grpcio_client import ServiceStub
from betterproto.grpc.grpcio_server import ServiceBase
{% endif %}

import betterproto
{% if output_file.services %}
{% if not output_file.use_grpcio %}
from betterproto.grpc.grpclib_client import ServiceStub
from betterproto.grpc.grpclib_server import ServiceBase
import grpclib
{% endif %}
Expand Down
Loading