Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 32 additions & 24 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,49 @@
from collections.abc import Iterator
from typing import Any

import pytest
from injection.testing import load_test_profile, set_test_constant

from cq import (
Bus,
CommandBus,
EventBus,
QueryBus,
new_command_bus,
new_event_bus,
new_query_bus,
)
from injection import Module

from cq import CQ, Bus, CommandBus, EventBus, QueryBus
from cq._core.dispatcher.bus import SimpleBus
from cq.ext.injection import InjectionAdapter
from tests.helpers.history import HistoryMiddleware


@pytest.fixture(scope="function")
def cq(injection_module: Module) -> CQ:
return CQ(InjectionAdapter(injection_module)).register_defaults()


@pytest.fixture(scope="function")
def bus() -> Bus[Any, Any]:
return SimpleBus()


@pytest.fixture(scope="function", autouse=True)
def ensure_test_dependencies(
cq: CQ,
history: HistoryMiddleware,
injection_module: Module,
) -> None:
injection_module.injectable(
lambda: cq.new_command_bus().add_middlewares(history),
on=CommandBus,
)
injection_module.injectable(
lambda: cq.new_event_bus().add_middlewares(history),
on=EventBus,
)
injection_module.injectable(
lambda: cq.new_query_bus().add_middlewares(history),
on=QueryBus,
)


@pytest.fixture(scope="function")
def history() -> HistoryMiddleware:
return HistoryMiddleware()


@pytest.fixture(scope="function", autouse=True)
def ensure_test_dependencies(history: HistoryMiddleware) -> Iterator[None]:
command_bus: CommandBus[Any] = new_command_bus().add_middlewares(history)
event_bus: EventBus = new_event_bus().add_middlewares(history)
query_bus: QueryBus[Any] = new_query_bus().add_middlewares(history)

set_test_constant(command_bus, on=CommandBus, alias=True, mode="override")
set_test_constant(event_bus, on=EventBus, alias=True, mode="override")
set_test_constant(query_bus, on=QueryBus, alias=True, mode="override")

with load_test_profile():
yield
@pytest.fixture(scope="function")
def injection_module() -> Module:
return Module()
14 changes: 8 additions & 6 deletions cq/_core/dispatcher/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable
from contextlib import AsyncExitStack, suppress
from typing import Protocol, Self, runtime_checkable

from cq._core.middleware import Middleware, MiddlewareGroup
Expand All @@ -10,6 +9,9 @@
class Dispatcher[I, O](Protocol):
__slots__ = ()

async def __call__(self, input_value: I, /) -> O:
return await self.dispatch(input_value)

@abstractmethod
async def dispatch(self, input_value: I, /) -> O:
raise NotImplementedError
Expand All @@ -34,10 +36,10 @@ async def _invoke_with_middlewares(
/,
fail_silently: bool = False,
) -> O:
async with AsyncExitStack() as stack:
if fail_silently:
stack.enter_context(suppress(Exception))

try:
return await self.__middleware_group.invoke(handler, input_value)
except Exception:
if fail_silently:
return NotImplemented

