Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions conformance/test/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@

from connectrpc.client import ResponseMetadata
from connectrpc.code import Code
from connectrpc.codec import proto_json_codec
from connectrpc.compression.brotli import BrotliCompression
from connectrpc.compression.gzip import GzipCompression
from connectrpc.compression.zstd import ZstdCompression
Expand Down Expand Up @@ -173,7 +174,9 @@ async def client_sync(
ZstdCompression(),
],
send_compression=_convert_compression(test_request.compression),
proto_json=test_request.codec == Codec.CODEC_JSON,
codec=proto_json_codec()
if test_request.codec == Codec.CODEC_JSON
else None,
protocol=protocol,
read_max_bytes=read_max_bytes,
) as client,
Expand Down Expand Up @@ -220,7 +223,9 @@ async def client_async(
ZstdCompression(),
],
send_compression=_convert_compression(test_request.compression),
proto_json=test_request.codec == Codec.CODEC_JSON,
codec=proto_json_codec()
if test_request.codec == Codec.CODEC_JSON
else None,
protocol=protocol,
read_max_bytes=read_max_bytes,
) as client,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Mapping,
)

from connectrpc.codec import Codec
from connectrpc.compression import Compression
from connectrpc.interceptor import Interceptor, InterceptorSync
from connectrpc.request import Headers, RequestContext
Expand Down Expand Up @@ -91,6 +92,7 @@ def __init__(
interceptors: Iterable[Interceptor] = (),
read_max_bytes: int | None = None,
compressions: Iterable[Compression] | None = None,
codecs: Iterable[Codec] | None = None,
) -> None:
super().__init__(
service=service,
Expand Down Expand Up @@ -159,6 +161,7 @@ def __init__(
interceptors=interceptors,
read_max_bytes=read_max_bytes,
compressions=compressions,
codecs=codecs,
)

@property
Expand Down Expand Up @@ -358,6 +361,7 @@ def __init__(
interceptors: Iterable[InterceptorSync] = (),
read_max_bytes: int | None = None,
compressions: Iterable[Compression] | None = None,
codecs: Iterable[Codec] | None = None,
) -> None:
super().__init__(
endpoints={
Expand Down Expand Up @@ -425,6 +429,7 @@ def __init__(
interceptors=interceptors,
read_max_bytes=read_max_bytes,
compressions=compressions,
codecs=codecs,
)

@property
Expand Down
3 changes: 3 additions & 0 deletions docs/api.md
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit, unrelated: I think we need connectrpc.protocol here as well to render the ProtocolType:

Image

Can be a follow-up.

Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@

::: connectrpc.interceptor

::: connectrpc.codec
::: connectrpc.protocol

::: connectrpc.compression
::: connectrpc.compression.brotli
::: connectrpc.compression.gzip
Expand Down
5 changes: 5 additions & 0 deletions example/example/eliza_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Mapping,
)

from connectrpc.codec import Codec
from connectrpc.compression import Compression
from connectrpc.interceptor import Interceptor, InterceptorSync
from connectrpc.request import Headers, RequestContext
Expand Down Expand Up @@ -58,6 +59,7 @@ def __init__(
interceptors: Iterable[Interceptor] = (),
read_max_bytes: int | None = None,
compressions: Iterable[Compression] | None = None,
codecs: Iterable[Codec] | None = None,
) -> None:
super().__init__(
service=service,
Expand Down Expand Up @@ -96,6 +98,7 @@ def __init__(
interceptors=interceptors,
read_max_bytes=read_max_bytes,
compressions=compressions,
codecs=codecs,
)

@property
Expand Down Expand Up @@ -194,6 +197,7 @@ def __init__(
interceptors: Iterable[InterceptorSync] = (),
read_max_bytes: int | None = None,
compressions: Iterable[Compression] | None = None,
codecs: Iterable[Codec] | None = None,
) -> None:
super().__init__(
endpoints={
Expand Down Expand Up @@ -231,6 +235,7 @@ def __init__(
interceptors=interceptors,
read_max_bytes=read_max_bytes,
compressions=compressions,
codecs=codecs,
)

@property
Expand Down
7 changes: 5 additions & 2 deletions protoc-gen-connect-python/generator/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ from typing import Protocol

from connectrpc.client import ConnectClient, ConnectClientSync
from connectrpc.code import Code
from connectrpc.codec import Codec
from connectrpc.compression import Compression
from connectrpc.errors import ConnectError
from connectrpc.interceptor import Interceptor, InterceptorSync
Expand All @@ -69,7 +70,7 @@ class {{.Name}}(Protocol):{{- range .Methods }}
{{ end }}

class {{.Name}}ASGIApplication(ConnectASGIApplication[{{.Name}}]):
def __init__(self, service: {{.Name}} | AsyncGenerator[{{.Name}}], *, interceptors: Iterable[Interceptor]=(), read_max_bytes: int | None = None, compressions: Iterable[Compression] | None = None) -> None:
def __init__(self, service: {{.Name}} | AsyncGenerator[{{.Name}}], *, interceptors: Iterable[Interceptor]=(), read_max_bytes: int | None = None, compressions: Iterable[Compression] | None = None, codecs: Iterable[Codec] | None = None) -> None:
super().__init__(
service=service,
endpoints=lambda svc: { {{- range .Methods }}
Expand All @@ -87,6 +88,7 @@ class {{.Name}}ASGIApplication(ConnectASGIApplication[{{.Name}}]):
interceptors=interceptors,
read_max_bytes=read_max_bytes,
compressions=compressions,
codecs=codecs,
)

@property
Expand Down Expand Up @@ -130,7 +132,7 @@ class {{.Name}}Sync(Protocol):{{- range .Methods }}


class {{.Name}}WSGIApplication(ConnectWSGIApplication):
def __init__(self, service: {{.Name}}Sync, interceptors: Iterable[InterceptorSync]=(), read_max_bytes: int | None = None, compressions: Iterable[Compression] | None = None) -> None:
def __init__(self, service: {{.Name}}Sync, interceptors: Iterable[InterceptorSync]=(), read_max_bytes: int | None = None, compressions: Iterable[Compression] | None = None, codecs: Iterable[Codec] | None = None) -> None:
super().__init__(
endpoints={ {{- range .Methods }}
"/{{.ServiceName}}/{{.Name}}": EndpointSync.{{.EndpointType}}(
Expand All @@ -147,6 +149,7 @@ class {{.Name}}WSGIApplication(ConnectWSGIApplication):
interceptors=interceptors,
read_max_bytes=read_max_bytes,
compressions=compressions,
codecs=codecs,
)

@property
Expand Down
10 changes: 6 additions & 4 deletions src/connectrpc/_client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from . import _client_shared
from ._asyncio_timeout import timeout as asyncio_timeout
from ._codec import Codec, get_proto_binary_codec, get_proto_json_codec
from ._codec import proto_binary_codec
from ._compression import IdentityCompression, _gzip, resolve_compressions
from ._interceptor_async import (
BidiStreamInterceptor,
Expand Down Expand Up @@ -43,6 +43,7 @@
from types import TracebackType

from ._envelope import EnvelopeReader
from .codec import Codec
from .compression import Compression
from .method import MethodInfo
from .request import Headers, RequestContext
Expand Down Expand Up @@ -92,7 +93,7 @@ def __init__(
self,
address: str,
*,
proto_json: bool = False,
codec: Codec | None = None,
protocol: ProtocolType = ProtocolType.CONNECT,
accept_compression: Iterable[Compression] | None = None,
send_compression: Compression | None = _gzip,
Expand All @@ -105,7 +106,8 @@ def __init__(

Args:
address: The address of the server to connect to, including scheme.
proto_json: Whether to use JSON for the protocol.
codec: The [Codec][] to use for requests. If unset, defaults to binary protobuf.
For JSON encoding, use [proto_json_codec][connectrpc.codec.proto_json_codec].
protocol: The [ProtocolType][] to use for requests.
accept_compression: Compression algorithms to accept from the server. If unset,
defaults to gzip. If set to empty, disables response compression.
Expand All @@ -117,7 +119,7 @@ def __init__(
http_client: A pyqwest Client to use for requests.
"""
self._address = address
self._codec = get_proto_json_codec() if proto_json else get_proto_binary_codec()
self._codec = codec or proto_binary_codec()
self._response_compressions = resolve_compressions(accept_compression)
self._accept_compression_header = ",".join(self._response_compressions.keys())
self._send_compression = send_compression or IdentityCompression()
Expand Down
10 changes: 6 additions & 4 deletions src/connectrpc/_client_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from connectrpc._protocol_grpc import GRPCClientProtocol, GRPCWebClientProtocol

from . import _client_shared
from ._codec import Codec, get_proto_binary_codec, get_proto_json_codec
from ._codec import proto_binary_codec
from ._compression import IdentityCompression, _gzip, resolve_compressions
from ._interceptor_sync import (
BidiStreamInterceptorSync,
Expand All @@ -33,6 +33,7 @@
from types import TracebackType

from ._envelope import EnvelopeReader
from .codec import Codec
from .compression import Compression
from .method import MethodInfo
from .request import Headers, RequestContext
Expand Down Expand Up @@ -82,7 +83,7 @@ def __init__(
self,
address: str,
*,
proto_json: bool = False,
codec: Codec | None = None,
protocol: ProtocolType = ProtocolType.CONNECT,
accept_compression: Iterable[Compression] | None = None,
send_compression: Compression | None = _gzip,
Expand All @@ -95,7 +96,8 @@ def __init__(

Args:
address: The address of the server to connect to, including scheme.
proto_json: Whether to use JSON for the protocol.
codec: The [Codec][] to use for requests. If unset, defaults to binary protobuf.
For JSON encoding, use [proto_json_codec][connectrpc.codec.proto_json_codec].
protocol: The [ProtocolType][] to use for requests.
accept_compression: Compression algorithms to accept from the server. If unset,
defaults to gzip. If set to empty, disables response compression.
Expand All @@ -107,7 +109,7 @@ def __init__(
http_client: A pyqwest SyncClient to use for requests.
"""
self._address = address
self._codec = get_proto_json_codec() if proto_json else get_proto_binary_codec()
self._codec = codec or proto_binary_codec()
self._timeout_ms = timeout_ms
self._read_max_bytes = read_max_bytes
self._response_compressions = resolve_compressions(accept_compression)
Expand Down
39 changes: 21 additions & 18 deletions src/connectrpc/_codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@

class Codec(Protocol[T_contra, U]):
def name(self) -> str:
"""Returns the name of the codec."""
"""Returns the name of the codec.

This corresponds to the content-type used in requests.
"""
...

def encode(self, message: T_contra) -> bytes:
Expand Down Expand Up @@ -49,8 +52,11 @@ def decode(self, data: bytes | bytearray, message: V) -> V:
class ProtoJSONCodec(Codec[Message, V]):
"""Codec for Protocol bytes | bytearrays JSON format."""

def __init__(self, name: str = "json") -> None:
self._name = name

def name(self) -> str:
return "json"
return self._name

def encode(self, message: Message) -> bytes:
return MessageToJson(message).encode()
Expand All @@ -60,27 +66,24 @@ def decode(self, data: bytes | bytearray, message: V) -> V:
return message


# TODO: Codecs can generally be customized per handler instead of as a global
# registry, though the usage isn't common.
_proto_binary_codec = ProtoBinaryCodec()
_proto_json_codec = ProtoJSONCodec()
_codecs = {
CODEC_NAME_PROTO: _proto_binary_codec,
CODEC_NAME_JSON: _proto_json_codec,
CODEC_NAME_JSON_CHARSET_UTF8: _proto_json_codec,
}
_default_codecs = [
_proto_binary_codec,
_proto_json_codec,
ProtoJSONCodec(name=CODEC_NAME_JSON_CHARSET_UTF8),
]


def get_proto_binary_codec() -> Codec:
"""Returns the Protocol bytes | bytearrays binary codec."""
return _proto_binary_codec
def get_default_codecs() -> list[Codec]:
return _default_codecs


def get_proto_json_codec() -> Codec:
"""Returns the Protocol bytes | bytearrays JSON codec."""
return _proto_json_codec
def proto_binary_codec() -> Codec:
"""Returns the Protocol Buffers binary codec."""
return _proto_binary_codec


def get_codec(name: str) -> Codec | None:
"""Returns the codec with the given name."""
return _codecs.get(name)
def proto_json_codec() -> Codec:
"""Returns the Protocol Buffers JSON codec."""
return _proto_json_codec
9 changes: 7 additions & 2 deletions src/connectrpc/_server_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import TYPE_CHECKING, Generic, TypeVar, cast
from urllib.parse import parse_qs

from ._codec import Codec, get_codec
from ._codec import Codec, get_default_codecs
from ._compression import negotiate_compression, resolve_compressions
from ._envelope import EnvelopeReader
from ._interceptor_async import (
Expand Down Expand Up @@ -91,6 +91,7 @@ def __init__(
interceptors: Iterable[Interceptor] = (),
read_max_bytes: int | None = None,
compressions: Iterable[Compression] | None = None,
codecs: Iterable[Codec] | None = None,
) -> None:
"""Initialize the ASGI application.

Expand All @@ -103,6 +104,8 @@ def __init__(
read_max_bytes: Maximum size of request messages.
compressions: Supported compression algorithms. If unset, defaults to gzip.
If set to empty, disables compression.
codecs: The codecs supported by the server. If unset, defaults to Protocol Buffers
binary and JSON codecs.
"""
super().__init__()
self._service = service
Expand All @@ -111,6 +114,8 @@ def __init__(
self._resolved_endpoints = None
self._read_max_bytes = read_max_bytes
self._compressions = resolve_compressions(compressions)
codecs = codecs if codecs is not None else get_default_codecs()
self._codecs = {codec.name(): codec for codec in codecs}

async def __call__(
self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
Expand Down Expand Up @@ -208,7 +213,7 @@ async def __call__(
codec_name = protocol.codec_name_from_content_type(
headers.get("content-type", ""), stream=not is_unary
)
codec = get_codec(codec_name.lower())
codec = self._codecs.get(codec_name.lower())
if not codec:
raise HTTPException(
HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
Expand Down
Loading
Loading