Skip to content
9 changes: 0 additions & 9 deletions src/nexusrpc/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import functools
import inspect
import typing
from collections.abc import Awaitable
from typing import TYPE_CHECKING, Any, Callable, Optional

Expand Down Expand Up @@ -142,14 +141,6 @@ def get_callable_name(fn: Callable[..., Any]) -> str:
return method_name


def is_subtype(type1: type[Any], type2: type[Any]) -> bool:
# Note that issubclass() argument 2 cannot be a parameterized generic
# TODO(nexus-preview): review desired type compatibility logic
if type1 == type2:
return True
return issubclass(type1, typing.get_origin(type2) or type2)


# See
# https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older

Expand Down
37 changes: 8 additions & 29 deletions src/nexusrpc/handler/_operation_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
get_operation_factory,
is_async_callable,
is_callable,
is_subtype,
)

from ._common import (
Expand Down Expand Up @@ -180,9 +179,8 @@ def validate_operation_handler_methods(
1. There must be a method in ``user_methods`` whose method name matches the method
name from the service definition.

2. The input and output types of the user method must be such that the user method
is a subtype of the operation defined in the service definition, i.e. respecting
input type contravariance and output type covariance.
2. The input and output types of the handler method must exactly match the types
declared in the service definition.
"""
operation_handler_factories_by_method_name = (
operation_handler_factories_by_method_name.copy()
Expand Down Expand Up @@ -212,40 +210,21 @@ def validate_operation_handler_methods(
f"is '{op_defn.name}'. Operation handlers may not override the name of an operation "
f"in the service definition."
)
# Input type is contravariant: op handler input must be superclass of op defn input.
# If handler's input_type is None (missing annotation), skip validation - the handler
# relies on the service definition for type information. This supports handlers without
# explicit type annotations when a service definition is provided.
if (
op.input_type is not None
and Any not in (op.input_type, op_defn.input_type)
and not (
op_defn.input_type == op.input_type
or is_subtype(op_defn.input_type, op.input_type)
)
):
if op.input_type is not None and op_defn.input_type != op.input_type:
raise TypeError(
f"Operation '{op_defn.method_name}' in service '{service_cls}' "
f"has input type '{op.input_type}', which is not "
f"compatible with the input type '{op_defn.input_type}' in interface "
f"'{service_definition.name}'. The input type must be the same as or a "
f"superclass of the operation definition input type."
f"OperationHandler input type mismatch for '{service_cls}.{op_defn.method_name}': "
f"expected {op_defn.input_type}, got {op.input_type}"
)

# Output type is covariant: op handler output must be subclass of op defn output.
# If handler's output_type is None (missing annotation), skip validation - the handler
# relies on the service definition for type information.
if (
op.output_type is not None
and Any not in (op.output_type, op_defn.output_type)
and not is_subtype(op.output_type, op_defn.output_type)
):
if op.output_type is not None and op.output_type != op_defn.output_type:
raise TypeError(
f"Operation '{op_defn.method_name}' in service '{service_cls}' "
f"has output type '{op.output_type}', which is not "
f"compatible with the output type '{op_defn.output_type}' in interface "
f" '{service_definition}'. The output type must be the same as or a "
f"subclass of the operation definition output type."
f"OperationHandler output type mismatch for '{service_cls}.{op_defn.method_name}': "
f"expected {op_defn.output_type}, got {op.output_type}"
)
if operation_handler_factories_by_method_name:
raise ValueError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@

import pytest

from nexusrpc import HandlerError, LazyValue
from nexusrpc.handler import (
CancelOperationContext,
Handler,
OperationHandler,
StartOperationContext,
StartOperationResultSync,
operation_handler,
service_handler,
)
from nexusrpc.handler._decorators import operation_handler
from tests.helpers import DummySerializer, TestOperationTaskCancellation


def test_service_must_use_decorator():
Expand Down Expand Up @@ -63,3 +65,31 @@ class Service2:

with pytest.raises(RuntimeError):
_ = Handler([Service1(), Service2()])


@pytest.mark.asyncio
async def test_operations_must_have_decorator():
@service_handler
class TestService:
async def op(self, _ctx: StartOperationContext, input: str) -> str:
return input

handler = Handler([TestService()])

with pytest.raises(HandlerError, match="has no operation 'op'"):
_ = await handler.start_operation(
StartOperationContext(
service=TestService.__name__,
operation=TestService.op.__name__,
headers={},
request_id="test-req",
task_cancellation=TestOperationTaskCancellation(),
request_deadline=None,
callback_url=None,
),
LazyValue(
serializer=DummySerializer(value="test"),
headers={},
stream=None,
),
)
16 changes: 7 additions & 9 deletions tests/handler/test_invalid_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
handler implementations.
"""

from typing import Any, Callable
from dataclasses import dataclass

import pytest
from typing_extensions import dataclass_transform

import nexusrpc
from nexusrpc.handler import (
Expand All @@ -19,15 +18,14 @@
from nexusrpc.handler._operation_handler import OperationHandler


@dataclass_transform()
class _BaseTestCase:
pass


class _TestCase(_BaseTestCase):
build: Callable[..., Any]
@dataclass()
class _TestCase:
error_message: str

@staticmethod
def build():
pass


class OperationHandlerOverridesNameInconsistentlyWithServiceDefinition(_TestCase):
@staticmethod
Expand Down
95 changes: 95 additions & 0 deletions tests/handler/test_operation_handler_runtime_behavior.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""
Test runtime behavior of operation handlers invoked through Handler.start_operation().

This file tests actual execution behavior, distinct from:
- Decoration-time validation (test_service_handler_decorator_validates_against_service_contract.py)
- Handler constructor validation (test_handler_validates_service_handler_collection.py)
"""

import pytest

from nexusrpc import LazyValue, Operation, service
from nexusrpc.handler import (
CancelOperationContext,
Handler,
OperationHandler,
StartOperationContext,
StartOperationResultSync,
operation_handler,
service_handler,
)
from nexusrpc.handler._decorators import sync_operation
from tests.helpers import DummySerializer, TestOperationTaskCancellation


@pytest.mark.asyncio
async def test_handler_can_return_covariant_type():
class Superclass:
pass

class Subclass(Superclass):
pass

@service
class CovariantService:
op_handler: Operation[None, Superclass]
inline: Operation[None, Superclass]

class ValidOperationHandler(OperationHandler[None, Superclass]):
async def start(
self, ctx: StartOperationContext, input: None
) -> StartOperationResultSync[Subclass]:
return StartOperationResultSync(Subclass())

async def cancel(self, ctx: CancelOperationContext, token: str) -> None:
pass

@service_handler(service=CovariantService)
class CovariantServiceHandler:
@operation_handler
def op_handler(self) -> OperationHandler[None, Superclass]:
return ValidOperationHandler()

@sync_operation
async def inline(self, ctx: StartOperationContext, input: None) -> Superclass: # pyright: ignore[reportUnusedParameter]
return Subclass()

handler = Handler([CovariantServiceHandler()])

result = await handler.start_operation(
StartOperationContext(
service=CovariantService.__name__,
operation=CovariantService.op_handler.name,
headers={},
request_id="test-req",
task_cancellation=TestOperationTaskCancellation(),
request_deadline=None,
callback_url=None,
),
LazyValue(
serializer=DummySerializer(None),
headers={},
stream=None,
),
)
assert type(result) is StartOperationResultSync
assert type(result.value) is Subclass

result = await handler.start_operation(
StartOperationContext(
service=CovariantService.__name__,
operation=CovariantService.inline.name,
headers={},
request_id="test-req",
task_cancellation=TestOperationTaskCancellation(),
request_deadline=None,
callback_url=None,
),
LazyValue(
serializer=DummySerializer(None),
headers={},
stream=None,
),
)
assert type(result) is StartOperationResultSync
assert type(result.value) is Subclass
10 changes: 3 additions & 7 deletions tests/handler/test_request_routing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import Any, Callable, cast

import pytest
from typing_extensions import dataclass_transform

import nexusrpc
from nexusrpc import LazyValue
Expand All @@ -16,12 +16,8 @@
from tests.helpers import DummySerializer, TestOperationTaskCancellation


@dataclass_transform()
class _BaseTestCase:
pass


class _TestCase(_BaseTestCase):
@dataclass()
class _TestCase:
UserService: type[Any]
# (service_name, op_name)
supported_request: tuple[str, str]
Expand Down
16 changes: 7 additions & 9 deletions tests/handler/test_service_handler_decorator_requirements.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any

import pytest
from typing_extensions import dataclass_transform

import nexusrpc
from nexusrpc._util import get_service_definition
Expand All @@ -15,12 +15,8 @@
from nexusrpc.handler._decorators import operation_handler


@dataclass_transform()
class _TestCase:
pass


class _DecoratorValidationTestCase(_TestCase):
@dataclass()
class _DecoratorValidationTestCase:
UserService: type[Any]
UserServiceHandler: type[Any]
expected_error_message_pattern: str
Expand Down Expand Up @@ -71,7 +67,8 @@ def test_decorator_validates_definition_compliance(
service_handler(service=test_case.UserService)(test_case.UserServiceHandler)


class _ServiceHandlerInheritanceTestCase(_TestCase):
@dataclass()
class _ServiceHandlerInheritanceTestCase:
UserServiceHandler: type[Any]
expected_operations: set[str]

Expand Down Expand Up @@ -134,7 +131,8 @@ def test_service_implementation_inheritance(
)


class _ServiceDefinitionInheritanceTestCase(_TestCase):
@dataclass()
class _ServiceDefinitionInheritanceTestCase:
UserService: type[Any]
expected_ops: set[str]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
Test that operation decorators result in operation factories that return the correct result.
"""

from dataclasses import dataclass
from typing import Any, Union, cast

import pytest
from typing_extensions import dataclass_transform

import nexusrpc
from nexusrpc import InputT, OutputT
Expand All @@ -26,12 +26,8 @@
from tests.helpers import TestOperationTaskCancellation


@dataclass_transform()
class _BaseTestCase:
pass


class _TestCase(_BaseTestCase):
@dataclass()
class _TestCase:
Service: type[Any]
expected_operation_factories: dict[str, Any]

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import Optional

import pytest
from typing_extensions import dataclass_transform

import nexusrpc
from nexusrpc._util import get_service_definition
Expand All @@ -18,12 +18,8 @@ class ServiceInterfaceWithNameOverride:
pass


@dataclass_transform()
class _BaseTestCase:
pass


class _NameOverrideTestCase(_BaseTestCase):
@dataclass()
class _NameOverrideTestCase:
ServiceImpl: type
expected_name: str
expected_error: Optional[type[Exception]] = None
Expand Down
Loading