Skip to content

Commit 38362fb

Browse files
authored
Overload __new__ (#710)
* Overload __new__ Mypy can't handle a large number of generic overloads (python/mypy#19622) so instead generate 2 different stub classes and override __new__. Seems to have the same functionality, without the generics overhead. Minor feature change in the revealed type of a stub class. Signed-off-by: Aidan Jensen <aidandj.github@gmail.com> * Write union methods to main typestub. Add Signed-off-by: Aidan Jensen <aidandj.github@gmail.com> * Clean up generated code and readme * Document async usage * Add tests for async/sync attribute usage * Generate sync methods by default, and type ignore async override for new override Signed-off-by: Aidan Jensen <aidandj.github@gmail.com> * Fix missing method comment regression Signed-off-by: Aidan Jensen <aidandj.github@gmail.com> --------- Signed-off-by: Aidan Jensen <aidandj.github@gmail.com>
1 parent c89ee56 commit 38362fb

29 files changed

+6606
-988
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
- Switch to types-grpcio instead of no longer maintained grpc-stubs
1111
- Add `_HasFieldArgType` and `_ClearFieldArgType` aliases to allow for typing field manipulation functions
1212
- Add `_WhichOneofArgType_<oneof_name>` and `_WhichOneofReturnType_<oneof_name>` type aliases
13+
- Use `__new__` overloads for async stubs instead of `TypeVar` based `__init__` overloads.
14+
- https://github.com/nipunn1313/mypy-protobuf/issues/707
1315

1416
## 3.7.0
1517

README.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,32 @@ protoc \
371371
Note that generated code for grpc will work only together with code for python and locations should be the same.
372372
If you need stubs for grpc internal code we suggest using this package https://pypi.org/project/types-grpcio/
373373

374+
#### Async GRPC usage
375+
376+
`mypy-protobuf` generates stubs that are compatible with both sync and async usage, with a few caveats.
377+
378+
In a simple use case, the stubs work as expected.
379+
380+
```python
381+
stub = dummy_Pb2_grpc.DummyServiceStub(grpc.aio.insecure_channel("localhost:1234"))
382+
result = await stub.UnaryUnary(dummy_pb2.DummyRequest(value="cprg"))
383+
typing.assert_type(result, dummy_pb2.DummyReply)
384+
```
385+
386+
If you need to explicitly type something as an async stub (class attr, etc) then you must use deferred annotations, and the async stub, as it does not exist at runtime.
387+
388+
```python
389+
class TestAttribute:
390+
stub: "dummy_pb2_grpc.DummyServiceAsyncStub"
391+
392+
def __init__(self) -> None:
393+
self.stub = dummy_pb2_grpc.DummyServiceStub(grpc.aio.insecure_channel("localhost:1234"))
394+
395+
async def test(self) -> None:
396+
result = await self.stub.UnaryUnary(dummy_pb2.DummyRequest(value="cprg"))
397+
typing.assert_type(result, dummy_pb2.DummyReply)
398+
```
399+
374400
### `_ClearFieldArgType`, `_WhichOneofArgType_<oneof_name>`, `_WhichOneofReturnType_<oneof_name>` and `_HasFieldArgType` aliases
375401

376402
Where applicable, type aliases are generated for the arguments to `ClearField`, `WhichOneof` and `HasField`. These can be used to create typed functions for field manipulation:

mypy_protobuf/main.py

Lines changed: 40 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,6 @@
8787
}
8888

8989

90-
def _build_typevar_name(service_name: str, method_name: str) -> str:
91-
# Prefix with underscore to avoid public api error: https://stackoverflow.com/a/78871465
92-
return f"_{service_name}{method_name}Type"
93-
94-
9590
def _mangle_global_identifier(name: str) -> str:
9691
"""
9792
Module level identifiers are mangled and aliased so that they can be disambiguated
@@ -180,7 +175,7 @@ def _import(self, path: str, name: str) -> str:
180175
eg. self._import("typing", "Literal") -> "Literal"
181176
"""
182177
if path == "typing_extensions":
183-
stabilization = {"TypeAlias": (3, 10), "TypeVar": (3, 13)}
178+
stabilization = {"TypeAlias": (3, 10), "TypeVar": (3, 13), "type_check_only": (3, 12)}
184179
assert name in stabilization
185180
if not self.typing_extensions_min or self.typing_extensions_min < stabilization[name]:
186181
self.typing_extensions_min = stabilization[name]
@@ -816,45 +811,28 @@ def write_grpc_async_hacks(self) -> None:
816811
wl("...")
817812
wl("")
818813

819-
def write_grpc_type_vars(self, service: d.ServiceDescriptorProto) -> None:
814+
def write_grpc_stub_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation, *, is_async: bool, both: bool = False, ignore_assignment_errors: bool = False) -> None:
820815
wl = self._write_line
821816
methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
822817
if not methods:
823818
return
824-
for _, method in methods:
825-
wl("{} = {}(", _build_typevar_name(service.name, method.name), self._import("typing_extensions", "TypeVar"))
826-
with self._indent():
827-
wl("'{}',", _build_typevar_name(service.name, method.name))
828-
wl("{}[", self._callable_type(method, is_async=False))
829-
with self._indent():
830-
wl("{},", self._input_type(method))
831-
wl("{},", self._output_type(method))
832-
wl("],")
833-
wl("{}[", self._callable_type(method, is_async=True))
834-
with self._indent():
835-
wl("{},", self._input_type(method))
836-
wl("{},", self._output_type(method))
837-
wl("],")
838-
wl("default={}[", self._callable_type(method, is_async=False))
839-
with self._indent():
840-
wl("{},", self._input_type(method))
841-
wl("{},", self._output_type(method))
842-
wl("],")
843-
wl(")")
844-
wl("")
845819

