From 073a99c7ad0fd9ba56fa5ffb2d41c6b37775687a Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Tue, 27 Jan 2026 17:50:07 +0900 Subject: [PATCH 1/5] Make Compression protocol public and default to gzip-only Signed-off-by: Anuraag Agrawal --- .github/workflows/ci.yaml | 4 - conformance/test/client.py | 41 ++++--- .../conformance/v1/service_connect.py | 6 +- conformance/test/server.py | 11 +- connect-python.code-workspace | 28 ++--- example/example/eliza_connect.py | 6 +- justfile | 9 +- noextras/README.md | 1 - noextras/pyproject.toml | 16 --- noextras/test/test_compression_default.py | 108 ---------------- .../generator/template.go | 7 +- pyproject.toml | 2 +- src/connectrpc/_client_async.py | 30 ++--- src/connectrpc/_client_shared.py | 29 ----- src/connectrpc/_client_sync.py | 30 ++--- src/connectrpc/_compression.py | 115 +++--------------- src/connectrpc/_protocol.py | 6 +- src/connectrpc/_protocol_connect.py | 33 +++-- src/connectrpc/_protocol_grpc.py | 27 ++-- src/connectrpc/_server_async.py | 36 ++---- src/connectrpc/_server_sync.py | 36 ++---- src/connectrpc/compression/__init__.py | 32 +++++ src/connectrpc/compression/brotli.py | 28 +++++ src/connectrpc/compression/gzip.py | 26 ++++ src/connectrpc/compression/zstd.py | 31 +++++ test/_util.py | 26 ++++ test/haberdasher_connect.py | 6 +- test/test_compression.py | 40 +++--- test/test_errors.py | 2 +- test/test_roundtrip.py | 52 +++++--- uv.lock | 24 ---- 31 files changed, 365 insertions(+), 483 deletions(-) delete mode 100644 noextras/README.md delete mode 100644 noextras/pyproject.toml delete mode 100644 noextras/test/test_compression_default.py create mode 100644 src/connectrpc/compression/__init__.py create mode 100644 src/connectrpc/compression/brotli.py create mode 100644 src/connectrpc/compression/gzip.py create mode 100644 src/connectrpc/compression/zstd.py create mode 100644 test/_util.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 3611a0d..7d03f9a 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -87,10 +87,6 @@ jobs: run: uv run pytest ${{ matrix.coverage == 'cov' && '--cov=connectrpc --cov-report=xml' || '' }} working-directory: conformance - - name: run tests with minimal dependencies - run: uv run --exact pytest ${{ matrix.coverage == 'cov' && '--cov=connectrpc --cov-report=xml' || '' }} - working-directory: noextras - - name: run Go tests run: go test ./... working-directory: protoc-gen-connect-python diff --git a/conformance/test/client.py b/conformance/test/client.py index ecabb30..ff7f1b1 100644 --- a/conformance/test/client.py +++ b/conformance/test/client.py @@ -18,11 +18,13 @@ from gen.connectrpc.conformance.v1.config_pb2 import Code as ConformanceCode from gen.connectrpc.conformance.v1.config_pb2 import ( Codec, - Compression, HTTPVersion, Protocol, StreamType, ) +from gen.connectrpc.conformance.v1.config_pb2 import ( + Compression as ConformanceCompression, +) from gen.connectrpc.conformance.v1.service_connect import ( ConformanceServiceClient, ConformanceServiceClientSync, @@ -42,6 +44,9 @@ from connectrpc.client import ResponseMetadata from connectrpc.code import Code +from connectrpc.compression.brotli import BrotliCompression +from connectrpc.compression.gzip import GZipCompression +from connectrpc.compression.zstd import ZstdCompression from connectrpc.errors import ConnectError from connectrpc.request import Headers @@ -50,6 +55,8 @@ from google.protobuf.any_pb2 import Any + from connectrpc.compression import Compression + def _convert_code(error: Code) -> ConformanceCode: match error: @@ -87,20 +94,16 @@ def _convert_code(error: Code) -> ConformanceCode: return ConformanceCode.CODE_UNAUTHENTICATED -def _convert_compression(compression: Compression) -> str: +def _convert_compression(compression: ConformanceCompression) -> Compression | None: match compression: - case Compression.COMPRESSION_IDENTITY: - return "identity" - case Compression.COMPRESSION_GZIP: - return "gzip" - case Compression.COMPRESSION_BR: - return "br" - case Compression.COMPRESSION_ZSTD: - return "zstd" - case Compression.COMPRESSION_DEFLATE: - return "deflate" - case Compression.COMPRESSION_SNAPPY: - return "snappy" + case ConformanceCompression.COMPRESSION_IDENTITY: + return None + case ConformanceCompression.COMPRESSION_GZIP: + return GZipCompression() + case ConformanceCompression.COMPRESSION_BR: + return BrotliCompression() + case ConformanceCompression.COMPRESSION_ZSTD: + return ZstdCompression() case _: msg = f"Unsupported compression: {compression}" raise ValueError(msg) @@ -152,6 +155,11 @@ async def client_sync( ConformanceServiceClientSync( f"{scheme}://{test_request.host}:{test_request.port}", http_client=http_client, + accept_compression=[ + GZipCompression(), + BrotliCompression(), + ZstdCompression(), + ], send_compression=_convert_compression(test_request.compression), proto_json=test_request.codec == Codec.CODEC_JSON, grpc=test_request.protocol == Protocol.PROTOCOL_GRPC, @@ -186,6 +194,11 @@ async def client_async( ConformanceServiceClient( f"{scheme}://{test_request.host}:{test_request.port}", http_client=http_client, + accept_compression=[ + GZipCompression(), + BrotliCompression(), + ZstdCompression(), + ], send_compression=_convert_compression(test_request.compression), proto_json=test_request.codec == Codec.CODEC_JSON, grpc=test_request.protocol == Protocol.PROTOCOL_GRPC, diff --git a/conformance/test/gen/connectrpc/conformance/v1/service_connect.py b/conformance/test/gen/connectrpc/conformance/v1/service_connect.py index f59439f..f24abdc 100644 --- a/conformance/test/gen/connectrpc/conformance/v1/service_connect.py +++ b/conformance/test/gen/connectrpc/conformance/v1/service_connect.py @@ -24,8 +24,10 @@ Iterable, Iterator, Mapping, + Sequence, ) + from connectrpc.compression import Compression from connectrpc.interceptor import Interceptor, InterceptorSync from connectrpc.request import Headers, RequestContext @@ -89,7 +91,7 @@ def __init__( *, interceptors: Iterable[Interceptor] = (), read_max_bytes: int | None = None, - compressions: Iterable[str] | None = None, + compressions: Sequence[Compression] | None = None, ) -> None: super().__init__( service=service, @@ -356,7 +358,7 @@ def __init__( service: ConformanceServiceSync, interceptors: Iterable[InterceptorSync] = (), read_max_bytes: int | None = None, - compressions: Iterable[str] | None = None, + compressions: Sequence[Compression] | None = None, ) -> None: super().__init__( endpoints={ diff --git a/conformance/test/server.py b/conformance/test/server.py index e634b2d..a65d22b 100644 --- a/conformance/test/server.py +++ b/conformance/test/server.py @@ -43,6 +43,9 @@ from google.protobuf.any_pb2 import Any from connectrpc.code import Code +from connectrpc.compression.brotli import BrotliCompression +from connectrpc.compression.gzip import GZipCompression +from connectrpc.compression.zstd import ZstdCompression from connectrpc.errors import ConnectError if TYPE_CHECKING: @@ -396,10 +399,14 @@ def bidi_stream( read_max_bytes = int(read_max_bytes) asgi_app = ConformanceServiceASGIApplication( - TestService(), read_max_bytes=read_max_bytes + TestService(), + read_max_bytes=read_max_bytes, + compressions=(GZipCompression(), ZstdCompression(), BrotliCompression()), ) wsgi_app = ConformanceServiceWSGIApplication( - TestServiceSync(), read_max_bytes=read_max_bytes + TestServiceSync(), + read_max_bytes=read_max_bytes, + compressions=(GZipCompression(), ZstdCompression(), BrotliCompression()), ) diff --git a/connect-python.code-workspace b/connect-python.code-workspace index 474f256..bab7fae 100644 --- a/connect-python.code-workspace +++ b/connect-python.code-workspace @@ -2,20 +2,16 @@ "folders": [ { "name": "/", - "path": "." + "path": ".", }, { "name": "conformance", - "path": "./conformance" + "path": "./conformance", }, { "name": "example", - "path": "./example" + "path": "./example", }, - { - "name": "noextras", - "path": "./noextras" - } ], "extensions": { "recommendations": [ @@ -26,8 +22,8 @@ "golang.go", "ms-python.python", "ms-python.vscode-pylance", - "nefrob.vscode-just-syntax" - ] + "nefrob.vscode-just-syntax", + ], }, "settings": { "editor.formatOnSave": true, @@ -39,25 +35,25 @@ "python.testing.unittestEnabled": false, "[github-actions-workflow]": { "editor.formatOnSave": true, - "editor.defaultFormatter": "esbenp.prettier-vscode" + "editor.defaultFormatter": "esbenp.prettier-vscode", }, "debugpy.debugJustMyCode": false, "[python]": { "editor.formatOnSave": true, "editor.codeActionsOnSave": { "source.fixAll": "explicit", - "source.organizeImports": "explicit" + "source.organizeImports": "explicit", }, - "editor.defaultFormatter": "charliermarsh.ruff" + "editor.defaultFormatter": "charliermarsh.ruff", }, "[json][jsonc]": { "editor.formatOnSave": true, - "editor.defaultFormatter": "esbenp.prettier-vscode" + "editor.defaultFormatter": "esbenp.prettier-vscode", }, "yaml.format.enable": false, "[yaml]": { "editor.formatOnSave": true, - "editor.defaultFormatter": "esbenp.prettier-vscode" - } - } + "editor.defaultFormatter": "esbenp.prettier-vscode", + }, + }, } diff --git a/example/example/eliza_connect.py b/example/example/eliza_connect.py index a58afbc..f5dc8d4 100644 --- a/example/example/eliza_connect.py +++ b/example/example/eliza_connect.py @@ -24,8 +24,10 @@ Iterable, Iterator, Mapping, + Sequence, ) + from connectrpc.compression import Compression from connectrpc.interceptor import Interceptor, InterceptorSync from connectrpc.request import Headers, RequestContext @@ -56,7 +58,7 @@ def __init__( *, interceptors: Iterable[Interceptor] = (), read_max_bytes: int | None = None, - compressions: Iterable[str] | None = None, + compressions: Sequence[Compression] | None = None, ) -> None: super().__init__( service=service, @@ -192,7 +194,7 @@ def __init__( service: ElizaServiceSync, interceptors: Iterable[InterceptorSync] = (), read_max_bytes: int | None = None, - compressions: Iterable[str] | None = None, + compressions: Sequence[Compression] | None = None, ) -> None: super().__init__( endpoints={ diff --git a/justfile b/justfile index 7d5b73d..320c8fc 100644 --- a/justfile +++ b/justfile @@ -15,17 +15,12 @@ lint: uv run ruff format --check . uv run ruff check . -# Typecheck Python files +# Typecheck Python filesno typecheck: uv run pyright -# Run unit tests with no extras -[working-directory('noextras')] -test-noextras *args: - uv run --exact pytest -W error {{ args }} - # Run unit tests -test *args: (test-noextras args) +test *args: uv run pytest -W error {{ args }} # Run lint, typecheck and test diff --git a/noextras/README.md b/noextras/README.md deleted file mode 100644 index b772ec4..0000000 --- a/noextras/README.md +++ /dev/null @@ -1 +0,0 @@ -Tests for when no extra compression packages are installed. diff --git a/noextras/pyproject.toml b/noextras/pyproject.toml deleted file mode 100644 index 4622713..0000000 --- a/noextras/pyproject.toml +++ /dev/null @@ -1,16 +0,0 @@ -[project] -name = "connect-python-noextras" -version = "0.1.0" -dependencies = [ - "connect-python", - "connect-python-example", - - # Versions locked in constraint-dependencies - "pytest", - "pytest-asyncio", - "pytest-cov", - "pytest-timeout", -] - -[tool.ruff] -extend = "../pyproject.toml" diff --git a/noextras/test/test_compression_default.py b/noextras/test/test_compression_default.py deleted file mode 100644 index 416ef3d..0000000 --- a/noextras/test/test_compression_default.py +++ /dev/null @@ -1,108 +0,0 @@ -from __future__ import annotations - -import pytest -from connectrpc._compression import get_accept_encoding -from example.eliza_connect import ( - ElizaService, - ElizaServiceASGIApplication, - ElizaServiceClient, - ElizaServiceClientSync, - ElizaServiceSync, - ElizaServiceWSGIApplication, -) -from example.eliza_pb2 import SayRequest, SayResponse -from pyqwest import Client, SyncClient -from pyqwest.testing import ASGITransport, WSGITransport - - -@pytest.mark.parametrize("compression", ["gzip", "identity", None]) -def test_roundtrip_sync(compression: str) -> None: - class RoundtripElizaServiceSync(ElizaServiceSync): - def say(self, request, ctx): - return SayResponse(sentence=request.sentence) - - app = ElizaServiceWSGIApplication(RoundtripElizaServiceSync()) - with ElizaServiceClientSync( - "http://localhost", - http_client=SyncClient(WSGITransport(app=app)), - send_compression=compression, - accept_compression=[compression] if compression else None, - ) as client: - response = client.say(SayRequest(sentence="Hello")) - assert response.sentence == "Hello" - - -@pytest.mark.parametrize("compression", ["gzip", "identity"]) -@pytest.mark.asyncio -async def test_roundtrip_async(compression: str) -> None: - class DetailsElizaService(ElizaService): - async def say(self, request, ctx): - return SayResponse(sentence=request.sentence) - - app = ElizaServiceASGIApplication(DetailsElizaService()) - transport = ASGITransport(app) - async with ElizaServiceClient( - "http://localhost", - http_client=Client(transport), - send_compression=compression, - accept_compression=[compression] if compression else None, - ) as client: - response = await client.say(SayRequest(sentence="Hello")) - assert response.sentence == "Hello" - - -@pytest.mark.parametrize("compression", ["br", "zstd"]) -def test_invalid_compression_sync(compression: str) -> None: - class RoundtripElizaServiceSync(ElizaServiceSync): - def say(self, request, ctx): - return SayResponse(sentence=request.sentence) - - app = ElizaServiceWSGIApplication(RoundtripElizaServiceSync()) - - with pytest.raises( - ValueError, match=r"Unsupported compression method: .*" - ) as exc_info: - ElizaServiceClientSync( - "http://localhost", - http_client=SyncClient(WSGITransport(app=app)), - send_compression=compression, - accept_compression=[compression] if compression else None, - ) - assert ( - str(exc_info.value) - == f"Unsupported compression method: {compression}. Available methods: gzip, identity" - ) - - -@pytest.mark.parametrize("compression", ["br", "zstd"]) -@pytest.mark.asyncio -async def test_invalid_compression_async(compression: str) -> None: - class DetailsElizaService(ElizaService): - async def say(self, request, ctx): - return SayResponse(sentence=request.sentence) - - app = ElizaServiceASGIApplication(DetailsElizaService()) - transport = ASGITransport(app) - with pytest.raises( - ValueError, match=r"Unsupported compression method: .*" - ) as exc_info: - ElizaServiceClient( - "http://localhost", - http_client=Client(transport), - send_compression=compression, - accept_compression=[compression] if compression else None, - ) - assert ( - str(exc_info.value) - == f"Unsupported compression method: {compression}. Available methods: gzip, identity" - ) - - -def test_accept_encoding_only_includes_available_compressions(): - """Verify Accept-Encoding only advertises compressions that are actually available. - - When brotli and zstandard are not installed (as in the noextras environment), - the Accept-Encoding header should not include 'br' or 'zstd'. - """ - accept_encoding = get_accept_encoding() - assert accept_encoding == "gzip", f"Expected 'gzip' only, got '{accept_encoding}'" diff --git a/protoc-gen-connect-python/generator/template.go b/protoc-gen-connect-python/generator/template.go index 07122f0..6d87b1b 100644 --- a/protoc-gen-connect-python/generator/template.go +++ b/protoc-gen-connect-python/generator/template.go @@ -44,11 +44,12 @@ var ConnectTemplate = template.Must(template.New("ConnectTemplate").Parse(`# -*- # Generated by https://github.com/connectrpc/connect-python. DO NOT EDIT! # source: {{.FileName}} {{if .Services}} -from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator, Mapping +from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator, Mapping, Sequence from typing import Protocol from connectrpc.client import ConnectClient, ConnectClientSync from connectrpc.code import Code +from connectrpc.compression import Compression from connectrpc.errors import ConnectError from connectrpc.interceptor import Interceptor, InterceptorSync from connectrpc.method import IdempotencyLevel, MethodInfo @@ -68,7 +69,7 @@ class {{.Name}}(Protocol):{{- range .Methods }} {{ end }} class {{.Name}}ASGIApplication(ConnectASGIApplication[{{.Name}}]): - def __init__(self, service: {{.Name}} | AsyncGenerator[{{.Name}}], *, interceptors: Iterable[Interceptor]=(), read_max_bytes: int | None = None, compressions: Iterable[str] | None = None) -> None: + def __init__(self, service: {{.Name}} | AsyncGenerator[{{.Name}}], *, interceptors: Iterable[Interceptor]=(), read_max_bytes: int | None = None, compressions: Sequence[Compression] | None = None) -> None: super().__init__( service=service, endpoints=lambda svc: { {{- range .Methods }} @@ -129,7 +130,7 @@ class {{.Name}}Sync(Protocol):{{- range .Methods }} class {{.Name}}WSGIApplication(ConnectWSGIApplication): - def __init__(self, service: {{.Name}}Sync, interceptors: Iterable[InterceptorSync]=(), read_max_bytes: int | None = None, compressions: Iterable[str] | None = None) -> None: + def __init__(self, service: {{.Name}}Sync, interceptors: Iterable[InterceptorSync]=(), read_max_bytes: int | None = None, compressions: Sequence[Compression] | None = None) -> None: super().__init__( endpoints={ {{- range .Methods }} "/{{.ServiceName}}/{{.Name}}": EndpointSync.{{.EndpointType}}( diff --git a/pyproject.toml b/pyproject.toml index 5b13a15..dfafc99 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -239,7 +239,7 @@ exclude = [ ] [tool.uv.workspace] -members = ["example", "noextras"] +members = ["example"] [tool.uv.sources] connect-python = { workspace = true } diff --git a/src/connectrpc/_client_async.py b/src/connectrpc/_client_async.py index db496db..e1e9e3f 100644 --- a/src/connectrpc/_client_async.py +++ b/src/connectrpc/_client_async.py @@ -13,6 +13,7 @@ from . import _client_shared from ._asyncio_timeout import timeout as asyncio_timeout from ._codec import Codec, get_proto_binary_codec, get_proto_json_codec +from ._compression import IdentityCompression, _gzip, resolve_compressions from ._interceptor_async import ( BidiStreamInterceptor, ClientStreamInterceptor, @@ -37,11 +38,11 @@ if TYPE_CHECKING: import sys - from collections.abc import AsyncIterator, Iterable, Mapping + from collections.abc import AsyncIterator, Iterable, Mapping, Sequence from types import TracebackType - from ._compression import Compression from ._envelope import EnvelopeReader + from .compression import Compression from .method import MethodInfo from .request import Headers, RequestContext @@ -92,8 +93,8 @@ def __init__( *, proto_json: bool = False, grpc: bool = False, - accept_compression: Iterable[str] | None = None, - send_compression: str | None = None, + accept_compression: Sequence[Compression] | None = None, + send_compression: Compression | None = _gzip, timeout_ms: int | None = None, read_max_bytes: int | None = None, interceptors: Iterable[Interceptor] = (), @@ -113,10 +114,9 @@ def __init__( """ self._address = address self._codec = get_proto_json_codec() if proto_json else get_proto_binary_codec() - self._accept_compression = accept_compression - self._send_compression = _client_shared.resolve_send_compression( - send_compression - ) + self._response_compressions = resolve_compressions(accept_compression) + self._accept_compression_header = ",".join(self._response_compressions.keys()) + self._send_compression = send_compression or IdentityCompression() self._timeout_ms = timeout_ms self._read_max_bytes = read_max_bytes if http_client: @@ -200,7 +200,7 @@ async def execute_unary( timeout_ms=timeout_ms or self._timeout_ms, codec=self._codec, stream=False, - accept_compression=self._accept_compression, + accept_compression=self._accept_compression_header, send_compression=self._send_compression, ) return await self._execute_unary(request, ctx) @@ -220,7 +220,7 @@ async def execute_client_stream( timeout_ms=timeout_ms or self._timeout_ms, codec=self._codec, stream=True, - accept_compression=self._accept_compression, + accept_compression=self._accept_compression_header, send_compression=self._send_compression, ) return await self._execute_client_stream(request, ctx) @@ -240,7 +240,7 @@ def execute_server_stream( timeout_ms=timeout_ms or self._timeout_ms, codec=self._codec, stream=True, - accept_compression=self._accept_compression, + accept_compression=self._accept_compression_header, send_compression=self._send_compression, ) return self._execute_server_stream(request, ctx) @@ -260,7 +260,7 @@ def execute_bidi_stream( timeout_ms=timeout_ms or self._timeout_ms, codec=self._codec, stream=True, - accept_compression=self._accept_compression, + accept_compression=self._accept_compression_header, send_compression=self._send_compression, ) return self._execute_bidi_stream(request, ctx) @@ -308,7 +308,9 @@ async def _send_request_unary( ) # Decompression itself is handled by pyqwest, but we validate it # by resolving it. - self._protocol.handle_response_compression(resp.headers, stream=False) + self._protocol.handle_response_compression( + resp.headers, self._response_compressions, stream=False + ) handle_response_headers(resp.headers) if resp.status == 200: @@ -375,7 +377,7 @@ async def _send_request_bidi_stream( self._codec.name(), resp.headers.get("content-type", "") ) compression = self._protocol.handle_response_compression( - resp.headers, stream=True + resp.headers, self._response_compressions, stream=True ) reader = self._protocol.create_envelope_reader( ctx.method().output, diff --git a/src/connectrpc/_client_shared.py b/src/connectrpc/_client_shared.py index fe91b9f..0b91777 100644 --- a/src/connectrpc/_client_shared.py +++ b/src/connectrpc/_client_shared.py @@ -8,9 +8,7 @@ from pyqwest import Headers as HTTPHeaders from pyqwest import StreamError, StreamErrorCode -from . import _compression from ._codec import CODEC_NAME_JSON, CODEC_NAME_JSON_CHARSET_UTF8, Codec -from ._compression import Compression, get_available_compressions, get_compression from ._protocol import ConnectWireError from ._protocol_connect import ( CONNECT_PROTOCOL_VERSION, @@ -29,19 +27,6 @@ RES = TypeVar("RES") -def resolve_send_compression(compression_name: str | None) -> Compression | None: - if compression_name is None: - return None - compression = get_compression(compression_name) - if compression is None: - msg = ( - f"Unsupported compression method: {compression_name}. " - f"Available methods: {', '.join(get_available_compressions())}" - ) - raise ValueError(msg) - return compression - - def prepare_get_params( codec: Codec, request_data: bytes, headers: HTTPHeaders ) -> dict[str, str]: @@ -55,20 +40,6 @@ def prepare_get_params( return params -def validate_response_content_encoding( - encoding: str | None, -) -> _compression.Compression: - if not encoding: - return _compression.IdentityCompression() - res = _compression.get_compression(encoding.lower()) - if not res: - raise ConnectError( - Code.INTERNAL, - f"unknown encoding '{encoding}'; accepted encodings are {', '.join(_compression.get_available_compressions())}", - ) - return res - - def validate_unary_response( request_codec_name: str, status_code: int, response_content_type: str ) -> None: diff --git a/src/connectrpc/_client_sync.py b/src/connectrpc/_client_sync.py index decdccd..a35b154 100644 --- a/src/connectrpc/_client_sync.py +++ b/src/connectrpc/_client_sync.py @@ -11,6 +11,7 @@ from . import _client_shared from ._codec import Codec, get_proto_binary_codec, get_proto_json_codec +from ._compression import IdentityCompression, _gzip, resolve_compressions from ._interceptor_sync import ( BidiStreamInterceptorSync, ClientStreamInterceptorSync, @@ -27,11 +28,11 @@ if TYPE_CHECKING: import sys - from collections.abc import Iterable, Iterator, Mapping + from collections.abc import Iterable, Iterator, Mapping, Sequence from types import TracebackType - from ._compression import Compression from ._envelope import EnvelopeReader + from .compression import Compression from .method import MethodInfo from .request import Headers, RequestContext @@ -82,8 +83,8 @@ def __init__( *, proto_json: bool = False, grpc: bool = False, - accept_compression: Iterable[str] | None = None, - send_compression: str | None = None, + accept_compression: Sequence[Compression] | None = None, + send_compression: Compression | None = _gzip, timeout_ms: int | None = None, read_max_bytes: int | None = None, interceptors: Iterable[InterceptorSync] = (), @@ -105,10 +106,9 @@ def __init__( self._codec = get_proto_json_codec() if proto_json else get_proto_binary_codec() self._timeout_ms = timeout_ms self._read_max_bytes = read_max_bytes - self._accept_compression = accept_compression - self._send_compression = _client_shared.resolve_send_compression( - send_compression - ) + self._response_compressions = resolve_compressions(accept_compression) + self._accept_compression_header = ",".join(self._response_compressions.keys()) + self._send_compression = send_compression or IdentityCompression() if http_client: self._http_client = http_client else: @@ -196,7 +196,7 @@ def execute_unary( timeout_ms=timeout_ms or self._timeout_ms, codec=self._codec, stream=False, - accept_compression=self._accept_compression, + accept_compression=self._accept_compression_header, send_compression=self._send_compression, ) return self._execute_unary(request, ctx) @@ -216,7 +216,7 @@ def execute_client_stream( timeout_ms=timeout_ms or self._timeout_ms, codec=self._codec, stream=True, - accept_compression=self._accept_compression, + accept_compression=self._accept_compression_header, send_compression=self._send_compression, ) return self._execute_client_stream(request, ctx) @@ -236,7 +236,7 @@ def execute_server_stream( timeout_ms=timeout_ms or self._timeout_ms, codec=self._codec, stream=True, - accept_compression=self._accept_compression, + accept_compression=self._accept_compression_header, send_compression=self._send_compression, ) return self._execute_server_stream(request, ctx) @@ -256,7 +256,7 @@ def execute_bidi_stream( timeout_ms=timeout_ms or self._timeout_ms, codec=self._codec, stream=True, - accept_compression=self._accept_compression, + accept_compression=self._accept_compression_header, send_compression=self._send_compression, ) return self._execute_bidi_stream(request, ctx) @@ -302,7 +302,9 @@ def _send_request_unary(self, request: REQ, ctx: RequestContext[REQ, RES]) -> RE ) # Decompression itself is handled by pyqwest, but we validate it # by resolving it. - self._protocol.handle_response_compression(resp.headers, stream=False) + self._protocol.handle_response_compression( + resp.headers, self._response_compressions, stream=False + ) handle_response_headers(resp.headers) if resp.status == 200: @@ -368,7 +370,7 @@ def _send_request_bidi_stream( self._codec.name(), resp.headers.get("content-type", "") ) compression = self._protocol.handle_response_compression( - resp.headers, stream=True + resp.headers, self._response_compressions, stream=True ) reader = self._protocol.create_envelope_reader( ctx.method().output, diff --git a/src/connectrpc/_compression.py b/src/connectrpc/_compression.py index eb4724d..c1d4820 100644 --- a/src/connectrpc/_compression.py +++ b/src/connectrpc/_compression.py @@ -1,125 +1,48 @@ from __future__ import annotations -import gzip -from typing import TYPE_CHECKING, Protocol +from typing import TYPE_CHECKING -if TYPE_CHECKING: - from collections.abc import KeysView - - -class Compression(Protocol): - def name(self) -> str: - """Returns the name of the compression method.""" - ... - - def compress(self, data: bytes | bytearray) -> bytes: - """Compress the given data.""" - ... - - def decompress(self, data: bytes | bytearray) -> bytes: - """Decompress the given data.""" - ... - - -_compressions: dict[str, Compression] = {} - - -class GZipCompression(Compression): - def name(self) -> str: - return "gzip" - - def compress(self, data: bytes | bytearray) -> bytes: - return gzip.compress(data, compresslevel=6) - - def decompress(self, data: bytes | bytearray) -> bytes: - return gzip.decompress(data) - - -_compressions["gzip"] = GZipCompression() +from connectrpc.compression.gzip import GZipCompression -try: - import brotli +from .compression import Compression - class BrotliCompression(Compression): - def name(self) -> str: - return "br" - - def compress(self, data: bytes | bytearray) -> bytes: - return brotli.compress(data, quality=3) - - def decompress(self, data: bytes | bytearray) -> bytes: - return brotli.decompress(data) - - _compressions["br"] = BrotliCompression() -except ImportError: - pass - -try: - import zstandard - - class ZstdCompression(Compression): - def name(self) -> str: - return "zstd" - - def compress(self, data: bytes | bytearray) -> bytes: - return zstandard.ZstdCompressor().compress(data) - - def decompress(self, data: bytes | bytearray) -> bytes: - # Support clients sending frames without length by using - # stream API. - with zstandard.ZstdDecompressor().stream_reader(data) as reader: - return reader.read() - - _compressions["zstd"] = ZstdCompression() -except ImportError: - pass +if TYPE_CHECKING: + from collections.abc import Sequence class IdentityCompression(Compression): def name(self) -> str: return "identity" - def compress(self, data: bytes | bytearray) -> bytes: + def compress(self, data: bytes | bytearray | memoryview) -> bytes: """Return data as-is without compression.""" return bytes(data) - def decompress(self, data: bytes | bytearray) -> bytes: + def decompress(self, data: bytes | bytearray | memoryview) -> bytes: """Return data as-is without decompression.""" return bytes(data) _identity = IdentityCompression() -_compressions["identity"] = _identity - -# Preferred compression names for Accept-Encoding header, in order of preference. -# Excludes 'identity' since it's an implicit fallback. -DEFAULT_ACCEPT_ENCODING_COMPRESSIONS = ("gzip", "br", "zstd") - - -def get_compression(name: str) -> Compression | None: - return _compressions.get(name.lower()) - - -def get_available_compressions() -> KeysView: - """Returns a list of available compression names.""" - return _compressions.keys() +_gzip = GZipCompression() +_default_compressions = {"gzip": _gzip, "identity": _identity} -def get_accept_encoding() -> str: - """Returns Accept-Encoding header value with available compressions in preference order. - This excludes 'identity' since it's an implicit fallback, and returns - only compressions that are actually available (i.e., their dependencies are installed). - """ - return ", ".join( - name for name in DEFAULT_ACCEPT_ENCODING_COMPRESSIONS if name in _compressions - ) +def resolve_compressions( + compressions: Sequence[Compression] | None, +) -> dict[str, Compression]: + if compressions is None: + return _default_compressions + res = {comp.name(): comp for comp in compressions} + # identity is always supported + res["identity"] = _identity + return res def negotiate_compression( - accept_encoding: str, compressions: dict[str, Compression] | None + accept_encoding: str, compressions: dict[str, Compression] ) -> Compression: - compressions = compressions if compressions is not None else _compressions for accept in accept_encoding.split(","): compression = compressions.get(accept.strip()) if compression: diff --git a/src/connectrpc/_protocol.py b/src/connectrpc/_protocol.py index 3995ef0..7009f29 100644 --- a/src/connectrpc/_protocol.py +++ b/src/connectrpc/_protocol.py @@ -13,7 +13,7 @@ from .errors import ConnectError if TYPE_CHECKING: - from collections.abc import Iterable, Mapping, Sequence + from collections.abc import Mapping, Sequence from pyqwest import FullResponse from pyqwest import Headers as HTTPHeaders @@ -207,7 +207,7 @@ def codec_name_from_content_type(self, content_type: str, *, stream: bool) -> st ... def negotiate_stream_compression( - self, headers: Headers, compressions: dict[str, Compression] | None + self, headers: Headers, compressions: dict[str, Compression] ) -> tuple[Compression | None, Compression]: """Negotiates request and response compression based on headers.""" ... @@ -223,7 +223,7 @@ def create_request_context( timeout_ms: int | None, codec: Codec, stream: bool, - accept_compression: Iterable[str] | None, + accept_compression: str, send_compression: Compression | None, ) -> RequestContext[REQ, RES]: """Creates a RequestContext for the given method and headers.""" diff --git a/src/connectrpc/_protocol_connect.py b/src/connectrpc/_protocol_connect.py index 0922ff2..d115c8d 100644 --- a/src/connectrpc/_protocol_connect.py +++ b/src/connectrpc/_protocol_connect.py @@ -6,13 +6,7 @@ from typing import TYPE_CHECKING, Any, TypeVar from ._codec import CODEC_NAME_JSON, CODEC_NAME_JSON_CHARSET_UTF8, Codec -from ._compression import ( - IdentityCompression, - get_accept_encoding, - get_available_compressions, - get_compression, - negotiate_compression, -) +from ._compression import IdentityCompression, negotiate_compression from ._envelope import EnvelopeReader, EnvelopeWriter from ._protocol import ConnectWireError, HTTPException from ._response_metadata import handle_response_trailers @@ -23,7 +17,7 @@ from .request import Headers, RequestContext if TYPE_CHECKING: - from collections.abc import Iterable, Mapping + from collections.abc import Mapping import pyqwest @@ -123,12 +117,14 @@ def codec_name_from_content_type(self, content_type: str, *, stream: bool) -> st return codec_name_from_content_type(content_type, stream=stream) def negotiate_stream_compression( - self, headers: Headers, compressions: dict[str, Compression] | None + self, headers: Headers, compressions: dict[str, Compression] ) -> tuple[Compression, Compression]: req_compression_name = headers.get( CONNECT_STREAMING_HEADER_COMPRESSION, "identity" ) - req_compression = get_compression(req_compression_name) or IdentityCompression() + req_compression = ( + compressions.get(req_compression_name) or IdentityCompression() + ) accept_compression = headers.get( CONNECT_STREAMING_HEADER_ACCEPT_COMPRESSION, "" ) @@ -162,7 +158,7 @@ def create_request_context( timeout_ms: int | None, codec: Codec, stream: bool, - accept_compression: Iterable[str] | None, + accept_compression: str, send_compression: Compression | None, ) -> RequestContext[REQ, RES]: match user_headers: @@ -190,10 +186,7 @@ def create_request_context( else CONNECT_UNARY_HEADER_ACCEPT_COMPRESSION ) - if accept_compression is not None: - headers[accept_compression_header] = ", ".join(accept_compression) - else: - headers[accept_compression_header] = get_accept_encoding() + headers[accept_compression_header] = accept_compression if send_compression is not None: headers[compression_header] = send_compression.name() else: @@ -269,7 +262,11 @@ def validate_stream_response( ) def handle_response_compression( - self, headers: pyqwest.Headers, *, stream: bool + self, + headers: pyqwest.Headers, + compressions: dict[str, Compression], + *, + stream: bool, ) -> Compression: compression_header = ( CONNECT_STREAMING_HEADER_COMPRESSION @@ -279,11 +276,11 @@ def handle_response_compression( encoding = headers.get(compression_header) if not encoding: return IdentityCompression() - res = get_compression(encoding) + res = compressions.get(encoding) if not res: raise ConnectError( Code.INTERNAL, - f"unknown encoding '{encoding}'; accepted encodings are {', '.join(get_available_compressions())}", + f"unknown encoding '{encoding}'; accepted encodings are {', '.join(compressions.keys())}", ) return res diff --git a/src/connectrpc/_protocol_grpc.py b/src/connectrpc/_protocol_grpc.py index 8c0216f..3fc0297 100644 --- a/src/connectrpc/_protocol_grpc.py +++ b/src/connectrpc/_protocol_grpc.py @@ -6,13 +6,7 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Any, TypeVar -from ._compression import ( - IdentityCompression, - get_accept_encoding, - get_available_compressions, - get_compression, - negotiate_compression, -) +from ._compression import IdentityCompression, negotiate_compression from ._envelope import EnvelopeReader, EnvelopeWriter from ._gen.status_pb2 import Status from ._protocol import ConnectWireError, HTTPException @@ -23,7 +17,7 @@ from .request import Headers, RequestContext if TYPE_CHECKING: - from collections.abc import Iterable, Mapping + from collections.abc import Mapping from pyqwest import Headers as HTTPHeaders from pyqwest import Response, SyncResponse @@ -82,10 +76,10 @@ def codec_name_from_content_type(self, content_type: str, *, stream: bool) -> st return "proto" def negotiate_stream_compression( - self, headers: Headers, compressions: dict[str, Compression] | None + self, headers: Headers, compressions: dict[str, Compression] ) -> tuple[Compression | None, Compression]: req_compression_name = headers.get(GRPC_HEADER_COMPRESSION, "identity") - req_compression = get_compression(req_compression_name) + req_compression = compressions.get(req_compression_name) accept_compression = headers.get(GRPC_HEADER_ACCEPT_COMPRESSION, "") resp_compression = negotiate_compression(accept_compression, compressions) return req_compression, resp_compression @@ -160,7 +154,7 @@ def create_request_context( timeout_ms: int | None, codec: Codec, stream: bool, - accept_compression: Iterable[str] | None, + accept_compression: str, send_compression: Compression | None, ) -> RequestContext[REQ, RES]: match user_headers: @@ -176,10 +170,7 @@ def create_request_context( if "user-agent" not in headers: headers["user-agent"] = _DEFAULT_GRPC_USER_AGENT - if accept_compression is not None: - headers[GRPC_HEADER_ACCEPT_COMPRESSION] = ",".join(accept_compression) - else: - headers[GRPC_HEADER_ACCEPT_COMPRESSION] = get_accept_encoding() + headers[GRPC_HEADER_ACCEPT_COMPRESSION] = accept_compression if send_compression is not None: headers[GRPC_HEADER_COMPRESSION] = send_compression.name() else: @@ -222,16 +213,16 @@ def validate_stream_response( ) def handle_response_compression( - self, headers: HTTPHeaders, *, stream: bool + self, headers: HTTPHeaders, compressions: dict[str, Compression], stream: bool ) -> Compression: encoding = headers.get(GRPC_HEADER_COMPRESSION) if not encoding: return IdentityCompression() - res = get_compression(encoding) + res = compressions.get(encoding) if not res: raise ConnectError( Code.INTERNAL, - f"unknown encoding '{encoding}'; accepted encodings are {', '.join(get_available_compressions())}", + f"unknown encoding '{encoding}'; accepted encodings are {', '.join(compressions.keys())}", ) return res diff --git a/src/connectrpc/_server_async.py b/src/connectrpc/_server_async.py index 52a171f..d6fc3e8 100644 --- a/src/connectrpc/_server_async.py +++ b/src/connectrpc/_server_async.py @@ -10,8 +10,8 @@ from typing import TYPE_CHECKING, Generic, TypeVar, cast from urllib.parse import parse_qs -from . import _compression, _server_shared from ._codec import Codec, get_codec +from ._compression import negotiate_compression, resolve_compressions from ._envelope import EnvelopeReader from ._interceptor_async import ( BidiStreamInterceptor, @@ -46,6 +46,9 @@ ) from asgiref.typing import ASGIReceiveCallable, ASGISendCallable, HTTPScope, Scope + + from . import _server_shared + from .compression import Compression else: ASGIReceiveCallable = "asgiref.typing.ASGIReceiveCallable" ASGISendCallable = "asgiref.typing.ASGISendCallable" @@ -86,7 +89,7 @@ def __init__( endpoints: Callable[[_SVC], Mapping[str, Endpoint]], interceptors: Iterable[Interceptor] = (), read_max_bytes: int | None = None, - compressions: Iterable[str] | None = None, + compressions: Sequence[Compression] | None = None, ) -> None: """Initialize the ASGI application. @@ -95,8 +98,7 @@ def __init__( endpoints: A mapping of URL paths to endpoints resolved from service. interceptors: A sequence of interceptors to apply to the endpoints. read_max_bytes: Maximum size of request messages. - compressions: Supported compression algorithms. If unset, - defaults to gzip along with zstd and br if available. + compressions: Supported compression algorithms. If unset, defaults to gzip. If set to empty, disables compression. """ super().__init__() @@ -105,17 +107,7 @@ def __init__( self._interceptors = interceptors self._resolved_endpoints = None self._read_max_bytes = read_max_bytes - if compressions is not None: - compressions_dict: dict[str, _compression.Compression] = {} - for name in compressions: - comp = _compression.get_compression(name) - if not comp: - msg = f"unknown compression: '{name}': supported encodings are {', '.join(_compression.get_available_compressions())}" - raise ValueError(msg) - compressions_dict[name] = comp - self._compressions = compressions_dict - else: - self._compressions = None + self._compressions = resolve_compressions(compressions) async def __call__( self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable @@ -247,9 +239,7 @@ async def _handle_unary_connect( ctx: RequestContext, ) -> None: accept_encoding = headers.get("accept-encoding", "") - compression = _compression.negotiate_compression( - accept_encoding, self._compressions - ) + compression = negotiate_compression(accept_encoding, self._compressions) if http_method == "GET": request = await self._read_get_request(endpoint, codec, query_params) @@ -313,11 +303,11 @@ async def _read_get_request( # Handle compression compression_name = params.get("compression", ["identity"])[0] - compression = _compression.get_compression(compression_name) + compression = self._compressions.get(compression_name) if not compression: raise ConnectError( Code.UNIMPLEMENTED, - f"unknown compression: '{compression_name}': supported encodings are {', '.join(_compression.get_available_compressions())}", + f"unknown compression: '{compression_name}': supported encodings are {', '.join(self._compressions.keys())}", ) # Decompress and decode message @@ -342,11 +332,11 @@ async def _read_post_request( # Handle compression if specified compression_name = headers.get("content-encoding", "identity").lower() - compression = _compression.get_compression(compression_name) + compression = self._compressions.get(compression_name) if not compression: raise ConnectError( Code.UNIMPLEMENTED, - f"unknown compression: '{compression_name}': supported encodings are {', '.join(_compression.get_available_compressions())}", + f"unknown compression: '{compression_name}': supported encodings are {', '.join(self._compressions.keys())}", ) if req_body: # Don't decompress empty body @@ -525,7 +515,7 @@ async def _request_stream( receive: ASGIReceiveCallable, request_class: type[_REQ], codec: Codec, - compression: _compression.Compression, + compression: Compression, read_max_bytes: int | None = None, ) -> AsyncIterator[_REQ]: reader = EnvelopeReader(request_class, codec, compression, read_max_bytes) diff --git a/src/connectrpc/_server_sync.py b/src/connectrpc/_server_sync.py index 33184fb..5e8c7e1 100644 --- a/src/connectrpc/_server_sync.py +++ b/src/connectrpc/_server_sync.py @@ -8,8 +8,9 @@ from typing import TYPE_CHECKING, TypeVar from urllib.parse import parse_qs -from . import _compression, _server_shared +from . import _server_shared from ._codec import Codec, get_codec +from ._compression import negotiate_compression, resolve_compressions from ._envelope import EnvelopeReader, EnvelopeWriter from ._interceptor_sync import ( BidiStreamInterceptorSync, @@ -42,6 +43,8 @@ from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence from io import BytesIO + from .compression import Compression + if sys.version_info >= (3, 11): from wsgiref.types import StartResponse, WSGIEnvironment else: @@ -165,7 +168,7 @@ def __init__( endpoints: Mapping[str, EndpointSync], interceptors: Iterable[InterceptorSync] = (), read_max_bytes: int | None = None, - compressions: Iterable[str] | None = None, + compressions: Sequence[Compression] | None = None, ) -> None: """Initialize the WSGI application. @@ -173,8 +176,7 @@ def __init__( endpoints: A mapping of URL paths to service endpoints. interceptors: A sequence of interceptors to apply to the endpoints. read_max_bytes: Maximum size of request messages. - compressions: Supported compression algorithms. If unset, - defaults to gzip along with zstd and br if available. + compressions: Supported compression algorithms. If unset, defaults to gzip. If set to empty, disables compression. """ super().__init__() @@ -191,17 +193,7 @@ def __init__( } self._endpoints = endpoints self._read_max_bytes = read_max_bytes - if compressions is not None: - compressions_dict: dict[str, _compression.Compression] = {} - for name in compressions: - comp = _compression.get_compression(name) - if not comp: - msg = f"unknown compression: '{name}': supported encodings are {', '.join(_compression.get_available_compressions())}" - raise ValueError(msg) - compressions_dict[name] = comp - self._compressions = compressions_dict - else: - self._compressions = None + self._compressions = resolve_compressions(compressions) def __call__( self, environ: WSGIEnvironment, start_response: StartResponse @@ -274,9 +266,7 @@ def _handle_unary( # Handle compression if accepted accept_encoding = headers.get("accept-encoding", "identity") - compression = _compression.negotiate_compression( - accept_encoding, self._compressions - ) + compression = negotiate_compression(accept_encoding, self._compressions) res_bytes = compression.compress(res_bytes) response_headers = prepare_response_headers(base_headers, compression.name()) @@ -319,11 +309,11 @@ def _handle_post_request( # Handle compression if specified compression_name = environ.get("HTTP_CONTENT_ENCODING", "identity").lower() if compression_name != "identity": - compression = _compression.get_compression(compression_name) + compression = self._compressions.get(compression_name) if not compression: raise ConnectError( Code.UNIMPLEMENTED, - f"unknown compression: '{compression_name}': supported encodings are {', '.join(_compression.get_available_compressions())}", + f"unknown compression: '{compression_name}': supported encodings are {', '.join(self._compressions.keys())}", ) try: req_body = compression.decompress(req_body) @@ -386,11 +376,11 @@ def _handle_get_request( # Handle compression if specified if "compression" in params: compression_name = params["compression"][0] - compression = _compression.get_compression(compression_name) + compression = self._compressions.get(compression_name) if not compression: raise ConnectError( Code.UNIMPLEMENTED, - f"unknown compression: '{compression_name}': supported encodings are {', '.join(_compression.get_available_compressions())}", + f"unknown compression: '{compression_name}': supported encodings are {', '.join(self._compressions.keys())}", ) message = compression.decompress(message) @@ -575,7 +565,7 @@ def _request_stream( environ: WSGIEnvironment, request_class: type[_REQ], codec: Codec, - compression: _compression.Compression, + compression: Compression, read_max_bytes: int | None = None, ) -> Iterator[_REQ]: reader = EnvelopeReader(request_class, codec, compression, read_max_bytes) diff --git a/src/connectrpc/compression/__init__.py b/src/connectrpc/compression/__init__.py new file mode 100644 index 0000000..4d72eb9 --- /dev/null +++ b/src/connectrpc/compression/__init__.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +__all__ = ["Compression"] + + +from typing import Protocol + + +class Compression(Protocol): + """Protocol for compression methods. + + By default, gzip compression is used. Other compression methods can be + used by specifying implementations of this protocol. We provide standard + implementations for + + - br (connectrpc.compression.brotli.BrotliCompression) - requires the brotli dependency + - zstd (connectrpc.compression.zstd.ZstdCompression) - requires the zstandard dependency + """ + + def name(self) -> str: + """Returns the name of the compression method. This value is used in HTTP + headers to indicate accepted and used compression. + """ + ... + + def compress(self, data: bytes | bytearray | memoryview) -> bytes: + """Compress the given data.""" + ... + + def decompress(self, data: bytes | bytearray | memoryview) -> bytes: + """Decompress the given data.""" + ... diff --git a/src/connectrpc/compression/brotli.py b/src/connectrpc/compression/brotli.py new file mode 100644 index 0000000..75818ca --- /dev/null +++ b/src/connectrpc/compression/brotli.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +__all__ = ["BrotliCompression"] + +import brotli + +from . import Compression + + +class BrotliCompression(Compression): + """Compression implementation using Brotli.""" + + def __init__(self, quality: int = 3) -> None: + """Creates a new BrotliCompression. + + Args: + quality: Compression quality to use. + """ + self._quality = quality + + def name(self) -> str: + return "br" + + def compress(self, data: bytes | bytearray | memoryview) -> bytes: + return brotli.compress(data, quality=self._quality) + + def decompress(self, data: bytes | bytearray | memoryview) -> bytes: + return brotli.decompress(data) diff --git a/src/connectrpc/compression/gzip.py b/src/connectrpc/compression/gzip.py new file mode 100644 index 0000000..5d25e0c --- /dev/null +++ b/src/connectrpc/compression/gzip.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import gzip + +from . import Compression + + +class GZipCompression(Compression): + """Compression implementation using GZip.""" + + def __init__(self, level: int = 6) -> None: + """Creates a new GZipCompression. + + Args: + level: Compression level to use. + """ + self._level = level + + def name(self) -> str: + return "gzip" + + def compress(self, data: bytes | bytearray | memoryview) -> bytes: + return gzip.compress(data, compresslevel=self._level) + + def decompress(self, data: bytes | bytearray | memoryview) -> bytes: + return gzip.decompress(data) diff --git a/src/connectrpc/compression/zstd.py b/src/connectrpc/compression/zstd.py new file mode 100644 index 0000000..3923ebb --- /dev/null +++ b/src/connectrpc/compression/zstd.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +__all__ = ["ZstdCompression"] + +import zstandard + +from . import Compression + + +class ZstdCompression(Compression): + """Compression implementation using Zstandard.""" + + def __init__(self, level: int = 3) -> None: + """Creates a new ZstdCompression. + + Args: + level: Compression level to use. + """ + self._level = level + + def name(self) -> str: + return "zstd" + + def compress(self, data: bytes | bytearray | memoryview) -> bytes: + return zstandard.ZstdCompressor(level=self._level).compress(data) + + def decompress(self, data: bytes | bytearray | memoryview) -> bytes: + # Support clients sending frames without length by using + # stream API. + with zstandard.ZstdDecompressor().stream_reader(data) as reader: + return reader.read() diff --git a/test/_util.py b/test/_util.py new file mode 100644 index 0000000..5da895e --- /dev/null +++ b/test/_util.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from connectrpc._compression import IdentityCompression +from connectrpc.compression.brotli import BrotliCompression +from connectrpc.compression.gzip import GZipCompression +from connectrpc.compression.zstd import ZstdCompression + +if TYPE_CHECKING: + from connectrpc.compression import Compression + + +def resolve_compression(encoding: str) -> Compression: + match encoding: + case "gzip": + return GZipCompression() + case "br": + return BrotliCompression() + case "zstd": + return ZstdCompression() + case "identity": + return IdentityCompression() + case _: + msg = f"unknown encoding '{encoding}'" + raise ValueError(msg) diff --git a/test/haberdasher_connect.py b/test/haberdasher_connect.py index 38e8718..31481e2 100644 --- a/test/haberdasher_connect.py +++ b/test/haberdasher_connect.py @@ -26,8 +26,10 @@ Iterable, Iterator, Mapping, + Sequence, ) + from connectrpc.compression import Compression from connectrpc.interceptor import Interceptor, InterceptorSync from connectrpc.request import Headers, RequestContext @@ -71,7 +73,7 @@ def __init__( *, interceptors: Iterable[Interceptor] = (), read_max_bytes: int | None = None, - compressions: Iterable[str] | None = None, + compressions: Sequence[Compression] | None = None, ) -> None: super().__init__( service=service, @@ -310,7 +312,7 @@ def __init__( service: HaberdasherSync, interceptors: Iterable[InterceptorSync] = (), read_max_bytes: int | None = None, - compressions: Iterable[str] | None = None, + compressions: Sequence[Compression] | None = None, ) -> None: super().__init__( endpoints={ diff --git a/test/test_compression.py b/test/test_compression.py index 3cb3a95..7e5d396 100644 --- a/test/test_compression.py +++ b/test/test_compression.py @@ -5,7 +5,11 @@ from pyqwest.testing import ASGITransport, WSGITransport from connectrpc.client import ResponseMetadata +from connectrpc.compression.brotli import BrotliCompression +from connectrpc.compression.gzip import GZipCompression +from connectrpc.compression.zstd import ZstdCompression +from ._util import resolve_compression from .haberdasher_connect import ( Haberdasher, HaberdasherASGIApplication, @@ -35,12 +39,19 @@ class SimpleHaberdasher(Haberdasher): async def make_hat(self, request, ctx): return Hat(size=10, color="blue") - app = HaberdasherASGIApplication(SimpleHaberdasher(), compressions=compressions) + app = HaberdasherASGIApplication( + SimpleHaberdasher(), compressions=[resolve_compression(c) for c in compressions] + ) with ResponseMetadata() as meta: client = HaberdasherClient( "http://localhost", http_client=Client(ASGITransport(app)), - accept_compression=["zstd", "gzip", "br"], + accept_compression=( + ZstdCompression(), + GZipCompression(), + BrotliCompression(), + ), + send_compression=None, ) res = await client.make_hat(Size(inches=10)) assert res.size == 10 @@ -63,32 +74,17 @@ class SimpleHaberdasher(HaberdasherSync): def make_hat(self, request, ctx): return Hat(size=10, color="blue") - app = HaberdasherWSGIApplication(SimpleHaberdasher(), compressions=compressions) + app = HaberdasherWSGIApplication( + SimpleHaberdasher(), compressions=[resolve_compression(c) for c in compressions] + ) client = HaberdasherClientSync( "http://localhost", http_client=SyncClient(WSGITransport(app)), - accept_compression=["zstd", "gzip", "br"], + accept_compression=(ZstdCompression(), GZipCompression(), BrotliCompression()), + send_compression=None, ) with ResponseMetadata() as meta: res = client.make_hat(Size(inches=10)) assert res.size == 10 assert res.color == "blue" assert meta.headers().get("content-encoding") == encoding - - -def test_server_unsupported_compression_async() -> None: - class SimpleHaberdasher(HaberdasherSync): - def make_hat(self, request, ctx): - return Hat(size=10, color="blue") - - with pytest.raises(ValueError, match="unknown compression"): - HaberdasherWSGIApplication(SimpleHaberdasher(), compressions=("unknown",)) - - -def test_server_unsupported_compression_sync() -> None: - class SimpleHaberdasher(Haberdasher): - async def make_hat(self, request, ctx): - return Hat(size=10, color="blue") - - with pytest.raises(ValueError, match="unknown compression"): - HaberdasherASGIApplication(SimpleHaberdasher(), compressions=("unknown",)) diff --git a/test/test_errors.py b/test/test_errors.py index fa3bb0c..cfcc226 100644 --- a/test/test_errors.py +++ b/test/test_errors.py @@ -170,7 +170,7 @@ async def execute(self, request: Request) -> Response: b"weird encoding", {"content-type": "application/proto", "content-encoding": "weird"}, Code.INTERNAL, - "unknown encoding 'weird'; accepted encodings are gzip, br, zstd, identity", + "unknown encoding 'weird'; accepted encodings are gzip, identity", id="bad encoding", ), ] diff --git a/test/test_roundtrip.py b/test/test_roundtrip.py index f3a6303..21386c5 100644 --- a/test/test_roundtrip.py +++ b/test/test_roundtrip.py @@ -9,6 +9,7 @@ from connectrpc.code import Code from connectrpc.errors import ConnectError +from ._util import resolve_compression from .haberdasher_connect import ( Haberdasher, HaberdasherASGIApplication, @@ -24,19 +25,22 @@ @pytest.mark.parametrize("proto_json", [False, True]) -@pytest.mark.parametrize("compression", ["gzip", "br", "zstd", "identity", None]) -def test_roundtrip_sync(proto_json: bool, compression: str) -> None: +@pytest.mark.parametrize("compression_name", ["gzip", "br", "zstd", "identity"]) +def test_roundtrip_sync(proto_json: bool, compression_name: str) -> None: class RoundtripHaberdasherSync(HaberdasherSync): def make_hat(self, request, ctx): return Hat(size=request.inches, color="green") - app = HaberdasherWSGIApplication(RoundtripHaberdasherSync()) + compression = resolve_compression(compression_name) + app = HaberdasherWSGIApplication( + RoundtripHaberdasherSync(), compressions=[compression] + ) with HaberdasherClientSync( "http://localhost", http_client=SyncClient(WSGITransport(app=app)), proto_json=proto_json, send_compression=compression, - accept_compression=[compression] if compression else None, + accept_compression=[compression], ) as client: response = client.make_hat(request=Size(inches=10)) assert response.size == 10 @@ -44,21 +48,22 @@ def make_hat(self, request, ctx): @pytest.mark.parametrize("proto_json", [False, True]) -@pytest.mark.parametrize("compression", ["gzip", "br", "zstd", "identity"]) +@pytest.mark.parametrize("compression_name", ["gzip", "br", "zstd", "identity"]) @pytest.mark.asyncio -async def test_roundtrip_async(proto_json: bool, compression: str) -> None: +async def test_roundtrip_async(proto_json: bool, compression_name: str) -> None: class DetailsHaberdasher(Haberdasher): async def make_hat(self, request, ctx): return Hat(size=request.inches, color="green") - app = HaberdasherASGIApplication(DetailsHaberdasher()) + compression = resolve_compression(compression_name) + app = HaberdasherASGIApplication(DetailsHaberdasher(), compressions=[compression]) transport = ASGITransport(app) async with HaberdasherClient( "http://localhost", http_client=Client(transport), proto_json=proto_json, send_compression=compression, - accept_compression=[compression] if compression else None, + accept_compression=[compression], ) as client: response = await client.make_hat(request=Size(inches=10)) assert response.size == 10 @@ -66,10 +71,10 @@ async def make_hat(self, request, ctx): @pytest.mark.parametrize("proto_json", [False, True]) -@pytest.mark.parametrize("compression", ["gzip", "br", "zstd", "identity"]) +@pytest.mark.parametrize("compression_name", ["gzip", "br", "zstd", "identity"]) @pytest.mark.asyncio async def test_roundtrip_response_stream_async( - proto_json: bool, compression: str + proto_json: bool, compression_name: str ) -> None: class StreamingHaberdasher(Haberdasher): async def make_similar_hats(self, request, ctx): @@ -78,7 +83,8 @@ async def make_similar_hats(self, request, ctx): yield Hat(size=request.inches, color="blue") raise ConnectError(Code.RESOURCE_EXHAUSTED, "No more hats available") - app = HaberdasherASGIApplication(StreamingHaberdasher()) + compression = resolve_compression(compression_name) + app = HaberdasherASGIApplication(StreamingHaberdasher(), compressions=[compression]) transport = ASGITransport(app) hats: list[Hat] = [] @@ -87,7 +93,7 @@ async def make_similar_hats(self, request, ctx): http_client=Client(transport=transport), proto_json=proto_json, send_compression=compression, - accept_compression=[compression] if compression else None, + accept_compression=[compression], ) as client: with pytest.raises(ConnectError) as exc_info: async for h in client.make_similar_hats(request=Size(inches=10)): @@ -104,8 +110,8 @@ async def make_similar_hats(self, request, ctx): @pytest.mark.parametrize("client_bad", [False, True]) -@pytest.mark.parametrize("compression", ["gzip", "br", "zstd", "identity"]) -def test_message_limit_sync(client_bad: bool, compression: str) -> None: +@pytest.mark.parametrize("compression_name", ["gzip", "br", "zstd", "identity"]) +def test_message_limit_sync(client_bad: bool, compression_name: str) -> None: requests: list[Size] = [] responses: list[Hat] = [] @@ -125,13 +131,16 @@ def make_various_hats(self, request: Iterator[Size], ctx) -> Iterator[Hat]: yield Hat(color="good") yield good_hat if client_bad else bad_hat - app = HaberdasherWSGIApplication(LargeHaberdasher(), read_max_bytes=100) + compression = resolve_compression(compression_name) + app = HaberdasherWSGIApplication( + LargeHaberdasher(), read_max_bytes=100, compressions=[compression] + ) transport = WSGITransport(app) with HaberdasherClientSync( "http://localhost", http_client=SyncClient(transport), send_compression=compression, - accept_compression=[compression] if compression else None, + accept_compression=[compression], read_max_bytes=100, ) as client: with pytest.raises(ConnectError) as exc_info: @@ -166,9 +175,9 @@ def request_stream(): @pytest.mark.parametrize("client_bad", [False, True]) -@pytest.mark.parametrize("compression", ["gzip", "br", "zstd", "identity"]) +@pytest.mark.parametrize("compression_name", ["gzip", "br", "zstd", "identity"]) @pytest.mark.asyncio -async def test_message_limit_async(client_bad: bool, compression: str) -> None: +async def test_message_limit_async(client_bad: bool, compression_name: str) -> None: requests: list[Size] = [] responses: list[Hat] = [] @@ -190,13 +199,16 @@ async def make_various_hats( yield Hat(color="good") yield good_hat if client_bad else bad_hat - app = HaberdasherASGIApplication(LargeHaberdasher(), read_max_bytes=100) + compression = resolve_compression(compression_name) + app = HaberdasherASGIApplication( + LargeHaberdasher(), read_max_bytes=100, compressions=[compression] + ) transport = ASGITransport(app) async with HaberdasherClient( "http://localhost", http_client=Client(transport=transport), send_compression=compression, - accept_compression=[compression] if compression else None, + accept_compression=[compression], read_max_bytes=100, ) as client: with pytest.raises(ConnectError) as exc_info: diff --git a/uv.lock b/uv.lock index d94c357..bea84f4 100644 --- a/uv.lock +++ b/uv.lock @@ -13,7 +13,6 @@ resolution-mode = "lowest-direct" members = [ "connect-python", "connect-python-example", - "connect-python-noextras", ] constraints = [ { name = "pytest", specifier = "==9.0.2" }, @@ -476,29 +475,6 @@ requires-dist = [ { name = "starlette", specifier = "==0.52.1" }, ] -[[package]] -name = "connect-python-noextras" -version = "0.1.0" -source = { virtual = "noextras" } -dependencies = [ - { name = "connect-python" }, - { name = "connect-python-example" }, - { name = "pytest" }, - { name = "pytest-asyncio" }, - { name = "pytest-cov" }, - { name = "pytest-timeout" }, -] - -[package.metadata] -requires-dist = [ - { name = "connect-python", editable = "." }, - { name = "connect-python-example", editable = "example" }, - { name = "pytest" }, - { name = "pytest-asyncio" }, - { name = "pytest-cov" }, - { name = "pytest-timeout" }, -] - [[package]] name = "constantly" version = "23.10.4" From 320740620136d5e72a7bf1d3d401a8c58ee3aa37 Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Tue, 27 Jan 2026 17:54:47 +0900 Subject: [PATCH 2/5] Doc Signed-off-by: Anuraag Agrawal --- justfile | 2 +- src/connectrpc/_client_async.py | 6 ++++-- src/connectrpc/_client_sync.py | 6 ++++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/justfile b/justfile index 320c8fc..5fba33f 100644 --- a/justfile +++ b/justfile @@ -15,7 +15,7 @@ lint: uv run ruff format --check . uv run ruff check . -# Typecheck Python filesno +# Typecheck Python files typecheck: uv run pyright diff --git a/src/connectrpc/_client_async.py b/src/connectrpc/_client_async.py index e1e9e3f..50a5d83 100644 --- a/src/connectrpc/_client_async.py +++ b/src/connectrpc/_client_async.py @@ -105,8 +105,10 @@ def __init__( Args: address: The address of the server to connect to, including scheme. proto_json: Whether to use JSON for the protocol - accept_compression: A list of compression algorithms to accept from the server - send_compression: The compression algorithm to use for sending requests + accept_compression: Compression algorithms to accept from the server. If unset, + defaults to gzip. If set to empty, disables response compression. + send_compression: Compression algorithm to use for sending requests. If unset, + defaults to gzip. If set to None, disables request compression. timeout_ms: The timeout for requests in milliseconds read_max_bytes: The maximum number of bytes to read from the response interceptors: A list of interceptors to apply to requests diff --git a/src/connectrpc/_client_sync.py b/src/connectrpc/_client_sync.py index a35b154..089d039 100644 --- a/src/connectrpc/_client_sync.py +++ b/src/connectrpc/_client_sync.py @@ -95,8 +95,10 @@ def __init__( Args: address: The address of the server to connect to, including scheme. proto_json: Whether to use JSON for the protocol - accept_compression: A list of compression algorithms to accept from the server - send_compression: The compression algorithm to use for sending requests + accept_compression: Compression algorithms to accept from the server. If unset, + defaults to gzip. If set to empty, disables response compression. + send_compression: Compression algorithm to use for sending requests. If unset, + defaults to gzip. If set to None, disables request compression. timeout_ms: The timeout for requests in milliseconds read_max_bytes: The maximum number of bytes to read from the response interceptors: A list of interceptors to apply to requests From f6345cf329b6c1c48567e1605316a805d5ac517e Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Wed, 28 Jan 2026 10:34:20 +0900 Subject: [PATCH 3/5] Rename Signed-off-by: Anuraag Agrawal --- conformance/test/client.py | 8 ++++---- conformance/test/server.py | 6 +++--- src/connectrpc/_compression.py | 4 ++-- src/connectrpc/compression/gzip.py | 5 ++--- test/_util.py | 4 ++-- test/test_compression.py | 6 +++--- 6 files changed, 16 insertions(+), 17 deletions(-) diff --git a/conformance/test/client.py b/conformance/test/client.py index ff7f1b1..00d7285 100644 --- a/conformance/test/client.py +++ b/conformance/test/client.py @@ -45,7 +45,7 @@ from connectrpc.client import ResponseMetadata from connectrpc.code import Code from connectrpc.compression.brotli import BrotliCompression -from connectrpc.compression.gzip import GZipCompression +from connectrpc.compression.gzip import GzipCompression from connectrpc.compression.zstd import ZstdCompression from connectrpc.errors import ConnectError from connectrpc.request import Headers @@ -99,7 +99,7 @@ def _convert_compression(compression: ConformanceCompression) -> Compression | N case ConformanceCompression.COMPRESSION_IDENTITY: return None case ConformanceCompression.COMPRESSION_GZIP: - return GZipCompression() + return GzipCompression() case ConformanceCompression.COMPRESSION_BR: return BrotliCompression() case ConformanceCompression.COMPRESSION_ZSTD: @@ -156,7 +156,7 @@ async def client_sync( f"{scheme}://{test_request.host}:{test_request.port}", http_client=http_client, accept_compression=[ - GZipCompression(), + GzipCompression(), BrotliCompression(), ZstdCompression(), ], @@ -195,7 +195,7 @@ async def client_async( f"{scheme}://{test_request.host}:{test_request.port}", http_client=http_client, accept_compression=[ - GZipCompression(), + GzipCompression(), BrotliCompression(), ZstdCompression(), ], diff --git a/conformance/test/server.py b/conformance/test/server.py index a65d22b..e7c45da 100644 --- a/conformance/test/server.py +++ b/conformance/test/server.py @@ -44,7 +44,7 @@ from connectrpc.code import Code from connectrpc.compression.brotli import BrotliCompression -from connectrpc.compression.gzip import GZipCompression +from connectrpc.compression.gzip import GzipCompression from connectrpc.compression.zstd import ZstdCompression from connectrpc.errors import ConnectError @@ -401,12 +401,12 @@ def bidi_stream( asgi_app = ConformanceServiceASGIApplication( TestService(), read_max_bytes=read_max_bytes, - compressions=(GZipCompression(), ZstdCompression(), BrotliCompression()), + compressions=(GzipCompression(), ZstdCompression(), BrotliCompression()), ) wsgi_app = ConformanceServiceWSGIApplication( TestServiceSync(), read_max_bytes=read_max_bytes, - compressions=(GZipCompression(), ZstdCompression(), BrotliCompression()), + compressions=(GzipCompression(), ZstdCompression(), BrotliCompression()), ) diff --git a/src/connectrpc/_compression.py b/src/connectrpc/_compression.py index c1d4820..291030c 100644 --- a/src/connectrpc/_compression.py +++ b/src/connectrpc/_compression.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING -from connectrpc.compression.gzip import GZipCompression +from connectrpc.compression.gzip import GzipCompression from .compression import Compression @@ -25,7 +25,7 @@ def decompress(self, data: bytes | bytearray | memoryview) -> bytes: _identity = IdentityCompression() -_gzip = GZipCompression() +_gzip = GzipCompression() _default_compressions = {"gzip": _gzip, "identity": _identity} diff --git a/src/connectrpc/compression/gzip.py b/src/connectrpc/compression/gzip.py index 5d25e0c..af060f7 100644 --- a/src/connectrpc/compression/gzip.py +++ b/src/connectrpc/compression/gzip.py @@ -5,12 +5,11 @@ from . import Compression -class GZipCompression(Compression): +class GzipCompression(Compression): """Compression implementation using GZip.""" def __init__(self, level: int = 6) -> None: - """Creates a new GZipCompression. - + """Creates a new GzipCompression. Args: level: Compression level to use. """ diff --git a/test/_util.py b/test/_util.py index 5da895e..db24edf 100644 --- a/test/_util.py +++ b/test/_util.py @@ -4,7 +4,7 @@ from connectrpc._compression import IdentityCompression from connectrpc.compression.brotli import BrotliCompression -from connectrpc.compression.gzip import GZipCompression +from connectrpc.compression.gzip import GzipCompression from connectrpc.compression.zstd import ZstdCompression if TYPE_CHECKING: @@ -14,7 +14,7 @@ def resolve_compression(encoding: str) -> Compression: match encoding: case "gzip": - return GZipCompression() + return GzipCompression() case "br": return BrotliCompression() case "zstd": diff --git a/test/test_compression.py b/test/test_compression.py index 7e5d396..838048a 100644 --- a/test/test_compression.py +++ b/test/test_compression.py @@ -6,7 +6,7 @@ from connectrpc.client import ResponseMetadata from connectrpc.compression.brotli import BrotliCompression -from connectrpc.compression.gzip import GZipCompression +from connectrpc.compression.gzip import GzipCompression from connectrpc.compression.zstd import ZstdCompression from ._util import resolve_compression @@ -48,7 +48,7 @@ async def make_hat(self, request, ctx): http_client=Client(ASGITransport(app)), accept_compression=( ZstdCompression(), - GZipCompression(), + GzipCompression(), BrotliCompression(), ), send_compression=None, @@ -80,7 +80,7 @@ def make_hat(self, request, ctx): client = HaberdasherClientSync( "http://localhost", http_client=SyncClient(WSGITransport(app)), - accept_compression=(ZstdCompression(), GZipCompression(), BrotliCompression()), + accept_compression=(ZstdCompression(), GzipCompression(), BrotliCompression()), send_compression=None, ) with ResponseMetadata() as meta: From b2aee61a2eb0e1762c46d14fc8057852bf4c1093 Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Wed, 28 Jan 2026 10:36:37 +0900 Subject: [PATCH 4/5] iterable Signed-off-by: Anuraag Agrawal --- .../test/gen/connectrpc/conformance/v1/service_connect.py | 5 ++--- example/example/eliza_connect.py | 5 ++--- protoc-gen-connect-python/generator/template.go | 6 +++--- src/connectrpc/_client_async.py | 4 ++-- src/connectrpc/_client_sync.py | 4 ++-- src/connectrpc/_server_async.py | 2 +- src/connectrpc/_server_sync.py | 2 +- test/haberdasher_connect.py | 5 ++--- 8 files changed, 15 insertions(+), 18 deletions(-) diff --git a/conformance/test/gen/connectrpc/conformance/v1/service_connect.py b/conformance/test/gen/connectrpc/conformance/v1/service_connect.py index f24abdc..b66bfcf 100644 --- a/conformance/test/gen/connectrpc/conformance/v1/service_connect.py +++ b/conformance/test/gen/connectrpc/conformance/v1/service_connect.py @@ -24,7 +24,6 @@ Iterable, Iterator, Mapping, - Sequence, ) from connectrpc.compression import Compression @@ -91,7 +90,7 @@ def __init__( *, interceptors: Iterable[Interceptor] = (), read_max_bytes: int | None = None, - compressions: Sequence[Compression] | None = None, + compressions: Iterable[Compression] | None = None, ) -> None: super().__init__( service=service, @@ -358,7 +357,7 @@ def __init__( service: ConformanceServiceSync, interceptors: Iterable[InterceptorSync] = (), read_max_bytes: int | None = None, - compressions: Sequence[Compression] | None = None, + compressions: Iterable[Compression] | None = None, ) -> None: super().__init__( endpoints={ diff --git a/example/example/eliza_connect.py b/example/example/eliza_connect.py index f5dc8d4..6ab062e 100644 --- a/example/example/eliza_connect.py +++ b/example/example/eliza_connect.py @@ -24,7 +24,6 @@ Iterable, Iterator, Mapping, - Sequence, ) from connectrpc.compression import Compression @@ -58,7 +57,7 @@ def __init__( *, interceptors: Iterable[Interceptor] = (), read_max_bytes: int | None = None, - compressions: Sequence[Compression] | None = None, + compressions: Iterable[Compression] | None = None, ) -> None: super().__init__( service=service, @@ -194,7 +193,7 @@ def __init__( service: ElizaServiceSync, interceptors: Iterable[InterceptorSync] = (), read_max_bytes: int | None = None, - compressions: Sequence[Compression] | None = None, + compressions: Iterable[Compression] | None = None, ) -> None: super().__init__( endpoints={ diff --git a/protoc-gen-connect-python/generator/template.go b/protoc-gen-connect-python/generator/template.go index 6d87b1b..3f74e0b 100644 --- a/protoc-gen-connect-python/generator/template.go +++ b/protoc-gen-connect-python/generator/template.go @@ -44,7 +44,7 @@ var ConnectTemplate = template.Must(template.New("ConnectTemplate").Parse(`# -*- # Generated by https://github.com/connectrpc/connect-python. DO NOT EDIT! # source: {{.FileName}} {{if .Services}} -from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator, Mapping, Sequence +from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator, Mapping from typing import Protocol from connectrpc.client import ConnectClient, ConnectClientSync @@ -69,7 +69,7 @@ class {{.Name}}(Protocol):{{- range .Methods }} {{ end }} class {{.Name}}ASGIApplication(ConnectASGIApplication[{{.Name}}]): - def __init__(self, service: {{.Name}} | AsyncGenerator[{{.Name}}], *, interceptors: Iterable[Interceptor]=(), read_max_bytes: int | None = None, compressions: Sequence[Compression] | None = None) -> None: + def __init__(self, service: {{.Name}} | AsyncGenerator[{{.Name}}], *, interceptors: Iterable[Interceptor]=(), read_max_bytes: int | None = None, compressions: Iterable[Compression] | None = None) -> None: super().__init__( service=service, endpoints=lambda svc: { {{- range .Methods }} @@ -130,7 +130,7 @@ class {{.Name}}Sync(Protocol):{{- range .Methods }} class {{.Name}}WSGIApplication(ConnectWSGIApplication): - def __init__(self, service: {{.Name}}Sync, interceptors: Iterable[InterceptorSync]=(), read_max_bytes: int | None = None, compressions: Sequence[Compression] | None = None) -> None: + def __init__(self, service: {{.Name}}Sync, interceptors: Iterable[InterceptorSync]=(), read_max_bytes: int | None = None, compressions: Iterable[Compression] | None = None) -> None: super().__init__( endpoints={ {{- range .Methods }} "/{{.ServiceName}}/{{.Name}}": EndpointSync.{{.EndpointType}}( diff --git a/src/connectrpc/_client_async.py b/src/connectrpc/_client_async.py index 50a5d83..1e96448 100644 --- a/src/connectrpc/_client_async.py +++ b/src/connectrpc/_client_async.py @@ -38,7 +38,7 @@ if TYPE_CHECKING: import sys - from collections.abc import AsyncIterator, Iterable, Mapping, Sequence + from collections.abc import AsyncIterator, Iterable, Mapping from types import TracebackType from ._envelope import EnvelopeReader @@ -93,7 +93,7 @@ def __init__( *, proto_json: bool = False, grpc: bool = False, - accept_compression: Sequence[Compression] | None = None, + accept_compression: Iterable[Compression] | None = None, send_compression: Compression | None = _gzip, timeout_ms: int | None = None, read_max_bytes: int | None = None, diff --git a/src/connectrpc/_client_sync.py b/src/connectrpc/_client_sync.py index 089d039..47a1938 100644 --- a/src/connectrpc/_client_sync.py +++ b/src/connectrpc/_client_sync.py @@ -28,7 +28,7 @@ if TYPE_CHECKING: import sys - from collections.abc import Iterable, Iterator, Mapping, Sequence + from collections.abc import Iterable, Iterator, Mapping from types import TracebackType from ._envelope import EnvelopeReader @@ -83,7 +83,7 @@ def __init__( *, proto_json: bool = False, grpc: bool = False, - accept_compression: Sequence[Compression] | None = None, + accept_compression: Iterable[Compression] | None = None, send_compression: Compression | None = _gzip, timeout_ms: int | None = None, read_max_bytes: int | None = None, diff --git a/src/connectrpc/_server_async.py b/src/connectrpc/_server_async.py index d6fc3e8..cadb456 100644 --- a/src/connectrpc/_server_async.py +++ b/src/connectrpc/_server_async.py @@ -89,7 +89,7 @@ def __init__( endpoints: Callable[[_SVC], Mapping[str, Endpoint]], interceptors: Iterable[Interceptor] = (), read_max_bytes: int | None = None, - compressions: Sequence[Compression] | None = None, + compressions: Iterable[Compression] | None = None, ) -> None: """Initialize the ASGI application. diff --git a/src/connectrpc/_server_sync.py b/src/connectrpc/_server_sync.py index 5e8c7e1..027485f 100644 --- a/src/connectrpc/_server_sync.py +++ b/src/connectrpc/_server_sync.py @@ -168,7 +168,7 @@ def __init__( endpoints: Mapping[str, EndpointSync], interceptors: Iterable[InterceptorSync] = (), read_max_bytes: int | None = None, - compressions: Sequence[Compression] | None = None, + compressions: Iterable[Compression] | None = None, ) -> None: """Initialize the WSGI application. diff --git a/test/haberdasher_connect.py b/test/haberdasher_connect.py index 31481e2..541420c 100644 --- a/test/haberdasher_connect.py +++ b/test/haberdasher_connect.py @@ -26,7 +26,6 @@ Iterable, Iterator, Mapping, - Sequence, ) from connectrpc.compression import Compression @@ -73,7 +72,7 @@ def __init__( *, interceptors: Iterable[Interceptor] = (), read_max_bytes: int | None = None, - compressions: Sequence[Compression] | None = None, + compressions: Iterable[Compression] | None = None, ) -> None: super().__init__( service=service, @@ -312,7 +311,7 @@ def __init__( service: HaberdasherSync, interceptors: Iterable[InterceptorSync] = (), read_max_bytes: int | None = None, - compressions: Sequence[Compression] | None = None, + compressions: Iterable[Compression] | None = None, ) -> None: super().__init__( endpoints={ From bc2e039872c6d9aa3dc731b824da820e1c9588da Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Wed, 28 Jan 2026 10:39:28 +0900 Subject: [PATCH 5/5] typing Signed-off-by: Anuraag Agrawal --- src/connectrpc/_compression.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/connectrpc/_compression.py b/src/connectrpc/_compression.py index 291030c..b4d806a 100644 --- a/src/connectrpc/_compression.py +++ b/src/connectrpc/_compression.py @@ -7,7 +7,7 @@ from .compression import Compression if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Iterable class IdentityCompression(Compression): @@ -30,7 +30,7 @@ def decompress(self, data: bytes | bytearray | memoryview) -> bytes: def resolve_compressions( - compressions: Sequence[Compression] | None, + compressions: Iterable[Compression] | None, ) -> dict[str, Compression]: if compressions is None: return _default_compressions