From 4165260a7aab8f0fc0424013621636eab658c64c Mon Sep 17 00:00:00 2001 From: isaacbmiller Date: Sat, 28 Feb 2026 14:36:41 -0500 Subject: [PATCH] refactor: detect async modules at discovery time instead of per-request Check for user-implemented aforward() once during module discovery. Walks the MRO up to (but not including) dspy.Module, so aforward inherited from an intermediate base class is still detected while the default base-class method is ignored. Store result as has_native_async on DiscoveredModule and use it in execute_pipeline. Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com> --- src/dspy_cli/discovery/module_finder.py | 23 ++++++++++ src/dspy_cli/server/execution.py | 2 +- tests/discovery/test_module_finder.py | 60 +++++++++++++++++++++++++ tests/test_execution.py | 5 +++ 4 files changed, 89 insertions(+), 1 deletion(-) diff --git a/src/dspy_cli/discovery/module_finder.py b/src/dspy_cli/discovery/module_finder.py index 98a5a3e..c4810fe 100644 --- a/src/dspy_cli/discovery/module_finder.py +++ b/src/dspy_cli/discovery/module_finder.py @@ -28,6 +28,7 @@ class DiscoveredModule: forward_input_fields: Optional[Dict[str, Any]] = None # Input field types from forward() method forward_output_fields: Optional[Dict[str, Any]] = None # Output field types from forward() method is_forward_typed: bool = False # True if forward() has proper type annotations + has_native_async: bool = False # True if the module defines its own aforward() gateway_classes: List[Type["Gateway"]] = None # Gateway classes if specified on module (supports list) # Backward compatibility: single gateway_class property @@ -43,6 +44,23 @@ def instantiate(self, lm: dspy.LM | None = None) -> dspy.Module: return self.class_obj() +def _has_user_implemented_aforward(cls: Type[dspy.Module]) -> bool: + """Check if a class (or an intermediate base) defines aforward(). + + Walks the MRO but stops before dspy.Module itself, so the default + base-class aforward is ignored while user-defined aforward on any + intermediate superclass is still detected. + """ + for klass in cls.__mro__: + if klass is dspy.Module or klass is object: + break + if 'aforward' in klass.__dict__: + method = klass.__dict__['aforward'] + if callable(method) or isinstance(method, (staticmethod, classmethod)): + return True + return False + + def discover_modules( package_path: Path, package_name: str, @@ -135,6 +153,10 @@ def discover_modules( # Extract gateway classes if specified (supports single or list) gateway_classes = _extract_gateway_classes(obj) + native_async = _has_user_implemented_aforward(obj) + if native_async: + logger.info(f"Module {name} has native async support (aforward)") + discovered.append( DiscoveredModule( name=name, @@ -144,6 +166,7 @@ def discover_modules( forward_input_fields=forward_info.get("inputs"), forward_output_fields=forward_info.get("outputs"), is_forward_typed=forward_info.get("is_typed", False), + has_native_async=native_async, gateway_classes=gateway_classes, ) ) diff --git a/src/dspy_cli/server/execution.py b/src/dspy_cli/server/execution.py index 8279392..90c3741 100644 --- a/src/dspy_cli/server/execution.py +++ b/src/dspy_cli/server/execution.py @@ -277,7 +277,7 @@ async def execute_pipeline( logger.info(f"Executing {program_name} with inputs: {inputs}") with dspy.context(lm=request_lm): - if hasattr(instance, 'aforward'): + if module.has_native_async: result = await instance.acall(**inputs) else: result = instance(**inputs) diff --git a/tests/discovery/test_module_finder.py b/tests/discovery/test_module_finder.py index 0da2e57..29a19d7 100644 --- a/tests/discovery/test_module_finder.py +++ b/tests/discovery/test_module_finder.py @@ -7,6 +7,7 @@ from dspy_cli.discovery.module_finder import ( DiscoveredModule, _extract_gateway_classes, + _has_user_implemented_aforward, discover_modules, ) from dspy_cli.gateway import APIGateway, CronGateway, IdentityGateway @@ -122,6 +123,65 @@ def forward(self, text: str) -> str: assert result == [] +class TestHasUserImplementedAforward: + """Tests for _has_user_implemented_aforward.""" + + def test_returns_false_for_plain_module(self): + class PlainModule(dspy.Module): + def forward(self, x: str) -> str: + return x + + assert _has_user_implemented_aforward(PlainModule) is False + + def test_returns_true_for_direct_aforward(self): + class AsyncModule(dspy.Module): + def forward(self, x: str) -> str: + return x + + async def aforward(self, x: str) -> str: + return x + + assert _has_user_implemented_aforward(AsyncModule) is True + + def test_returns_true_for_inherited_aforward_from_intermediate_base(self): + """aforward defined on an intermediate class (not dspy.Module) should be detected.""" + class AsyncBase(dspy.Module): + def forward(self, x: str) -> str: + return x + + async def aforward(self, x: str) -> str: + return x + + class ConcreteChild(AsyncBase): + pass + + assert _has_user_implemented_aforward(ConcreteChild) is True + + def test_returns_false_for_dspy_module_default_aforward(self): + """The base dspy.Module aforward (if any) should not count.""" + class BasicModule(dspy.Module): + def forward(self, x: str) -> str: + return x + + assert _has_user_implemented_aforward(BasicModule) is False + + def test_deep_inheritance_chain(self): + class Base(dspy.Module): + def forward(self, x: str) -> str: + return x + + async def aforward(self, x: str) -> str: + return x + + class Mid(Base): + pass + + class Leaf(Mid): + pass + + assert _has_user_implemented_aforward(Leaf) is True + + class TestDiscoverModulesWithGateway: """Integration tests for discover_modules with gateway extraction.""" diff --git a/tests/test_execution.py b/tests/test_execution.py index 98c977a..964a576 100644 --- a/tests/test_execution.py +++ b/tests/test_execution.py @@ -212,6 +212,7 @@ def test_successful_execution(self, tmp_path): module = MagicMock() module.is_forward_typed = False module.forward_output_fields = None + module.has_native_async = False # Use spec to ensure instance doesn't have aforward (sync execution path) instance = MagicMock(spec=['__call__']) @@ -242,6 +243,7 @@ def test_async_module_execution(self, tmp_path): module = MagicMock() module.is_forward_typed = False module.forward_output_fields = None + module.has_native_async = True instance = MagicMock(spec=['aforward', 'acall']) instance.acall = AsyncMock(return_value={"async_result": "done"}) @@ -271,6 +273,7 @@ def test_logs_on_success(self, tmp_path): module = MagicMock() module.is_forward_typed = False module.forward_output_fields = None + module.has_native_async = False # Use spec to ensure instance doesn't have aforward (sync execution path) instance = MagicMock(spec=['__call__']) @@ -306,6 +309,7 @@ def test_logs_on_error(self, tmp_path): module = MagicMock() module.is_forward_typed = False module.forward_output_fields = None + module.has_native_async = False # Use spec to ensure instance doesn't have aforward (sync execution path) instance = MagicMock(spec=['__call__']) @@ -344,6 +348,7 @@ def test_copies_lm_for_isolation(self, tmp_path): module = MagicMock() module.is_forward_typed = False module.forward_output_fields = None + module.has_native_async = False # Use spec to ensure instance doesn't have aforward (sync execution path) instance = MagicMock(spec=['__call__'])