return NotImplemented
raise
23 changes: 20 additions & 3 deletions cq/_core/dispatcher/pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,15 @@ def add[T](
self.__steps.append(PipelineStep(converter, dispatcher))
return self

def add_static[T](
self,
input_value: T,
dispatcher: Dispatcher[T, Any] | None,
) -> Self:
converter = _StaticPipelineConverter(input_value)
self.add(converter, dispatcher) # type: ignore[arg-type]
return self

async def execute(self, input_value: I, /, *args: P.args, **kwargs: P.kwargs) -> O:
dispatcher = self.default_dispatcher

Expand Down Expand Up @@ -128,11 +137,10 @@ def decorator(wp: Convert[[], T, Any]) -> Convert[[], T, Any]:
def add_static_step[T](
self,
input_value: T,
*,
/,
dispatcher: Dispatcher[T, Any] | None = None,
) -> Self:
converter = _StaticPipelineConverter(input_value)
self.__steps.add(converter, dispatcher)
self.__steps.add_static(input_value, dispatcher)
return self

async def dispatch(self, input_value: I, /) -> O:
Expand Down Expand Up @@ -189,6 +197,15 @@ def add_middlewares(self, *middlewares: Middleware[[I], Any]) -> Self:
self.__middleware_group.add(*middlewares)
return self

def add_static_step[T](
self,
input_value: T,
/,
dispatcher: Dispatcher[T, Any] | None = None,
) -> Self:
self.__steps.add_static(input_value, dispatcher)
return self

if TYPE_CHECKING: # pragma: no cover

@overload
Expand Down
5 changes: 4 additions & 1 deletion cq/_core/pipetools.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, overload
from typing import TYPE_CHECKING, Any, Self, overload

from cq import Dispatcher
from cq._core.common.typing import Decorator
Expand All @@ -25,6 +25,9 @@ def __init__(self, di: DIAdapter) -> None:
command_middleware = CommandDispatchScopeMiddleware(di)
self.add_middlewares(command_middleware)

def add_static_query_step[Q: Query](self, query: Q, /) -> Self:
return self.add_static_step(query, dispatcher=self.__query_dispatcher)

if TYPE_CHECKING: # pragma: no cover

@overload
Expand Down
22 changes: 14 additions & 8 deletions tests/test_command_bus.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
from injection import find_instance
from injection import Module

from cq import AnyCommandBus, RelatedEvents, command_handler, event_handler
from cq import CQ, AnyCommandBus, RelatedEvents
from tests.helpers.history import HistoryMiddleware


class TestCommandBus:
async def test_dispatch_with_related_events(
self,
cq: CQ,
history: HistoryMiddleware,
injection_module: Module,
) -> None:
class _Event: ...

@event_handler
@cq.event_handler
class _EventHandler:
async def handle(self, event: _Event) -> None: ...

class _Command: ...

@command_handler
@cq.command_handler
class _CommandHandler:
def __init__(self, related_events: RelatedEvents) -> None:
self.related_events = related_events
Expand All @@ -26,21 +28,25 @@ async def handle(self, command: _Command) -> None:
event = _Event()
self.related_events.add(event)

command_bus = find_instance(AnyCommandBus)
command_bus = injection_module.find_instance(AnyCommandBus)
command = _Command()
await command_bus.dispatch(command)

assert len(history.records) == 2
assert isinstance(history.records[0].args[0], _Event)
assert isinstance(history.records[1].args[0], _Command)

async def test_dispatch_with_fail_silently(self) -> None:
async def test_dispatch_with_fail_silently(
self,
cq: CQ,
injection_module: Module,
) -> None:
class _Command: ...

@command_handler(fail_silently=True)
@cq.command_handler(fail_silently=True)
class _CommandHandler:
async def handle(self, command: _Command) -> None:
raise ValueError

command_bus = find_instance(AnyCommandBus)
command_bus = injection_module.find_instance(AnyCommandBus)
assert await command_bus.dispatch(_Command()) is NotImplemented
26 changes: 18 additions & 8 deletions tests/test_context_command_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from cq import ContextCommandPipeline, command_handler, query_handler
from cq import CQ, ContextCommandPipeline
from tests.helpers.history import HistoryMiddleware


class TestContextCommandPipeline:
async def test_dispatch_with_success_return_any(
self,
cq: CQ,
history: HistoryMiddleware,
) -> None:
class Command0: ...

class Command1: ...

class Command2: ...
Expand All @@ -19,17 +22,22 @@ class Bar: ...

class Baz: ...

@command_handler
@cq.command_handler
class CommandHandler0:
async def handle(self, command: Command0) -> None:
return

@cq.command_handler
class CommandHandler1:
async def handle(self, command: Command1) -> Foo:
return Foo()

@command_handler
@cq.command_handler
class CommandHandler2:
async def handle(self, command: Command2) -> Bar:
return Bar()

@query_handler
@cq.query_handler
class QueryHandler:
async def handle(self, query: Query) -> Baz:
return Baz()
Expand All @@ -39,7 +47,9 @@ class Context:
bar: Bar
baz: Baz

pipeline: ContextCommandPipeline[Command1] = ContextCommandPipeline()
pipeline: ContextCommandPipeline[Command0] = ContextCommandPipeline(cq.di)

pipeline.add_static_step(Command1())

@pipeline.step
def _(self, foo: Foo) -> Command2:
Expand All @@ -55,11 +65,11 @@ def _(self, bar: Bar) -> Query:
async def _(self, baz: Baz) -> None:
self.baz = baz

cmd = Command1()
ctx = await Context.pipeline.dispatch(cmd)
cmd = Command0()
ctx = await Context.pipeline(cmd)

assert isinstance(ctx, Context)
assert isinstance(ctx.foo, Foo)
assert isinstance(ctx.bar, Bar)
assert isinstance(ctx.baz, Baz)
assert len(history.records) == 3
assert len(history.records) == 4
Loading
Loading