Skip to content

Commit f4a016e

Browse files
authored
Enhance StreamListener to support generic type annotations for output (#9112)
* Enhance StreamListener to support generic type annotations for output fields - Added import of inspect to facilitate type checking. - Updated the condition for handling custom streamable types to include a check for class type using inspect.isclass. - Introduced a new test for StreamListener to validate behavior with generic type annotations in output fields. Signed-off-by: TomuHirata <tomu.hirata@gmail.com> * test * comment Signed-off-by: TomuHirata <tomu.hirata@gmail.com> --------- Signed-off-by: TomuHirata <tomu.hirata@gmail.com>
1 parent d27c5af commit f4a016e

File tree

2 files changed

+53
-1
lines changed

2 files changed

+53
-1
lines changed

dspy/streaming/streaming_listener.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
import re
23
from collections import defaultdict
34
from queue import Queue
@@ -135,7 +136,12 @@ def receive(self, chunk: ModelResponseStream):
135136
return
136137

137138
# Handle custom streamable types
138-
if self._output_type and issubclass(self._output_type, Type) and self._output_type.is_streamable():
139+
if (
140+
self._output_type
141+
and inspect.isclass(self._output_type)
142+
and issubclass(self._output_type, Type)
143+
and self._output_type.is_streamable()
144+
):
139145
if parsed_chunk := self._output_type.parse_stream_chunk(chunk):
140146
return StreamResponse(
141147
self.predict_name,

tests/streaming/test_streaming.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,6 +1205,52 @@ async def chat_stream(*args, **kwargs):
12051205
assert "success" in full_content
12061206

12071207

1208+
@pytest.mark.anyio
1209+
async def test_chat_adapter_with_generic_type_annotation():
1210+
class TestSignature(dspy.Signature):
1211+
question: str = dspy.InputField()
1212+
response: list[str] | int = dspy.OutputField()
1213+
1214+
class MyProgram(dspy.Module):
1215+
def __init__(self):
1216+
self.predict = dspy.Predict(TestSignature)
1217+
1218+
def forward(self, question, **kwargs):
1219+
return self.predict(question=question, **kwargs)
1220+
1221+
async def chat_stream(*args, **kwargs):
1222+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="[[ ##"))])
1223+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" response"))])
1224+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ## ]]\n\n"))])
1225+
yield ModelResponseStream(
1226+
model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="1"))]
1227+
)
1228+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="\n\n[[ ##"))])
1229+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" completed"))])
1230+
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ## ]]"))])
1231+
1232+
program = dspy.streamify(
1233+
MyProgram(),
1234+
stream_listeners=[
1235+
dspy.streaming.StreamListener(signature_field_name="response"),
1236+
],
1237+
)
1238+
1239+
with mock.patch("litellm.acompletion", side_effect=chat_stream):
1240+
with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.ChatAdapter()):
1241+
output = program(question="Say hello")
1242+
chunks = []
1243+
async for value in output:
1244+
if isinstance(value, StreamResponse):
1245+
chunks.append(value)
1246+
1247+
assert len(chunks) > 0
1248+
assert chunks[0].signature_field_name == "response"
1249+
1250+
full_content = "".join(chunk.chunk for chunk in chunks)
1251+
assert "1" in full_content
1252+
1253+
12081254
@pytest.mark.anyio
12091255
async def test_chat_adapter_nested_pydantic_streaming():
12101256
"""Test ChatAdapter streaming with nested pydantic model."""

0 commit comments

Comments
 (0)