Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,11 @@
*/
public class OllamaChatModelSetup extends BaseChatModelSetup {

private final String model;
private final Object think;
private final boolean extractReasoning;

public OllamaChatModelSetup(ResourceDescriptor descriptor, ResourceContext resourceContext) {
super(descriptor, resourceContext);
this.model = descriptor.getArgument("model");
this.think = descriptor.getArgument("think", true);
this.extractReasoning = descriptor.getArgument("extract_reasoning", true);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.agents.integrations.chatmodels.ollama;

import org.apache.flink.agents.api.resource.ResourceContext;
import org.apache.flink.agents.api.resource.ResourceDescriptor;
import org.junit.jupiter.api.Test;

import static org.assertj.core.api.Assertions.assertThat;

/** Tests for {@link OllamaChatModelSetup}. */
class OllamaChatModelSetupTest {

private static final ResourceContext NOOP = ResourceContext.fromGetResource((a, b) -> null);

@Test
void getModel_returnsValueFromDescriptor() {
ResourceDescriptor desc =
ResourceDescriptor.Builder.newBuilder(OllamaChatModelSetup.class.getName())
.addInitialArgument("connection", "dummy-connection")
.addInitialArgument("model", "qwen3:4b")
.build();
OllamaChatModelSetup setup = new OllamaChatModelSetup(desc, NOOP);

assertThat(setup.getModel()).isEqualTo("qwen3:4b");
}

@Test
void getParameters_includesModelFromDescriptor() {
ResourceDescriptor desc =
ResourceDescriptor.Builder.newBuilder(OllamaChatModelSetup.class.getName())
.addInitialArgument("connection", "dummy-connection")
.addInitialArgument("model", "qwen3:4b")
.build();
OllamaChatModelSetup setup = new OllamaChatModelSetup(desc, NOOP);

assertThat(setup.getParameters()).containsEntry("model", "qwen3:4b");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ public static void main(String[] args) throws IOException {
kwargs.put("name", "chat_model");
kwargs.put("prompt", "prompt");
kwargs.put("tools", List.of("add"));
kwargs.put("connection", "mock_connection");
kwargs.put("model", "mock-model");
ResourceDescriptor chatModelDescriptor =
new ResourceDescriptor(
"flink_agents.plan.tests.compatibility.python_agent_plan_compatibility_test_agent",
Expand Down
1 change: 1 addition & 0 deletions python/flink_agents/api/agents/react_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class OutputData(BaseModel):
chat_model=ResourceDescriptor(
clazz=OllamaChatModelSetup,
connection="ollama_server",
model="qwen3:8b",
tools=["notify_shipping_manager"],
),
prompt=prompt,
Expand Down
3 changes: 3 additions & 0 deletions python/flink_agents/api/chat_models/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ class BaseChatModelSetup(Resource):
"""Base abstract class for chat model setup.

Responsible for managing chat configurations, such as:
- Connection to chat model service (connection)
- Model name (model)
- Prompt templates (prompt)
- Available tools (tools)
- Generation parameters (temperature, max_tokens, etc.)
Expand All @@ -143,6 +145,7 @@ class BaseChatModelSetup(Resource):
"""

connection: str = Field(description="The referenced connection name.")
model: str = Field(description="Name of the chat model to use.")
_resolved_connection: BaseChatModelConnection | None = PrivateAttr(default=None)
prompt: Prompt | str | None = None
tools: List[str] | List[Tool] = Field(default_factory=list)
Expand Down
47 changes: 47 additions & 0 deletions python/flink_agents/api/chat_models/tests/test_chat_model_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#################################################################################
from typing import Any, Dict

import pytest
from pydantic import ValidationError

from flink_agents.api.chat_models.chat_model import BaseChatModelSetup


class _MinimalChatModelSetup(BaseChatModelSetup):
"""Minimal subclass that omits the `model` field declaration.

Used to assert the `model` field is inherited from `BaseChatModelSetup`.
"""

@property
def model_kwargs(self) -> Dict[str, Any]:
"""Return chat model settings derived from the inherited `model` field."""
return {"model": self.model}


def test_inherits_model_field_from_base() -> None:
"""A subclass that omits `model` still exposes it via inheritance."""
setup = _MinimalChatModelSetup(connection="c", model="m1")
assert setup.model == "m1"


def test_missing_model_raises_validation_error() -> None:
"""Constructing without `model` must raise a Pydantic ValidationError."""
with pytest.raises(ValidationError):
_MinimalChatModelSetup(connection="c")
12 changes: 6 additions & 6 deletions python/flink_agents/api/chat_models/tests/test_token_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class TestBaseChatModelTokenMetrics:

def test_record_token_metrics_with_metric_group(self) -> None:
"""Test token metrics are recorded when metric group is set."""
chat_model = TestChatModelSetup(connection="mock")
chat_model = TestChatModelSetup(connection="mock", model="mock-model")
mock_metric_group = _MockMetricGroup()

# Set the metric group
Expand All @@ -116,7 +116,7 @@ def test_record_token_metrics_with_metric_group(self) -> None:

def test_record_token_metrics_without_metric_group(self) -> None:
"""Test token metrics are not recorded when metric group is null."""
chat_model = TestChatModelSetup(connection="mock")
chat_model = TestChatModelSetup(connection="mock", model="mock-model")

# Do not set metric group (should be None by default)
# Record token metrics - should not throw
Expand All @@ -125,7 +125,7 @@ def test_record_token_metrics_without_metric_group(self) -> None:

def test_token_metrics_hierarchy(self) -> None:
"""Test token metrics hierarchy: actionMetricGroup -> modelName -> counters."""
chat_model = TestChatModelSetup(connection="mock")
chat_model = TestChatModelSetup(connection="mock", model="mock-model")
mock_metric_group = _MockMetricGroup()

# Set the metric group
Expand All @@ -148,7 +148,7 @@ def test_token_metrics_hierarchy(self) -> None:

def test_token_metrics_accumulation(self) -> None:
"""Test that token metrics accumulate across multiple calls."""
chat_model = TestChatModelSetup(connection="mock")
chat_model = TestChatModelSetup(connection="mock", model="mock-model")
mock_metric_group = _MockMetricGroup()

# Set the metric group
Expand All @@ -165,12 +165,12 @@ def test_token_metrics_accumulation(self) -> None:

def test_resource_type(self) -> None:
"""Test resource type is CHAT_MODEL_CONNECTION."""
chat_model = TestChatModelSetup(connection="mock")
chat_model = TestChatModelSetup(connection="mock", model="mock-model")
assert chat_model.resource_type() == ResourceType.CHAT_MODEL

def test_bound_metric_group_property(self) -> None:
"""Test bound_metric_group property."""
chat_model = TestChatModelSetup(connection="mock")
chat_model = TestChatModelSetup(connection="mock", model="mock-model")

# Initially should be None
assert chat_model.metric_group is None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def slow_chat_model() -> ResourceDescriptor:
return ResourceDescriptor(
clazz=f"{SlowMockChatModel.__module__}.{SlowMockChatModel.__name__}",
connection="placement",
model="slow-mock-model",
tools=["add"],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,23 +246,19 @@ class AnthropicChatModelSetup(BaseChatModelSetup):
----------
connection : str
Name of the referenced connection. (Inherited from BaseChatModelSetup)
model : str
Specifies the Anthropic model to use. Defaults to claude-sonnet-4-20250514
when omitted via ``__init__``. (Inherited from BaseChatModelSetup)
prompt : Optional[Union[Prompt, str]
Prompt template or string for the model. (Inherited from BaseChatModelSetup)
tools : Optional[List[str]]
List of available tools to use in the chat. (Inherited from BaseChatModelSetup)
model : str
Specifies the Anthropic model to use. Defaults to claude-sonnet-4-20250514.
max_tokens: int
The maximum number of tokens to generate before stopping. Defaults to 1024.
temperature : float
Amount of randomness injected into the response.
"""

model: str = Field(
default=DEFAULT_ANTHROPIC_MODEL,
description="Specifies the Anthropic model to use. Defaults to "
"claude-sonnet-4-20250514.",
)
max_tokens: int = Field(
default=DEFAULT_MAX_TOKENS,
description="The maximum number of tokens to generate before stopping. Defaults to 1024.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from flink_agents.api.resource import Resource, ResourceType
from flink_agents.api.resource_context import ResourceContext
from flink_agents.integrations.chat_models.anthropic.anthropic_chat_model import (
DEFAULT_ANTHROPIC_MODEL,
AnthropicChatModelConnection,
AnthropicChatModelSetup,
)
Expand Down Expand Up @@ -103,3 +104,16 @@ def get_resource(name: str, type: ResourceType) -> Resource:
tool_call = tool_calls[0]
assert add(**tool_call["function"]["arguments"]) == 2
assert tool_call.get("original_id") is not None


def test_model_field_roundtrip() -> None:
"""Verify `model` is preserved through pydantic dump/validate round-trip."""
setup = AnthropicChatModelSetup(connection="conn", model="test-model")
restored = AnthropicChatModelSetup.model_validate(setup.model_dump())
assert restored.model == "test-model"


def test_default_model_when_omitted() -> None:
"""Verify per-integration default applies when `model` is omitted from __init__."""
setup = AnthropicChatModelSetup(connection="conn")
assert setup.model == DEFAULT_ANTHROPIC_MODEL
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,12 @@ class AzureOpenAIChatModelSetup(BaseChatModelSetup):
----------
connection : str
Name of the referenced connection. (Inherited from BaseChatModelSetup)
model : str
Name of OpenAI model deployment on Azure. (Inherited from BaseChatModelSetup)
prompt : Optional[Union[Prompt, str]
Prompt template or string for the model. (Inherited from BaseChatModelSetup)
tools : Optional[List[str]]
List of available tools to use in the chat. (Inherited from BaseChatModelSetup)
model : str
Name of OpenAI model deployment on Azure.
model_of_azure_deployment : Optional[str]
The underlying model name of the Azure deployment (e.g., 'gpt-4').
Used for token counting and cost calculation.
Expand All @@ -193,9 +193,6 @@ class AzureOpenAIChatModelSetup(BaseChatModelSetup):
Additional kwargs for the Azure OpenAI API.
"""

model: str = Field(
description="Name of OpenAI model deployment on Azure.",
)
model_of_azure_deployment: str | None = Field(
default=None,
description="The underlying model name of the Azure deployment (e.g., 'gpt-4', "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,10 @@ def get_resource(name: str, type: ResourceType) -> Resource:
assert len(tool_calls) == 1
tool_call = tool_calls[0]
assert add(**tool_call["function"]["arguments"]) == 1065


def test_model_field_roundtrip() -> None:
"""Verify `model` is preserved through pydantic dump/validate round-trip."""
setup = AzureOpenAIChatModelSetup(connection="conn", model="test-deployment")
restored = AzureOpenAIChatModelSetup.model_validate(setup.model_dump())
assert restored.model == "test-deployment"
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,12 @@ class OllamaChatModelSetup(BaseChatModelSetup):
----------
connection : str
Name of the referenced connection. (Inherited from BaseChatModelSetup)
model : str
Model name to use. (Inherited from BaseChatModelSetup)
prompt : Optional[Union[Prompt, str]
Prompt template or string for the model. (Inherited from BaseChatModelSetup)
tools : Optional[List[str]]
List of available tools to use in the chat. (Inherited from BaseChatModelSetup)
model : str
Model name to use.
temperature : float
The temperature to use for sampling.
num_ctx : int
Expand All @@ -196,8 +196,6 @@ class OllamaChatModelSetup(BaseChatModelSetup):
stores it in additional_kwargs.
"""

model: str = Field(description="Model name to use.")

temperature: float = Field(
default=0.75,
description="The temperature to use for sampling.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,12 +203,13 @@ class OpenAIChatModelSetup(BaseChatModelSetup):
----------
connection : str
Name of the referenced connection. (Inherited from BaseChatModelSetup)
model : str
The OpenAI model to use. Defaults to ``DEFAULT_OPENAI_MODEL`` when omitted via
``__init__``. (Inherited from BaseChatModelSetup)
prompt : Optional[Union[Prompt, str]
Prompt template or string for the model. (Inherited from BaseChatModelSetup)
tools : Optional[List[str]]
List of available tools to use in the chat. (Inherited from BaseChatModelSetup)
model : str
The OpenAI model to use.
temperature : float
The temperature to use during generation.
max_tokens : Optional[int]
Expand All @@ -225,9 +226,6 @@ class OpenAIChatModelSetup(BaseChatModelSetup):
The effort to use for reasoning models.
"""

model: str = Field(
default=DEFAULT_OPENAI_MODEL, description="The OpenAI model to use."
)
temperature: float = Field(
default=DEFAULT_TEMPERATURE,
description="The temperature to use during generation.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from flink_agents.api.resource import Resource, ResourceType
from flink_agents.api.resource_context import ResourceContext
from flink_agents.integrations.chat_models.openai.openai_chat_model import (
DEFAULT_OPENAI_MODEL,
OpenAIChatModelConnection,
OpenAIChatModelSetup,
)
Expand Down Expand Up @@ -104,3 +105,16 @@ def get_resource(name: str, type: ResourceType) -> Resource:
assert len(tool_calls) == 1
tool_call = tool_calls[0]
assert add(**tool_call["function"]["arguments"]) == 1065


def test_model_field_roundtrip() -> None:
"""Verify `model` is preserved through pydantic dump/validate round-trip."""
setup = OpenAIChatModelSetup(connection="conn", model="test-model")
restored = OpenAIChatModelSetup.model_validate(setup.model_dump())
assert restored.model == "test-model"


def test_default_model_when_omitted() -> None:
"""Verify per-integration default applies when `model` is omitted from __init__."""
setup = OpenAIChatModelSetup(connection="conn")
assert setup.model == DEFAULT_OPENAI_MODEL
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,13 @@ def get_resource(name: str, type: ResourceType) -> Resource:
assert add(**tool_call["function"]["arguments"]) == 3


def test_model_field_roundtrip() -> None:
"""Verify `model` is preserved through pydantic dump/validate round-trip."""
setup = OllamaChatModelSetup(connection="conn", model="test-model")
restored = OllamaChatModelSetup.model_validate(setup.model_dump())
assert restored.model == "test-model"


def test_extract_think_tags() -> None:
"""Test the static method that extracts content from <think></think> tags."""
# Test with a think tag at the beginning (most common case)
Expand Down
Loading
Loading