diff --git a/api/src/main/java/org/apache/flink/agents/api/resource/ResourceName.java b/api/src/main/java/org/apache/flink/agents/api/resource/ResourceName.java index d6dfa6976..17cbe1777 100644 --- a/api/src/main/java/org/apache/flink/agents/api/resource/ResourceName.java +++ b/api/src/main/java/org/apache/flink/agents/api/resource/ResourceName.java @@ -83,6 +83,12 @@ public static final class ChatModel { public static final String OPENAI_RESPONSES_SETUP = "org.apache.flink.agents.integrations.chatmodels.openai.OpenAIResponsesModelSetup"; + // Azure OpenAI + public static final String AZURE_OPENAI_CONNECTION = + "org.apache.flink.agents.integrations.chatmodels.openai.AzureOpenAIChatModelConnection"; + public static final String AZURE_OPENAI_SETUP = + "org.apache.flink.agents.integrations.chatmodels.openai.AzureOpenAIChatModelSetup"; + // Python Wrapper public static final String PYTHON_WRAPPER_CONNECTION = "org.apache.flink.agents.api.chat.model.python.PythonChatModelConnection"; diff --git a/docs/content/docs/development/chat_models.md b/docs/content/docs/development/chat_models.md index 99ac9d7e5..f4a66cd0c 100644 --- a/docs/content/docs/development/chat_models.md +++ b/docs/content/docs/development/chat_models.md @@ -390,10 +390,6 @@ Model availability and specifications may change. Always check the official Azur Azure OpenAI provides access to OpenAI models (GPT-4, GPT-4o, etc.) through Azure's cloud infrastructure, using the same OpenAI SDK with Azure-specific authentication and endpoints. This offers enterprise security, compliance, and regional availability while using familiar OpenAI APIs. -{{< hint info >}} -Azure OpenAI is only supported in Python currently. To use Azure OpenAI from Java agents, see [Using Cross-Language Providers](#using-cross-language-providers). -{{< /hint >}} - {{< hint warning >}} **Azure OpenAI vs Azure AI:** Azure OpenAI uses the OpenAI SDK to access OpenAI models (GPT-4, etc.) hosted on Azure. 
If you want to use other models like Llama, Mistral, or Phi deployed via Azure AI Studio, see [Azure AI](#azure-ai) instead. {{< /hint >}} @@ -420,6 +416,19 @@ Azure OpenAI is only supported in Python currently. To use Azure OpenAI from Jav {{< /tab >}} +{{< tab "Java" >}} + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `api_key` | String | Required | Azure OpenAI API key for authentication | +| `api_version` | String | Required | Azure OpenAI REST API version (e.g., "2024-02-01"). See [API versions](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning) | +| `azure_endpoint` | String | Required | Azure OpenAI endpoint URL (e.g., `https://{resource-name}.openai.azure.com`) — either a direct Azure resource or a proxy/gateway URL that fronts an Azure OpenAI service | +| `timeout` | int | None | Timeout in seconds for API requests; must be greater than 0, otherwise ignored (SDK default applies) | +| `max_retries` | int | None | Maximum number of API retry attempts; must be non-negative, otherwise ignored (SDK default applies) | +| `azure_url_path_mode` | String | `"AUTO"` | Controls how the SDK constructs Azure OpenAI request URLs. One of `"AUTO"`, `"LEGACY"`, or `"UNIFIED"`. Custom gateways that proxy Azure OpenAI typically need `"LEGACY"` to force the `/openai/deployments/{model}` path | + +{{< /tab >}} + {{< /tabs >}} #### AzureOpenAIChatModelSetup Parameters @@ -442,6 +451,22 @@ Azure OpenAI is only supported in Python currently. To use Azure OpenAI from Jav {{< /tab >}} +{{< tab "Java" >}} + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `connection` | String | Required | Reference to connection method name | +| `model` | String | Required | Azure deployment name (not the underlying OpenAI model name) | +| `model_of_azure_deployment` | String | None | The underlying model name (e.g., 'gpt-4', 'gpt-4o'). 
Used solely for token metrics tracking | +| `prompt` | Prompt \| String | None | Prompt template or reference to prompt resource | +| `tools` | List | None | List of tool names available to the model | +| `temperature` | double | None | Sampling temperature (0.0 to 2.0). Not supported by reasoning models | +| `max_tokens` | int | None | Maximum number of tokens to generate (must be greater than 0) | +| `logprobs` | boolean | `false` | Whether to return log probabilities of output tokens | +| `additional_kwargs` | Map | `{}` | Additional Azure OpenAI API parameters (forwarded to the OpenAI request body) | + +{{< /tab >}} + {{< /tabs >}} #### Usage Example @@ -477,6 +502,34 @@ class MyAgent(Agent): ``` {{< /tab >}} +{{< tab "Java" >}} +```java +public class MyAgent extends Agent { + @ChatModelConnection + public static ResourceDescriptor azureOpenAIConnection() { + return ResourceDescriptor.Builder.newBuilder(ResourceName.ChatModel.AZURE_OPENAI_CONNECTION) + .addInitialArgument("api_key", System.getenv("AZURE_OPENAI_API_KEY")) + .addInitialArgument("api_version", "2024-02-01") + .addInitialArgument("azure_endpoint", "https://your-resource.openai.azure.com") + .build(); + } + + @ChatModelSetup + public static ResourceDescriptor azureOpenAIChatModel() { + return ResourceDescriptor.Builder.newBuilder(ResourceName.ChatModel.AZURE_OPENAI_SETUP) + .addInitialArgument("connection", "azureOpenAIConnection") + .addInitialArgument("model", "my-gpt4-deployment") // Your Azure deployment name + .addInitialArgument("model_of_azure_deployment", "gpt-4") // Underlying model for metrics + .addInitialArgument("temperature", 0.3d) + .addInitialArgument("max_tokens", 1000) + .build(); + } + + ... 
+} +``` +{{< /tab >}} + {{< /tabs >}} #### Available Models diff --git a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationAgent.java b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationAgent.java index 2c56b10a4..4492a8f48 100644 --- a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationAgent.java +++ b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationAgent.java @@ -91,6 +91,16 @@ public static ResourceDescriptor chatModelConnection() { ResourceName.ChatModel.OPENAI_RESPONSES_CONNECTION) .addInitialArgument("api_key", System.getenv().get("OPENAI_API_KEY")) .build(); + } else if (provider.equals("AZURE_OPENAI")) { + return ResourceDescriptor.Builder.newBuilder( + ResourceName.ChatModel.AZURE_OPENAI_CONNECTION) + .addInitialArgument("api_key", System.getenv().get("AZURE_OPENAI_API_KEY")) + .addInitialArgument( + "api_version", System.getenv().get("AZURE_OPENAI_API_VERSION")) + .addInitialArgument( + "azure_endpoint", System.getenv().get("AZURE_OPENAI_ENDPOINT")) + .addInitialArgument("azure_url_path_mode", "LEGACY") + .build(); } else if (provider.equals("ANTHROPIC")) { String apiKey = System.getenv().get("ANTHROPIC_API_KEY"); return ResourceDescriptor.Builder.newBuilder( @@ -150,6 +160,14 @@ public static ResourceDescriptor chatModel() { "tools", List.of("calculateBMI", "convertTemperature", "createRandomNumber")) .build(); + } else if (provider.equals("AZURE_OPENAI")) { + return ResourceDescriptor.Builder.newBuilder(ResourceName.ChatModel.AZURE_OPENAI_SETUP) + .addInitialArgument("connection", "chatModelConnection") + .addInitialArgument("model", System.getenv().get("AZURE_OPENAI_DEPLOYMENT")) + .addInitialArgument( + "tools", + List.of("calculateBMI", 
"convertTemperature", "createRandomNumber")) + .build(); } else { throw new RuntimeException(String.format("Unknown model provider %s", provider)); } diff --git a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationTest.java b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationTest.java index 75a3d5c1e..de967858f 100644 --- a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationTest.java +++ b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationTest.java @@ -27,8 +27,6 @@ import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import java.io.IOException; import java.util.ArrayList; @@ -41,7 +39,6 @@ * prompts. 
*/ public class ChatModelIntegrationTest extends OllamaPreparationUtils { - private static final Logger LOG = LoggerFactory.getLogger(ChatModelIntegrationTest.class); private static final String API_KEY = "_API_KEY"; private static final String OLLAMA = "OLLAMA"; @@ -53,7 +50,15 @@ public ChatModelIntegrationTest() throws IOException { } @ParameterizedTest() - @ValueSource(strings = {"ANTHROPIC", "AZURE", "OLLAMA", "OPENAI", "OPENAI_RESPONSES"}) + @ValueSource( + strings = { + "ANTHROPIC", + "AZURE", + "AZURE_OPENAI", + "OLLAMA", + "OPENAI", + "OPENAI_RESPONSES" + }) public void testChatModeIntegration(String provider) throws Exception { Assumptions.assumeTrue( (OLLAMA.equals(provider) && ollamaReady) diff --git a/integrations/chat-models/openai/pom.xml b/integrations/chat-models/openai/pom.xml index e1c31cb5d..ba0c6ce1b 100644 --- a/integrations/chat-models/openai/pom.xml +++ b/integrations/chat-models/openai/pom.xml @@ -43,6 +43,12 @@ under the License. openai-java ${openai.version} + + + org.slf4j + slf4j-api + ${slf4j.version} + diff --git a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelConnection.java b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelConnection.java new file mode 100644 index 000000000..7d6b5c2c1 --- /dev/null +++ b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelConnection.java @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.integrations.chatmodels.openai; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.openai.azure.AzureOpenAIServiceVersion; +import com.openai.azure.AzureUrlPathMode; +import com.openai.azure.credential.AzureApiKeyCredential; +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.core.JsonValue; +import com.openai.models.ChatModel; +import com.openai.models.FunctionDefinition; +import com.openai.models.FunctionParameters; +import com.openai.models.chat.completions.ChatCompletion; +import com.openai.models.chat.completions.ChatCompletionCreateParams; +import com.openai.models.chat.completions.ChatCompletionFunctionTool; +import com.openai.models.chat.completions.ChatCompletionTool; +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.model.BaseChatModelConnection; +import org.apache.flink.agents.api.resource.ResourceContext; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.tools.Tool; +import org.apache.flink.agents.api.tools.ToolMetadata; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Chat model integration for Azure OpenAI Service. 
Built on the openai-java SDK using its built-in + * Azure support ({@link AzureOpenAIServiceVersion}, {@link AzureApiKeyCredential}). + * + *

Required connection arguments: + * + *

    + *
  • api_key: Azure OpenAI API key + *
  • api_version: Azure OpenAI REST API version (e.g., {@code "2024-02-01"}) + *
  • azure_endpoint: base URL for the Azure OpenAI deployment — either a direct Azure + * resource (e.g., {@code "https://your-resource.openai.azure.com"}) or a proxy/gateway URL + * that fronts an Azure OpenAI service. Custom gateway hostnames also require setting {@code + * azure_url_path_mode} below. + *
+ * + *

Optional connection arguments: + * + *

    + *
  • timeout (Number): seconds before an API call times out; must be greater than 0, + * otherwise ignored (SDK default applies) + *
  • max_retries (Number): retry attempts on failure; must be non-negative, otherwise + * ignored (SDK default applies) + *
  • azure_url_path_mode (String): one of {@code "AUTO"}, {@code "LEGACY"}, or {@code + * "UNIFIED"} (default {@code "AUTO"}). Controls how the SDK constructs Azure OpenAI request + * URLs. In {@code AUTO} mode the SDK only treats the endpoint as Azure when its hostname + * matches a known suffix (e.g. {@code .openai.azure.com}); custom gateways that proxy Azure + * OpenAI need {@code LEGACY} to force the {@code /openai/deployments/{model}} path. + *
+ * + *

Example usage: + * + *

{@code
+ * @ChatModelConnection
+ * public static ResourceDescriptor azureOpenAIConnection() {
+ *   return ResourceDescriptor.Builder.newBuilder(
+ *               AzureOpenAIChatModelConnection.class.getName())
+ *           .addInitialArgument("api_key", System.getenv("AZURE_OPENAI_API_KEY"))
+ *           .addInitialArgument("api_version", "2024-02-01")
+ *           .addInitialArgument("azure_endpoint", "https://my-resource.openai.azure.com")
+ *           .build();
+ * }
+ * }
+ */ +public class AzureOpenAIChatModelConnection extends BaseChatModelConnection { + + private static final ObjectMapper mapper = new ObjectMapper(); + + private static final Set RESERVED_KWARG_KEYS = + Set.of("model", "model_of_azure_deployment", "temperature", "max_tokens", "logprobs"); + + private final OpenAIClient client; + + public AzureOpenAIChatModelConnection( + ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); + + String apiKey = descriptor.getArgument("api_key"); + if (apiKey == null || apiKey.isBlank()) { + throw new IllegalArgumentException("api_key should not be null or empty."); + } + + String apiVersion = descriptor.getArgument("api_version"); + if (apiVersion == null || apiVersion.isBlank()) { + throw new IllegalArgumentException("api_version should not be null or empty."); + } + + String azureEndpoint = descriptor.getArgument("azure_endpoint"); + if (azureEndpoint == null || azureEndpoint.isBlank()) { + throw new IllegalArgumentException("azure_endpoint should not be null or empty."); + } + + OpenAIOkHttpClient.Builder clientBuilder = + OpenAIOkHttpClient.builder() + .baseUrl(azureEndpoint) + .credential(AzureApiKeyCredential.create(apiKey)) + .azureServiceVersion(AzureOpenAIServiceVersion.fromString(apiVersion)); + + Integer timeoutSeconds = descriptor.getArgument("timeout"); + if (timeoutSeconds != null && timeoutSeconds > 0) { + clientBuilder.timeout(Duration.ofSeconds(timeoutSeconds)); + } + + Integer maxRetries = descriptor.getArgument("max_retries"); + if (maxRetries != null && maxRetries >= 0) { + clientBuilder.maxRetries(maxRetries); + } + + String azureUrlPathMode = descriptor.getArgument("azure_url_path_mode"); + if (azureUrlPathMode != null && !azureUrlPathMode.isBlank()) { + try { + clientBuilder.azureUrlPathMode( + AzureUrlPathMode.valueOf(azureUrlPathMode.trim().toUpperCase())); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException( + "azure_url_path_mode 
must be one of AUTO, LEGACY, or UNIFIED; got: " + + azureUrlPathMode, + e); + } + } + + this.client = clientBuilder.build(); + } + + @Override + public ChatMessage chat( + List messages, List tools, Map arguments) { + try { + Map mutableArgs = + arguments != null ? new HashMap<>(arguments) : new HashMap<>(); + + String azureDeployment = (String) mutableArgs.remove("model"); + if (azureDeployment == null || azureDeployment.isBlank()) { + throw new IllegalArgumentException("model is required for Azure OpenAI API calls"); + } + String modelOfAzureDeployment = + (String) mutableArgs.remove("model_of_azure_deployment"); + + ChatCompletionCreateParams.Builder builder = + ChatCompletionCreateParams.builder() + .model(ChatModel.of(azureDeployment)) + .messages(OpenAIChatCompletionsUtils.convertToOpenAIMessages(messages)); + + if (tools != null && !tools.isEmpty()) { + builder.tools(convertTools(tools)); + } + + Object temperature = mutableArgs.remove("temperature"); + if (temperature instanceof Number) { + builder.temperature(((Number) temperature).doubleValue()); + } + + Object maxTokens = mutableArgs.remove("max_tokens"); + if (maxTokens instanceof Number) { + builder.maxCompletionTokens(((Number) maxTokens).longValue()); + } + + Object logprobs = mutableArgs.remove("logprobs"); + if (Boolean.TRUE.equals(logprobs)) { + builder.logprobs(true); + } + + @SuppressWarnings("unchecked") + Map additionalKwargs = + (Map) mutableArgs.remove("additional_kwargs"); + if (additionalKwargs != null) { + Set collisions = new HashSet<>(additionalKwargs.keySet()); + collisions.retainAll(RESERVED_KWARG_KEYS); + if (!collisions.isEmpty()) { + throw new IllegalArgumentException( + "additional_kwargs must not contain reserved typed fields: " + + collisions + + ". 
Set these via the corresponding Setup field instead."); + } + for (Map.Entry entry : additionalKwargs.entrySet()) { + builder.putAdditionalBodyProperty( + entry.getKey(), toJsonValue(entry.getValue())); + } + } + + ChatCompletion completion = client.chat().completions().create(builder.build()); + + ChatMessage response = + OpenAIChatCompletionsUtils.convertFromOpenAIMessage( + completion.choices().get(0).message()); + + if (modelOfAzureDeployment != null + && !modelOfAzureDeployment.isBlank() + && completion.usage().isPresent()) { + recordTokenMetrics( + modelOfAzureDeployment, + completion.usage().get().promptTokens(), + completion.usage().get().completionTokens()); + } + + return response; + } catch (IllegalArgumentException e) { + throw e; + } catch (Exception e) { + throw new RuntimeException("Failed to call Azure OpenAI chat completions API.", e); + } + } + + @Override + public void close() throws Exception { + this.client.close(); + } + + private List convertTools(List tools) { + List openaiTools = new ArrayList<>(tools.size()); + for (Tool tool : tools) { + ToolMetadata metadata = tool.getMetadata(); + FunctionDefinition.Builder functionBuilder = + FunctionDefinition.builder() + .name(metadata.getName()) + .description(metadata.getDescription()); + + String schema = metadata.getInputSchema(); + if (schema != null && !schema.isBlank()) { + functionBuilder.parameters(parseFunctionParameters(schema)); + } + + ChatCompletionFunctionTool functionTool = + ChatCompletionFunctionTool.builder() + .function(functionBuilder.build()) + .type(JsonValue.from("function")) + .build(); + + openaiTools.add(ChatCompletionTool.ofFunction(functionTool)); + } + return openaiTools; + } + + private FunctionParameters parseFunctionParameters(String schemaJson) { + try { + JsonNode root = mapper.readTree(schemaJson); + if (root == null || !root.isObject()) { + return FunctionParameters.builder().build(); + } + FunctionParameters.Builder builder = FunctionParameters.builder(); + 
root.fields() + .forEachRemaining( + entry -> + builder.putAdditionalProperty( + entry.getKey(), + JsonValue.fromJsonNode(entry.getValue()))); + return builder.build(); + } catch (JsonProcessingException e) { + throw new RuntimeException("Failed to parse tool schema JSON.", e); + } + } + + private JsonValue toJsonValue(Object value) { + if (value instanceof JsonValue) { + return (JsonValue) value; + } + if (value instanceof String + || value instanceof Number + || value instanceof Boolean + || value == null) { + return JsonValue.from(value); + } + return JsonValue.fromJsonNode(mapper.valueToTree(value)); + } +} diff --git a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelSetup.java b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelSetup.java new file mode 100644 index 000000000..44a7c8431 --- /dev/null +++ b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelSetup.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.agents.integrations.chatmodels.openai; + +import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; +import org.apache.flink.agents.api.resource.ResourceContext; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +/** + * Setup for Azure OpenAI Chat Completions. + * + *

{@code model} (inherited from {@link BaseChatModelSetup}) is the Azure deployment name, not + * the underlying OpenAI model name. The underlying model name can be supplied via {@code + * model_of_azure_deployment} and is used solely for token-metrics tracking. + * + *

Example usage: + * + *

{@code
+ * @ChatModelSetup
+ * public static ResourceDescriptor azureOpenAIModel() {
+ *   return ResourceDescriptor.Builder.newBuilder(AzureOpenAIChatModelSetup.class.getName())
+ *           .addInitialArgument("connection", "myAzureOpenAIConnection")
+ *           .addInitialArgument("model", "my-gpt4o-deployment")
+ *           .addInitialArgument("model_of_azure_deployment", "gpt-4o")
+ *           .addInitialArgument("temperature", 0.3d)
+ *           .addInitialArgument("max_tokens", 500)
+ *           .build();
+ * }
+ * }
+ */ +public class AzureOpenAIChatModelSetup extends BaseChatModelSetup { + + private static final Logger LOG = LoggerFactory.getLogger(AzureOpenAIChatModelSetup.class); + + private final String modelOfAzureDeployment; + private final Double temperature; + private final Integer maxTokens; + private final Boolean logprobs; + private final Map additionalKwargs; + + public AzureOpenAIChatModelSetup( + ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); + + this.modelOfAzureDeployment = descriptor.getArgument("model_of_azure_deployment"); + if (this.modelOfAzureDeployment == null || this.modelOfAzureDeployment.isBlank()) { + LOG.warn( + "model_of_azure_deployment is not set; token usage metrics will not be recorded for this Azure OpenAI deployment '{}'.", + this.model); + } + + this.temperature = + Optional.ofNullable(descriptor.getArgument("temperature")) + .map(Number::doubleValue) + .orElse(null); + if (this.temperature != null && (this.temperature < 0.0 || this.temperature > 2.0)) { + throw new IllegalArgumentException("temperature must be between 0.0 and 2.0"); + } + + this.maxTokens = + Optional.ofNullable(descriptor.getArgument("max_tokens")) + .map(Number::intValue) + .orElse(null); + if (this.maxTokens != null && this.maxTokens <= 0) { + throw new IllegalArgumentException("max_tokens must be greater than 0"); + } + + this.logprobs = + Optional.ofNullable(descriptor.getArgument("logprobs")).orElse(false); + + Map additional = + Optional.ofNullable( + descriptor.>getArgument("additional_kwargs")) + .map(HashMap::new) + .orElseGet(HashMap::new); + this.additionalKwargs = additional; + } + + @Override + public Map getParameters() { + Map params = new HashMap<>(); + if (model != null) { + params.put("model", model); + } + if (modelOfAzureDeployment != null) { + params.put("model_of_azure_deployment", modelOfAzureDeployment); + } + params.put("logprobs", logprobs); + if (temperature != null) { + 
params.put("temperature", temperature); + } + if (maxTokens != null) { + params.put("max_tokens", maxTokens); + } + if (additionalKwargs != null && !additionalKwargs.isEmpty()) { + params.put("additional_kwargs", additionalKwargs); + } + return params; + } +} diff --git a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatCompletionsUtils.java b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatCompletionsUtils.java new file mode 100644 index 000000000..c9d8c8d9b --- /dev/null +++ b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatCompletionsUtils.java @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.agents.integrations.chatmodels.openai; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.openai.core.JsonValue; +import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessage; +import com.openai.models.chat.completions.ChatCompletionMessageFunctionToolCall; +import com.openai.models.chat.completions.ChatCompletionMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessageToolCall; +import com.openai.models.chat.completions.ChatCompletionSystemMessageParam; +import com.openai.models.chat.completions.ChatCompletionToolMessageParam; +import com.openai.models.chat.completions.ChatCompletionUserMessageParam; +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * Static helpers for converting between Flink Agents {@link ChatMessage} and OpenAI Chat + * Completions API message types. Restricted to message conversion (no tool-definition conversion — + * that stays per-connection). + * + *

Used by both {@code OpenAICompletionsConnection} (OpenAI / OpenAI-compatible providers) and + * {@code AzureOpenAIChatModelConnection} (Azure OpenAI). Both rely on the same openai-java SDK + * message types. + */ +final class OpenAIChatCompletionsUtils { + + private OpenAIChatCompletionsUtils() {} + + private static final ObjectMapper mapper = new ObjectMapper(); + private static final TypeReference> MAP_TYPE = new TypeReference<>() {}; + + /** Convert a list of Flink Agents ChatMessages to OpenAI ChatCompletionMessageParams. */ + public static List convertToOpenAIMessages( + List messages) { + return messages.stream() + .map(OpenAIChatCompletionsUtils::convertToOpenAIMessage) + .collect(Collectors.toList()); + } + + /** Convert a single Flink Agents ChatMessage to an OpenAI ChatCompletionMessageParam. */ + public static ChatCompletionMessageParam convertToOpenAIMessage(ChatMessage message) { + MessageRole role = message.getRole(); + String content = Optional.ofNullable(message.getContent()).orElse(""); + + switch (role) { + case SYSTEM: + return ChatCompletionMessageParam.ofSystem( + ChatCompletionSystemMessageParam.builder().content(content).build()); + case USER: + return ChatCompletionMessageParam.ofUser( + ChatCompletionUserMessageParam.builder().content(content).build()); + case ASSISTANT: + ChatCompletionAssistantMessageParam.Builder assistantBuilder = + ChatCompletionAssistantMessageParam.builder(); + if (!content.isEmpty()) { + assistantBuilder.content(content); + } + List> toolCalls = message.getToolCalls(); + if (toolCalls != null && !toolCalls.isEmpty()) { + assistantBuilder.toolCalls(convertAssistantToolCalls(toolCalls)); + } + Object refusal = message.getExtraArgs().get("refusal"); + if (refusal instanceof String) { + assistantBuilder.refusal((String) refusal); + } + return ChatCompletionMessageParam.ofAssistant(assistantBuilder.build()); + case TOOL: + ChatCompletionToolMessageParam.Builder toolBuilder = + 
ChatCompletionToolMessageParam.builder().content(content); + Object toolCallId = message.getExtraArgs().get("externalId"); + if (toolCallId == null) { + throw new IllegalArgumentException( + "Tool message must have an externalId in extraArgs."); + } + toolBuilder.toolCallId(toolCallId.toString()); + return ChatCompletionMessageParam.ofTool(toolBuilder.build()); + default: + throw new IllegalArgumentException("Unsupported role: " + role); + } + } + + /** + * Convert an OpenAI {@link ChatCompletionMessage} to a Flink Agents {@link ChatMessage}. {@code + * message.refusal()} is written as {@code extraArgs["refusal"]} on the returned ChatMessage + * when present, preserving prior Java behavior. + */ + public static ChatMessage convertFromOpenAIMessage(ChatCompletionMessage message) { + String content = message.content().orElse(""); + ChatMessage response = ChatMessage.assistant(content); + + message.refusal().ifPresent(refusal -> response.getExtraArgs().put("refusal", refusal)); + + List toolCalls = message.toolCalls().orElse(List.of()); + if (!toolCalls.isEmpty()) { + response.setToolCalls(convertResponseToolCalls(toolCalls)); + } + return response; + } + + private static List convertAssistantToolCalls( + List> toolCalls) { + List result = new ArrayList<>(toolCalls.size()); + for (Map call : toolCalls) { + Object type = call.getOrDefault("type", "function"); + if (!"function".equals(String.valueOf(type))) { + continue; + } + + Map functionPayload = toMap(call.get("function")); + ChatCompletionMessageFunctionToolCall.Function.Builder functionBuilder = + ChatCompletionMessageFunctionToolCall.Function.builder(); + + Object functionName = functionPayload.get("name"); + if (functionName != null) { + functionBuilder.name(functionName.toString()); + } + + Object arguments = functionPayload.get("arguments"); + functionBuilder.arguments(serializeArguments(arguments)); + + Object idObj = call.get("id"); + if (idObj == null) { + throw new IllegalArgumentException("Tool call must 
have an id."); + } + String toolCallId = idObj.toString(); + + ChatCompletionMessageFunctionToolCall.Builder toolCallBuilder = + ChatCompletionMessageFunctionToolCall.builder() + .id(toolCallId) + .function(functionBuilder.build()) + .type(JsonValue.from(String.valueOf(type))); + + result.add(ChatCompletionMessageToolCall.ofFunction(toolCallBuilder.build())); + } + return result; + } + + private static List> convertResponseToolCalls( + List toolCalls) { + List> result = new ArrayList<>(toolCalls.size()); + for (ChatCompletionMessageToolCall toolCall : toolCalls) { + if (!toolCall.isFunction()) { + continue; + } + + ChatCompletionMessageFunctionToolCall functionToolCall = toolCall.asFunction(); + Map callMap = new LinkedHashMap<>(); + String toolCallId = functionToolCall.id(); + if (toolCallId == null || toolCallId.isBlank()) { + throw new IllegalStateException("OpenAI tool call ID is null or empty."); + } + + callMap.put("id", toolCallId); + callMap.put("type", "function"); + + ChatCompletionMessageFunctionToolCall.Function function = functionToolCall.function(); + Map functionMap = new LinkedHashMap<>(); + functionMap.put("name", function.name()); + functionMap.put("arguments", parseArguments(function.arguments())); + callMap.put("function", functionMap); + callMap.put("original_id", toolCallId); + result.add(callMap); + } + return result; + } + + private static Map parseArguments(String arguments) { + if (arguments == null || arguments.isBlank()) { + return Map.of(); + } + try { + return mapper.readValue(arguments, MAP_TYPE); + } catch (JsonProcessingException e) { + throw new RuntimeException("Failed to parse tool arguments: " + arguments, e); + } + } + + private static String serializeArguments(Object arguments) { + if (arguments == null) { + return "{}"; + } + if (arguments instanceof String) { + return (String) arguments; + } + try { + return mapper.writeValueAsString(arguments); + } catch (JsonProcessingException e) { + throw new RuntimeException("Failed to 
serialize tool call arguments.", e); + } + } + + private static Map toMap(Object value) { + if (value instanceof Map) { + @SuppressWarnings("unchecked") + Map casted = (Map) value; + return new LinkedHashMap<>(casted); + } + if (value == null) { + return new LinkedHashMap<>(); + } + return mapper.convertValue(value, MAP_TYPE); + } +} diff --git a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnection.java b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnection.java index 153073075..e4947e8f3 100644 --- a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnection.java +++ b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnection.java @@ -18,7 +18,6 @@ package org.apache.flink.agents.integrations.chatmodels.openai; import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.openai.client.OpenAIClient; @@ -29,19 +28,10 @@ import com.openai.models.FunctionParameters; import com.openai.models.ReasoningEffort; import com.openai.models.chat.completions.ChatCompletion; -import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam; import com.openai.models.chat.completions.ChatCompletionCreateParams; import com.openai.models.chat.completions.ChatCompletionFunctionTool; -import com.openai.models.chat.completions.ChatCompletionMessage; -import com.openai.models.chat.completions.ChatCompletionMessageFunctionToolCall; -import com.openai.models.chat.completions.ChatCompletionMessageParam; -import com.openai.models.chat.completions.ChatCompletionMessageToolCall; -import 
com.openai.models.chat.completions.ChatCompletionSystemMessageParam; import com.openai.models.chat.completions.ChatCompletionTool; -import com.openai.models.chat.completions.ChatCompletionToolMessageParam; -import com.openai.models.chat.completions.ChatCompletionUserMessageParam; import org.apache.flink.agents.api.chat.messages.ChatMessage; -import org.apache.flink.agents.api.chat.messages.MessageRole; import org.apache.flink.agents.api.chat.model.BaseChatModelConnection; import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; @@ -51,11 +41,8 @@ import java.time.Duration; import java.util.ArrayList; import java.util.HashMap; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Optional; -import java.util.stream.Collectors; /** * A chat model integration for the OpenAI Chat Completions service using the official Java SDK. @@ -91,9 +78,7 @@ */ public class OpenAICompletionsConnection extends BaseChatModelConnection { - private static final TypeReference> MAP_TYPE = new TypeReference<>() {}; - - private final ObjectMapper mapper = new ObjectMapper(); + private static final ObjectMapper mapper = new ObjectMapper(); private final OpenAIClient client; private final String defaultModel; @@ -140,7 +125,9 @@ public ChatMessage chat( try { ChatCompletionCreateParams params = buildRequest(messages, tools, arguments); ChatCompletion completion = client.chat().completions().create(params); - ChatMessage response = convertResponse(completion); + ChatMessage response = + OpenAIChatCompletionsUtils.convertFromOpenAIMessage( + completion.choices().get(0).message()); // Record token metrics if (completion.usage().isPresent()) { @@ -176,10 +163,7 @@ private ChatCompletionCreateParams buildRequest( ChatCompletionCreateParams.Builder builder = ChatCompletionCreateParams.builder() .model(ChatModel.of(modelName)) - .messages( - messages.stream() - 
.map(this::convertToOpenAIMessage) - .collect(Collectors.toList())); + .messages(OpenAIChatCompletionsUtils.convertToOpenAIMessages(messages)); if (tools != null && !tools.isEmpty()) { builder.tools(convertTools(tools, strictMode)); @@ -272,145 +256,6 @@ private FunctionParameters parseFunctionParameters(String schemaJson) { } } - private ChatCompletionMessageParam convertToOpenAIMessage(ChatMessage message) { - MessageRole role = message.getRole(); - String content = Optional.ofNullable(message.getContent()).orElse(""); - - switch (role) { - case SYSTEM: - return ChatCompletionMessageParam.ofSystem( - ChatCompletionSystemMessageParam.builder().content(content).build()); - case USER: - return ChatCompletionMessageParam.ofUser( - ChatCompletionUserMessageParam.builder().content(content).build()); - case ASSISTANT: - ChatCompletionAssistantMessageParam.Builder assistantBuilder = - ChatCompletionAssistantMessageParam.builder(); - if (!content.isEmpty()) { - assistantBuilder.content(content); - } - List> toolCalls = message.getToolCalls(); - if (toolCalls != null && !toolCalls.isEmpty()) { - assistantBuilder.toolCalls(convertAssistantToolCalls(toolCalls)); - } - Object refusal = message.getExtraArgs().get("refusal"); - if (refusal instanceof String) { - assistantBuilder.refusal((String) refusal); - } - return ChatCompletionMessageParam.ofAssistant(assistantBuilder.build()); - case TOOL: - ChatCompletionToolMessageParam.Builder toolBuilder = - ChatCompletionToolMessageParam.builder().content(content); - Object toolCallId = message.getExtraArgs().get("externalId"); - if (toolCallId == null) { - throw new IllegalArgumentException( - "Tool message must have an externalId in extraArgs."); - } - toolBuilder.toolCallId(toolCallId.toString()); - return ChatCompletionMessageParam.ofTool(toolBuilder.build()); - default: - throw new IllegalArgumentException("Unsupported role: " + role); - } - } - - private List convertAssistantToolCalls( - List> toolCalls) { - List result = new 
ArrayList<>(toolCalls.size()); - for (Map call : toolCalls) { - Object type = call.getOrDefault("type", "function"); - if (!"function".equals(String.valueOf(type))) { - continue; - } - - Map functionPayload = toMap(call.get("function")); - ChatCompletionMessageFunctionToolCall.Function.Builder functionBuilder = - ChatCompletionMessageFunctionToolCall.Function.builder(); - - Object functionName = functionPayload.get("name"); - if (functionName != null) { - functionBuilder.name(functionName.toString()); - } - - Object arguments = functionPayload.get("arguments"); - functionBuilder.arguments(serializeArguments(arguments)); - - Object idObj = call.get("id"); - if (idObj == null) { - throw new IllegalArgumentException("Tool call must have an id."); - } - String toolCallId = idObj.toString(); - - ChatCompletionMessageFunctionToolCall.Builder toolCallBuilder = - ChatCompletionMessageFunctionToolCall.builder() - .id(toolCallId) - .function(functionBuilder.build()) - .type(JsonValue.from(String.valueOf(type))); - - result.add(ChatCompletionMessageToolCall.ofFunction(toolCallBuilder.build())); - } - return result; - } - - private ChatMessage convertResponse(ChatCompletion completion) { - List choices = completion.choices(); - if (choices.isEmpty()) { - throw new IllegalStateException("OpenAI response did not contain any choices."); - } - - ChatCompletionMessage message = choices.get(0).message(); - String content = message.content().orElse(""); - ChatMessage response = ChatMessage.assistant(content); - - message.refusal().ifPresent(refusal -> response.getExtraArgs().put("refusal", refusal)); - - List toolCalls = message.toolCalls().orElse(List.of()); - if (!toolCalls.isEmpty()) { - response.setToolCalls(convertResponseToolCalls(toolCalls)); - } - - return response; - } - - private List> convertResponseToolCalls( - List toolCalls) { - List> result = new ArrayList<>(toolCalls.size()); - for (ChatCompletionMessageToolCall toolCall : toolCalls) { - if (!toolCall.isFunction()) { 
- continue; - } - - ChatCompletionMessageFunctionToolCall functionToolCall = toolCall.asFunction(); - Map callMap = new LinkedHashMap<>(); - String toolCallId = functionToolCall.id(); - if (toolCallId == null || toolCallId.isBlank()) { - throw new IllegalStateException("OpenAI tool call ID is null or empty."); - } - - callMap.put("id", toolCallId); - callMap.put("type", "function"); - - ChatCompletionMessageFunctionToolCall.Function function = functionToolCall.function(); - Map functionMap = new LinkedHashMap<>(); - functionMap.put("name", function.name()); - functionMap.put("arguments", parseArguments(function.arguments())); - callMap.put("function", functionMap); - callMap.put("original_id", toolCallId); - result.add(callMap); - } - return result; - } - - private Map parseArguments(String arguments) { - if (arguments == null || arguments.isBlank()) { - return Map.of(); - } - try { - return mapper.readValue(arguments, MAP_TYPE); - } catch (JsonProcessingException e) { - throw new RuntimeException("Failed to parse tool arguments: " + arguments, e); - } - } - private JsonValue toJsonValue(Object value) { if (value instanceof JsonValue) { return (JsonValue) value; @@ -424,32 +269,6 @@ private JsonValue toJsonValue(Object value) { return JsonValue.fromJsonNode(mapper.valueToTree(value)); } - private String serializeArguments(Object arguments) { - if (arguments == null) { - return "{}"; - } - if (arguments instanceof String) { - return (String) arguments; - } - try { - return mapper.writeValueAsString(arguments); - } catch (JsonProcessingException e) { - throw new RuntimeException("Failed to serialize tool call arguments.", e); - } - } - - private Map toMap(Object value) { - if (value instanceof Map) { - @SuppressWarnings("unchecked") - Map casted = (Map) value; - return new LinkedHashMap<>(casted); - } - if (value == null) { - return new LinkedHashMap<>(); - } - return mapper.convertValue(value, MAP_TYPE); - } - @Override public void close() throws Exception { 
this.client.close(); diff --git a/integrations/chat-models/openai/src/test/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelConnectionTest.java b/integrations/chat-models/openai/src/test/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelConnectionTest.java new file mode 100644 index 000000000..60a29729b --- /dev/null +++ b/integrations/chat-models/openai/src/test/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelConnectionTest.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.agents.integrations.chatmodels.openai; + +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.chat.model.BaseChatModelConnection; +import org.apache.flink.agents.api.resource.ResourceContext; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link AzureOpenAIChatModelConnection} — constructor validation only, no network + * access. End-to-end tests against a real Azure OpenAI deployment live in {@link + * AzureOpenAIChatModelIT}. + */ +class AzureOpenAIChatModelConnectionTest { + + private static final ResourceContext NOOP = ResourceContext.fromGetResource((a, b) -> null); + + private static ResourceDescriptor.Builder connectionDescriptor() { + return ResourceDescriptor.Builder.newBuilder( + AzureOpenAIChatModelConnection.class.getName()); + } + + @Test + @DisplayName("Constructor throws when api_key is missing") + void testConstructorMissingApiKey() { + ResourceDescriptor desc = + connectionDescriptor() + .addInitialArgument("api_version", "2024-02-01") + .addInitialArgument("azure_endpoint", "https://example.openai.azure.com") + .build(); + assertThatThrownBy(() -> new AzureOpenAIChatModelConnection(desc, NOOP)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("api_key"); + } + + @Test + @DisplayName("Constructor throws when api_version is missing") + void testConstructorMissingApiVersion() { + ResourceDescriptor desc = + connectionDescriptor() + .addInitialArgument("api_key", "test-key") + .addInitialArgument("azure_endpoint", "https://example.openai.azure.com") + .build(); + assertThatThrownBy(() -> new 
AzureOpenAIChatModelConnection(desc, NOOP)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("api_version"); + } + + @Test + @DisplayName("Constructor throws when azure_endpoint is missing") + void testConstructorMissingAzureEndpoint() { + ResourceDescriptor desc = + connectionDescriptor() + .addInitialArgument("api_key", "test-key") + .addInitialArgument("api_version", "2024-02-01") + .build(); + assertThatThrownBy(() -> new AzureOpenAIChatModelConnection(desc, NOOP)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("azure_endpoint"); + } + + @Test + @DisplayName("Constructor succeeds with all required args (no network call yet)") + void testConstructorAllRequiredArgs() { + ResourceDescriptor desc = + connectionDescriptor() + .addInitialArgument("api_key", "test-key") + .addInitialArgument("api_version", "2024-02-01") + .addInitialArgument("azure_endpoint", "https://example.openai.azure.com") + .build(); + AzureOpenAIChatModelConnection conn = new AzureOpenAIChatModelConnection(desc, NOOP); + assertThat(conn).isInstanceOf(BaseChatModelConnection.class); + } + + @Test + @DisplayName("chat() rejects additional_kwargs that collide with reserved typed fields") + void testChatRejectsReservedKeyInAdditionalKwargs() { + ResourceDescriptor desc = + connectionDescriptor() + .addInitialArgument("api_key", "test-key") + .addInitialArgument("api_version", "2024-02-01") + .addInitialArgument("azure_endpoint", "https://example.openai.azure.com") + .build(); + AzureOpenAIChatModelConnection conn = new AzureOpenAIChatModelConnection(desc, NOOP); + + Map args = + Map.of( + "model", + "my-deployment", + "temperature", + 0.3d, + "additional_kwargs", + Map.of("temperature", 5.0d)); + + assertThatThrownBy( + () -> + conn.chat( + List.of(new ChatMessage(MessageRole.USER, "hi")), + null, + args)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("additional_kwargs") + .hasMessageContaining("temperature"); + } +} diff 
--git a/integrations/chat-models/openai/src/test/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelSetupTest.java b/integrations/chat-models/openai/src/test/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelSetupTest.java new file mode 100644 index 000000000..ffcb19396 --- /dev/null +++ b/integrations/chat-models/openai/src/test/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelSetupTest.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.chatmodels.openai; + +import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; +import org.apache.flink.agents.api.resource.ResourceContext; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for {@link AzureOpenAIChatModelSetup}. 
*/ +class AzureOpenAIChatModelSetupTest { + + private static final ResourceContext NOOP = ResourceContext.fromGetResource((a, b) -> null); + + private static ResourceDescriptor.Builder descriptorBuilder() { + return ResourceDescriptor.Builder.newBuilder(AzureOpenAIChatModelSetup.class.getName()); + } + + @Test + @DisplayName("getParameters includes model and default logprobs=false") + void testGetParametersMinimal() { + ResourceDescriptor desc = + descriptorBuilder().addInitialArgument("model", "my-deployment").build(); + AzureOpenAIChatModelSetup setup = new AzureOpenAIChatModelSetup(desc, NOOP); + + Map params = setup.getParameters(); + assertThat(params).containsEntry("model", "my-deployment"); + assertThat(params).containsEntry("logprobs", false); + assertThat(params) + .doesNotContainKeys("temperature", "max_tokens", "model_of_azure_deployment"); + } + + @Test + @DisplayName("getParameters includes all explicitly-set fields") + void testGetParametersAllFields() { + ResourceDescriptor desc = + descriptorBuilder() + .addInitialArgument("model", "my-deployment") + .addInitialArgument("model_of_azure_deployment", "gpt-4o") + .addInitialArgument("temperature", 0.3d) + .addInitialArgument("max_tokens", 500) + .addInitialArgument("logprobs", true) + .build(); + AzureOpenAIChatModelSetup setup = new AzureOpenAIChatModelSetup(desc, NOOP); + + Map params = setup.getParameters(); + assertThat(params) + .containsEntry("model", "my-deployment") + .containsEntry("model_of_azure_deployment", "gpt-4o") + .containsEntry("temperature", 0.3d) + .containsEntry("max_tokens", 500) + .containsEntry("logprobs", true); + } + + @Test + @DisplayName("getParameters nests additional_kwargs under a dedicated key") + void testGetParametersNestsAdditionalKwargs() { + ResourceDescriptor desc = + descriptorBuilder() + .addInitialArgument("model", "my-deployment") + .addInitialArgument( + "additional_kwargs", Map.of("seed", 42, "user", "user-123")) + .build(); + AzureOpenAIChatModelSetup setup 
= new AzureOpenAIChatModelSetup(desc, NOOP); + + Map params = setup.getParameters(); + assertThat(params) + .containsEntry("model", "my-deployment") + .containsEntry("additional_kwargs", Map.of("seed", 42, "user", "user-123")) + .doesNotContainKeys("seed", "user"); + } + + @Test + @DisplayName("temperature must be in [0.0, 2.0]") + void testTemperatureValidation() { + ResourceDescriptor tooHigh = + descriptorBuilder() + .addInitialArgument("model", "m") + .addInitialArgument("temperature", 2.5d) + .build(); + assertThatThrownBy(() -> new AzureOpenAIChatModelSetup(tooHigh, NOOP)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("temperature must be between 0.0 and 2.0"); + + ResourceDescriptor negative = + descriptorBuilder() + .addInitialArgument("model", "m") + .addInitialArgument("temperature", -0.1d) + .build(); + assertThatThrownBy(() -> new AzureOpenAIChatModelSetup(negative, NOOP)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + @DisplayName("max_tokens must be greater than 0") + void testMaxTokensValidation() { + ResourceDescriptor zero = + descriptorBuilder() + .addInitialArgument("model", "m") + .addInitialArgument("max_tokens", 0) + .build(); + assertThatThrownBy(() -> new AzureOpenAIChatModelSetup(zero, NOOP)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("max_tokens must be greater than 0"); + } + + @Test + @DisplayName("Extends BaseChatModelSetup") + void testInheritance() { + ResourceDescriptor desc = descriptorBuilder().addInitialArgument("model", "m").build(); + assertThat(new AzureOpenAIChatModelSetup(desc, NOOP)) + .isInstanceOf(BaseChatModelSetup.class); + } + + @Test + @DisplayName("model field is preserved through descriptor round-trip") + void testModelFieldRoundtrip() { + ResourceDescriptor desc = + descriptorBuilder().addInitialArgument("model", "test-deployment").build(); + AzureOpenAIChatModelSetup setup = new AzureOpenAIChatModelSetup(desc, NOOP); + 
assertThat(setup.getParameters()).containsEntry("model", "test-deployment"); + } +} diff --git a/python/flink_agents/api/resource.py b/python/flink_agents/api/resource.py index 3f20821e7..c8a1e88bc 100644 --- a/python/flink_agents/api/resource.py +++ b/python/flink_agents/api/resource.py @@ -291,6 +291,10 @@ class Java: OPENAI_RESPONSES_CONNECTION = "org.apache.flink.agents.integrations.chatmodels.openai.OpenAIResponsesModelConnection" OPENAI_RESPONSES_SETUP = "org.apache.flink.agents.integrations.chatmodels.openai.OpenAIResponsesModelSetup" + # Azure OpenAI + AZURE_OPENAI_CONNECTION = "org.apache.flink.agents.integrations.chatmodels.openai.AzureOpenAIChatModelConnection" + AZURE_OPENAI_SETUP = "org.apache.flink.agents.integrations.chatmodels.openai.AzureOpenAIChatModelSetup" + class EmbeddingModel: """EmbeddingModel resource names.""" diff --git a/python/flink_agents/integrations/chat_models/azure/azure_openai_chat_model.py b/python/flink_agents/integrations/chat_models/azure/azure_openai_chat_model.py index 576d3ab76..18a092128 100644 --- a/python/flink_agents/integrations/chat_models/azure/azure_openai_chat_model.py +++ b/python/flink_agents/integrations/chat_models/azure/azure_openai_chat_model.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################# +import logging from typing import Any, Dict, List, Sequence from openai import NOT_GIVEN, AzureOpenAI @@ -32,6 +33,12 @@ convert_to_openai_messages, ) +logger = logging.getLogger(__name__) + +_RESERVED_KWARG_KEYS = frozenset( + {"model", "model_of_azure_deployment", "temperature", "max_tokens", "logprobs"} +) + class AzureOpenAIChatModelConnection(BaseChatModelConnection): """The connection to the Azure OpenAI LLM. 
@@ -139,6 +146,16 @@ def chat( msg = "model is required for Azure OpenAI API calls" raise ValueError(msg) model_of_azure_deployment = kwargs.pop("model_of_azure_deployment", None) + additional_kwargs = kwargs.pop("additional_kwargs", None) or {} + + collisions = _RESERVED_KWARG_KEYS & additional_kwargs.keys() + if collisions: + msg = ( + f"additional_kwargs must not contain reserved typed fields: " + f"{sorted(collisions)}. Set these via the corresponding " + f"Setup field instead." + ) + raise ValueError(msg) response = self.client.chat.completions.create( # Azure OpenAI APIs use Azure deployment name as the model parameter @@ -146,6 +163,7 @@ def chat( messages=convert_to_openai_messages(messages), tools=tool_specs or NOT_GIVEN, **kwargs, + **additional_kwargs, ) extra_args = {} @@ -235,6 +253,12 @@ def __init__( ) -> None: """Init method.""" additional_kwargs = additional_kwargs or {} + if not model_of_azure_deployment: + logger.warning( + "model_of_azure_deployment is not set; token usage metrics will " + "not be recorded for this Azure OpenAI deployment '%s'.", + model, + ) super().__init__( model=model, model_of_azure_deployment=model_of_azure_deployment, @@ -257,6 +281,6 @@ def model_kwargs(self) -> Dict[str, Any]: base_kwargs["temperature"] = self.temperature if self.max_tokens is not None: base_kwargs["max_tokens"] = self.max_tokens - - all_kwargs = {**base_kwargs, **self.additional_kwargs} - return all_kwargs + if self.additional_kwargs: + base_kwargs["additional_kwargs"] = self.additional_kwargs + return base_kwargs diff --git a/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py b/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py index 79d2fb5c1..ce69d42ec 100644 --- a/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py +++ b/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py @@ -127,3 +127,42 @@ def 
test_model_field_roundtrip() -> None: setup = AzureOpenAIChatModelSetup(connection="conn", model="test-deployment") restored = AzureOpenAIChatModelSetup.model_validate(setup.model_dump()) assert restored.model == "test-deployment" + + +def test_model_kwargs_nests_additional_kwargs() -> None: + """`additional_kwargs` is nested under its own key, not flattened. + + Flattening would allow a colliding key (e.g. `temperature`) in + `additional_kwargs` to silently overwrite the field-validated value. + """ + setup = AzureOpenAIChatModelSetup( + connection="conn", + model="my-deployment", + additional_kwargs={"seed": 42, "user": "user-123"}, + ) + kwargs = setup.model_kwargs + assert kwargs["model"] == "my-deployment" + assert kwargs["additional_kwargs"] == {"seed": 42, "user": "user-123"} + assert "seed" not in kwargs + assert "user" not in kwargs + + +def test_chat_rejects_reserved_key_in_additional_kwargs() -> None: + """`additional_kwargs` containing a reserved typed key must raise. + + Without this check, `**kwargs, **additional_kwargs` would raise an opaque + TypeError, and (worse) leaves the door open for callers to bypass the + field-level validation on `temperature`, `max_tokens`, etc. + """ + connection = AzureOpenAIChatModelConnection( + api_key="fake-key", + azure_endpoint="https://example.openai.azure.com", + api_version="2024-02-01", + ) + with pytest.raises(ValueError, match="additional_kwargs"): + connection.chat( + messages=[ChatMessage(role=MessageRole.USER, content="hi")], + model="my-deployment", + temperature=0.3, + additional_kwargs={"temperature": 5.0}, + )