diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index 3c5217f13345..2bf66842a8d8 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -35,6 +35,7 @@ + diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Connectors.Google.UnitTests.csproj b/dotnet/src/Connectors/Connectors.Google.UnitTests/Connectors.Google.UnitTests.csproj index 4a0ae4032f3e..535cb3e6389a 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/Connectors.Google.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Connectors.Google.UnitTests.csproj @@ -25,6 +25,7 @@ + diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Extensions/GoogleAIServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Google.UnitTests/Extensions/GoogleAIServiceCollectionExtensionsTests.cs index 844a2341bbc9..399ae13c0d13 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/Extensions/GoogleAIServiceCollectionExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Extensions/GoogleAIServiceCollectionExtensionsTests.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using Google.GenAI; using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.SemanticKernel; @@ -113,4 +114,147 @@ public void GoogleAIEmbeddingGeneratorShouldBeRegisteredInServiceCollection() Assert.NotNull(embeddingsGenerationService); Assert.IsType(embeddingsGenerationService); } + +#if NET + [Fact] + public void GoogleGenAIChatClientShouldBeRegisteredInKernelServicesWithApiKey() + { + // Arrange + var kernelBuilder = Kernel.CreateBuilder(); + + // Act + kernelBuilder.AddGoogleGenAIChatClient("modelId", "apiKey"); + var kernel = kernelBuilder.Build(); + + // Assert + var chatClient = kernel.GetRequiredService(); + Assert.NotNull(chatClient); + } + + [Fact] + public void GoogleGenAIChatClientShouldBeRegisteredInServiceCollectionWithApiKey() + { + // Arrange + var services = new ServiceCollection(); + + // Act + services.AddGoogleGenAIChatClient("modelId", "apiKey"); + var serviceProvider = services.BuildServiceProvider(); + + // Assert + var chatClient = serviceProvider.GetRequiredService(); + Assert.NotNull(chatClient); + } + + [Fact] + public void GoogleVertexAIChatClientShouldBeRegisteredInKernelServices() + { + // Arrange + var kernelBuilder = Kernel.CreateBuilder(); + + // Act + kernelBuilder.AddGoogleVertexAIChatClient("modelId", project: "test-project", location: "us-central1"); + + // Assert - just verify no exception during registration + // Resolution requires real credentials, so skip that in unit tests + var kernel = kernelBuilder.Build(); + Assert.NotNull(kernel.Services); + } + + [Fact] + public void GoogleVertexAIChatClientShouldBeRegisteredInServiceCollection() + { + // Arrange + var services = new ServiceCollection(); + + // Act + services.AddGoogleVertexAIChatClient("modelId", project: "test-project", location: "us-central1"); + var serviceProvider = services.BuildServiceProvider(); + + // Assert - just verify no exception during registration + // Resolution requires real credentials, so skip that in unit tests + Assert.NotNull(serviceProvider); + } + + [Fact] + public void GoogleAIChatClientShouldBeRegisteredInKernelServicesWithClient() + { + // Arrange + var kernelBuilder = Kernel.CreateBuilder(); + using var googleClient = new Client(apiKey: "apiKey"); + + // Act + kernelBuilder.AddGoogleAIChatClient("modelId", googleClient); + var kernel = 
kernelBuilder.Build(); + + // Assert + var chatClient = kernel.GetRequiredService(); + Assert.NotNull(chatClient); + } + + [Fact] + public void GoogleAIChatClientShouldBeRegisteredInServiceCollectionWithClient() + { + // Arrange + var services = new ServiceCollection(); + using var googleClient = new Client(apiKey: "apiKey"); + + // Act + services.AddGoogleAIChatClient("modelId", googleClient); + var serviceProvider = services.BuildServiceProvider(); + + // Assert + var chatClient = serviceProvider.GetRequiredService(); + Assert.NotNull(chatClient); + } + + [Fact] + public void GoogleGenAIChatClientShouldBeRegisteredWithServiceId() + { + // Arrange + var services = new ServiceCollection(); + const string ServiceId = "test-service-id"; + + // Act + services.AddGoogleGenAIChatClient("modelId", "apiKey", serviceId: ServiceId); + var serviceProvider = services.BuildServiceProvider(); + + // Assert + var chatClient = serviceProvider.GetKeyedService(ServiceId); + Assert.NotNull(chatClient); + } + + [Fact] + public void GoogleVertexAIChatClientShouldBeRegisteredWithServiceId() + { + // Arrange + var services = new ServiceCollection(); + const string ServiceId = "test-service-id"; + + // Act + services.AddGoogleVertexAIChatClient("modelId", project: "test-project", location: "us-central1", serviceId: ServiceId); + var serviceProvider = services.BuildServiceProvider(); + + // Assert - just verify no exception during registration + // Resolution requires real credentials, so skip that in unit tests + Assert.NotNull(serviceProvider); + } + + [Fact] + public void GoogleAIChatClientShouldResolveFromServiceProviderWhenClientNotProvided() + { + // Arrange + var services = new ServiceCollection(); + using var googleClient = new Client(apiKey: "apiKey"); + services.AddSingleton(googleClient); + + // Act + services.AddGoogleAIChatClient("modelId"); + var serviceProvider = services.BuildServiceProvider(); + + // Assert + var chatClient = serviceProvider.GetRequiredService(); + Assert.NotNull(chatClient); + } +#endif } diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Services/GoogleGeminiChatClientTests.cs b/dotnet/src/Connectors/Connectors.Google.UnitTests/Services/GoogleGeminiChatClientTests.cs new file mode 100644 index 000000000000..91bf5435efdc --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Services/GoogleGeminiChatClientTests.cs @@ -0,0 +1,161 @@ +// Copyright (c) Microsoft. All rights reserved. 
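The unit tests above only assert that registration succeeds. For orientation, here is a minimal usage sketch of the kernel-builder path introduced in this PR; the model name and API key are placeholders, and the call shape assumes the Microsoft.Extensions.AI IChatClient surface already used throughout this change.

using System;
using Microsoft.Extensions.AI;
using Microsoft.SemanticKernel;

// Register the Google GenAI chat client on the kernel builder (extension added by this PR).
var builder = Kernel.CreateBuilder();
builder.AddGoogleGenAIChatClient("gemini-1.5-pro", apiKey: "<your-api-key>");
Kernel kernel = builder.Build();

// Resolve the Microsoft.Extensions.AI abstraction the connector registers.
IChatClient chatClient = kernel.GetRequiredService<IChatClient>();

// Send a single prompt and print the aggregated response text.
ChatResponse response = await chatClient.GetResponseAsync(
    new[] { new ChatMessage(ChatRole.User, "Expand the abbreviation LLM.") });
Console.WriteLine(response.Text);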
+ +#if NET + +using System; +using Google.GenAI; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel; +using Xunit; + +namespace SemanticKernel.Connectors.Google.UnitTests.Services; + +public sealed class GoogleGeminiChatClientTests +{ + [Fact] + public void GenAIChatClientShouldBeCreatedWithApiKey() + { + // Arrange + string modelId = "gemini-1.5-pro"; + string apiKey = "test-api-key"; + + // Act + var kernelBuilder = Kernel.CreateBuilder(); + kernelBuilder.AddGoogleGenAIChatClient(modelId, apiKey); + var kernel = kernelBuilder.Build(); + + // Assert + var chatClient = kernel.GetRequiredService(); + Assert.NotNull(chatClient); + } + + [Fact] + public void VertexAIChatClientShouldBeCreated() + { + // Arrange + string modelId = "gemini-1.5-pro"; + + // Act + var kernelBuilder = Kernel.CreateBuilder(); + kernelBuilder.AddGoogleVertexAIChatClient(modelId, project: "test-project", location: "us-central1"); + var kernel = kernelBuilder.Build(); + + // Assert - just verify no exception during registration + // Resolution requires real credentials, so skip that in unit tests + Assert.NotNull(kernel.Services); + } + + [Fact] + public void ChatClientShouldBeCreatedWithGoogleClient() + { + // Arrange + string modelId = "gemini-1.5-pro"; + using var googleClient = new Client(apiKey: "test-api-key"); + + // Act + var kernelBuilder = Kernel.CreateBuilder(); + kernelBuilder.AddGoogleAIChatClient(modelId, googleClient); + var kernel = kernelBuilder.Build(); + + // Assert + var chatClient = kernel.GetRequiredService(); + Assert.NotNull(chatClient); + } + + [Fact] + public void GenAIChatClientShouldBeCreatedWithServiceId() + { + // Arrange + string modelId = "gemini-1.5-pro"; + string apiKey = "test-api-key"; + string serviceId = "test-service"; + + // Act + var kernelBuilder = Kernel.CreateBuilder(); + kernelBuilder.AddGoogleGenAIChatClient(modelId, apiKey, serviceId: serviceId); + var kernel = kernelBuilder.Build(); + + // Assert + var chatClient = kernel.GetRequiredService(serviceId); + Assert.NotNull(chatClient); + } + + [Fact] + public void VertexAIChatClientShouldBeCreatedWithServiceId() + { + // Arrange + string modelId = "gemini-1.5-pro"; + string serviceId = "test-service"; + + // Act + var kernelBuilder = Kernel.CreateBuilder(); + kernelBuilder.AddGoogleVertexAIChatClient(modelId, project: "test-project", location: "us-central1", serviceId: serviceId); + var kernel = kernelBuilder.Build(); + + // Assert - just verify no exception during registration + // Resolution requires real credentials, so skip that in unit tests + Assert.NotNull(kernel.Services); + } + + [Fact] + public void GenAIChatClientThrowsForNullModelId() + { + // Arrange + var kernelBuilder = Kernel.CreateBuilder(); + + // Act & Assert + Assert.ThrowsAny(() => kernelBuilder.AddGoogleGenAIChatClient(null!, "apiKey")); + } + + [Fact] + public void GenAIChatClientThrowsForEmptyModelId() + { + // Arrange + var kernelBuilder = Kernel.CreateBuilder(); + + // Act & Assert + Assert.ThrowsAny(() => kernelBuilder.AddGoogleGenAIChatClient("", "apiKey")); + } + + [Fact] + public void GenAIChatClientThrowsForNullApiKey() + { + // Arrange + var kernelBuilder = Kernel.CreateBuilder(); + + // Act & Assert + Assert.ThrowsAny(() => kernelBuilder.AddGoogleGenAIChatClient("modelId", null!)); + } + + [Fact] + public void GenAIChatClientThrowsForEmptyApiKey() + { + // Arrange + var kernelBuilder = Kernel.CreateBuilder(); + + // Act & Assert + Assert.ThrowsAny(() => kernelBuilder.AddGoogleGenAIChatClient("modelId", "")); + } + + [Fact] + 
public void VertexAIChatClientThrowsForNullModelId() + { + // Arrange + var kernelBuilder = Kernel.CreateBuilder(); + + // Act & Assert + Assert.ThrowsAny(() => kernelBuilder.AddGoogleVertexAIChatClient(null!, project: "test-project", location: "us-central1")); + } + + [Fact] + public void VertexAIChatClientThrowsForEmptyModelId() + { + // Arrange + var kernelBuilder = Kernel.CreateBuilder(); + + // Act & Assert + Assert.ThrowsAny(() => kernelBuilder.AddGoogleVertexAIChatClient("", project: "test-project", location: "us-central1")); + } +} + +#endif diff --git a/dotnet/src/Connectors/Connectors.Google/Connectors.Google.csproj b/dotnet/src/Connectors/Connectors.Google/Connectors.Google.csproj index e71d80d17a00..7e104ef8b230 100644 --- a/dotnet/src/Connectors/Connectors.Google/Connectors.Google.csproj +++ b/dotnet/src/Connectors/Connectors.Google/Connectors.Google.csproj @@ -24,6 +24,11 @@ + + + + + diff --git a/dotnet/src/Connectors/Connectors.Google/Extensions/GoogleAIKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Google/Extensions/GoogleAIKernelBuilderExtensions.cs index d6ab3768d0e0..72518e91aaf8 100644 --- a/dotnet/src/Connectors/Connectors.Google/Extensions/GoogleAIKernelBuilderExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Google/Extensions/GoogleAIKernelBuilderExtensions.cs @@ -118,4 +118,102 @@ public static IKernelBuilder AddGoogleAIEmbeddingGenerator( dimensions: dimensions); return builder; } + +#if NET + /// + /// Add Google GenAI to the . + /// + /// The kernel builder. + /// The model for chat completion. + /// The API key for authentication with the Google GenAI API. + /// The optional service ID. + /// An optional name for the OpenTelemetry source. + /// An optional callback that can be used to configure the instance. + /// The updated kernel builder. + public static IKernelBuilder AddGoogleGenAIChatClient( + this IKernelBuilder builder, + string modelId, + string apiKey, + string? serviceId = null, + string? openTelemetrySourceName = null, + Action? openTelemetryConfig = null) + { + Verify.NotNull(builder); + + builder.Services.AddGoogleGenAIChatClient( + modelId, + apiKey, + serviceId, + openTelemetrySourceName, + openTelemetryConfig); + + return builder; + } + + /// + /// Add Google Vertex AI to the . + /// + /// The kernel builder. + /// The model for chat completion. + /// The Google Cloud project ID. If null, will attempt to use the GOOGLE_CLOUD_PROJECT environment variable. + /// The Google Cloud location (e.g., "us-central1"). If null, will attempt to use the GOOGLE_CLOUD_LOCATION environment variable. + /// The optional for authentication. If null, the client will use its internal discovery implementation to get credentials from the environment. + /// The optional service ID. + /// An optional name for the OpenTelemetry source. + /// An optional callback that can be used to configure the instance. + /// The updated kernel builder. + public static IKernelBuilder AddGoogleVertexAIChatClient( + this IKernelBuilder builder, + string modelId, + string? project = null, + string? location = null, + Google.Apis.Auth.OAuth2.ICredential? credential = null, + string? serviceId = null, + string? openTelemetrySourceName = null, + Action? openTelemetryConfig = null) + { + Verify.NotNull(builder); + + builder.Services.AddGoogleVertexAIChatClient( + modelId, + project, + location, + credential, + serviceId, + openTelemetrySourceName, + openTelemetryConfig); + + return builder; + } + + /// + /// Add Google AI to the . + /// + /// The kernel builder. 
+ /// The model for chat completion. + /// The to use for the service. If null, one must be available in the service provider when this service is resolved. + /// The optional service ID. + /// An optional name for the OpenTelemetry source. + /// An optional callback that can be used to configure the instance. + /// The updated kernel builder. + public static IKernelBuilder AddGoogleAIChatClient( + this IKernelBuilder builder, + string modelId, + Google.GenAI.Client? googleClient = null, + string? serviceId = null, + string? openTelemetrySourceName = null, + Action? openTelemetryConfig = null) + { + Verify.NotNull(builder); + + builder.Services.AddGoogleAIChatClient( + modelId, + googleClient, + serviceId, + openTelemetrySourceName, + openTelemetryConfig); + + return builder; + } +#endif } diff --git a/dotnet/src/Connectors/Connectors.Google/Extensions/GoogleAIServiceCollectionExtensions.DependencyInjection.cs b/dotnet/src/Connectors/Connectors.Google/Extensions/GoogleAIServiceCollectionExtensions.DependencyInjection.cs index a45001278e9a..40585b55971a 100644 --- a/dotnet/src/Connectors/Connectors.Google/Extensions/GoogleAIServiceCollectionExtensions.DependencyInjection.cs +++ b/dotnet/src/Connectors/Connectors.Google/Extensions/GoogleAIServiceCollectionExtensions.DependencyInjection.cs @@ -1,5 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. +#pragma warning disable IDE0005 // Using directive is unnecessary +using System; +#pragma warning restore IDE0005 // Using directive is unnecessary using System.Net.Http; using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; @@ -47,4 +50,146 @@ public static IServiceCollection AddGoogleAIEmbeddingGenerator( loggerFactory: serviceProvider.GetService(), dimensions: dimensions)); } + +#if NET + /// + /// Add Google GenAI to the specified service collection. + /// + /// The service collection to add the Google GenAI Chat Client to. + /// The model for chat completion. + /// The API key for authentication with the Google GenAI API. + /// Optional service ID. + /// An optional name for the OpenTelemetry source. + /// An optional callback that can be used to configure the instance. + /// The updated service collection. + public static IServiceCollection AddGoogleGenAIChatClient( + this IServiceCollection services, + string modelId, + string apiKey, + string? serviceId = null, + string? openTelemetrySourceName = null, + Action? openTelemetryConfig = null) + { + Verify.NotNull(services); + Verify.NotNullOrWhiteSpace(modelId); + Verify.NotNullOrWhiteSpace(apiKey); + + IChatClient Factory(IServiceProvider serviceProvider, object? _) + { + var loggerFactory = serviceProvider.GetService(); + + var googleClient = new Google.GenAI.Client(apiKey: apiKey); + + var builder = new GoogleGenAIChatClient(googleClient, modelId) + .AsBuilder() + .UseKernelFunctionInvocation(loggerFactory) + .UseOpenTelemetry(loggerFactory, openTelemetrySourceName, openTelemetryConfig); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder.Build(); + } + + services.AddKeyedSingleton(serviceId, (Func)Factory); + + return services; + } + + /// + /// Add Google Vertex AI to the specified service collection. + /// + /// The service collection to add the Google Vertex AI Chat Client to. + /// The model for chat completion. + /// The Google Cloud project ID. If null, will attempt to use the GOOGLE_CLOUD_PROJECT environment variable. + /// The Google Cloud location (e.g., "us-central1"). 
If null, will attempt to use the GOOGLE_CLOUD_LOCATION environment variable. + /// The optional for authentication. If null, the client will use its internal discovery implementation to get credentials from the environment. + /// Optional service ID. + /// An optional name for the OpenTelemetry source. + /// An optional callback that can be used to configure the instance. + /// The updated service collection. + public static IServiceCollection AddGoogleVertexAIChatClient( + this IServiceCollection services, + string modelId, + string? project = null, + string? location = null, + Google.Apis.Auth.OAuth2.ICredential? credential = null, + string? serviceId = null, + string? openTelemetrySourceName = null, + Action? openTelemetryConfig = null) + { + Verify.NotNull(services); + Verify.NotNullOrWhiteSpace(modelId); + + IChatClient Factory(IServiceProvider serviceProvider, object? _) + { + var loggerFactory = serviceProvider.GetService(); + + var googleClient = new Google.GenAI.Client(vertexAI: true, credential: credential, project: project, location: location); + + var builder = new GoogleGenAIChatClient(googleClient, modelId) + .AsBuilder() + .UseKernelFunctionInvocation(loggerFactory) + .UseOpenTelemetry(loggerFactory, openTelemetrySourceName, openTelemetryConfig); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder.Build(); + } + + services.AddKeyedSingleton(serviceId, (Func)Factory); + + return services; + } + + /// + /// Add Google AI to the specified service collection. + /// + /// The service collection to add the Google AI Chat Client to. + /// The model for chat completion. + /// The to use for the service. If null, one must be available in the service provider when this service is resolved. + /// Optional service ID. + /// An optional name for the OpenTelemetry source. + /// An optional callback that can be used to configure the instance. + /// The updated service collection. + public static IServiceCollection AddGoogleAIChatClient( + this IServiceCollection services, + string modelId, + Google.GenAI.Client? googleClient = null, + string? serviceId = null, + string? openTelemetrySourceName = null, + Action? openTelemetryConfig = null) + { + Verify.NotNull(services); + Verify.NotNullOrWhiteSpace(modelId); + + IChatClient Factory(IServiceProvider serviceProvider, object? _) + { + var loggerFactory = serviceProvider.GetService(); + + var client = googleClient ?? serviceProvider.GetRequiredService(); + + var builder = new GoogleGenAIChatClient(client, modelId) + .AsBuilder() + .UseKernelFunctionInvocation(loggerFactory) + .UseOpenTelemetry(loggerFactory, openTelemetrySourceName, openTelemetryConfig); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder.Build(); + } + + services.AddKeyedSingleton(serviceId, (Func)Factory); + + return services; + } +#endif } diff --git a/dotnet/src/Connectors/Connectors.Google/Services/GoogleGenAIChatClient.cs b/dotnet/src/Connectors/Connectors.Google/Services/GoogleGenAIChatClient.cs new file mode 100644 index 000000000000..e265ffd70d6c --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Google/Services/GoogleGenAIChatClient.cs @@ -0,0 +1,723 @@ +// Copyright (c) Microsoft. All rights reserved. 
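Each of the service-collection extensions above registers the client via AddKeyedSingleton, so several Google chat clients can coexist under different service IDs. The following is a hedged sketch of how that keyed registration is expected to be consumed; model names, project, and key are placeholders, and resolving the Vertex AI client assumes Google Cloud credentials are discoverable from the environment, per the remarks on AddGoogleVertexAIChatClient.

using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.SemanticKernel;

// Two Google chat clients registered side by side under different service IDs.
var services = new ServiceCollection();
services.AddGoogleGenAIChatClient("gemini-1.5-pro", apiKey: "<your-api-key>", serviceId: "gemini-pro");
services.AddGoogleVertexAIChatClient("gemini-1.5-flash", project: "my-project", location: "us-central1", serviceId: "gemini-flash");

using ServiceProvider provider = services.BuildServiceProvider();

// The serviceId is the key passed to AddKeyedSingleton, so keyed resolution retrieves each client.
// Resolving the Vertex AI client here assumes application-default Google Cloud credentials are available.
IChatClient proClient = provider.GetRequiredKeyedService<IChatClient>("gemini-pro");
IChatClient flashClient = provider.GetRequiredKeyedService<IChatClient>("gemini-flash");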
+ +#if NET + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Google.GenAI; +using Google.GenAI.Types; +using Microsoft.Extensions.AI; +using AIDataContent = Microsoft.Extensions.AI.DataContent; +using AIFunctionCallContent = Microsoft.Extensions.AI.FunctionCallContent; +using AIFunctionResultContent = Microsoft.Extensions.AI.FunctionResultContent; +// Type aliases to distinguish between Semantic Kernel and M.E.AI types +using AITextContent = Microsoft.Extensions.AI.TextContent; +using AIUriContent = Microsoft.Extensions.AI.UriContent; + +namespace Microsoft.SemanticKernel.Connectors.Google; + +/// +/// Provides an implementation based on Google.GenAI . +/// +internal sealed class GoogleGenAIChatClient : IChatClient +{ + /// A thought signature that can be used to skip thought validation when sending foreign function calls. + /// + /// See https://ai.google.dev/gemini-api/docs/thought-signatures#faqs. + /// This is more common in agentic scenarios, where a chat history is built up across multiple providers/models. + /// + private static readonly byte[] s_skipThoughtValidation = Encoding.UTF8.GetBytes("skip_thought_signature_validator"); + + /// The wrapped instance (optional). + private readonly Client? _client; + + /// The wrapped instance. + private readonly Models _models; + + /// The default model that should be used when no override is specified. + private readonly string? _defaultModelId; + + /// Lazily-initialized metadata describing the implementation. + private ChatClientMetadata? _metadata; + + /// Initializes a new instance. + /// The to wrap. + /// The default model ID to use for chat requests if not specified. + public GoogleGenAIChatClient(Client client, string? defaultModelId) + { + Verify.NotNull(client); + + this._client = client; + this._models = client.Models; + this._defaultModelId = defaultModelId; + } + + /// Initializes a new instance. + /// The client to wrap. + /// The default model ID to use for chat requests if not specified. + public GoogleGenAIChatClient(Models models, string? defaultModelId) + { + Verify.NotNull(models); + + this._models = models; + this._defaultModelId = defaultModelId; + } + + /// + public async Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + Verify.NotNull(messages); + + // Create the request. + (string? modelId, List contents, GenerateContentConfig config) = this.CreateRequest(messages, options); + + // Send it. + GenerateContentResponse generateResult = await this._models.GenerateContentAsync(modelId!, contents, config).ConfigureAwait(false); + + // Create the response. + ChatResponse chatResponse = new(new ChatMessage(ChatRole.Assistant, new List())) + { + CreatedAt = generateResult.CreateTime is { } dt ? new DateTimeOffset(dt) : null, + ModelId = !string.IsNullOrWhiteSpace(generateResult.ModelVersion) ? generateResult.ModelVersion : modelId, + RawRepresentation = generateResult, + ResponseId = generateResult.ResponseId, + }; + + // Populate the response messages. + chatResponse.FinishReason = PopulateResponseContents(generateResult, chatResponse.Messages[0].Contents); + + // Populate usage information if there is any. + if (generateResult.UsageMetadata is { } usageMetadata) + { + chatResponse.Usage = ExtractUsageDetails(usageMetadata); + } + + // Return the response. 
+ return chatResponse; + } + + /// <inheritdoc/> + public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Verify.NotNull(messages); + + // Create the request. + (string? modelId, List<Content> contents, GenerateContentConfig config) = this.CreateRequest(messages, options); + + // Send it, and process the results. + await foreach (GenerateContentResponse generateResult in this._models.GenerateContentStreamAsync(modelId!, contents, config).WithCancellation(cancellationToken).ConfigureAwait(false)) + { + // Create a response update for each result in the stream. + ChatResponseUpdate responseUpdate = new(ChatRole.Assistant, new List<AIContent>()) + { + CreatedAt = generateResult.CreateTime is { } dt ? new DateTimeOffset(dt) : null, + ModelId = !string.IsNullOrWhiteSpace(generateResult.ModelVersion) ? generateResult.ModelVersion : modelId, + RawRepresentation = generateResult, + ResponseId = generateResult.ResponseId, + }; + + // Populate the response update contents. + responseUpdate.FinishReason = PopulateResponseContents(generateResult, responseUpdate.Contents); + + // Populate usage information if there is any. + if (generateResult.UsageMetadata is { } usageMetadata) + { + responseUpdate.Contents.Add(new UsageContent(ExtractUsageDetails(usageMetadata))); + } + + // Yield the update. + yield return responseUpdate; + } + } + + /// <inheritdoc/> + public object? GetService(System.Type serviceType, object? serviceKey = null) + { + Verify.NotNull(serviceType); + + if (serviceKey is null) + { + // If there's a request for metadata, lazily-initialize it and return it. We don't need to worry about race conditions, + // as there's no requirement that the same instance be returned each time, and creation is idempotent. + if (serviceType == typeof(ChatClientMetadata)) + { + return this._metadata ??= new("gcp.gen_ai", new Uri("https://generativelanguage.googleapis.com/"), defaultModelId: this._defaultModelId); + } + + // Allow a consumer to "break glass" and access the underlying client if they need it. + if (serviceType.IsInstanceOfType(this._models)) + { + return this._models; + } + + if (this._client is not null && serviceType.IsInstanceOfType(this._client)) + { + return this._client; + } + + if (serviceType.IsInstanceOfType(this)) + { + return this; + } + } + + return null; + } + + /// <inheritdoc/> + void IDisposable.Dispose() { /* nop */ } + + /// <summary>Creates the message parameters for <see cref="Models"/> from <paramref name="messages"/> and <paramref name="options"/>.</summary> + private (string? ModelId, List<Content> Contents, GenerateContentConfig Config) CreateRequest(IEnumerable<ChatMessage> messages, ChatOptions? options) + { + // Create the GenerateContentConfig object. If the options contain a RawRepresentationFactory, try to use it to + // create the request instance, allowing the caller to populate it with GenAI-specific options. Otherwise, create + // a new instance directly. + string? model = this._defaultModelId; + List<Content> contents = []; + GenerateContentConfig config = options?.RawRepresentationFactory?.Invoke(this) as GenerateContentConfig ?? 
new(); + + if (options is not null) + { + if (options.FrequencyPenalty is { } frequencyPenalty) + { + config.FrequencyPenalty ??= frequencyPenalty; + } + + if (options.Instructions is { } instructions) + { + ((config.SystemInstruction ??= new()).Parts ??= []).Add(new() { Text = instructions }); + } + + if (options.MaxOutputTokens is { } maxOutputTokens) + { + config.MaxOutputTokens ??= maxOutputTokens; + } + + if (!string.IsNullOrWhiteSpace(options.ModelId)) + { + model = options.ModelId; + } + + if (options.PresencePenalty is { } presencePenalty) + { + config.PresencePenalty ??= presencePenalty; + } + + if (options.Seed is { } seed) + { + config.Seed ??= (int)seed; + } + + if (options.StopSequences is { } stopSequences) + { + (config.StopSequences ??= []).AddRange(stopSequences); + } + + if (options.Temperature is { } temperature) + { + config.Temperature ??= temperature; + } + + if (options.TopP is { } topP) + { + config.TopP ??= topP; + } + + if (options.TopK is { } topK) + { + config.TopK ??= topK; + } + + // Populate tools. Each kind of tool is added on its own, except for function declarations, + // which are grouped into a single FunctionDeclaration. + List? functionDeclarations = null; + if (options.Tools is { } tools) + { + foreach (var tool in tools) + { + switch (tool) + { + case AIFunction af: + functionDeclarations ??= []; + functionDeclarations.Add(new() + { + Name = af.Name, + Description = af.Description ?? "", + }); + break; + + case HostedCodeInterpreterTool: + (config.Tools ??= []).Add(new() { CodeExecution = new() }); + break; + + case HostedFileSearchTool: + (config.Tools ??= []).Add(new() { Retrieval = new() }); + break; + + case HostedWebSearchTool: + (config.Tools ??= []).Add(new() { GoogleSearch = new() }); + break; + } + } + } + + if (functionDeclarations is { Count: > 0 }) + { + Tool functionTools = new(); + (functionTools.FunctionDeclarations ??= []).AddRange(functionDeclarations); + (config.Tools ??= []).Add(functionTools); + } + + // Transfer over the tool mode if there are any tools. + if (options.ToolMode is { } toolMode && config.Tools?.Count > 0) + { + switch (toolMode) + { + case NoneChatToolMode: + config.ToolConfig = new() { FunctionCallingConfig = new() { Mode = FunctionCallingConfigMode.NONE } }; + break; + + case AutoChatToolMode: + config.ToolConfig = new() { FunctionCallingConfig = new() { Mode = FunctionCallingConfigMode.AUTO } }; + break; + + case RequiredChatToolMode required: + config.ToolConfig = new() { FunctionCallingConfig = new() { Mode = FunctionCallingConfigMode.ANY } }; + if (required.RequiredFunctionName is not null) + { + ((config.ToolConfig.FunctionCallingConfig ??= new()).AllowedFunctionNames ??= []).Add(required.RequiredFunctionName); + } + break; + } + } + + // Set the response format if specified. + if (options.ResponseFormat is ChatResponseFormatJson responseFormat) + { + config.ResponseMimeType = "application/json"; + if (responseFormat.Schema is { } schema) + { + config.ResponseJsonSchema = schema; + } + } + } + + // Transfer messages to request, handling system messages specially + Dictionary? callIdToFunctionNames = null; + foreach (var message in messages) + { + if (message.Role == ChatRole.System) + { + string instruction = message.Text; + if (!string.IsNullOrWhiteSpace(instruction)) + { + ((config.SystemInstruction ??= new()).Parts ??= []).Add(new() { Text = instruction }); + } + + continue; + } + + Content content = new() { Role = message.Role == ChatRole.Assistant ? 
"model" : "user" }; + content.Parts ??= []; + AddPartsForAIContents(ref callIdToFunctionNames, message.Contents, content.Parts); + + contents.Add(content); + } + + // Make sure the request contains at least one content part (the request would always fail if empty). + if (!contents.SelectMany(c => c.Parts ?? Enumerable.Empty()).Any()) + { + contents.Add(new() { Role = "user", Parts = [new() { Text = "" }] }); + } + + return (model, contents, config); + } + + /// Creates s for and adds them to . + private static void AddPartsForAIContents(ref Dictionary? callIdToFunctionNames, IList contents, List parts) + { + for (int i = 0; i < contents.Count; i++) + { + var content = contents[i]; + + byte[]? thoughtSignature = null; + if (content is not TextReasoningContent { ProtectedData: not null } && + i + 1 < contents.Count && + contents[i + 1] is TextReasoningContent nextReasoning && + string.IsNullOrWhiteSpace(nextReasoning.Text) && + nextReasoning.ProtectedData is { } protectedData) + { + i++; + thoughtSignature = Convert.FromBase64String(protectedData); + } + + // Before the main switch, do any necessary state tracking. We want to do this + // even if the AIContent includes a Part as its RawRepresentation. + if (content is AIFunctionCallContent fcc) + { + (callIdToFunctionNames ??= [])[fcc.CallId] = fcc.Name; + callIdToFunctionNames[""] = fcc.Name; // track last function name in case calls don't have IDs + } + + Part? part = null; + switch (content) + { + case AIContent aic when aic.RawRepresentation is Part rawPart: + part = rawPart; + break; + + case AITextContent textContent: + part = new() { Text = textContent.Text }; + break; + + case TextReasoningContent reasoningContent: + part = new() + { + Thought = true, + Text = !string.IsNullOrWhiteSpace(reasoningContent.Text) ? reasoningContent.Text : null, + ThoughtSignature = reasoningContent.ProtectedData is not null ? Convert.FromBase64String(reasoningContent.ProtectedData) : null, + }; + break; + + case AIDataContent dataContent: + part = new() + { + InlineData = new() + { + MimeType = dataContent.MediaType, + Data = dataContent.Data.ToArray(), + DisplayName = dataContent.Name, + } + }; + break; + + case AIUriContent uriContent: + part = new() + { + FileData = new() + { + FileUri = uriContent.Uri.AbsoluteUri, + MimeType = uriContent.MediaType, + } + }; + break; + + case AIFunctionCallContent functionCallContent: + part = new() + { + FunctionCall = new() + { + Id = functionCallContent.CallId, + Name = functionCallContent.Name, + Args = functionCallContent.Arguments is null ? null : functionCallContent.Arguments as Dictionary ?? new(functionCallContent.Arguments!), + }, + ThoughtSignature = thoughtSignature ?? s_skipThoughtValidation, + }; + break; + + case AIFunctionResultContent functionResultContent: + FunctionResponse funcResponse = new() + { + Id = functionResultContent.CallId, + }; + + if (callIdToFunctionNames?.TryGetValue(functionResultContent.CallId ?? "", out string? functionName) is true || + callIdToFunctionNames?.TryGetValue("", out functionName) is true) + { + funcResponse.Name = functionName; + } + + switch (functionResultContent.Result) + { + case null: + break; + + case AIContent aic when ToFunctionResponsePart(aic) is { } singleContentBlob: + funcResponse.Parts = [singleContentBlob]; + break; + + case IEnumerable aiContents: + List? 
nonBlobContent = null; + foreach (var aiContent in aiContents) + { + if (ToFunctionResponsePart(aiContent) is { } contentBlob) + { + (funcResponse.Parts ??= []).Add(contentBlob); + } + else + { + (nonBlobContent ??= []).Add(aiContent); + } + } + + if (nonBlobContent is not null) + { + funcResponse.Response = new() { ["result"] = nonBlobContent }; + } + break; + + case AITextContent textContent: + funcResponse.Response = new() { ["result"] = textContent.Text }; + break; + + default: + funcResponse.Response = new() { ["result"] = functionResultContent.Result }; + break; + } + + part = new() + { + FunctionResponse = funcResponse, + }; + + static FunctionResponsePart? ToFunctionResponsePart(AIContent content) + { + switch (content) + { + case AIContent when content.RawRepresentation is FunctionResponsePart functionResponsePart: + return functionResponsePart; + + case AIDataContent dc when IsSupportedMediaType(dc.MediaType): + FunctionResponseBlob dataBlob = new() + { + MimeType = dc.MediaType, + Data = dc.Data.Span.ToArray(), + }; + + if (!string.IsNullOrWhiteSpace(dc.Name)) + { + dataBlob.DisplayName = dc.Name; + } + + return new() { InlineData = dataBlob }; + + case AIUriContent uc when IsSupportedMediaType(uc.MediaType): + return new() + { + FileData = new() + { + MimeType = uc.MediaType, + FileUri = uc.Uri.AbsoluteUri, + } + }; + + default: + return null; + } + + // https://docs.cloud.google.com/vertex-ai/generative-ai/docs/multimodal/function-calling#mm-fr + static bool IsSupportedMediaType(string mediaType) => + // images + mediaType.Equals("image/png", StringComparison.OrdinalIgnoreCase) || + mediaType.Equals("image/jpeg", StringComparison.OrdinalIgnoreCase) || + mediaType.Equals("image/webp", StringComparison.OrdinalIgnoreCase) || + // documents + mediaType.Equals("application/pdf", StringComparison.OrdinalIgnoreCase) || + mediaType.Equals("text/plain", StringComparison.OrdinalIgnoreCase); + } + break; + } + + if (part is not null) + { + part.ThoughtSignature ??= thoughtSignature; + parts.Add(part); + } + + thoughtSignature = null; + } + } + + /// Creates s for and adds them to . + [SuppressMessage("Design", "MEAI001:Suppress for experimental types", Justification = "Using experimental MEAI types")] + private static void AddAIContentsForParts(List parts, IList contents) + { + foreach (var part in parts) + { + AIContent content; + + if (!string.IsNullOrEmpty(part.Text)) + { + content = part.Thought is true ? + new TextReasoningContent(part.Text) : + new AITextContent(part.Text); + } + else if (part.InlineData is { } inlineData) + { + content = new AIDataContent(inlineData.Data, inlineData.MimeType ?? "application/octet-stream") + { + Name = inlineData.DisplayName, + }; + } + else if (part.FileData is { FileUri: not null } fileData) + { + content = new AIUriContent(new Uri(fileData.FileUri), fileData.MimeType ?? "application/octet-stream"); + } + else if (part.FunctionCall is { Name: not null } functionCall) + { + content = new AIFunctionCallContent(functionCall.Id ?? "", functionCall.Name, functionCall.Args!); + } + else if (part.FunctionResponse is { } functionResponse) + { + content = new AIFunctionResultContent( + functionResponse.Id ?? "", + functionResponse.Response?.TryGetValue("output", out var output) is true ? output : + functionResponse.Response?.TryGetValue("error", out var error) is true ? 
error : + null); + } + else if (part.ExecutableCode is { Code: not null } executableCode) + { +#pragma warning disable MEAI001 // CodeInterpreterToolCallContent is experimental + content = new CodeInterpreterToolCallContent() + { + Inputs = + [ + new AIDataContent(Encoding.UTF8.GetBytes(executableCode.Code), executableCode.Language switch + { + Language.PYTHON => "text/x-python", + _ => "text/x-source-code", + }) + ], + }; +#pragma warning restore MEAI001 + } + else if (part.CodeExecutionResult is { Output: { } codeOutput } codeExecutionResult) + { +#pragma warning disable MEAI001 // CodeInterpreterToolResultContent is experimental + content = new CodeInterpreterToolResultContent() + { + Outputs = + [ + codeExecutionResult.Outcome is Outcome.OUTCOME_OK ? + new AITextContent(codeOutput) : + new ErrorContent(codeOutput) { ErrorCode = codeExecutionResult.Outcome.ToString() } + ], + }; +#pragma warning restore MEAI001 + } + else + { + content = new AIContent(); + } + + content.RawRepresentation = part; + contents.Add(content); + + if (part.ThoughtSignature is { } thoughtSignature) + { + contents.Add(new TextReasoningContent(null) + { + ProtectedData = Convert.ToBase64String(thoughtSignature), + }); + } + } + } + + private static ChatFinishReason? PopulateResponseContents(GenerateContentResponse generateResult, IList responseContents) + { + ChatFinishReason? finishReason = null; + + // Populate the response messages. There should only be at most one candidate, but if there are more, ignore all but the first. + if (generateResult.Candidates is { Count: > 0 } && + generateResult.Candidates[0] is { Content: { } candidateContent } candidate) + { + // Grab the finish reason if one exists. + finishReason = ConvertFinishReason(candidate.FinishReason); + + // Add all of the response content parts as AIContents. + if (candidateContent.Parts is { } parts) + { + AddAIContentsForParts(parts, responseContents); + } + + // Add any citation metadata. + if (candidate.CitationMetadata is { Citations: { Count: > 0 } citations } && + responseContents.OfType().FirstOrDefault() is AITextContent textContent) + { + foreach (var citation in citations) + { + textContent.Annotations = + [ + new CitationAnnotation() + { + Title = citation.Title, + Url = Uri.TryCreate(citation.Uri, UriKind.Absolute, out Uri? uri) ? uri : null, + AnnotatedRegions = + [ + new TextSpanAnnotatedRegion() + { + StartIndex = citation.StartIndex, + EndIndex = citation.EndIndex, + } + ], + } + ]; + } + } + } + + // Populate error information if there is any. + if (generateResult.PromptFeedback is { } promptFeedback) + { + responseContents.Add(new ErrorContent(promptFeedback.BlockReasonMessage)); + } + + return finishReason; + } + + /// Creates an M.E.AI from a Google . + private static ChatFinishReason? ConvertFinishReason(FinishReason? finishReason) + { + return finishReason switch + { + null => null, + + FinishReason.MAX_TOKENS => + ChatFinishReason.Length, + + FinishReason.MALFORMED_FUNCTION_CALL or + FinishReason.UNEXPECTED_TOOL_CALL => + ChatFinishReason.ToolCalls, + + FinishReason.FINISH_REASON_UNSPECIFIED or + FinishReason.STOP => + ChatFinishReason.Stop, + + _ => ChatFinishReason.ContentFilter, + }; + } + + /// Creates a populated from the supplied . 
+ private static UsageDetails ExtractUsageDetails(GenerateContentResponseUsageMetadata usageMetadata) + { + UsageDetails details = new() + { + InputTokenCount = usageMetadata.PromptTokenCount, + OutputTokenCount = usageMetadata.CandidatesTokenCount, + TotalTokenCount = usageMetadata.TotalTokenCount, + }; + + AddIfPresent(nameof(usageMetadata.CachedContentTokenCount), usageMetadata.CachedContentTokenCount); + AddIfPresent(nameof(usageMetadata.ThoughtsTokenCount), usageMetadata.ThoughtsTokenCount); + AddIfPresent(nameof(usageMetadata.ToolUsePromptTokenCount), usageMetadata.ToolUsePromptTokenCount); + + return details; + + void AddIfPresent(string key, int? value) + { + if (value is int intValue) + { + (details.AdditionalCounts ??= [])[key] = intValue; + } + } + } +} + +#endif diff --git a/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiChatClientTests.cs b/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiChatClientTests.cs new file mode 100644 index 000000000000..cf649b24a09f --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiChatClientTests.cs @@ -0,0 +1,173 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using xRetry; +using Xunit; +using Xunit.Abstractions; + +namespace SemanticKernel.IntegrationTests.Connectors.Google.Gemini; + +public sealed class GeminiGenAIChatClientTests(ITestOutputHelper output) : TestsBase(output) +{ + private const string SkipReason = "This test is for manual verification."; + + [RetryFact(Skip = SkipReason)] + public async Task ChatClientGenerationReturnsValidResponseAsync() + { + // Arrange + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, I'm Brandon, how are you?"), + new ChatMessage(ChatRole.Assistant, "I'm doing well, thanks for asking."), + new ChatMessage(ChatRole.User, "Call me by my name and expand this abbreviation: LLM") + }; + + var sut = this.GetGenAIChatClient(); + + // Act + var response = await sut.GetResponseAsync(chatHistory); + + // Assert + Assert.NotNull(response); + Assert.NotNull(response.Messages); + Assert.NotEmpty(response.Messages); + var content = string.Join("", response.Messages.Select(m => m.Text)); + this.Output.WriteLine(content); + Assert.Contains("Large Language Model", content, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Brandon", content, StringComparison.OrdinalIgnoreCase); + } + + [RetryFact(Skip = SkipReason)] + public async Task ChatClientStreamingReturnsValidResponseAsync() + { + // Arrange + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, I'm Brandon, how are you?"), + new ChatMessage(ChatRole.Assistant, "I'm doing well, thanks for asking."), + new ChatMessage(ChatRole.User, "Call me by my name and write a long story about my name.") + }; + + var sut = this.GetGenAIChatClient(); + + // Act + var responses = await sut.GetStreamingResponseAsync(chatHistory).ToListAsync(); + + // Assert + Assert.NotEmpty(responses); + Assert.True(responses.Count > 1); + var message = string.Concat(responses.Select(c => c.Text)); + Assert.False(string.IsNullOrWhiteSpace(message)); + this.Output.WriteLine(message); + } + + [RetryFact(Skip = SkipReason)] + public async Task ChatClientWithSystemMessagesAsync() + { + // Arrange + var chatHistory = new[] + { + new ChatMessage(ChatRole.System, "You are helpful assistant. 
Your name is Roger."), + new ChatMessage(ChatRole.System, "You know ACDD equals 1520"), + new ChatMessage(ChatRole.User, "Hello, I'm Brandon, how are you?"), + new ChatMessage(ChatRole.Assistant, "I'm doing well, thanks for asking."), + new ChatMessage(ChatRole.User, "Tell me your name and the value of ACDD.") + }; + + var sut = this.GetGenAIChatClient(); + + // Act + var response = await sut.GetResponseAsync(chatHistory); + + // Assert + Assert.NotNull(response); + Assert.NotNull(response.Messages); + Assert.NotEmpty(response.Messages); + var content = string.Join("", response.Messages.Select(m => m.Text)); + this.Output.WriteLine(content); + Assert.Contains("1520", content, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Roger", content, StringComparison.OrdinalIgnoreCase); + } + + [RetryFact(Skip = SkipReason)] + public async Task ChatClientStreamingWithSystemMessagesAsync() + { + // Arrange + var chatHistory = new[] + { + new ChatMessage(ChatRole.System, "You are helpful assistant. Your name is Roger."), + new ChatMessage(ChatRole.System, "You know ACDD equals 1520"), + new ChatMessage(ChatRole.User, "Hello, I'm Brandon, how are you?"), + new ChatMessage(ChatRole.Assistant, "I'm doing well, thanks for asking."), + new ChatMessage(ChatRole.User, "Tell me your name and the value of ACDD.") + }; + + var sut = this.GetGenAIChatClient(); + + // Act + var responses = await sut.GetStreamingResponseAsync(chatHistory).ToListAsync(); + + // Assert + Assert.NotEmpty(responses); + Assert.True(responses.Count > 1); + var message = string.Concat(responses.Select(c => c.Text)); + this.Output.WriteLine(message); + Assert.Contains("1520", message, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Roger", message, StringComparison.OrdinalIgnoreCase); + } + + [RetryFact(Skip = SkipReason)] + public async Task ChatClientReturnsUsageDetailsAsync() + { + // Arrange + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, I'm Brandon, how are you?"), + new ChatMessage(ChatRole.Assistant, "I'm doing well, thanks for asking."), + new ChatMessage(ChatRole.User, "Call me by my name and expand this abbreviation: LLM") + }; + + var sut = this.GetGenAIChatClient(); + + // Act + var response = await sut.GetResponseAsync(chatHistory); + + // Assert + Assert.NotNull(response); + Assert.NotNull(response.Usage); + this.Output.WriteLine($"Input tokens: {response.Usage.InputTokenCount}"); + this.Output.WriteLine($"Output tokens: {response.Usage.OutputTokenCount}"); + this.Output.WriteLine($"Total tokens: {response.Usage.TotalTokenCount}"); + } + + [RetryFact(Skip = SkipReason)] + public async Task ChatClientWithChatOptionsAsync() + { + // Arrange + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Generate a random number between 1 and 100.") + }; + + var chatOptions = new ChatOptions + { + Temperature = 0.0f, + MaxOutputTokens = 100 + }; + + var sut = this.GetGenAIChatClient(); + + // Act + var response = await sut.GetResponseAsync(chatHistory, chatOptions); + + // Assert + Assert.NotNull(response); + Assert.NotNull(response.Messages); + Assert.NotEmpty(response.Messages); + var content = string.Join("", response.Messages.Select(m => m.Text)); + this.Output.WriteLine(content); + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiFunctionCallingChatClientTests.cs b/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiFunctionCallingChatClientTests.cs new file mode 100644 index 000000000000..9173365a60b9 --- /dev/null +++ 
b/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiFunctionCallingChatClientTests.cs @@ -0,0 +1,257 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.ComponentModel; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel; +using xRetry; +using Xunit; +using Xunit.Abstractions; +using AIFunctionCallContent = Microsoft.Extensions.AI.FunctionCallContent; + +namespace SemanticKernel.IntegrationTests.Connectors.Google.Gemini; + +public sealed class GeminiGenAIFunctionCallingChatClientTests(ITestOutputHelper output) : TestsBase(output) +{ + private const string SkipMessage = "This test is for manual verification."; + + [RetryFact(Skip = SkipMessage)] + public async Task ChatClientWithFunctionCallingReturnsToolCallsAsync() + { + // Arrange + var kernel = new Kernel(); + kernel.ImportPluginFromType(nameof(CustomerPlugin)); + + var sut = this.GetGenAIChatClient(); + + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, could you show me list of customers?") + }; + + var tools = kernel.Plugins + .SelectMany(p => p) + .Cast() + .ToList(); + + var chatOptions = new ChatOptions + { + Tools = tools + }; + + // Act + var response = await sut.GetResponseAsync(chatHistory, chatOptions); + + // Assert + Assert.NotNull(response); + Assert.NotNull(response.Messages); + Assert.NotEmpty(response.Messages); + + var functionCallContent = response.Messages + .SelectMany(m => m.Contents) + .OfType() + .FirstOrDefault(); + + Assert.NotNull(functionCallContent); + Assert.Contains("GetCustomers", functionCallContent.Name, StringComparison.OrdinalIgnoreCase); + } + + [RetryFact(Skip = SkipMessage)] + public async Task ChatClientStreamingWithFunctionCallingReturnsToolCallsAsync() + { + // Arrange + var kernel = new Kernel(); + kernel.ImportPluginFromType(nameof(CustomerPlugin)); + + var sut = this.GetGenAIChatClient(); + + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, could you show me list of customers?") + }; + + var tools = kernel.Plugins + .SelectMany(p => p) + .Cast() + .ToList(); + + var chatOptions = new ChatOptions + { + Tools = tools + }; + + // Act + var responses = await sut.GetStreamingResponseAsync(chatHistory, chatOptions).ToListAsync(); + + // Assert + Assert.NotEmpty(responses); + + var functionCallContent = responses + .SelectMany(r => r.Contents) + .OfType() + .FirstOrDefault(); + + Assert.NotNull(functionCallContent); + Assert.Contains("GetCustomers", functionCallContent.Name, StringComparison.OrdinalIgnoreCase); + } + + [RetryFact(Skip = SkipMessage)] + public async Task ChatClientWithAutoInvokeFunctionsAsync() + { + // Arrange + var kernel = new Kernel(); + kernel.ImportPluginFromType("CustomerPlugin"); + + var sut = this.GetGenAIChatClient(); + + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, could you show me list of customers?") + }; + + var tools = kernel.Plugins + .SelectMany(p => p) + .Cast() + .ToList(); + + var chatOptions = new ChatOptions + { + Tools = tools, + ToolMode = ChatToolMode.Auto + }; + + // Use FunctionInvokingChatClient for auto-invoke + using var autoInvokingClient = new FunctionInvokingChatClient(sut); + + // Act + var response = await autoInvokingClient.GetResponseAsync(chatHistory, chatOptions); + + // Assert + Assert.NotNull(response); + var content = string.Join("", response.Messages.Select(m => m.Text)); + this.Output.WriteLine(content); + Assert.Contains("John Kowalski", content, 
StringComparison.OrdinalIgnoreCase); + Assert.Contains("Anna Nowak", content, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Steve Smith", content, StringComparison.OrdinalIgnoreCase); + } + + [RetryFact(Skip = SkipMessage)] + public async Task ChatClientStreamingWithAutoInvokeFunctionsAsync() + { + // Arrange + var kernel = new Kernel(); + kernel.ImportPluginFromType("CustomerPlugin"); + + var sut = this.GetGenAIChatClient(); + + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, could you show me list of customers?") + }; + + var tools = kernel.Plugins + .SelectMany(p => p) + .Cast() + .ToList(); + + var chatOptions = new ChatOptions + { + Tools = tools, + ToolMode = ChatToolMode.Auto + }; + + // Use FunctionInvokingChatClient for auto-invoke + using var autoInvokingClient = new FunctionInvokingChatClient(sut); + + // Act + var responses = await autoInvokingClient.GetStreamingResponseAsync(chatHistory, chatOptions).ToListAsync(); + + // Assert + Assert.NotEmpty(responses); + var content = string.Concat(responses.Select(c => c.Text)); + this.Output.WriteLine(content); + Assert.Contains("John Kowalski", content, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Anna Nowak", content, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Steve Smith", content, StringComparison.OrdinalIgnoreCase); + } + + [RetryFact(Skip = SkipMessage)] + public async Task ChatClientWithMultipleFunctionCallsAsync() + { + // Arrange + var kernel = new Kernel(); + kernel.ImportPluginFromType("CustomerPlugin"); + + var sut = this.GetGenAIChatClient(); + + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, could you show me list of customers first and next return age of Anna customer?") + }; + + var tools = kernel.Plugins + .SelectMany(p => p) + .Cast() + .ToList(); + + var chatOptions = new ChatOptions + { + Tools = tools, + ToolMode = ChatToolMode.Auto + }; + + // Use FunctionInvokingChatClient for auto-invoke + using var autoInvokingClient = new FunctionInvokingChatClient(sut); + + // Act + var response = await autoInvokingClient.GetResponseAsync(chatHistory, chatOptions); + + // Assert + Assert.NotNull(response); + var content = string.Join("", response.Messages.Select(m => m.Text)); + this.Output.WriteLine(content); + Assert.Contains("28", content, StringComparison.OrdinalIgnoreCase); + } + + public sealed class CustomerPlugin + { + [KernelFunction(nameof(GetCustomers))] + [Description("Get list of customers.")] + [return: Description("List of customers.")] + public string[] GetCustomers() + { + return + [ + "John Kowalski", + "Anna Nowak", + "Steve Smith", + ]; + } + + [KernelFunction(nameof(GetCustomerAge))] + [Description("Get age of customer.")] + [return: Description("Age of customer.")] + public int GetCustomerAge([Description("Name of customer")] string customerName) + { + return customerName switch + { + "John Kowalski" => 35, + "Anna Nowak" => 28, + "Steve Smith" => 42, + _ => throw new ArgumentException("Customer not found."), + }; + } + } + + public sealed class MathPlugin + { + [KernelFunction(nameof(Sum))] + [Description("Sum numbers.")] + public int Sum([Description("Numbers to sum")] int[] numbers) + { + return numbers.Sum(); + } + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiVertexAIChatClientTests.cs b/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiVertexAIChatClientTests.cs new file mode 100644 index 000000000000..de51bd66d0d8 --- /dev/null +++ 
b/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiVertexAIChatClientTests.cs @@ -0,0 +1,173 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using xRetry; +using Xunit; +using Xunit.Abstractions; + +namespace SemanticKernel.IntegrationTests.Connectors.Google.Gemini; + +public sealed class GeminiVertexAIChatClientTests(ITestOutputHelper output) : TestsBase(output) +{ + private const string SkipReason = "This test is for manual verification."; + + [RetryFact(Skip = SkipReason)] + public async Task ChatClientGenerationReturnsValidResponseAsync() + { + // Arrange + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, I'm Brandon, how are you?"), + new ChatMessage(ChatRole.Assistant, "I'm doing well, thanks for asking."), + new ChatMessage(ChatRole.User, "Call me by my name and expand this abbreviation: LLM") + }; + + var sut = this.GetVertexAIChatClient(); + + // Act + var response = await sut.GetResponseAsync(chatHistory); + + // Assert + Assert.NotNull(response); + Assert.NotNull(response.Messages); + Assert.NotEmpty(response.Messages); + var content = string.Join("", response.Messages.Select(m => m.Text)); + this.Output.WriteLine(content); + Assert.Contains("Large Language Model", content, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Brandon", content, StringComparison.OrdinalIgnoreCase); + } + + [RetryFact(Skip = SkipReason)] + public async Task ChatClientStreamingReturnsValidResponseAsync() + { + // Arrange + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, I'm Brandon, how are you?"), + new ChatMessage(ChatRole.Assistant, "I'm doing well, thanks for asking."), + new ChatMessage(ChatRole.User, "Call me by my name and write a long story about my name.") + }; + + var sut = this.GetVertexAIChatClient(); + + // Act + var responses = await sut.GetStreamingResponseAsync(chatHistory).ToListAsync(); + + // Assert + Assert.NotEmpty(responses); + Assert.True(responses.Count > 1); + var message = string.Concat(responses.Select(c => c.Text)); + Assert.False(string.IsNullOrWhiteSpace(message)); + this.Output.WriteLine(message); + } + + [RetryFact(Skip = SkipReason)] + public async Task ChatClientWithSystemMessagesAsync() + { + // Arrange + var chatHistory = new[] + { + new ChatMessage(ChatRole.System, "You are helpful assistant. Your name is Roger."), + new ChatMessage(ChatRole.System, "You know ACDD equals 1520"), + new ChatMessage(ChatRole.User, "Hello, I'm Brandon, how are you?"), + new ChatMessage(ChatRole.Assistant, "I'm doing well, thanks for asking."), + new ChatMessage(ChatRole.User, "Tell me your name and the value of ACDD.") + }; + + var sut = this.GetVertexAIChatClient(); + + // Act + var response = await sut.GetResponseAsync(chatHistory); + + // Assert + Assert.NotNull(response); + Assert.NotNull(response.Messages); + Assert.NotEmpty(response.Messages); + var content = string.Join("", response.Messages.Select(m => m.Text)); + this.Output.WriteLine(content); + Assert.Contains("1520", content, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Roger", content, StringComparison.OrdinalIgnoreCase); + } + + [RetryFact(Skip = SkipReason)] + public async Task ChatClientStreamingWithSystemMessagesAsync() + { + // Arrange + var chatHistory = new[] + { + new ChatMessage(ChatRole.System, "You are helpful assistant. 
+ new ChatMessage(ChatRole.System, "You know ACDD equals 1520"), + new ChatMessage(ChatRole.User, "Hello, I'm Brandon, how are you?"), + new ChatMessage(ChatRole.Assistant, "I'm doing well, thanks for asking."), + new ChatMessage(ChatRole.User, "Tell me your name and the value of ACDD.") + }; + + var sut = this.GetVertexAIChatClient(); + + // Act + var responses = await sut.GetStreamingResponseAsync(chatHistory).ToListAsync(); + + // Assert + Assert.NotEmpty(responses); + Assert.True(responses.Count > 1); + var message = string.Concat(responses.Select(c => c.Text)); + this.Output.WriteLine(message); + Assert.Contains("1520", message, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Roger", message, StringComparison.OrdinalIgnoreCase); + } + + [RetryFact(Skip = SkipReason)] + public async Task ChatClientReturnsUsageDetailsAsync() + { + // Arrange + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, I'm Brandon, how are you?"), + new ChatMessage(ChatRole.Assistant, "I'm doing well, thanks for asking."), + new ChatMessage(ChatRole.User, "Call me by my name and expand this abbreviation: LLM") + }; + + var sut = this.GetVertexAIChatClient(); + + // Act + var response = await sut.GetResponseAsync(chatHistory); + + // Assert + Assert.NotNull(response); + Assert.NotNull(response.Usage); + this.Output.WriteLine($"Input tokens: {response.Usage.InputTokenCount}"); + this.Output.WriteLine($"Output tokens: {response.Usage.OutputTokenCount}"); + this.Output.WriteLine($"Total tokens: {response.Usage.TotalTokenCount}"); + } + + [RetryFact(Skip = SkipReason)] + public async Task ChatClientWithChatOptionsAsync() + { + // Arrange + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Generate a random number between 1 and 100.") + }; + + var chatOptions = new ChatOptions + { + Temperature = 0.0f, + MaxOutputTokens = 100 + }; + + var sut = this.GetVertexAIChatClient(); + + // Act + var response = await sut.GetResponseAsync(chatHistory, chatOptions); + + // Assert + Assert.NotNull(response); + Assert.NotNull(response.Messages); + Assert.NotEmpty(response.Messages); + var content = string.Join("", response.Messages.Select(m => m.Text)); + this.Output.WriteLine(content); + } +}
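The Vertex AI tests resolve the same IChatClient abstraction through the new AddGoogleVertexAIChatClient extension, so per-request tuning goes through ChatOptions rather than connector-specific settings. A hedged sketch of that flow follows; project, location, and model id are placeholders, and ambient Google Cloud credentials are assumed since the extension takes no API key.

using System;
using Microsoft.Extensions.AI;
using Microsoft.SemanticKernel;

// Vertex AI registration takes project/location instead of an API key; credentials are assumed
// to come from the environment (for example, application default credentials).
var kernel = Kernel.CreateBuilder()
    .AddGoogleVertexAIChatClient("gemini-2.0-flash", project: "my-project", location: "us-central1")
    .Build();
var chatClient = kernel.GetRequiredService<IChatClient>();

// ChatOptions carries generation settings such as temperature and output token limits.
var options = new ChatOptions { Temperature = 0.0f, MaxOutputTokens = 100 };
var response = await chatClient.GetResponseAsync(
    [new ChatMessage(ChatRole.User, "Expand this abbreviation: LLM")],
    options);

Console.WriteLine(response.Text);
// The connector also reports usage, which the usage-details test above asserts on.
Console.WriteLine($"Tokens in/out: {response.Usage?.InputTokenCount}/{response.Usage?.OutputTokenCount}");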
diff --git a/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiVertexAIFunctionCallingChatClientTests.cs b/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiVertexAIFunctionCallingChatClientTests.cs new file mode 100644 index 000000000000..7510b1609719 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiVertexAIFunctionCallingChatClientTests.cs @@ -0,0 +1,257 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.ComponentModel; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel; +using xRetry; +using Xunit; +using Xunit.Abstractions; +using AIFunctionCallContent = Microsoft.Extensions.AI.FunctionCallContent; + +namespace SemanticKernel.IntegrationTests.Connectors.Google.Gemini; + +public sealed class GeminiVertexAIFunctionCallingChatClientTests(ITestOutputHelper output) : TestsBase(output) +{ + private const string SkipMessage = "This test is for manual verification."; + + [RetryFact(Skip = SkipMessage)] + public async Task ChatClientWithFunctionCallingReturnsToolCallsAsync() + { + // Arrange + var kernel = new Kernel(); + kernel.ImportPluginFromType<CustomerPlugin>(nameof(CustomerPlugin)); + + var sut = this.GetVertexAIChatClient(); + + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, could you show me list of customers?") + }; + + var tools = kernel.Plugins + .SelectMany(p => p) + .Cast<AITool>() + .ToList(); + + var chatOptions = new ChatOptions + { + Tools = tools + }; + + // Act + var response = await sut.GetResponseAsync(chatHistory, chatOptions); + + // Assert + Assert.NotNull(response); + Assert.NotNull(response.Messages); + Assert.NotEmpty(response.Messages); + + var functionCallContent = response.Messages + .SelectMany(m => m.Contents) + .OfType<AIFunctionCallContent>() + .FirstOrDefault(); + + Assert.NotNull(functionCallContent); + Assert.Contains("GetCustomers", functionCallContent.Name, StringComparison.OrdinalIgnoreCase); + } + + [RetryFact(Skip = SkipMessage)] + public async Task ChatClientStreamingWithFunctionCallingReturnsToolCallsAsync() + { + // Arrange + var kernel = new Kernel(); + kernel.ImportPluginFromType<CustomerPlugin>(nameof(CustomerPlugin)); + + var sut = this.GetVertexAIChatClient(); + + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, could you show me list of customers?") + }; + + var tools = kernel.Plugins + .SelectMany(p => p) + .Cast<AITool>() + .ToList(); + + var chatOptions = new ChatOptions + { + Tools = tools + }; + + // Act + var responses = await sut.GetStreamingResponseAsync(chatHistory, chatOptions).ToListAsync(); + + // Assert + Assert.NotEmpty(responses); + + var functionCallContent = responses + .SelectMany(r => r.Contents) + .OfType<AIFunctionCallContent>() + .FirstOrDefault(); + + Assert.NotNull(functionCallContent); + Assert.Contains("GetCustomers", functionCallContent.Name, StringComparison.OrdinalIgnoreCase); + } + + [RetryFact(Skip = SkipMessage)] + public async Task ChatClientWithAutoInvokeFunctionsAsync() + { + // Arrange + var kernel = new Kernel(); + kernel.ImportPluginFromType<CustomerPlugin>("CustomerPlugin"); + + var sut = this.GetVertexAIChatClient(); + + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, could you show me list of customers?") + }; + + var tools = kernel.Plugins + .SelectMany(p => p) + .Cast<AITool>() + .ToList(); + + var chatOptions = new ChatOptions + { + Tools = tools, + ToolMode = ChatToolMode.Auto + }; + + // Use FunctionInvokingChatClient for auto-invoke + using var autoInvokingClient = new FunctionInvokingChatClient(sut); + + // Act + var response = await autoInvokingClient.GetResponseAsync(chatHistory, chatOptions); + + // Assert + Assert.NotNull(response); + var content = string.Join("", response.Messages.Select(m => m.Text)); + this.Output.WriteLine(content); + Assert.Contains("John Kowalski", content, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Anna Nowak", content, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Steve Smith", content,
StringComparison.OrdinalIgnoreCase); + } + + [RetryFact(Skip = SkipMessage)] + public async Task ChatClientStreamingWithAutoInvokeFunctionsAsync() + { + // Arrange + var kernel = new Kernel(); + kernel.ImportPluginFromType<CustomerPlugin>("CustomerPlugin"); + + var sut = this.GetVertexAIChatClient(); + + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, could you show me list of customers?") + }; + + var tools = kernel.Plugins + .SelectMany(p => p) + .Cast<AITool>() + .ToList(); + + var chatOptions = new ChatOptions + { + Tools = tools, + ToolMode = ChatToolMode.Auto + }; + + // Use FunctionInvokingChatClient for auto-invoke + using var autoInvokingClient = new FunctionInvokingChatClient(sut); + + // Act + var responses = await autoInvokingClient.GetStreamingResponseAsync(chatHistory, chatOptions).ToListAsync(); + + // Assert + Assert.NotEmpty(responses); + var content = string.Concat(responses.Select(c => c.Text)); + this.Output.WriteLine(content); + Assert.Contains("John Kowalski", content, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Anna Nowak", content, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Steve Smith", content, StringComparison.OrdinalIgnoreCase); + } + + [RetryFact(Skip = SkipMessage)] + public async Task ChatClientWithMultipleFunctionCallsAsync() + { + // Arrange + var kernel = new Kernel(); + kernel.ImportPluginFromType<CustomerPlugin>("CustomerPlugin"); + + var sut = this.GetVertexAIChatClient(); + + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, could you show me list of customers first and next return age of Anna customer?") + }; + + var tools = kernel.Plugins + .SelectMany(p => p) + .Cast<AITool>() + .ToList(); + + var chatOptions = new ChatOptions + { + Tools = tools, + ToolMode = ChatToolMode.Auto + }; + + // Use FunctionInvokingChatClient for auto-invoke + using var autoInvokingClient = new FunctionInvokingChatClient(sut); + + // Act + var response = await autoInvokingClient.GetResponseAsync(chatHistory, chatOptions); + + // Assert + Assert.NotNull(response); + var content = string.Join("", response.Messages.Select(m => m.Text)); + this.Output.WriteLine(content); + Assert.Contains("28", content, StringComparison.OrdinalIgnoreCase); + } + + public sealed class CustomerPlugin + { + [KernelFunction(nameof(GetCustomers))] + [Description("Get list of customers.")] + [return: Description("List of customers.")] + public string[] GetCustomers() + { + return + [ + "John Kowalski", + "Anna Nowak", + "Steve Smith", + ]; + } + + [KernelFunction(nameof(GetCustomerAge))] + [Description("Get age of customer.")] + [return: Description("Age of customer.")] + public int GetCustomerAge([Description("Name of customer")] string customerName) + { + return customerName switch + { + "John Kowalski" => 35, + "Anna Nowak" => 28, + "Steve Smith" => 42, + _ => throw new ArgumentException("Customer not found."), + }; + } + } + + public sealed class MathPlugin + { + [KernelFunction(nameof(Sum))] + [Description("Sum numbers.")] + public int Sum([Description("Numbers to sum")] int[] numbers) + { + return numbers.Sum(); + } + } +}
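The non-auto-invoke tests above inspect FunctionCallContent directly, which is the raw shape callers see when they do not wrap the client in FunctionInvokingChatClient. Purely as an illustration of what that wrapper automates, a hand-rolled round trip might look like the sketch below; the helper name is made up, and a real implementation would dispatch on call.Name across multiple tools.

using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;

internal static class ManualToolCallSketch
{
    // One manual tool-call round trip; FunctionInvokingChatClient automates exactly this loop.
    public static async Task<string> RunAsync(IChatClient chatClient, AIFunction tool, string prompt)
    {
        List<ChatMessage> messages = [new(ChatRole.User, prompt)];
        var options = new ChatOptions { Tools = [tool] };

        var response = await chatClient.GetResponseAsync(messages, options);
        messages.AddRange(response.Messages);

        // Execute every tool call the model returned and append the results as tool-role messages.
        // (Simplification: assumes the single 'tool' matches every call instead of matching call.Name.)
        foreach (var call in response.Messages.SelectMany(m => m.Contents).OfType<FunctionCallContent>())
        {
            var result = await tool.InvokeAsync(new AIFunctionArguments(call.Arguments));
            messages.Add(new ChatMessage(ChatRole.Tool, [new FunctionResultContent(call.CallId, result)]));
        }

        // A second pass lets the model turn the tool output into the final answer.
        var final = await chatClient.GetResponseAsync(messages, options);
        return final.Text;
    }
}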
diff --git a/dotnet/src/IntegrationTests/Connectors/Google/TestsBase.cs b/dotnet/src/IntegrationTests/Connectors/Google/TestsBase.cs index 723785497ccd..7e6bb8a45f54 --- a/dotnet/src/IntegrationTests/Connectors/Google/TestsBase.cs +++ b/dotnet/src/IntegrationTests/Connectors/Google/TestsBase.cs @@ -3,6 +3,8 @@ using System; using Microsoft.Extensions.AI; using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Connectors.Google; using Microsoft.SemanticKernel.Embeddings; @@ -65,6 +67,52 @@ protected TestsBase(ITestOutputHelper output) _ => throw new ArgumentOutOfRangeException(nameof(serviceType), serviceType, null) }; + protected IChatClient GetGenAIChatClient(string? overrideModelId = null) + { + var modelId = overrideModelId ?? this.GoogleAI.Gemini.ModelId; + var apiKey = this.GoogleAI.ApiKey; + + var kernel = Kernel.CreateBuilder() + .AddGoogleGenAIChatClient(modelId, apiKey) + .Build(); + + return kernel.GetRequiredService<IChatClient>(); + } + + protected IChatClient GetVertexAIChatClient(string? overrideModelId = null) + { + var modelId = overrideModelId ?? this.VertexAI.Gemini.ModelId; + + var kernel = Kernel.CreateBuilder() + .AddGoogleVertexAIChatClient(modelId, project: this.VertexAI.ProjectId, location: this.VertexAI.Location) + .Build(); + + return kernel.GetRequiredService<IChatClient>(); + } + + protected IChatClient GetGenAIChatClientWithVision() + { + var modelId = this.GoogleAI.Gemini.VisionModelId; + var apiKey = this.GoogleAI.ApiKey; + + var kernel = Kernel.CreateBuilder() + .AddGoogleGenAIChatClient(modelId, apiKey) + .Build(); + + return kernel.GetRequiredService<IChatClient>(); + } + + protected IChatClient GetVertexAIChatClientWithVision() + { + var modelId = this.VertexAI.Gemini.VisionModelId; + + var kernel = Kernel.CreateBuilder() + .AddGoogleVertexAIChatClient(modelId, project: this.VertexAI.ProjectId, location: this.VertexAI.Location) + .Build(); + + return kernel.GetRequiredService<IChatClient>(); + } + [Obsolete("Temporary test utility for Obsolete ITextEmbeddingGenerationService")] protected ITextEmbeddingGenerationService GetEmbeddingService(ServiceType serviceType) => serviceType switch { diff --git a/dotnet/src/IntegrationTests/IntegrationTests.csproj b/dotnet/src/IntegrationTests/IntegrationTests.csproj index d0e45a75f94f..ec65cb12f288 --- a/dotnet/src/IntegrationTests/IntegrationTests.csproj +++ b/dotnet/src/IntegrationTests/IntegrationTests.csproj @@ -41,6 +41,7 @@ +
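Taken together, the TestsBase helpers above are thin wrappers over the kernel-builder extensions introduced by this change. An application-side counterpart might look like the sketch below; the configuration keys mirror the test settings layout and are illustrative only, not something this PR defines.

using Microsoft.Extensions.AI;
using Microsoft.Extensions.Configuration;
using Microsoft.SemanticKernel;

// Illustrative factory mirroring TestsBase.GetGenAIChatClient/GetVertexAIChatClient for app code.
internal static class GeminiChatClientFactory
{
    public static IChatClient CreateGenAI(IConfiguration config) =>
        Kernel.CreateBuilder()
            .AddGoogleGenAIChatClient(
                config["GoogleAI:Gemini:ModelId"]!,
                config["GoogleAI:ApiKey"]!)
            .Build()
            .GetRequiredService<IChatClient>();

    public static IChatClient CreateVertexAI(IConfiguration config) =>
        Kernel.CreateBuilder()
            .AddGoogleVertexAIChatClient(
                config["VertexAI:Gemini:ModelId"]!,
                project: config["VertexAI:ProjectId"]!,
                location: config["VertexAI:Location"]!)
            .Build()
            .GetRequiredService<IChatClient>();
}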