diff --git a/src/api/providers/__tests__/native-ollama.spec.ts b/src/api/providers/__tests__/native-ollama.spec.ts index 73327a3012..940741bdf7 100644 --- a/src/api/providers/__tests__/native-ollama.spec.ts +++ b/src/api/providers/__tests__/native-ollama.spec.ts @@ -3,14 +3,19 @@ import { NativeOllamaHandler } from "../native-ollama" import { ApiHandlerOptions } from "../../../shared/api" import { getOllamaModels } from "../fetchers/ollama" +import { getApiRequestTimeout } from "../utils/timeout-config" // Mock the ollama package -const mockChat = vitest.fn() +const { mockChat, MockOllama } = vitest.hoisted(() => { + const mockChat = vitest.fn() + const MockOllama = vitest.fn().mockImplementation(() => ({ + chat: mockChat, + })) + return { mockChat, MockOllama } +}) vitest.mock("ollama", () => { return { - Ollama: vitest.fn().mockImplementation(() => ({ - chat: mockChat, - })), + Ollama: MockOllama, Message: vitest.fn(), } }) @@ -20,6 +25,13 @@ vitest.mock("../fetchers/ollama", () => ({ getOllamaModels: vitest.fn(), })) +// Mock the timeout config +vitest.mock("../utils/timeout-config", () => ({ + getApiRequestTimeout: vitest.fn(), +})) + +const mockGetApiRequestTimeout = vitest.mocked(getApiRequestTimeout) + const mockGetOllamaModels = vitest.mocked(getOllamaModels) describe("NativeOllamaHandler", () => { @@ -28,6 +40,9 @@ describe("NativeOllamaHandler", () => { beforeEach(() => { vitest.clearAllMocks() + // Default mock for timeout config (600s = 600000ms) + mockGetApiRequestTimeout.mockReturnValue(600_000) + // Default mock for getOllamaModels mockGetOllamaModels.mockResolvedValue({ llama2: { @@ -605,4 +620,64 @@ describe("NativeOllamaHandler", () => { expect(firstEndIndex).toBeGreaterThan(lastPartialIndex) }) }) + + describe("timeout configuration", () => { + it("should pass a custom fetch with timeout to the Ollama client", async () => { + mockGetApiRequestTimeout.mockReturnValue(900_000) // 900s + + // Create a new handler to trigger ensureClient with the 
mocked timeout + const options: ApiHandlerOptions = { + apiModelId: "llama2", + ollamaModelId: "llama2", + ollamaBaseUrl: "http://localhost:11434", + } + + const timeoutHandler = new NativeOllamaHandler(options) + + mockChat.mockImplementation(async function* () { + yield { message: { content: "Response" } } + }) + + const stream = timeoutHandler.createMessage("System", [{ role: "user" as const, content: "Test" }]) + for await (const _ of stream) { + // consume stream + } + + // Verify Ollama constructor was called with a fetch option + expect(MockOllama).toHaveBeenCalledWith( + expect.objectContaining({ + host: "http://localhost:11434", + fetch: expect.any(Function), + }), + ) + }) + + it("should not pass custom fetch when timeout is undefined", async () => { + mockGetApiRequestTimeout.mockReturnValue(undefined) + + const options: ApiHandlerOptions = { + apiModelId: "llama2", + ollamaModelId: "llama2", + ollamaBaseUrl: "http://localhost:11434", + } + + const timeoutHandler = new NativeOllamaHandler(options) + + mockChat.mockImplementation(async function* () { + yield { message: { content: "Response" } } + }) + + const stream = timeoutHandler.createMessage("System", [{ role: "user" as const, content: "Test" }]) + for await (const _ of stream) { + // consume stream + } + + // Verify Ollama constructor was called WITHOUT a fetch option + expect(MockOllama).toHaveBeenCalledWith( + expect.not.objectContaining({ + fetch: expect.any(Function), + }), + ) + }) + }) }) diff --git a/src/api/providers/native-ollama.ts b/src/api/providers/native-ollama.ts index 99c1dc03cf..bba894e9db 100644 --- a/src/api/providers/native-ollama.ts +++ b/src/api/providers/native-ollama.ts @@ -7,6 +7,7 @@ import { BaseProvider } from "./base-provider" import type { ApiHandlerOptions } from "../../shared/api" import { getOllamaModels } from "./fetchers/ollama" import { TagMatcher } from "../../utils/tag-matcher" +import { getApiRequestTimeout } from "./utils/timeout-config" import type { 
SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" interface OllamaChatOptions { @@ -160,7 +161,20 @@ export class NativeOllamaHandler extends BaseProvider implements SingleCompletio try { const clientOptions: OllamaOptions = { host: this.options.ollamaBaseUrl || "http://localhost:11434", - // Note: The ollama npm package handles timeouts internally + } + + // Apply configurable timeout via custom fetch wrapper. + // The ollama npm package uses Node.js native fetch (Undici), which + // enforces 300s (5 minute) headers/body idle timeouts by default. This respects the user's + // apiRequestTimeout setting (default 600s) to support slow inference. NOTE(review): a caller-supplied init.signal suppresses this timeout — confirm intended, or combine via AbortSignal.any. + const timeoutMs = getApiRequestTimeout() + if (timeoutMs) { + clientOptions.fetch = ((url: RequestInfo | URL, init?: RequestInit) => { + return fetch(url, { + ...init, + signal: init?.signal ?? AbortSignal.timeout(timeoutMs), + }) + }) as typeof fetch } + + // Add API key if provided (for Ollama cloud or authenticated instances)