6 changes: 6 additions & 0 deletions docs/wayflowcore/source/core/api/conversation.rst
@@ -16,6 +16,12 @@ Messages
.. _textcontent:
.. autoclass:: wayflowcore.messagelist.TextContent

.. _texttokenlogprob:
.. autoclass:: wayflowcore.messagelist.TextTokenLogProb

.. _texttokentoplogprob:
.. autoclass:: wayflowcore.messagelist.TextTokenTopLogProb

.. _imagecontent:
.. autoclass:: wayflowcore.messagelist.ImageContent

2 changes: 1 addition & 1 deletion docs/wayflowcore/source/core/api/llmmodels.rst
@@ -61,7 +61,7 @@ Class that is used to gather all token usage information.
LLM Generation Config
---------------------

Parameters for LLM generation (``max_tokens``, ``temperature``, ``top_p``).
Parameters for LLM generation (``max_tokens``, ``temperature``, ``top_p``, ``top_logprobs``).

.. _llmgenerationconfig:
.. autoclass:: wayflowcore.models.llmgenerationconfig.LlmGenerationConfig
8 changes: 8 additions & 0 deletions docs/wayflowcore/source/core/changelog.rst
@@ -40,6 +40,14 @@ New features
:doc:`how to use LLMs from different providers <howtoguides/llm_from_different_providers>`.


* **Logprob support in `LlmGenerationConfig` and `PromptExecutionStep`**

  Added per-token log-probability support via the ``top_logprobs`` generation config parameter;
  the ``PromptExecutionStep`` can now also return per-token log-probabilities.
  For more information, read the guide on :ref:`How to request per-token log-probabilities <request_logprobs>`.



Improvements
^^^^^^^^^^^^

@@ -47,6 +47,21 @@
print(conversation.get_last_message())
# .. end-##_Build_the_agent_and_run_it

# .. start-##_Request_logprobs_from_a_direct_llm_call
from wayflowcore.messagelist import Message, TextContent
from wayflowcore.models import Prompt

prompt = Prompt(
    messages=[Message(content="Say 'Bern' and nothing else.")],
    generation_config=LlmGenerationConfig(top_logprobs=2, max_tokens=16),
)
completion = llm.generate(prompt)
text_chunk = next(chunk for chunk in completion.message.contents if isinstance(chunk, TextContent))

print(text_chunk.content)
print(text_chunk.logprobs)
# .. end-##_Request_logprobs_from_a_direct_llm_call


from wayflowcore.controlconnection import ControlFlowEdge
from wayflowcore.dataconnection import DataFlowEdge
@@ -80,6 +95,38 @@
conversation.execute()
# .. end-##_Build_the_flow_using_custom_generation_parameters

# .. start-##_Request_logprobs_from_a_flow_step
from wayflowcore.executors.executionstatus import FinishedStatus

logprob_start_step = StartStep(
    name="logprob_start_step",
    input_descriptors=[StringProperty("user_question")],
)
logprob_step = PromptExecutionStep(
    name="PromptExecutionWithLogprobs",
    prompt_template="{{user_question}}",
    llm=llm,
    top_logprobs=2,
)
logprob_flow = Flow(
    begin_step=logprob_start_step,
    control_flow_edges=[
        ControlFlowEdge(source_step=logprob_start_step, destination_step=logprob_step),
        ControlFlowEdge(source_step=logprob_step, destination_step=None),
    ],
    data_flow_edges=[
        DataFlowEdge(logprob_start_step, "user_question", logprob_step, "user_question")
    ],
)
conversation = logprob_flow.start_conversation(
    inputs={"user_question": "What is the capital of Switzerland?"}
)
status = conversation.execute()
if isinstance(status, FinishedStatus):
    print(status.output_values[PromptExecutionStep.OUTPUT])
    print(status.output_values[PromptExecutionStep.LOGPROBS])
# .. end-##_Request_logprobs_from_a_flow_step

# .. start-##_Export_config_to_Agent_Spec
from wayflowcore.agentspec import AgentSpecExporter

35 changes: 34 additions & 1 deletion docs/wayflowcore/source/core/howtoguides/generation_config.rst
@@ -17,13 +17,14 @@ How to Specify the Generation Configuration when Using LLMs

Python script/notebook for this guide.

Generation parameters, such as temperature, top-p, and the maximum number of output tokens, are important for achieving the desired performance with Large Language Models (LLMs).
Generation parameters, such as temperature, top-p, the maximum number of output tokens, and per-token log-probabilities, are important for achieving the desired performance with Large Language Models (LLMs).
In WayFlow, these parameters can be configured with the :ref:`LlmGenerationConfig <llmgenerationconfig>` class.

This guide will show you how to:

- Configure the generation parameters for an agent.
- Configure the generation parameters for a flow.
- Request token log probabilities.
- Apply the generation configuration from a dictionary.
- Save a custom generation configuration.

@@ -54,6 +55,7 @@ The generation configuration dictionary can have the following arguments:
- ``top_p``: controls the randomness of the output;
- ``stop``: defines a list of stop words to indicate the LLM to stop generating;
- ``frequency_penalty``: controls the frequency of tokens generated;
- ``top_logprobs``: requests token-level log probabilities, including alternate candidates when the provider supports them.

Additionally, the :ref:`LlmGenerationConfig <llmgenerationconfig>` offers the possibility to set a dictionary
of arbitrary parameters, called ``extra_args``, that will be sent as part of the LLM generation call.
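As a rough illustration of how ``extra_args`` reaches the provider call, the merge can be sketched with plain dictionaries. ``build_request_kwargs`` and all names below are hypothetical stand-ins, not WayFlow internals:

```python
from typing import Any, Dict, Optional


def build_request_kwargs(
    base: Dict[str, Any], extra_args: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
    kwargs = dict(base)  # start from the standard generation parameters
    if extra_args:
        kwargs.update(extra_args)  # extra_args win on key collisions
    return kwargs


print(build_request_kwargs({"max_tokens": 256}, {"seed": 42}))
# {'max_tokens': 256, 'seed': 42}
```

Because the extras are applied last, they can also override a standard parameter, which is the usual escape hatch for provider-specific options.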
@@ -111,6 +113,37 @@ Advanced usage

The :ref:`LlmGenerationConfig <llmgenerationconfig>` class is a serializable object. It can be instantiated from a dictionary or saved to one, as you will see below.


.. _request_logprobs:

Request token log probabilities
-------------------------------

Use ``top_logprobs`` when you want the model to return token-level probabilities for generated text.
WayFlow stores those values on ``TextContent.logprobs`` for direct LLM calls, and the
:ref:`PromptExecutionStep <promptexecutionstep>` also exposes them as an additional ``logprobs`` output.

.. note::

``top_logprobs`` is only available for raw text generation.
It is not supported with structured generation in :ref:`PromptExecutionStep <promptexecutionstep>`,
and support depends on the selected provider and model.

For direct ``LlmModel`` calls, configure ``top_logprobs`` on the prompt and inspect the ``TextContent`` chunk:

.. literalinclude:: ../code_examples/example_generationconfig.py
:language: python
:start-after: .. start-##_Request_logprobs_from_a_direct_llm_call
:end-before: .. end-##_Request_logprobs_from_a_direct_llm_call
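Since log probabilities live in the range (-inf, 0], a common follow-up is converting them back to plain probabilities with ``exp``. A minimal self-contained sketch; ``TokenLogProb`` is an illustrative stand-in, not wayflowcore's ``TextTokenLogProb``:

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class TokenLogProb:
    token: str
    logprob: float


def to_probability(logprob: float) -> float:
    # a log probability in (-inf, 0] maps to a probability in (0, 1]
    return math.exp(logprob)


tokens = [TokenLogProb("Bern", -0.105360516), TokenLogProb(".", -0.223143551)]
for t in tokens:
    print(f"{t.token!r}: p={to_probability(t.logprob):.2f}")
# 'Bern': p=0.90
# '.': p=0.80
```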

For flows, you can request logprobs directly on :ref:`PromptExecutionStep <promptexecutionstep>`.
When enabled, the step appends a ``logprobs`` output alongside the normal text output:

.. literalinclude:: ../code_examples/example_generationconfig.py
:language: python
:start-after: .. start-##_Request_logprobs_from_a_flow_step
:end-before: .. end-##_Request_logprobs_from_a_flow_step

Apply the generation configuration from a dictionary
----------------------------------------------------

2 changes: 1 addition & 1 deletion docs/wayflowcore/source/core/misc/glossary.rst
@@ -153,7 +153,7 @@ Generation Config
=================

The :ref:`LLM generation config <llmgenerationconfig>` is the set of parameters that control the output of a :ref:`Large Language Model (LLM) <llmmodel>` in WayFlow.
These parameters include the maximum number of tokens to generate (``max_tokens``), the sampling ``temperature``, and the probability threshold for nucleus sampling (``top_p``).
These parameters include the maximum number of tokens to generate (``max_tokens``), the sampling ``temperature``, the probability threshold for nucleus sampling (``top_p``), and optional per-token log-probabilities (``top_logprobs``).

Learn more about the LLM generation config in the :doc:`How to Specify the Generation Configuration when Using LLMs <../howtoguides/generation_config>`
or read the :ref:`API reference <llmgenerationconfig>`.
58 changes: 58 additions & 0 deletions wayflowcore/src/wayflowcore/messagelist.py
@@ -35,6 +35,7 @@
from wayflowcore._utils.hash import fast_stable_hash
from wayflowcore.serialization.context import DeserializationContext, SerializationContext
from wayflowcore.serialization.serializer import (
    FrozenSerializableDataclass,
    SerializableDataclass,
    SerializableDataclassMixin,
    SerializableObject,
@@ -51,6 +52,28 @@
_ReasoningContent: TypeAlias = Dict[str, Any]


@dataclass(frozen=True, slots=True)
class TextTokenTopLogProb(FrozenSerializableDataclass):
    """Represents a single candidate token with its associated log probability."""

    token: str
    """The literal text of the candidate token."""
    logprob: float
    """The log probability assigned to the candidate token."""


@dataclass(frozen=True, slots=True)
class TextTokenLogProb(FrozenSerializableDataclass):
    """Captures a generated token, its log probability, and alternate candidates."""

    token: str
    """The literal text of the generated token."""
    logprob: float
    """The log probability assigned to the generated token. Ranges from -inf to 0."""
    top_logprobs: Optional[List[TextTokenTopLogProb]] = None
    """Optional ranked list of alternate tokens with their log probabilities."""


class MessageType(str, Enum):
    """Type of messages"""

@@ -98,6 +121,7 @@ class TextContent(MessageContent, SerializableObject):
"""

content: str = ""
logprobs: Optional[List[TextTokenLogProb]] = None
type: ClassVar[Literal["text"]] = "text"

    def __post_init__(self) -> None:

@@ -109,6 +133,40 @@ def __post_init__(self) -> None:
            )
            self.content = str(self.content)

        if self.logprobs is None:
            return

        # We accept both already-built `TextTokenLogProb` objects and raw dicts
        # (e.g., from provider payloads) to keep adapters simple.
        validated: List[TextTokenLogProb] = []
        for item in self.logprobs:
            if isinstance(item, TextTokenLogProb):
                validated.append(item)
                continue

            raw_item = cast(Dict[str, Any], item)
            raw_top = raw_item.get("top_logprobs")
            top_converted = None
            if raw_top is not None:
                top_converted = [
                    (
                        c
                        if isinstance(c, TextTokenTopLogProb)
                        else TextTokenTopLogProb(**cast(Dict[str, Any], c))
                    )
                    for c in raw_top
                ]

            validated.append(
                TextTokenLogProb(
                    token=raw_item["token"],
                    logprob=raw_item["logprob"],
                    top_logprobs=top_converted,
                )
            )

        self.logprobs = validated
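The accept-objects-or-dicts normalization in ``__post_init__`` can be exercised standalone with stdlib-only stand-ins. ``TokenLogProb`` and ``normalize`` below are illustrative, not wayflowcore classes:

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Union


@dataclass(frozen=True)
class TokenLogProb:
    token: str
    logprob: float


def normalize(items: List[Union[TokenLogProb, Dict[str, Any]]]) -> List[TokenLogProb]:
    # pass ready-made instances through unchanged, build instances from raw dicts
    out: List[TokenLogProb] = []
    for item in items:
        if isinstance(item, TokenLogProb):
            out.append(item)
        else:
            out.append(TokenLogProb(token=item["token"], logprob=item["logprob"]))
    return out


print(normalize([TokenLogProb("Bern", -0.1), {"token": ".", "logprob": -0.2}]))
```

Keeping the conversion in one place means every downstream consumer sees a uniform list of dataclass instances regardless of which adapter produced the data.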


@dataclass
class ImageContent(MessageContent, SerializableObject):
@@ -8,7 +8,13 @@
from typing import Any, AsyncIterable, Callable, Dict, List, Optional, TypedDict

from wayflowcore._utils.formatting import stringify
from wayflowcore.messagelist import ImageContent, Message, TextContent
from wayflowcore.messagelist import (
    ImageContent,
    Message,
    TextContent,
    TextTokenLogProb,
    TextTokenTopLogProb,
)
from wayflowcore.tokenusage import TokenUsage
from wayflowcore.tools import Tool, ToolRequest
from wayflowcore.tools.tools import ExtraContentT
@@ -31,6 +37,28 @@ class OpenAIToolRequestAsDictT(TypedDict, total=True):

class _ChatCompletionsAPIProcessor(_APIProcessor):

    @staticmethod
    def _convert_openai_logprobs_into_text_logprobs(logprobs: Any) -> List[TextTokenLogProb]:
        converted: List[TextTokenLogProb] = []
        if not logprobs:
            return converted

        for item in logprobs:
            top = item.get("top_logprobs")
            top_converted = None
            if top is not None:
                top_converted = [
                    TextTokenTopLogProb(token=c["token"], logprob=c["logprob"]) for c in top
                ]
            converted.append(
                TextTokenLogProb(
                    token=item["token"],
                    logprob=item["logprob"],
                    top_logprobs=top_converted,
                )
            )
        return converted

    @staticmethod
    def _tool_to_openai_function_dict(tool: Tool) -> Dict[str, Any]:
        openai_function_dict: Dict[str, Any] = {
@@ -159,7 +187,14 @@ def _convert_generation_params(
kwargs["stop"] = generation_config.stop
if generation_config.frequency_penalty is not None:
kwargs["frequency_penalty"] = generation_config.frequency_penalty
if generation_config.top_logprobs is not None:
kwargs["logprobs"] = True
kwargs["top_logprobs"] = generation_config.top_logprobs
if generation_config.extra_args:
if "include" in kwargs and "include" in generation_config.extra_args:
# prevent overriding any include
kwargs["include"].update(generation_config.extra_args["update"])
generation_config.extra_args.pop("include")
kwargs.update(generation_config.extra_args)
return kwargs
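The merge-don't-override idea for ``include`` can be illustrated with plain dictionaries and sets; all values below are made up for the example:

```python
# base kwargs already carry an "include" set; extra args add to it plus a new key
kwargs = {"include": {"usage"}}
extra_args = {"include": {"logprobs"}, "seed": 7}

merged = dict(extra_args)                        # copy so the original stays untouched
kwargs["include"].update(merged.pop("include"))  # merge, don't overwrite
kwargs.update(merged)                            # remaining extras pass through as-is
print(sorted(kwargs["include"]), kwargs["seed"])
# ['logprobs', 'usage'] 7
```

Copying before popping matters: popping from the config's own dictionary would silently drop the setting for every subsequent call that reuses the same config object.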

@@ -183,9 +218,17 @@ def _convert_openai_response_into_message(self, response: Any) -> "Message":
        # content might be empty when certain models (like gemini) decide
        # to finish the conversation
        content = extracted_message.get("content", "")

        logprobs = None
        choice_logprobs = response["choices"][0].get("logprobs")
        if choice_logprobs and choice_logprobs.get("content") is not None:
            logprobs = self._convert_openai_logprobs_into_text_logprobs(
                choice_logprobs["content"]
            )

        message = Message(
            role="assistant",
            contents=[TextContent(content=content, logprobs=logprobs)],
            _extra_content=extracted_message.get("extra_content"),
        )
        return message
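The response shape assumed above mirrors the OpenAI Chat Completions payload, where logprobs live under ``choices[0].logprobs.content``. The snippet below walks that same path on a hand-written example dict (all values made up):

```python
# hand-written payload in the Chat Completions shape (values are illustrative)
response = {
    "choices": [{
        "message": {"content": "Bern"},
        "logprobs": {"content": [
            {"token": "Bern", "logprob": -0.05,
             "top_logprobs": [{"token": "Bern", "logprob": -0.05},
                              {"token": "Zurich", "logprob": -3.2}]},
        ]},
    }]
}

choice = response["choices"][0]
# `or {}` / `or []` guard against providers that omit logprobs entirely
entries = (choice.get("logprobs") or {}).get("content") or []
for entry in entries:
    alts = ", ".join(f"{c['token']}: {c['logprob']}" for c in entry["top_logprobs"])
    print(f"{entry['token']} ({entry['logprob']}) alternatives: {alts}")
# Bern (-0.05) alternatives: Bern: -0.05, Zurich: -3.2
```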