diff --git a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelSetup.java b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelSetup.java index 81a3f92ff..a04ea3215 100644 --- a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelSetup.java +++ b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelSetup.java @@ -53,13 +53,11 @@ */ public class OllamaChatModelSetup extends BaseChatModelSetup { - private final String model; private final Object think; private final boolean extractReasoning; public OllamaChatModelSetup(ResourceDescriptor descriptor, ResourceContext resourceContext) { super(descriptor, resourceContext); - this.model = descriptor.getArgument("model"); this.think = descriptor.getArgument("think", true); this.extractReasoning = descriptor.getArgument("extract_reasoning", true); } diff --git a/integrations/chat-models/ollama/src/test/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelSetupTest.java b/integrations/chat-models/ollama/src/test/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelSetupTest.java new file mode 100644 index 000000000..7836339ee --- /dev/null +++ b/integrations/chat-models/ollama/src/test/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelSetupTest.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.chatmodels.ollama; + +import org.apache.flink.agents.api.resource.ResourceContext; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link OllamaChatModelSetup}. */ +class OllamaChatModelSetupTest { + + private static final ResourceContext NOOP = ResourceContext.fromGetResource((a, b) -> null); + + @Test + void getModel_returnsValueFromDescriptor() { + ResourceDescriptor desc = + ResourceDescriptor.Builder.newBuilder(OllamaChatModelSetup.class.getName()) + .addInitialArgument("connection", "dummy-connection") + .addInitialArgument("model", "qwen3:4b") + .build(); + OllamaChatModelSetup setup = new OllamaChatModelSetup(desc, NOOP); + + assertThat(setup.getModel()).isEqualTo("qwen3:4b"); + } + + @Test + void getParameters_includesModelFromDescriptor() { + ResourceDescriptor desc = + ResourceDescriptor.Builder.newBuilder(OllamaChatModelSetup.class.getName()) + .addInitialArgument("connection", "dummy-connection") + .addInitialArgument("model", "qwen3:4b") + .build(); + OllamaChatModelSetup setup = new OllamaChatModelSetup(desc, NOOP); + + assertThat(setup.getParameters()).containsEntry("model", "qwen3:4b"); + } +} diff --git a/plan/src/test/java/org/apache/flink/agents/plan/compatibility/CreateJavaAgentPlanFromJson.java b/plan/src/test/java/org/apache/flink/agents/plan/compatibility/CreateJavaAgentPlanFromJson.java index 3388af225..c5589be3c 100644 --- a/plan/src/test/java/org/apache/flink/agents/plan/compatibility/CreateJavaAgentPlanFromJson.java +++ b/plan/src/test/java/org/apache/flink/agents/plan/compatibility/CreateJavaAgentPlanFromJson.java @@ -144,6 +144,8 @@ public static void main(String[] args) throws IOException { kwargs.put("name", "chat_model"); kwargs.put("prompt", "prompt"); kwargs.put("tools", List.of("add")); + kwargs.put("connection", "mock_connection"); + kwargs.put("model", "mock-model"); ResourceDescriptor chatModelDescriptor = new ResourceDescriptor( "flink_agents.plan.tests.compatibility.python_agent_plan_compatibility_test_agent", diff --git a/python/flink_agents/api/agents/react_agent.py b/python/flink_agents/api/agents/react_agent.py index 740afbe73..cef651a1d 100644 --- a/python/flink_agents/api/agents/react_agent.py +++ b/python/flink_agents/api/agents/react_agent.py @@ -91,6 +91,7 @@ class OutputData(BaseModel): chat_model=ResourceDescriptor( clazz=OllamaChatModelSetup, connection="ollama_server", + model="qwen3:8b", tools=["notify_shipping_manager"], ), prompt=prompt, diff --git a/python/flink_agents/api/chat_models/chat_model.py b/python/flink_agents/api/chat_models/chat_model.py index 7d7fc1945..960559793 100644 --- a/python/flink_agents/api/chat_models/chat_model.py +++ b/python/flink_agents/api/chat_models/chat_model.py @@ -131,6 +131,8 @@ class BaseChatModelSetup(Resource): """Base abstract class for chat model setup. Responsible for managing chat configurations, such as: + - Connection to chat model service (connection) + - Model name (model) - Prompt templates (prompt) - Available tools (tools) - Generation parameters (temperature, max_tokens, etc.) @@ -143,6 +145,7 @@ class BaseChatModelSetup(Resource): """ connection: str = Field(description="The referenced connection name.") + model: str = Field(description="Name of the chat model to use.") _resolved_connection: BaseChatModelConnection | None = PrivateAttr(default=None) prompt: Prompt | str | None = None tools: List[str] | List[Tool] = Field(default_factory=list) diff --git a/python/flink_agents/api/chat_models/tests/test_chat_model_base.py b/python/flink_agents/api/chat_models/tests/test_chat_model_base.py new file mode 100644 index 000000000..651061f49 --- /dev/null +++ b/python/flink_agents/api/chat_models/tests/test_chat_model_base.py @@ -0,0 +1,47 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +from typing import Any, Dict + +import pytest +from pydantic import ValidationError + +from flink_agents.api.chat_models.chat_model import BaseChatModelSetup + + +class _MinimalChatModelSetup(BaseChatModelSetup): + """Minimal subclass that omits the `model` field declaration. + + Used to assert the `model` field is inherited from `BaseChatModelSetup`. + """ + + @property + def model_kwargs(self) -> Dict[str, Any]: + """Return chat model settings derived from the inherited `model` field.""" + return {"model": self.model} + + +def test_inherits_model_field_from_base() -> None: + """A subclass that omits `model` still exposes it via inheritance.""" + setup = _MinimalChatModelSetup(connection="c", model="m1") + assert setup.model == "m1" + + +def test_missing_model_raises_validation_error() -> None: + """Constructing without `model` must raise a Pydantic ValidationError.""" + with pytest.raises(ValidationError): + _MinimalChatModelSetup(connection="c") diff --git a/python/flink_agents/api/chat_models/tests/test_token_metrics.py b/python/flink_agents/api/chat_models/tests/test_token_metrics.py index 982565ab0..d987d1210 100644 --- a/python/flink_agents/api/chat_models/tests/test_token_metrics.py +++ b/python/flink_agents/api/chat_models/tests/test_token_metrics.py @@ -100,7 +100,7 @@ class TestBaseChatModelTokenMetrics: def test_record_token_metrics_with_metric_group(self) -> None: """Test token metrics are recorded when metric group is set.""" - chat_model = TestChatModelSetup(connection="mock") + chat_model = TestChatModelSetup(connection="mock", model="mock-model") mock_metric_group = _MockMetricGroup() # Set the metric group @@ -116,7 +116,7 @@ def test_record_token_metrics_with_metric_group(self) -> None: def test_record_token_metrics_without_metric_group(self) -> None: """Test token metrics are not recorded when metric group is null.""" - chat_model = TestChatModelSetup(connection="mock") + chat_model = TestChatModelSetup(connection="mock", model="mock-model") # Do not set metric group (should be None by default) # Record token metrics - should not throw @@ -125,7 +125,7 @@ def test_record_token_metrics_without_metric_group(self) -> None: def test_token_metrics_hierarchy(self) -> None: """Test token metrics hierarchy: actionMetricGroup -> modelName -> counters.""" - chat_model = TestChatModelSetup(connection="mock") + chat_model = TestChatModelSetup(connection="mock", model="mock-model") mock_metric_group = _MockMetricGroup() # Set the metric group @@ -148,7 +148,7 @@ def test_token_metrics_hierarchy(self) -> None: def test_token_metrics_accumulation(self) -> None: """Test that token metrics accumulate across multiple calls.""" - chat_model = TestChatModelSetup(connection="mock") + chat_model = TestChatModelSetup(connection="mock", model="mock-model") mock_metric_group = _MockMetricGroup() # Set the metric group @@ -165,12 +165,12 @@ def test_token_metrics_accumulation(self) -> None: def test_resource_type(self) -> None: """Test resource type is CHAT_MODEL_CONNECTION.""" - chat_model = TestChatModelSetup(connection="mock") + chat_model = TestChatModelSetup(connection="mock", model="mock-model") assert chat_model.resource_type() == ResourceType.CHAT_MODEL def test_bound_metric_group_property(self) -> None: """Test bound_metric_group property.""" - chat_model = TestChatModelSetup(connection="mock") + chat_model = TestChatModelSetup(connection="mock", model="mock-model") # Initially should be None assert chat_model.metric_group is None diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/built_in_action_async_execution_test.py b/python/flink_agents/e2e_tests/e2e_tests_integration/built_in_action_async_execution_test.py index cd4d54844..f5fefd725 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_integration/built_in_action_async_execution_test.py +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/built_in_action_async_execution_test.py @@ -72,6 +72,7 @@ def slow_chat_model() -> ResourceDescriptor: return ResourceDescriptor( clazz=f"{SlowMockChatModel.__module__}.{SlowMockChatModel.__name__}", connection="placement", + model="slow-mock-model", tools=["add"], ) diff --git a/python/flink_agents/integrations/chat_models/anthropic/anthropic_chat_model.py b/python/flink_agents/integrations/chat_models/anthropic/anthropic_chat_model.py index 171fba24e..c077c6c8e 100644 --- a/python/flink_agents/integrations/chat_models/anthropic/anthropic_chat_model.py +++ b/python/flink_agents/integrations/chat_models/anthropic/anthropic_chat_model.py @@ -246,23 +246,19 @@ class AnthropicChatModelSetup(BaseChatModelSetup): ---------- connection : str Name of the referenced connection. (Inherited from BaseChatModelSetup) + model : str + Specifies the Anthropic model to use. Defaults to claude-sonnet-4-20250514 + when omitted via ``__init__``. (Inherited from BaseChatModelSetup) prompt : Optional[Union[Prompt, str] Prompt template or string for the model. (Inherited from BaseChatModelSetup) tools : Optional[List[str]] List of available tools to use in the chat. (Inherited from BaseChatModelSetup) - model : str - Specifies the Anthropic model to use. Defaults to claude-sonnet-4-20250514. max_tokens: int The maximum number of tokens to generate before stopping. Defaults to 1024. temperature : float Amount of randomness injected into the response. """ - model: str = Field( - default=DEFAULT_ANTHROPIC_MODEL, - description="Specifies the Anthropic model to use. Defaults to " - "claude-sonnet-4-20250514.", - ) max_tokens: int = Field( default=DEFAULT_MAX_TOKENS, description="The maximum number of tokens to generate before stopping. Defaults to 1024.", diff --git a/python/flink_agents/integrations/chat_models/anthropic/tests/test_anthropic_chat_model.py b/python/flink_agents/integrations/chat_models/anthropic/tests/test_anthropic_chat_model.py index 767cb6a48..247759a23 100644 --- a/python/flink_agents/integrations/chat_models/anthropic/tests/test_anthropic_chat_model.py +++ b/python/flink_agents/integrations/chat_models/anthropic/tests/test_anthropic_chat_model.py @@ -24,6 +24,7 @@ from flink_agents.api.resource import Resource, ResourceType from flink_agents.api.resource_context import ResourceContext from flink_agents.integrations.chat_models.anthropic.anthropic_chat_model import ( + DEFAULT_ANTHROPIC_MODEL, AnthropicChatModelConnection, AnthropicChatModelSetup, ) @@ -103,3 +104,16 @@ def get_resource(name: str, type: ResourceType) -> Resource: tool_call = tool_calls[0] assert add(**tool_call["function"]["arguments"]) == 2 assert tool_call.get("original_id") is not None + + +def test_model_field_roundtrip() -> None: + """Verify `model` is preserved through pydantic dump/validate round-trip.""" + setup = AnthropicChatModelSetup(connection="conn", model="test-model") + restored = AnthropicChatModelSetup.model_validate(setup.model_dump()) + assert restored.model == "test-model" + + +def test_default_model_when_omitted() -> None: + """Verify per-integration default applies when `model` is omitted from __init__.""" + setup = AnthropicChatModelSetup(connection="conn") + assert setup.model == DEFAULT_ANTHROPIC_MODEL diff --git a/python/flink_agents/integrations/chat_models/azure/azure_openai_chat_model.py b/python/flink_agents/integrations/chat_models/azure/azure_openai_chat_model.py index 64da68900..576d3ab76 100644 --- a/python/flink_agents/integrations/chat_models/azure/azure_openai_chat_model.py +++ b/python/flink_agents/integrations/chat_models/azure/azure_openai_chat_model.py @@ -167,12 +167,12 @@ class AzureOpenAIChatModelSetup(BaseChatModelSetup): ---------- connection : str Name of the referenced connection. (Inherited from BaseChatModelSetup) + model : str + Name of OpenAI model deployment on Azure. (Inherited from BaseChatModelSetup) prompt : Optional[Union[Prompt, str] Prompt template or string for the model. (Inherited from BaseChatModelSetup) tools : Optional[List[str]] List of available tools to use in the chat. (Inherited from BaseChatModelSetup) - model : str - Name of OpenAI model deployment on Azure. model_of_azure_deployment : Optional[str] The underlying model name of the Azure deployment (e.g., 'gpt-4'). Used for token counting and cost calculation. @@ -193,9 +193,6 @@ class AzureOpenAIChatModelSetup(BaseChatModelSetup): Additional kwargs for the Azure OpenAI API. """ - model: str = Field( - description="Name of OpenAI model deployment on Azure.", - ) model_of_azure_deployment: str | None = Field( default=None, description="The underlying model name of the Azure deployment (e.g., 'gpt-4', " diff --git a/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py b/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py index 95172a675..30753d755 100644 --- a/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py +++ b/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py @@ -119,3 +119,10 @@ def get_resource(name: str, type: ResourceType) -> Resource: assert len(tool_calls) == 1 tool_call = tool_calls[0] assert add(**tool_call["function"]["arguments"]) == 1065 + + +def test_model_field_roundtrip() -> None: + """Verify `model` is preserved through pydantic dump/validate round-trip.""" + setup = AzureOpenAIChatModelSetup(connection="conn", model="test-deployment") + restored = AzureOpenAIChatModelSetup.model_validate(setup.model_dump()) + assert restored.model == "test-deployment" diff --git a/python/flink_agents/integrations/chat_models/ollama_chat_model.py b/python/flink_agents/integrations/chat_models/ollama_chat_model.py index a879dcf9c..7c36ec38e 100644 --- a/python/flink_agents/integrations/chat_models/ollama_chat_model.py +++ b/python/flink_agents/integrations/chat_models/ollama_chat_model.py @@ -176,12 +176,12 @@ class OllamaChatModelSetup(BaseChatModelSetup): ---------- connection : str Name of the referenced connection. (Inherited from BaseChatModelSetup) + model : str + Model name to use. (Inherited from BaseChatModelSetup) prompt : Optional[Union[Prompt, str] Prompt template or string for the model. (Inherited from BaseChatModelSetup) tools : Optional[List[str]] List of available tools to use in the chat. (Inherited from BaseChatModelSetup) - model : str - Model name to use. temperature : float The temperature to use for sampling. num_ctx : int @@ -196,8 +196,6 @@ class OllamaChatModelSetup(BaseChatModelSetup): stores it in additional_kwargs. """ - model: str = Field(description="Model name to use.") - temperature: float = Field( default=0.75, description="The temperature to use for sampling.", diff --git a/python/flink_agents/integrations/chat_models/openai/openai_chat_model.py b/python/flink_agents/integrations/chat_models/openai/openai_chat_model.py index 95edb2986..2e5fe720a 100644 --- a/python/flink_agents/integrations/chat_models/openai/openai_chat_model.py +++ b/python/flink_agents/integrations/chat_models/openai/openai_chat_model.py @@ -203,12 +203,13 @@ class OpenAIChatModelSetup(BaseChatModelSetup): ---------- connection : str Name of the referenced connection. (Inherited from BaseChatModelSetup) + model : str + The OpenAI model to use. Defaults to ``DEFAULT_OPENAI_MODEL`` when omitted via + ``__init__``. (Inherited from BaseChatModelSetup) prompt : Optional[Union[Prompt, str] Prompt template or string for the model. (Inherited from BaseChatModelSetup) tools : Optional[List[str]] List of available tools to use in the chat. (Inherited from BaseChatModelSetup) - model : str - The OpenAI model to use. temperature : float The temperature to use during generation. max_tokens : Optional[int] @@ -225,9 +226,6 @@ class OpenAIChatModelSetup(BaseChatModelSetup): The effort to use for reasoning models. """ - model: str = Field( - default=DEFAULT_OPENAI_MODEL, description="The OpenAI model to use." - ) temperature: float = Field( default=DEFAULT_TEMPERATURE, description="The temperature to use during generation.", diff --git a/python/flink_agents/integrations/chat_models/openai/tests/test_openai_chat_model.py b/python/flink_agents/integrations/chat_models/openai/tests/test_openai_chat_model.py index 3c893ecbf..2280bf44d 100644 --- a/python/flink_agents/integrations/chat_models/openai/tests/test_openai_chat_model.py +++ b/python/flink_agents/integrations/chat_models/openai/tests/test_openai_chat_model.py @@ -24,6 +24,7 @@ from flink_agents.api.resource import Resource, ResourceType from flink_agents.api.resource_context import ResourceContext from flink_agents.integrations.chat_models.openai.openai_chat_model import ( + DEFAULT_OPENAI_MODEL, OpenAIChatModelConnection, OpenAIChatModelSetup, ) @@ -104,3 +105,16 @@ def get_resource(name: str, type: ResourceType) -> Resource: assert len(tool_calls) == 1 tool_call = tool_calls[0] assert add(**tool_call["function"]["arguments"]) == 1065 + + +def test_model_field_roundtrip() -> None: + """Verify `model` is preserved through pydantic dump/validate round-trip.""" + setup = OpenAIChatModelSetup(connection="conn", model="test-model") + restored = OpenAIChatModelSetup.model_validate(setup.model_dump()) + assert restored.model == "test-model" + + +def test_default_model_when_omitted() -> None: + """Verify per-integration default applies when `model` is omitted from __init__.""" + setup = OpenAIChatModelSetup(connection="conn") + assert setup.model == DEFAULT_OPENAI_MODEL diff --git a/python/flink_agents/integrations/chat_models/tests/test_ollama_chat_model.py b/python/flink_agents/integrations/chat_models/tests/test_ollama_chat_model.py index e56fc0fd5..5503185bf 100644 --- a/python/flink_agents/integrations/chat_models/tests/test_ollama_chat_model.py +++ b/python/flink_agents/integrations/chat_models/tests/test_ollama_chat_model.py @@ -133,6 +133,13 @@ def get_resource(name: str, type: ResourceType) -> Resource: assert add(**tool_call["function"]["arguments"]) == 3 +def test_model_field_roundtrip() -> None: + """Verify `model` is preserved through pydantic dump/validate round-trip.""" + setup = OllamaChatModelSetup(connection="conn", model="test-model") + restored = OllamaChatModelSetup.model_validate(setup.model_dump()) + assert restored.model == "test-model" + + def test_extract_think_tags() -> None: """Test the static method that extracts content from tags.""" # Test with a think tag at the beginning (most common case) diff --git a/python/flink_agents/integrations/chat_models/tests/test_tongyi_chat_model.py b/python/flink_agents/integrations/chat_models/tests/test_tongyi_chat_model.py index b6be74043..fc80ac6a0 100644 --- a/python/flink_agents/integrations/chat_models/tests/test_tongyi_chat_model.py +++ b/python/flink_agents/integrations/chat_models/tests/test_tongyi_chat_model.py @@ -25,6 +25,7 @@ from flink_agents.api.resource import Resource, ResourceType from flink_agents.api.resource_context import ResourceContext from flink_agents.integrations.chat_models.tongyi_chat_model import ( + DEFAULT_MODEL, TongyiChatModelConnection, TongyiChatModelSetup, ) @@ -175,3 +176,16 @@ def get_resource(name: str, type: ResourceType) -> Resource: assert "reasoning" in response.extra_args assert "philosophical perspectives" in response.extra_args["reasoning"] assert "Hitchhiker's Guide to the Galaxy" in response.extra_args["reasoning"] + + +def test_model_field_roundtrip() -> None: + """Verify `model` is preserved through pydantic dump/validate round-trip.""" + setup = TongyiChatModelSetup(connection="conn", model="test-model") + restored = TongyiChatModelSetup.model_validate(setup.model_dump()) + assert restored.model == "test-model" + + +def test_default_model_when_omitted() -> None: + """Verify per-integration default applies when `model` is omitted from __init__.""" + setup = TongyiChatModelSetup(connection="conn") + assert setup.model == DEFAULT_MODEL diff --git a/python/flink_agents/integrations/chat_models/tongyi_chat_model.py b/python/flink_agents/integrations/chat_models/tongyi_chat_model.py index 7d37f3fc6..6587a8cba 100644 --- a/python/flink_agents/integrations/chat_models/tongyi_chat_model.py +++ b/python/flink_agents/integrations/chat_models/tongyi_chat_model.py @@ -219,12 +219,13 @@ class TongyiChatModelSetup(BaseChatModelSetup): ---------- connection : str Name of the referenced connection. (Inherited from BaseChatModelSetup) + model : str + Model name to use. Defaults to ``DEFAULT_MODEL`` when omitted via + ``__init__``. (Inherited from BaseChatModelSetup) prompt : Optional[Union[Prompt, str] Prompt template or string for the model. (Inherited from BaseChatModelSetup) tools : Optional[List[str]] List of available tools to use in the chat. (Inherited from BaseChatModelSetup) - model : str - Model name to use. temperature : float The temperature to use for sampling. additional_kwargs : Dict[str, Any] @@ -234,7 +235,6 @@ class TongyiChatModelSetup(BaseChatModelSetup): in additional_kwargs. """ - model: str = Field(default=DEFAULT_MODEL, description="Model name to use.") temperature: float = Field( default=0.7, description="The temperature to use for sampling.", diff --git a/python/flink_agents/plan/tests/compatibility/python_agent_plan_compatibility_test_agent.py b/python/flink_agents/plan/tests/compatibility/python_agent_plan_compatibility_test_agent.py index 122f8ad0d..5080cefbe 100644 --- a/python/flink_agents/plan/tests/compatibility/python_agent_plan_compatibility_test_agent.py +++ b/python/flink_agents/plan/tests/compatibility/python_agent_plan_compatibility_test_agent.py @@ -70,6 +70,8 @@ def chat_model() -> ResourceDescriptor: name="chat_model", prompt="prompt", tools=["add"], + connection="mock_connection", + model="mock-model", ) @tool diff --git a/python/flink_agents/plan/tests/resources/agent_plan.json b/python/flink_agents/plan/tests/resources/agent_plan.json index ad0a58757..9f9a3f416 100644 --- a/python/flink_agents/plan/tests/resources/agent_plan.json +++ b/python/flink_agents/plan/tests/resources/agent_plan.json @@ -95,7 +95,8 @@ "arguments": { "host": "8.8.8.8", "desc": "mock resource just for testing.", - "connection": "mock" + "connection": "mock", + "model": "mock-model" } }, "__resource_provider_type__": "PythonResourceProvider" diff --git a/python/flink_agents/plan/tests/test_agent_plan.py b/python/flink_agents/plan/tests/test_agent_plan.py index a58289ea5..001b1d4de 100644 --- a/python/flink_agents/plan/tests/test_agent_plan.py +++ b/python/flink_agents/plan/tests/test_agent_plan.py @@ -205,6 +205,7 @@ def mock() -> ResourceDescriptor: host="8.8.8.8", desc="mock resource just for testing.", connection="mock", + model="mock-model", ) @embedding_model_connection @@ -298,6 +299,7 @@ def test_add_action_and_resource_to_agent() -> None: host="8.8.8.8", desc="mock resource just for testing.", connection="mock", + model="mock-model", ), ) diff --git a/python/flink_agents/runtime/java/java_chat_model.py b/python/flink_agents/runtime/java/java_chat_model.py index 28ca408d1..7160d5315 100644 --- a/python/flink_agents/runtime/java/java_chat_model.py +++ b/python/flink_agents/runtime/java/java_chat_model.py @@ -106,9 +106,10 @@ def __init__(self, j_resource: Any, j_resource_adapter: Any, **kwargs: Any) -> N j_resource_adapter: The Java resource adapter for method invocation **kwargs: Additional keyword arguments """ - # connection is a required parameter for BaseChatModelSetup + # connection and model are required parameters for BaseChatModelSetup connection = kwargs.pop("connection", "") - super().__init__(connection=connection, **kwargs) + model = kwargs.pop("model", "") + super().__init__(connection=connection, model=model, **kwargs) self._j_resource = j_resource self._j_resource_adapter = j_resource_adapter diff --git a/python/flink_agents/runtime/tests/test_built_in_actions.py b/python/flink_agents/runtime/tests/test_built_in_actions.py index 09178dd31..05a918813 100644 --- a/python/flink_agents/runtime/tests/test_built_in_actions.py +++ b/python/flink_agents/runtime/tests/test_built_in_actions.py @@ -144,6 +144,7 @@ def mock_chat_model() -> ResourceDescriptor: return ResourceDescriptor( clazz=f"{MockChatModel.__module__}.{MockChatModel.__name__}", connection="mock_connection", + model="mock-model", prompt="prompt", tools=["add"], ) diff --git a/python/flink_agents/runtime/tests/test_get_resource_in_action.py b/python/flink_agents/runtime/tests/test_get_resource_in_action.py index 1953645d6..f1f1e8e15 100644 --- a/python/flink_agents/runtime/tests/test_get_resource_in_action.py +++ b/python/flink_agents/runtime/tests/test_get_resource_in_action.py @@ -54,6 +54,7 @@ def mock_chat_model() -> ResourceDescriptor: host="8.8.8.8", desc="mock chat model just for testing.", connection="mock", + model="mock-model", ) @tool