diff --git a/src/nexusrpc/_util.py b/src/nexusrpc/_util.py index 776079a..7e89137 100644 --- a/src/nexusrpc/_util.py +++ b/src/nexusrpc/_util.py @@ -2,7 +2,6 @@ import functools import inspect -import typing from collections.abc import Awaitable from typing import TYPE_CHECKING, Any, Callable, Optional @@ -142,14 +141,6 @@ def get_callable_name(fn: Callable[..., Any]) -> str: return method_name -def is_subtype(type1: type[Any], type2: type[Any]) -> bool: - # Note that issubclass() argument 2 cannot be a parameterized generic - # TODO(nexus-preview): review desired type compatibility logic - if type1 == type2: - return True - return issubclass(type1, typing.get_origin(type2) or type2) - - # See # https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older diff --git a/src/nexusrpc/handler/_operation_handler.py b/src/nexusrpc/handler/_operation_handler.py index 7a06a51..181fa67 100644 --- a/src/nexusrpc/handler/_operation_handler.py +++ b/src/nexusrpc/handler/_operation_handler.py @@ -12,7 +12,6 @@ get_operation_factory, is_async_callable, is_callable, - is_subtype, ) from ._common import ( @@ -180,9 +179,8 @@ def validate_operation_handler_methods( 1. There must be a method in ``user_methods`` whose method name matches the method name from the service definition. - 2. The input and output types of the user method must be such that the user method - is a subtype of the operation defined in the service definition, i.e. respecting - input type contravariance and output type covariance. + 2. The input and output types of the handler method must exactly match the types + declared in the service definition. """ operation_handler_factories_by_method_name = ( operation_handler_factories_by_method_name.copy() @@ -212,40 +210,21 @@ def validate_operation_handler_methods( f"is '{op_defn.name}'. Operation handlers may not override the name of an operation " f"in the service definition." ) - # Input type is contravariant: op handler input must be superclass of op defn input. # If handler's input_type is None (missing annotation), skip validation - the handler # relies on the service definition for type information. This supports handlers without # explicit type annotations when a service definition is provided. - if ( - op.input_type is not None - and Any not in (op.input_type, op_defn.input_type) - and not ( - op_defn.input_type == op.input_type - or is_subtype(op_defn.input_type, op.input_type) - ) - ): + if op.input_type is not None and op_defn.input_type != op.input_type: raise TypeError( - f"Operation '{op_defn.method_name}' in service '{service_cls}' " - f"has input type '{op.input_type}', which is not " - f"compatible with the input type '{op_defn.input_type}' in interface " - f"'{service_definition.name}'. The input type must be the same as or a " - f"superclass of the operation definition input type." + f"OperationHandler input type mismatch for '{service_cls}.{op_defn.method_name}': " + f"expected {op_defn.input_type}, got {op.input_type}" ) - # Output type is covariant: op handler output must be subclass of op defn output. # If handler's output_type is None (missing annotation), skip validation - the handler # relies on the service definition for type information. - if ( - op.output_type is not None - and Any not in (op.output_type, op_defn.output_type) - and not is_subtype(op.output_type, op_defn.output_type) - ): + if op.output_type is not None and op.output_type != op_defn.output_type: raise TypeError( - f"Operation '{op_defn.method_name}' in service '{service_cls}' " - f"has output type '{op.output_type}', which is not " - f"compatible with the output type '{op_defn.output_type}' in interface " - f" '{service_definition}'. The output type must be the same as or a " - f"subclass of the operation definition output type." + f"OperationHandler output type mismatch for '{service_cls}.{op_defn.method_name}': " + f"expected {op_defn.output_type}, got {op.output_type}" ) if operation_handler_factories_by_method_name: raise ValueError( diff --git a/tests/handler/test_handler_validates_service_handler_collection.py b/tests/handler/test_handler_validates_service_handler_collection.py index fbf2680..2b1b8ec 100644 --- a/tests/handler/test_handler_validates_service_handler_collection.py +++ b/tests/handler/test_handler_validates_service_handler_collection.py @@ -5,15 +5,17 @@ import pytest +from nexusrpc import HandlerError, LazyValue from nexusrpc.handler import ( CancelOperationContext, Handler, OperationHandler, StartOperationContext, StartOperationResultSync, + operation_handler, service_handler, ) -from nexusrpc.handler._decorators import operation_handler +from tests.helpers import DummySerializer, TestOperationTaskCancellation def test_service_must_use_decorator(): @@ -63,3 +65,31 @@ class Service2: with pytest.raises(RuntimeError): _ = Handler([Service1(), Service2()]) + + +@pytest.mark.asyncio +async def test_operations_must_have_decorator(): + @service_handler + class TestService: + async def op(self, _ctx: StartOperationContext, input: str) -> str: + return input + + handler = Handler([TestService()]) + + with pytest.raises(HandlerError, match="has no operation 'op'"): + _ = await handler.start_operation( + StartOperationContext( + service=TestService.__name__, + operation=TestService.op.__name__, + headers={}, + request_id="test-req", + task_cancellation=TestOperationTaskCancellation(), + request_deadline=None, + callback_url=None, + ), + LazyValue( + serializer=DummySerializer(value="test"), + headers={}, + stream=None, + ), + ) diff --git a/tests/handler/test_invalid_usage.py b/tests/handler/test_invalid_usage.py index 2cd1277..5ffa31b 100644 --- a/tests/handler/test_invalid_usage.py +++ b/tests/handler/test_invalid_usage.py @@ -3,10 +3,9 @@ handler implementations. """ -from typing import Any, Callable +from dataclasses import dataclass import pytest -from typing_extensions import dataclass_transform import nexusrpc from nexusrpc.handler import ( @@ -19,15 +18,14 @@ from nexusrpc.handler._operation_handler import OperationHandler -@dataclass_transform() -class _BaseTestCase: - pass - - -class _TestCase(_BaseTestCase): - build: Callable[..., Any] +@dataclass() +class _TestCase: error_message: str + @staticmethod + def build(): + pass + class OperationHandlerOverridesNameInconsistentlyWithServiceDefinition(_TestCase): @staticmethod diff --git a/tests/handler/test_operation_handler_runtime_behavior.py b/tests/handler/test_operation_handler_runtime_behavior.py new file mode 100644 index 0000000..8689a35 --- /dev/null +++ b/tests/handler/test_operation_handler_runtime_behavior.py @@ -0,0 +1,95 @@ +""" +Test runtime behavior of operation handlers invoked through Handler.start_operation(). + +This file tests actual execution behavior, distinct from: +- Decoration-time validation (test_service_handler_decorator_validates_against_service_contract.py) +- Handler constructor validation (test_handler_validates_service_handler_collection.py) +""" + +import pytest + +from nexusrpc import LazyValue, Operation, service +from nexusrpc.handler import ( + CancelOperationContext, + Handler, + OperationHandler, + StartOperationContext, + StartOperationResultSync, + operation_handler, + service_handler, +) +from nexusrpc.handler._decorators import sync_operation +from tests.helpers import DummySerializer, TestOperationTaskCancellation + + +@pytest.mark.asyncio +async def test_handler_can_return_covariant_type(): + class Superclass: + pass + + class Subclass(Superclass): + pass + + @service + class CovariantService: + op_handler: Operation[None, Superclass] + inline: Operation[None, Superclass] + + class ValidOperationHandler(OperationHandler[None, Superclass]): + async def start( + self, ctx: StartOperationContext, input: None + ) -> StartOperationResultSync[Subclass]: + return StartOperationResultSync(Subclass()) + + async def cancel(self, ctx: CancelOperationContext, token: str) -> None: + pass + + @service_handler(service=CovariantService) + class CovariantServiceHandler: + @operation_handler + def op_handler(self) -> OperationHandler[None, Superclass]: + return ValidOperationHandler() + + @sync_operation + async def inline(self, ctx: StartOperationContext, input: None) -> Superclass: # pyright: ignore[reportUnusedParameter] + return Subclass() + + handler = Handler([CovariantServiceHandler()]) + + result = await handler.start_operation( + StartOperationContext( + service=CovariantService.__name__, + operation=CovariantService.op_handler.name, + headers={}, + request_id="test-req", + task_cancellation=TestOperationTaskCancellation(), + request_deadline=None, + callback_url=None, + ), + LazyValue( + serializer=DummySerializer(None), + headers={}, + stream=None, + ), + ) + assert type(result) is StartOperationResultSync + assert type(result.value) is Subclass + + result = await handler.start_operation( + StartOperationContext( + service=CovariantService.__name__, + operation=CovariantService.inline.name, + headers={}, + request_id="test-req", + task_cancellation=TestOperationTaskCancellation(), + request_deadline=None, + callback_url=None, + ), + LazyValue( + serializer=DummySerializer(None), + headers={}, + stream=None, + ), + ) + assert type(result) is StartOperationResultSync + assert type(result.value) is Subclass diff --git a/tests/handler/test_request_routing.py b/tests/handler/test_request_routing.py index 64b54bc..04a4a75 100644 --- a/tests/handler/test_request_routing.py +++ b/tests/handler/test_request_routing.py @@ -1,7 +1,7 @@ +from dataclasses import dataclass from typing import Any, Callable, cast import pytest -from typing_extensions import dataclass_transform import nexusrpc from nexusrpc import LazyValue @@ -16,12 +16,8 @@ from tests.helpers import DummySerializer, TestOperationTaskCancellation -@dataclass_transform() -class _BaseTestCase: - pass - - -class _TestCase(_BaseTestCase): +@dataclass() +class _TestCase: UserService: type[Any] # (service_name, op_name) supported_request: tuple[str, str] diff --git a/tests/handler/test_service_handler_decorator_requirements.py b/tests/handler/test_service_handler_decorator_requirements.py index 80953ad..8548e99 100644 --- a/tests/handler/test_service_handler_decorator_requirements.py +++ b/tests/handler/test_service_handler_decorator_requirements.py @@ -1,9 +1,9 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Any import pytest -from typing_extensions import dataclass_transform import nexusrpc from nexusrpc._util import get_service_definition @@ -15,12 +15,8 @@ from nexusrpc.handler._decorators import operation_handler -@dataclass_transform() -class _TestCase: - pass - - -class _DecoratorValidationTestCase(_TestCase): +@dataclass() +class _DecoratorValidationTestCase: UserService: type[Any] UserServiceHandler: type[Any] expected_error_message_pattern: str @@ -71,7 +67,8 @@ def test_decorator_validates_definition_compliance( service_handler(service=test_case.UserService)(test_case.UserServiceHandler) -class _ServiceHandlerInheritanceTestCase(_TestCase): +@dataclass() +class _ServiceHandlerInheritanceTestCase: UserServiceHandler: type[Any] expected_operations: set[str] @@ -134,7 +131,8 @@ def test_service_implementation_inheritance( ) -class _ServiceDefinitionInheritanceTestCase(_TestCase): +@dataclass() +class _ServiceDefinitionInheritanceTestCase: UserService: type[Any] expected_ops: set[str] diff --git a/tests/handler/test_service_handler_decorator_results_in_correctly_functioning_operation_factories.py b/tests/handler/test_service_handler_decorator_results_in_correctly_functioning_operation_factories.py index bbe1c0c..cd7570a 100644 --- a/tests/handler/test_service_handler_decorator_results_in_correctly_functioning_operation_factories.py +++ b/tests/handler/test_service_handler_decorator_results_in_correctly_functioning_operation_factories.py @@ -2,10 +2,10 @@ Test that operation decorators result in operation factories that return the correct result. """ +from dataclasses import dataclass from typing import Any, Union, cast import pytest -from typing_extensions import dataclass_transform import nexusrpc from nexusrpc import InputT, OutputT @@ -26,12 +26,8 @@ from tests.helpers import TestOperationTaskCancellation -@dataclass_transform() -class _BaseTestCase: - pass - - -class _TestCase(_BaseTestCase): +@dataclass() +class _TestCase: Service: type[Any] expected_operation_factories: dict[str, Any] diff --git a/tests/handler/test_service_handler_decorator_selects_correct_service_name.py b/tests/handler/test_service_handler_decorator_selects_correct_service_name.py index 786449f..b02897b 100644 --- a/tests/handler/test_service_handler_decorator_selects_correct_service_name.py +++ b/tests/handler/test_service_handler_decorator_selects_correct_service_name.py @@ -1,7 +1,7 @@ +from dataclasses import dataclass from typing import Optional import pytest -from typing_extensions import dataclass_transform import nexusrpc from nexusrpc._util import get_service_definition @@ -18,12 +18,8 @@ class ServiceInterfaceWithNameOverride: pass -@dataclass_transform() -class _BaseTestCase: - pass - - -class _NameOverrideTestCase(_BaseTestCase): +@dataclass() +class _NameOverrideTestCase: ServiceImpl: type expected_name: str expected_error: Optional[type[Exception]] = None diff --git a/tests/handler/test_service_handler_decorator_validates_against_service_contract.py b/tests/handler/test_service_handler_decorator_validates_against_service_contract.py index 4be5c70..65404a9 100644 --- a/tests/handler/test_service_handler_decorator_validates_against_service_contract.py +++ b/tests/handler/test_service_handler_decorator_validates_against_service_contract.py @@ -1,7 +1,7 @@ -from typing import Any, Optional +from dataclasses import dataclass +from typing import Any import pytest -from typing_extensions import dataclass_transform import nexusrpc from nexusrpc.handler import ( @@ -11,18 +11,26 @@ ) -@dataclass_transform() -class _BaseTestCase: - pass - - -class _InterfaceImplementationTestCase(_BaseTestCase): +@dataclass() +class _InterfaceImplementationTestCase: Interface: type Impl: type - error_message: Optional[str] + error_message: str | None + + +class _InvalidInputTestCase(_InterfaceImplementationTestCase): + error_message = "OperationHandler input type mismatch" + +class _InvalidOutputTestCase(_InterfaceImplementationTestCase): + error_message = "OperationHandler output type mismatch" -class ValidImpl(_InterfaceImplementationTestCase): + +class _ValidTestCase(_InterfaceImplementationTestCase): + error_message = None + + +class ValidImpl(_ValidTestCase): @nexusrpc.service class Interface: op: nexusrpc.Operation[None, None] @@ -33,8 +41,6 @@ class Impl: @sync_operation async def op(self, _ctx: StartOperationContext, _input: None) -> None: ... - error_message = None - class ValidImplWithEmptyInterfaceAndExtraOperation(_InterfaceImplementationTestCase): @nexusrpc.service @@ -50,7 +56,7 @@ def unrelated_method(self) -> None: ... error_message = "does not match an operation method name in the service definition" -class ValidImplWithoutTypeAnnotations(_InterfaceImplementationTestCase): +class ValidImplWithoutTypeAnnotations(_ValidTestCase): @nexusrpc.service class Interface: op: nexusrpc.Operation[int, str] @@ -59,8 +65,6 @@ class Impl: @sync_operation async def op(self, ctx, input): ... # type: ignore[reportMissingParameterType] - error_message = None - class MissingOperation(_InterfaceImplementationTestCase): @nexusrpc.service @@ -73,7 +77,7 @@ class Impl: error_message = "does not implement an operation with method name 'op'" -class MissingInputAnnotation(_InterfaceImplementationTestCase): +class MissingInputAnnotation(_ValidTestCase): @nexusrpc.service class Interface: op: nexusrpc.Operation[None, None] @@ -82,10 +86,8 @@ class Impl: @sync_operation async def op(self, ctx: StartOperationContext, input) -> None: ... # type: ignore[reportMissingParameterType] - error_message = None - -class MissingContextAnnotation(_InterfaceImplementationTestCase): +class MissingContextAnnotation(_ValidTestCase): @nexusrpc.service class Interface: op: nexusrpc.Operation[None, None] @@ -94,10 +96,8 @@ class Impl: @sync_operation async def op(self, ctx, input: None) -> None: ... # type: ignore[reportMissingParameterType] - error_message = None - -class WrongOutputType(_InterfaceImplementationTestCase): +class WrongOutputType(_InvalidOutputTestCase): @nexusrpc.service class Interface: op: nexusrpc.Operation[None, int] @@ -106,10 +106,8 @@ class Impl: @sync_operation async def op(self, _ctx: StartOperationContext, _input: None) -> str: ... - error_message = "is not compatible with the output type" - -class WrongOutputTypeWithNone(_InterfaceImplementationTestCase): +class WrongOutputTypeWithNone(_InvalidOutputTestCase): @nexusrpc.service class Interface: op: nexusrpc.Operation[str, None] @@ -118,10 +116,8 @@ class Impl: @sync_operation async def op(self, _ctx: StartOperationContext, _input: str) -> str: ... - error_message = "is not compatible with the output type" - -class ValidImplWithNone(_InterfaceImplementationTestCase): +class ValidImplWithNone(_ValidTestCase): @nexusrpc.service class Interface: op: nexusrpc.Operation[str, None] @@ -130,105 +126,133 @@ class Impl: @sync_operation async def op(self, _ctx: StartOperationContext, _input: str) -> None: ... - error_message = None +class X: + pass + + +class SuperClass: + pass + + +class Subclass(SuperClass): + pass -class MoreSpecificImplAllowed(_InterfaceImplementationTestCase): + +class OutputCovarianceImplOutputCannotBeSubclass(_InvalidOutputTestCase): @nexusrpc.service class Interface: - op: nexusrpc.Operation[Any, Any] + op: nexusrpc.Operation[X, SuperClass] class Impl: @sync_operation - async def op(self, _ctx: StartOperationContext, _input: str) -> str: ... - - error_message = None + async def op(self, _ctx: StartOperationContext, _input: X) -> Subclass: ... -class X: - pass +class OutputCovarianceImplOutputCannotBeStrictSuperclass(_InvalidOutputTestCase): + @nexusrpc.service + class Interface: + op: nexusrpc.Operation[X, Subclass] + class Impl: + @sync_operation + async def op(self, _ctx: StartOperationContext, _input: X) -> SuperClass: ... -class SuperClass: - pass +class InputContravarianceImplInputCannotBeSuperclass(_InvalidInputTestCase): + @nexusrpc.service + class Interface: + op: nexusrpc.Operation[Subclass, X] -class Subclass(SuperClass): - pass + class Impl: + @sync_operation + async def op(self, _ctx: StartOperationContext, _input: SuperClass) -> X: ... -class OutputCovarianceImplOutputCanBeSameType(_InterfaceImplementationTestCase): +class InputContravarianceImplInputCannotBeSubclass(_InvalidInputTestCase): @nexusrpc.service class Interface: - op: nexusrpc.Operation[X, X] + op: nexusrpc.Operation[SuperClass, X] class Impl: @sync_operation - async def op(self, _ctx: StartOperationContext, _input: X) -> X: ... + async def op(self, _ctx: StartOperationContext, _input: Subclass) -> X: ... - error_message = None +class ValidImplWithGenericTypes(_ValidTestCase): + """Validates that generic types work with equality comparison.""" -class OutputCovarianceImplOutputCanBeSubclass(_InterfaceImplementationTestCase): @nexusrpc.service class Interface: - op: nexusrpc.Operation[X, SuperClass] + op: nexusrpc.Operation[list[int], dict[str, bool]] class Impl: @sync_operation - async def op(self, _ctx: StartOperationContext, _input: X) -> Subclass: ... + async def op( + self, _ctx: StartOperationContext, _input: list[int] + ) -> dict[str, bool]: ... - error_message = None +class InvalidImplWithWrongGenericInputType(_InvalidInputTestCase): + """Validates that mismatched generic input types are caught.""" -class OutputCovarianceImplOutputCannnotBeStrictSuperclass( - _InterfaceImplementationTestCase -): @nexusrpc.service class Interface: - op: nexusrpc.Operation[X, Subclass] + op: nexusrpc.Operation[list[int], str] class Impl: @sync_operation - async def op(self, _ctx: StartOperationContext, _input: X) -> SuperClass: ... + async def op(self, _ctx: StartOperationContext, _input: list[str]) -> str: ... - error_message = "is not compatible with the output type" +class InvalidImplWithWrongGenericOutputType(_InvalidOutputTestCase): + """Validates that mismatched generic output types are caught.""" -class InputContravarianceImplInputCanBeSameType(_InterfaceImplementationTestCase): @nexusrpc.service class Interface: - op: nexusrpc.Operation[X, X] + op: nexusrpc.Operation[str, dict[str, int]] class Impl: @sync_operation - async def op(self, _ctx: StartOperationContext, _input: X) -> X: ... + async def op( + self, _ctx: StartOperationContext, _input: str + ) -> dict[str, str]: ... - error_message = None +class ValidImplWithAnyTypes(_ValidTestCase): + """Validates that Any types require exact match (Any == Any).""" -class InputContravarianceImplInputCanBeSuperclass(_InterfaceImplementationTestCase): @nexusrpc.service class Interface: - op: nexusrpc.Operation[Subclass, X] + op: nexusrpc.Operation[Any, Any] class Impl: @sync_operation - async def op(self, _ctx: StartOperationContext, _input: SuperClass) -> X: ... + async def op(self, _ctx: StartOperationContext, _input: Any) -> Any: ... - error_message = None +class InvalidImplWithAnyInputButSpecificImpl(_InvalidInputTestCase): + """Validates that Any in interface does not act as wildcard for input types.""" -class InputContravarianceImplInputCannotBeSubclass(_InterfaceImplementationTestCase): @nexusrpc.service class Interface: - op: nexusrpc.Operation[SuperClass, X] + op: nexusrpc.Operation[Any, str] class Impl: @sync_operation - async def op(self, _ctx: StartOperationContext, _input: Subclass) -> X: ... + async def op(self, _ctx: StartOperationContext, _input: str) -> str: ... + + +class InvalidImplWithAnyOutputButSpecificImpl(_InvalidOutputTestCase): + """Validates that Any in interface does not act as wildcard for output types.""" + + @nexusrpc.service + class Interface: + op: nexusrpc.Operation[str, Any] - error_message = "is not compatible with the input type" + class Impl: + @sync_operation + async def op(self, _ctx: StartOperationContext, _input: str) -> str: ... @pytest.mark.parametrize( @@ -243,12 +267,16 @@ async def op(self, _ctx: StartOperationContext, _input: Subclass) -> X: ... WrongOutputType, WrongOutputTypeWithNone, ValidImplWithNone, - MoreSpecificImplAllowed, - OutputCovarianceImplOutputCanBeSameType, - OutputCovarianceImplOutputCanBeSubclass, - OutputCovarianceImplOutputCannnotBeStrictSuperclass, - InputContravarianceImplInputCanBeSameType, - InputContravarianceImplInputCanBeSuperclass, + OutputCovarianceImplOutputCannotBeSubclass, + OutputCovarianceImplOutputCannotBeStrictSuperclass, + InputContravarianceImplInputCannotBeSuperclass, + InputContravarianceImplInputCannotBeSubclass, + ValidImplWithGenericTypes, + InvalidImplWithWrongGenericInputType, + InvalidImplWithWrongGenericOutputType, + ValidImplWithAnyTypes, + InvalidImplWithAnyInputButSpecificImpl, + InvalidImplWithAnyOutputButSpecificImpl, ], ) def test_service_decorator_enforces_interface_implementation( diff --git a/tests/handler/test_service_handler_decorator_validates_duplicate_operation_names.py b/tests/handler/test_service_handler_decorator_validates_duplicate_operation_names.py index 52159d4..4ee52a4 100644 --- a/tests/handler/test_service_handler_decorator_validates_duplicate_operation_names.py +++ b/tests/handler/test_service_handler_decorator_validates_duplicate_operation_names.py @@ -1,7 +1,7 @@ +from dataclasses import dataclass from typing import Any import pytest -from typing_extensions import dataclass_transform from nexusrpc.handler import ( OperationHandler, @@ -10,12 +10,8 @@ from nexusrpc.handler._decorators import operation_handler -@dataclass_transform() -class _BaseTestCase: - pass - - -class _TestCase(_BaseTestCase): +@dataclass() +class _TestCase: UserServiceHandler: type[Any] expected_error_message: str diff --git a/tests/service_definition/test_service_decorator_creates_expected_operation_declaration.py b/tests/service_definition/test_service_decorator_creates_expected_operation_declaration.py index 1fa470c..c92b3f3 100644 --- a/tests/service_definition/test_service_decorator_creates_expected_operation_declaration.py +++ b/tests/service_definition/test_service_decorator_creates_expected_operation_declaration.py @@ -1,7 +1,7 @@ +from dataclasses import dataclass from typing import Any import pytest -from typing_extensions import dataclass_transform import nexusrpc from nexusrpc._util import get_service_definition @@ -11,12 +11,8 @@ class Output: pass -@dataclass_transform() -class _BaseTestCase: - pass - - -class OperationDeclarationTestCase(_BaseTestCase): +@dataclass() +class OperationDeclarationTestCase: Interface: type expected_ops: dict[str, tuple[type[Any], type[Any]]] diff --git a/tests/service_definition/test_service_decorator_selects_correct_service_name.py b/tests/service_definition/test_service_decorator_selects_correct_service_name.py index db25655..8227013 100644 --- a/tests/service_definition/test_service_decorator_selects_correct_service_name.py +++ b/tests/service_definition/test_service_decorator_selects_correct_service_name.py @@ -1,16 +1,13 @@ +from dataclasses import dataclass + import pytest -from typing_extensions import dataclass_transform import nexusrpc from nexusrpc._util import get_service_definition -@dataclass_transform() -class _BaseTestCase: - pass - - -class NameOverrideTestCase(_BaseTestCase): +@dataclass +class NameOverrideTestCase: Interface: type expected_name: str diff --git a/tests/service_definition/test_service_decorator_validation.py b/tests/service_definition/test_service_decorator_validation.py index 05f76a4..fa5985d 100644 --- a/tests/service_definition/test_service_decorator_validation.py +++ b/tests/service_definition/test_service_decorator_validation.py @@ -1,5 +1,6 @@ +from dataclasses import dataclass + import pytest -from typing_extensions import dataclass_transform import nexusrpc @@ -8,12 +9,8 @@ class Output: pass -@dataclass_transform() -class _BaseTestCase: - pass - - -class _TestCase(_BaseTestCase): +@dataclass() +class _TestCase: Contract: type expected_error: Exception diff --git a/tests/service_definition/test_service_definition_inheritance.py b/tests/service_definition/test_service_definition_inheritance.py index d8af5b2..e6974b4 100644 --- a/tests/service_definition/test_service_definition_inheritance.py +++ b/tests/service_definition/test_service_definition_inheritance.py @@ -3,10 +3,10 @@ # See https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older from __future__ import annotations +from dataclasses import dataclass from typing import Any, Optional import pytest -from typing_extensions import dataclass_transform import nexusrpc from nexusrpc import Operation, ServiceDefinition @@ -15,12 +15,8 @@ # See https://docs.python.org/3/howto/annotations.html -@dataclass_transform() -class _BaseTestCase: - pass - - -class _TestCase(_BaseTestCase): +@dataclass() +class _TestCase: UserService: type[Any] expected_operation_names: set[str] expected_error: Optional[str] = None diff --git a/tests/test_get_input_and_output_types.py b/tests/test_get_input_and_output_types.py index ff404eb..ee06b3b 100644 --- a/tests/test_get_input_and_output_types.py +++ b/tests/test_get_input_and_output_types.py @@ -1,5 +1,6 @@ import warnings from collections.abc import Awaitable +from dataclasses import dataclass from typing import ( Any, Callable, @@ -9,7 +10,6 @@ ) import pytest -from typing_extensions import dataclass_transform from nexusrpc.handler import StartOperationContext from nexusrpc.handler._util import get_start_method_input_and_output_type_annotations @@ -23,12 +23,8 @@ class Output: pass -@dataclass_transform() -class _BaseTestCase: - pass - - -class _TestCase(_BaseTestCase): +@dataclass() +class _TestCase: start: Callable[..., Any] expected_types: tuple[Any, Any]