diff --git a/.changeset/hungry-bugs-lick.md b/.changeset/hungry-bugs-lick.md new file mode 100644 index 000000000000..672de3ba8487 --- /dev/null +++ b/.changeset/hungry-bugs-lick.md @@ -0,0 +1,5 @@ +--- +"@langchain/google-common": minor +--- + +Add jsonSchema method support to withStructuredOutput in langchain-google-common and VertexAI diff --git a/libs/providers/langchain-google-common/src/chat_models.ts b/libs/providers/langchain-google-common/src/chat_models.ts index 1df6e27198a4..06e831ab00af 100644 --- a/libs/providers/langchain-google-common/src/chat_models.ts +++ b/libs/providers/langchain-google-common/src/chat_models.ts @@ -20,7 +20,10 @@ import { RunnableSequence, } from "@langchain/core/runnables"; import { JsonOutputKeyToolsParser } from "@langchain/core/output_parsers/openai_tools"; -import { BaseLLMOutputParser } from "@langchain/core/output_parsers"; +import { + BaseLLMOutputParser, + JsonOutputParser, +} from "@langchain/core/output_parsers"; import { AsyncCaller } from "@langchain/core/utils/async_caller"; import { concat } from "@langchain/core/utils/stream"; import { @@ -492,60 +495,73 @@ export abstract class ChatGoogleBase const method = config?.method; const includeRaw = config?.includeRaw; if (method === "jsonMode") { - throw new Error(`Google only supports "functionCalling" as a method.`); + throw new Error( + `Google only supports "jsonSchema" or "functionCalling" as a method.` + ); } - let functionName = name ?? "extract"; + let llm; let outputParser: BaseLLMOutputParser; - let tools: GeminiTool[]; - if (isInteropZodSchema(schema)) { - const jsonSchema = schemaToGeminiParameters(schema); - tools = [ - { - functionDeclarations: [ - { - name: functionName, - description: - jsonSchema.description ?? "A function available to call.", - parameters: jsonSchema as GeminiFunctionSchema, - }, - ], - }, - ]; - outputParser = new JsonOutputKeyToolsParser({ - returnSingle: true, - keyName: functionName, - zodSchema: schema, - }); - } else { - let geminiFunctionDefinition: GeminiFunctionDeclaration; - if ( - typeof schema.name === "string" && - typeof schema.parameters === "object" && - schema.parameters != null - ) { - geminiFunctionDefinition = schema as GeminiFunctionDeclaration; - functionName = schema.name; + if (method === "functionCalling") { + let functionName = name ?? "extract"; + let tools: GeminiTool[]; + if (isInteropZodSchema(schema)) { + const jsonSchema = schemaToGeminiParameters(schema); + tools = [ + { + functionDeclarations: [ + { + name: functionName, + description: + jsonSchema.description ?? "A function available to call.", + parameters: jsonSchema as GeminiFunctionSchema, + }, + ], + }, + ]; + outputParser = new JsonOutputKeyToolsParser({ + returnSingle: true, + keyName: functionName, + zodSchema: schema, + }); } else { - // We are providing the schema for *just* the parameters, probably - const parameters: GeminiJsonSchema = removeAdditionalProperties(schema); - geminiFunctionDefinition = { - name: functionName, - description: schema.description ?? "", - parameters, - }; + let geminiFunctionDefinition: GeminiFunctionDeclaration; + if ( + typeof schema.name === "string" && + typeof schema.parameters === "object" && + schema.parameters != null + ) { + geminiFunctionDefinition = schema as GeminiFunctionDeclaration; + functionName = schema.name; + } else { + // We are providing the schema for *just* the parameters, probably + const parameters: GeminiJsonSchema = + removeAdditionalProperties(schema); + geminiFunctionDefinition = { + name: functionName, + description: schema.description ?? "", + parameters, + }; + } + tools = [ + { + functionDeclarations: [geminiFunctionDefinition], + }, + ]; + outputParser = new JsonOutputKeyToolsParser({ + returnSingle: true, + keyName: functionName, + }); } - tools = [ - { - functionDeclarations: [geminiFunctionDefinition], - }, - ]; - outputParser = new JsonOutputKeyToolsParser({ - returnSingle: true, - keyName: functionName, + llm = this.bindTools(tools).withConfig({ tool_choice: functionName }); + } else { + // Default to jsonSchema method + const jsonSchema = schemaToGeminiParameters(schema); + llm = this.withConfig({ + responseSchema: jsonSchema as GeminiJsonSchema, }); + outputParser = new JsonOutputParser(); } - const llm = this.bindTools(tools).withConfig({ tool_choice: functionName }); if (!includeRaw) { return llm.pipe(outputParser).withConfig({ diff --git a/libs/providers/langchain-google-common/src/tests/chat_models.test.ts b/libs/providers/langchain-google-common/src/tests/chat_models.test.ts index 9415f8869623..17033961fc7c 100644 --- a/libs/providers/langchain-google-common/src/tests/chat_models.test.ts +++ b/libs/providers/langchain-google-common/src/tests/chat_models.test.ts @@ -1649,7 +1649,9 @@ describe("Mock ChatGoogle - Gemini", () => { const baseModel = new ChatGoogle({ authOptions, }); - const model = baseModel.withStructuredOutput(tool); + const model = baseModel.withStructuredOutput(tool, { + method: "functionCalling", + }); await model.invoke("What?"); @@ -1711,7 +1713,9 @@ describe("Mock ChatGoogle - Gemini", () => { }, required: ["greeterName"], }; - const model = baseModel.withStructuredOutput(schema); + const model = baseModel.withStructuredOutput(schema, { + method: "functionCalling", + }); await model.invoke("Hi, I'm kwkaiser"); const func = record?.opts?.data?.tools?.[0]?.functionDeclarations?.[0]; @@ -1721,6 +1725,124 @@ describe("Mock ChatGoogle - Gemini", () => { expect(func.parameters?.properties?.greeterName?.nullable).toEqual(true); }); + test("4. Functions withStructuredOutput - jsonSchema method request", async () => { + const record: Record = {}; + const projectId = mockId(); + const authOptions: MockClientAuthInfo = { + record, + projectId, + resultFile: "chat-json-schema-mock.json", + }; + + const schema = z.object({ + testName: z.string().describe("The name of the test that should be run."), + }); + + const baseModel = new ChatGoogle({ + authOptions, + }); + const model = baseModel.withStructuredOutput(schema, { + method: "jsonSchema", + }); + + await model.invoke("What?"); + + const { data } = record.opts; + // Should not have tools when using jsonSchema method + expect(data.tools).not.toBeDefined(); + // Should have responseSchema in generationConfig + expect(data.generationConfig).toBeDefined(); + expect(data.generationConfig.responseSchema).toBeDefined(); + expect(data.generationConfig.responseSchema.type).toBe("object"); + expect(data.generationConfig.responseSchema.properties).toBeDefined(); + expect( + data.generationConfig.responseSchema.properties.testName + ).toBeDefined(); + expect(data.generationConfig.responseSchema.properties.testName.type).toBe( + "string" + ); + // Should set responseMimeType to application/json + expect(data.generationConfig.responseMimeType).toBe("application/json"); + }); + + test("4. Functions withStructuredOutput - default uses jsonSchema method", async () => { + const record: Record = {}; + const projectId = mockId(); + const authOptions: MockClientAuthInfo = { + record, + projectId, + resultFile: "chat-json-schema-mock.json", + }; + + const schema = z.object({ + testName: z.string().describe("The name of the test."), + }); + + const baseModel = new ChatGoogle({ + authOptions, + }); + // Not specifying method - should default to jsonSchema + const model = baseModel.withStructuredOutput(schema); + + await model.invoke("What is the answer?"); + + const { data } = record.opts; + // Should not have tools when using jsonSchema method (default) + expect(data.tools).not.toBeDefined(); + // Should have responseSchema in generationConfig + expect(data.generationConfig).toBeDefined(); + expect(data.generationConfig.responseSchema).toBeDefined(); + expect(data.generationConfig.responseMimeType).toBe("application/json"); + }); + + test("4. Functions withStructuredOutput - functionCalling method request", async () => { + const record: Record = {}; + const projectId = mockId(); + const authOptions: MockClientAuthInfo = { + record, + projectId, + resultFile: "chat-4-mock.json", + }; + + const schema = z.object({ + testName: z.string().describe("The name of the test that should be run."), + }); + + const baseModel = new ChatGoogle({ + authOptions, + }); + const model = baseModel.withStructuredOutput(schema, { + method: "functionCalling", + }); + + await model.invoke("What?"); + + const { data } = record.opts; + // Should have tools when using functionCalling method + expect(data.tools).toBeDefined(); + expect(Array.isArray(data.tools)).toBeTruthy(); + expect(data.tools).toHaveLength(1); + expect(data.tools[0].functionDeclarations).toBeDefined(); + // Should not have responseSchema in generationConfig + expect(data.generationConfig?.responseSchema).not.toBeDefined(); + }); + + test("4. Functions withStructuredOutput - jsonMode throws error", async () => { + const baseModel = new ChatGoogle({}); + + const schema = z.object({ + answer: z.string(), + }); + + expect(() => + baseModel.withStructuredOutput(schema, { + method: "jsonMode", + }) + ).toThrowError( + `Google only supports "jsonSchema" or "functionCalling" as a method.` + ); + }); + test("4. Functions - results", async () => { const record: Record = {}; const projectId = mockId(); diff --git a/libs/providers/langchain-google-common/src/tests/data/chat-json-schema-mock.json b/libs/providers/langchain-google-common/src/tests/data/chat-json-schema-mock.json new file mode 100644 index 000000000000..64f0c5910a70 --- /dev/null +++ b/libs/providers/langchain-google-common/src/tests/data/chat-json-schema-mock.json @@ -0,0 +1,54 @@ +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "{\"testName\": \"cobalt\"}" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "promptFeedback": { + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } +} diff --git a/libs/providers/langchain-google-common/src/types.ts b/libs/providers/langchain-google-common/src/types.ts index 51f2a071e91e..40699067512b 100644 --- a/libs/providers/langchain-google-common/src/types.ts +++ b/libs/providers/langchain-google-common/src/types.ts @@ -322,6 +322,12 @@ export interface GoogleAIModelParams extends GoogleModelParams { */ responseMimeType?: GoogleAIResponseMimeType; + /** + * The schema that the model's output should conform to. + * When this is set, the model will output JSON that conforms to the schema. + */ + responseSchema?: GeminiJsonSchema; + /** * Whether or not to stream. * @default false @@ -412,6 +418,12 @@ export interface GoogleAIModelRequestParams extends GoogleAIModelParams { * https://cloud.google.com/vertex-ai/generative-ai/docs/context-cache/context-cache-use */ cachedContent?: string; + + /** + * The schema that the model's output should conform to. + * When this is set, the model will output JSON that conforms to the schema. + */ + responseSchema?: GeminiJsonSchema; } export interface GoogleAIBaseLLMInput @@ -713,6 +725,7 @@ export interface GeminiGenerationConfig { responseModalities?: GoogleAIModelModality[]; thinkingConfig?: GoogleThinkingConfig; speechConfig?: GoogleSpeechConfig; + responseSchema?: GeminiJsonSchema; } export interface GeminiRequest { diff --git a/libs/providers/langchain-google-common/src/utils/common.ts b/libs/providers/langchain-google-common/src/utils/common.ts index b09b602ac028..3885941d9ce7 100644 --- a/libs/providers/langchain-google-common/src/utils/common.ts +++ b/libs/providers/langchain-google-common/src/utils/common.ts @@ -197,6 +197,8 @@ export function copyAIModelParamsInto( options?.responseMimeType ?? params?.responseMimeType ?? target?.responseMimeType; + ret.responseSchema = + options?.responseSchema ?? params?.responseSchema ?? target?.responseSchema; ret.responseModalities = options?.responseModalities ?? params?.responseModalities ?? diff --git a/libs/providers/langchain-google-common/src/utils/gemini.ts b/libs/providers/langchain-google-common/src/utils/gemini.ts index 8c46e25d8e81..58e0a8105cf2 100644 --- a/libs/providers/langchain-google-common/src/utils/gemini.ts +++ b/libs/providers/langchain-google-common/src/utils/gemini.ts @@ -1673,7 +1673,10 @@ export function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI { frequencyPenalty: parameters.frequencyPenalty, maxOutputTokens: parameters.maxOutputTokens, stopSequences: parameters.stopSequences, - responseMimeType: parameters.responseMimeType, + responseMimeType: parameters.responseSchema + ? "application/json" + : parameters.responseMimeType, + responseSchema: parameters.responseSchema, responseModalities: parameters.responseModalities, speechConfig: normalizeSpeechConfig(parameters.speechConfig), };