4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20250522003454958473.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Add full LLM response to LLM Provider output"
+}
10 changes: 8 additions & 2 deletions graphrag/language_model/providers/fnllm/models.py
@@ -83,7 +83,10 @@ async def achat(
         else:
             response = await self.model(prompt, history=history, **kwargs)
         return BaseModelResponse(
-            output=BaseModelOutput(content=response.output.content),
+            output=BaseModelOutput(
+                content=response.output.content,
+                full_response=response.output.raw_model.to_dict(),
+            ),
             parsed_response=response.parsed_json,
             history=response.history,
             cache_hit=response.cache_hit,
@@ -282,7 +285,10 @@ async def achat(
         else:
             response = await self.model(prompt, history=history, **kwargs)
         return BaseModelResponse(
-            output=BaseModelOutput(content=response.output.content),
+            output=BaseModelOutput(
+                content=response.output.content,
+                full_response=response.output.raw_model.to_dict(),
+            ),
             parsed_response=response.parsed_json,
             history=response.history,
             cache_hit=response.cache_hit,
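With this change, both achat paths in the fnllm provider attach the provider's raw payload to the output. A minimal usage sketch, not code from this PR: it assumes "model" is an fnllm chat model constructed elsewhere, and the "usage" key is an assumed OpenAI-style field, not guaranteed by this diff.

# Hedged sketch (runs inside an async function); "model" construction is assumed.
response = await model.achat("Summarize the document.")
print(response.output.content)       # textual completion, unchanged behavior
raw = response.output.full_response  # complete provider JSON payload, or None
if raw is not None:
    # "usage" is an assumed OpenAI-style key; inspect raw to confirm its shape
    print(raw.get("usage"))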
9 changes: 9 additions & 0 deletions graphrag/language_model/response/base.py
@@ -18,6 +18,11 @@ def content(self) -> str:
         """Return the textual content of the output."""
         ...
 
+    @property
+    def full_response(self) -> dict[str, Any] | None:
+        """Return the complete JSON response returned by the model."""
+        ...
+
 
 class ModelResponse(Protocol, Generic[T]):
     """Protocol for LLM response."""
@@ -43,6 +48,10 @@ class BaseModelOutput(BaseModel):
 
     content: str = Field(..., description="The textual content of the output.")
     """The textual content of the output."""
+    full_response: dict[str, Any] | None = Field(
+        None, description="The complete JSON response returned by the LLM provider."
+    )
+    """The complete JSON response returned by the LLM provider."""
 
 
 class BaseModelResponse(BaseModel, Generic[T]):
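Since full_response defaults to None, existing BaseModelOutput call sites remain valid. A short sketch of the new field, mirroring what the updated test further below exercises; the {"id": "resp-123"} payload is illustrative only.

from graphrag.language_model.response.base import BaseModelOutput, BaseModelResponse

# Omitting full_response is still valid; the field simply stays None.
minimal = BaseModelOutput(content="hello")
assert minimal.full_response is None

# Providers can now attach the raw payload returned by the LLM.
detailed = BaseModelOutput(content="hello", full_response={"id": "resp-123"})
response = BaseModelResponse(output=detailed)
assert response.output.full_response == {"id": "resp-123"}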
50 changes: 50 additions & 0 deletions graphrag/language_model/response/base.pyi
@@ -0,0 +1,50 @@
+# Copyright (c) 2025 Microsoft Corporation.
+# Licensed under the MIT License
+
+from typing import Any, Generic, Protocol, TypeVar
+
+from pydantic import BaseModel
+
+_T = TypeVar("_T", bound=BaseModel, covariant=True)
+
+class ModelOutput(Protocol):
+    @property
+    def content(self) -> str: ...
+    @property
+    def full_response(self) -> dict[str, Any] | None: ...
+
+class ModelResponse(Protocol, Generic[_T]):
+    @property
+    def output(self) -> ModelOutput: ...
+    @property
+    def parsed_response(self) -> _T | None: ...
+    @property
+    def history(self) -> list[Any]: ...
+
+class BaseModelOutput(BaseModel):
+    content: str
+    full_response: dict[str, Any] | None
+
+    def __init__(
+        self,
+        content: str,
+        full_response: dict[str, Any] | None = None,
+    ) -> None: ...
+
+class BaseModelResponse(BaseModel, Generic[_T]):
+    output: BaseModelOutput
+    parsed_response: _T | None
+    history: list[Any]
+    tool_calls: list[Any]
+    metrics: Any | None
+    cache_hit: bool | None
+
+    def __init__(
+        self,
+        output: BaseModelOutput,
+        parsed_response: _T | None = None,
+        history: list[Any] = ...,  # default provided by Pydantic
+        tool_calls: list[Any] = ...,  # default provided by Pydantic
+        metrics: Any | None = None,
+        cache_hit: bool | None = None,
+    ) -> None: ...
8 changes: 7 additions & 1 deletion tests/integration/language_model/test_factory.py
@@ -33,7 +33,11 @@ async def achat(
     def chat(
         self, prompt: str, history: list | None = None, **kwargs: Any
     ) -> ModelResponse:
-        return BaseModelResponse(output=BaseModelOutput(content="content"))
+        return BaseModelResponse(
+            output=BaseModelOutput(
+                content="content", full_response={"key": "value"}
+            )
+        )
 
     async def achat_stream(
         self, prompt: str, history: list | None = None, **kwargs: Any
@@ -49,9 +53,11 @@ def chat_stream(
     assert isinstance(model, CustomChatModel)
     response = await model.achat("prompt")
    assert response.output.content == "content"
+    assert response.output.full_response is None
 
     response = model.chat("prompt")
     assert response.output.content == "content"
+    assert response.output.full_response == {"key": "value"}
 
 
 async def test_create_custom_embedding_llm():