Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/dspy_cli/discovery/module_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class DiscoveredModule:
forward_input_fields: Optional[Dict[str, Any]] = None # Input field types from forward() method
forward_output_fields: Optional[Dict[str, Any]] = None # Output field types from forward() method
is_forward_typed: bool = False # True if forward() has proper type annotations
has_native_async: bool = False # True if the module defines its own aforward()
gateway_classes: List[Type["Gateway"]] = None # Gateway classes if specified on module (supports list)

# Backward compatibility: single gateway_class property
Expand All @@ -43,6 +44,23 @@ def instantiate(self, lm: dspy.LM | None = None) -> dspy.Module:
return self.class_obj()


def _has_user_implemented_aforward(cls: Type[dspy.Module]) -> bool:
"""Check if a class (or an intermediate base) defines aforward().

Walks the MRO but stops before dspy.Module itself, so the default
base-class aforward is ignored while user-defined aforward on any
intermediate superclass is still detected.
"""
for klass in cls.__mro__:
if klass is dspy.Module or klass is object:
break
if 'aforward' in klass.__dict__:
method = klass.__dict__['aforward']
if callable(method) or isinstance(method, (staticmethod, classmethod)):
return True
return False


def discover_modules(
package_path: Path,
package_name: str,
Expand Down Expand Up @@ -135,6 +153,10 @@ def discover_modules(
# Extract gateway classes if specified (supports single or list)
gateway_classes = _extract_gateway_classes(obj)

native_async = _has_user_implemented_aforward(obj)
if native_async:
logger.info(f"Module {name} has native async support (aforward)")

discovered.append(
DiscoveredModule(
name=name,
Expand All @@ -144,6 +166,7 @@ def discover_modules(
forward_input_fields=forward_info.get("inputs"),
forward_output_fields=forward_info.get("outputs"),
is_forward_typed=forward_info.get("is_typed", False),
has_native_async=native_async,
gateway_classes=gateway_classes,
)
)
Expand Down
2 changes: 1 addition & 1 deletion src/dspy_cli/server/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ async def execute_pipeline(
logger.info(f"Executing {program_name} with inputs: {inputs}")

with dspy.context(lm=request_lm):
if hasattr(instance, 'aforward'):
if module.has_native_async:
result = await instance.acall(**inputs)
else:
result = instance(**inputs)
Expand Down
60 changes: 60 additions & 0 deletions tests/discovery/test_module_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dspy_cli.discovery.module_finder import (
DiscoveredModule,
_extract_gateway_classes,
_has_user_implemented_aforward,
discover_modules,
)
from dspy_cli.gateway import APIGateway, CronGateway, IdentityGateway
Expand Down Expand Up @@ -122,6 +123,65 @@ def forward(self, text: str) -> str:
assert result == []


class TestHasUserImplementedAforward:
"""Tests for _has_user_implemented_aforward."""

def test_returns_false_for_plain_module(self):
class PlainModule(dspy.Module):
def forward(self, x: str) -> str:
return x

assert _has_user_implemented_aforward(PlainModule) is False

def test_returns_true_for_direct_aforward(self):
class AsyncModule(dspy.Module):
def forward(self, x: str) -> str:
return x

async def aforward(self, x: str) -> str:
return x

assert _has_user_implemented_aforward(AsyncModule) is True

def test_returns_true_for_inherited_aforward_from_intermediate_base(self):
"""aforward defined on an intermediate class (not dspy.Module) should be detected."""
class AsyncBase(dspy.Module):
def forward(self, x: str) -> str:
return x

async def aforward(self, x: str) -> str:
return x

class ConcreteChild(AsyncBase):
pass

assert _has_user_implemented_aforward(ConcreteChild) is True

def test_returns_false_for_dspy_module_default_aforward(self):
"""The base dspy.Module aforward (if any) should not count."""
class BasicModule(dspy.Module):
def forward(self, x: str) -> str:
return x

assert _has_user_implemented_aforward(BasicModule) is False

def test_deep_inheritance_chain(self):
class Base(dspy.Module):
def forward(self, x: str) -> str:
return x

async def aforward(self, x: str) -> str:
return x

class Mid(Base):
pass

class Leaf(Mid):
pass

assert _has_user_implemented_aforward(Leaf) is True


class TestDiscoverModulesWithGateway:
"""Integration tests for discover_modules with gateway extraction."""

Expand Down
5 changes: 5 additions & 0 deletions tests/test_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def test_successful_execution(self, tmp_path):
module = MagicMock()
module.is_forward_typed = False
module.forward_output_fields = None
module.has_native_async = False

# Use spec to ensure instance doesn't have aforward (sync execution path)
instance = MagicMock(spec=['__call__'])
Expand Down Expand Up @@ -242,6 +243,7 @@ def test_async_module_execution(self, tmp_path):
module = MagicMock()
module.is_forward_typed = False
module.forward_output_fields = None
module.has_native_async = True

instance = MagicMock(spec=['aforward', 'acall'])
instance.acall = AsyncMock(return_value={"async_result": "done"})
Expand Down Expand Up @@ -271,6 +273,7 @@ def test_logs_on_success(self, tmp_path):
module = MagicMock()
module.is_forward_typed = False
module.forward_output_fields = None
module.has_native_async = False

# Use spec to ensure instance doesn't have aforward (sync execution path)
instance = MagicMock(spec=['__call__'])
Expand Down Expand Up @@ -306,6 +309,7 @@ def test_logs_on_error(self, tmp_path):
module = MagicMock()
module.is_forward_typed = False
module.forward_output_fields = None
module.has_native_async = False

# Use spec to ensure instance doesn't have aforward (sync execution path)
instance = MagicMock(spec=['__call__'])
Expand Down Expand Up @@ -344,6 +348,7 @@ def test_copies_lm_for_isolation(self, tmp_path):
module = MagicMock()
module.is_forward_typed = False
module.forward_output_fields = None
module.has_native_async = False

# Use spec to ensure instance doesn't have aforward (sync execution path)
instance = MagicMock(spec=['__call__'])
Expand Down