846-
def write_self_types(self, service: d.ServiceDescriptorProto, is_async: bool) -> None:
847-
wl = self._write_line
848-
methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
849-
if not methods:
850-
return
851-
for _, method in methods:
852-
with self._indent():
853-
wl("{}[", self._callable_type(method, is_async=is_async))
854-
with self._indent():
855-
wl("{},", self._input_type(method))
856-
wl("{},", self._output_type(method))
857-
wl("],")
820+
def type_str(method: d.MethodDescriptorProto, is_async: bool) -> str:
821+
return f"{self._callable_type(method, is_async=is_async)}[{self._input_type(method)}, {self._output_type(method)}]"
822+
823+
for i, method in methods:
824+
scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
825+
if both:
826+
wl(
827+
"{}: {}[{}, {}]",
828+
method.name,
829+
self._import("typing", "Union"),
830+
type_str(method, is_async=False),
831+
type_str(method, is_async=True),
832+
)
833+
else:
834+
wl("{}: {}{}", method.name, type_str(method, is_async=is_async), "" if not ignore_assignment_errors else " # type: ignore[assignment]")
835+
self._write_comments(scl)
858836

859837
def write_grpc_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation) -> None:
860838
wl = self._write_line
@@ -885,19 +863,6 @@ def write_grpc_methods(self, service: d.ServiceDescriptorProto, scl_prefix: Sour
885863
wl("...")
886864
wl("")
887865

888-
def write_grpc_stub_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation, is_async: bool = False) -> None:
889-
wl = self._write_line
890-
methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
891-
if not methods:
892-
wl("...")
893-
wl("")
894-
for i, method in methods:
895-
scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
896-
897-
wl("{}: {}", method.name, f"{_build_typevar_name(service.name, method.name)}")
898-
self._write_comments(scl)
899-
wl("")
900-
901866
def write_grpc_services(
902867
self,
903868
services: Iterable[d.ServiceDescriptorProto],
@@ -906,58 +871,60 @@ def write_grpc_services(
906871
wl = self._write_line
907872
wl("GRPC_GENERATED_VERSION: str")
908873
wl("GRPC_VERSION: str")
874+
wl("")
909875
for i, service in enumerate(services):
910876
if service.name in PYTHON_RESERVED:
911877
continue
912878

913879
scl = scl_prefix + [i]
914880

915-
# Type vars
916-
self.write_grpc_type_vars(service)
881+
class_name = f"{service.name}Stub"
882+
async_class_alias = f"{service.name}AsyncStub"
917883

918884
# The stub client
919885
if service.options.deprecated:
920886
self._write_deprecation_warning(
921887
scl + [d.ServiceDescriptorProto.OPTIONS_FIELD_NUMBER] + [d.ServiceOptions.DEPRECATED_FIELD_NUMBER],
922888
"This stub has been marked as deprecated using proto service options.",
923889
)
924-
class_name = f"{service.name}Stub"
925890
wl(
926-
"class {}({}[{}]):",
891+
"class {}:",
927892
class_name,
928-
self._import("typing", "Generic"),
929-
", ".join(f"{_build_typevar_name(service.name, method.name)}" for method in service.method),
930893
)
931894
with self._indent():
932895
if self._write_comments(scl):
933896
wl("")
934-
935897
# Write sync overload
936898
wl("@{}", self._import("typing", "overload"))
937-
wl("def __init__(self: {}[", class_name)
938-
self.write_self_types(service, False)
939899
wl(
940-
"], channel: {}) -> None: ...",
900+
"def __new__(cls, channel: {}) -> {}: ...",
941901
self._import("grpc", "Channel"),
902+
class_name,
942903
)
943-
wl("")
944904

945905
# Write async overload
946906
wl("@{}", self._import("typing", "overload"))
947-
wl("def __init__(self: {}[", class_name)
948-
self.write_self_types(service, True)
949907
wl(
950-
"], channel: {}) -> None: ...",
908+
"def __new__(cls, channel: {}) -> {}: ...",
951909
self._import("grpc.aio", "Channel"),
910+
async_class_alias,
952911
)
912+
self.write_grpc_stub_methods(service, scl, is_async=False)
953913
wl("")
954914

955-
self.write_grpc_stub_methods(service, scl)
956-
957-
# Write AsyncStub alias
958-
wl("{}AsyncStub: {} = {}[", service.name, self._import("typing_extensions", "TypeAlias"), class_name)
959-
self.write_self_types(service, True)
960-
wl("]")
915+
# Write AsyncStub
916+
if service.options.deprecated:
917+
self._write_deprecation_warning(
918+
scl + [d.ServiceDescriptorProto.OPTIONS_FIELD_NUMBER] + [d.ServiceOptions.DEPRECATED_FIELD_NUMBER],
919+
"This stub has been marked as deprecated using proto service options.",
920+
)
921+
wl("@{}", self._import("typing", "type_check_only"))
922+
wl("class {}({}):", async_class_alias, class_name)
923+
with self._indent():
924+
if self._write_comments(scl):
925+
wl("")
926+
wl("def __init__(self, channel: {}) -> None: ...", self._import("grpc.aio", "Channel"))
927+
self.write_grpc_stub_methods(service, scl, is_async=True, ignore_assignment_errors=True)
961928
wl("")
962929

963930
# The service definition interface

0 commit comments

Comments
 (0)