diff --git a/src/__tests__/cost.test.ts b/src/__tests__/cost.test.ts
new file mode 100644
index 0000000..681a008
--- /dev/null
+++ b/src/__tests__/cost.test.ts
@@ -0,0 +1,63 @@
+import { describe, it, expect, beforeEach } from "vitest";
+import { trackUsage, getTotalCost, resetCostTracker } from "../cost";
+import { configure, resetConfig } from "../config";
+import { BudgetExceededError } from "../errors";
+
+describe("Cost Tracking", () => {
+  beforeEach(async () => {
+    await resetCostTracker();
+    resetConfig();
+  });
+
+  it("should track usage for a known model and calculate correct cost", async () => {
+    // google/gemini-3-flash: $0.075 / 1M input, $0.30 / 1M output
+    await trackUsage("google/gemini-3-flash", { promptTokens: 1_000_000, completionTokens: 2_000_000 });
+    // Cost should be 0.075 + 2 * 0.30 = 0.675
+    expect(await getTotalCost()).toBeCloseTo(0.675);
+  });
+
+  it("should match model ignoring prefix if exact match not found", async () => {
+    // This matches the suffix "gemini-3-flash"
+    await trackUsage("custom-provider/gemini-3-flash", { promptTokens: 2_000_000, completionTokens: 1_000_000 });
+    // Cost should be 2 * 0.075 + 0.30 = 0.45
+    expect(await getTotalCost()).toBeCloseTo(0.45);
+  });
+
+  it("should accumulate costs across multiple calls", async () => {
+    await trackUsage("google/gemini-3-flash", { promptTokens: 1_000_000, completionTokens: 0 }); // 0.075
+    await trackUsage("anthropic/claude-haiku-4.5", { promptTokens: 1_000_000, completionTokens: 1_000_000 }); // 0.25 + 1.25 = 1.50
+    expect(await getTotalCost()).toBeCloseTo(1.575);
+  });
+
+  it("should throw BudgetExceededError if cost exceeds maxCostPerRun", async () => {
+    configure({ maxCostPerRun: 1.0 });
+
+    await trackUsage("google/gemini-3-flash", { promptTokens: 1_000_000, completionTokens: 1_000_000 }); // 0.375
+    expect(await getTotalCost()).toBeCloseTo(0.375);
+
+    await expect(async () => {
+      // 1M prompt + 3M completion = 0.25 + 3.75 = 4.00
+      await trackUsage("anthropic/claude-haiku-4.5", { promptTokens: 1_000_000, completionTokens: 3_000_000 });
+    }).rejects.toThrow(BudgetExceededError);
+
+    // The cost still accumulated
+    expect(await getTotalCost()).toBeCloseTo(4.375);
+  });
+
+  it("should not throw if maxCostPerRun is not set", async () => {
+    // Awaiting directly: any rejection fails the test, which is the behavior under test.
+    await trackUsage("gpt-5.5", { promptTokens: 10_000_000, completionTokens: 10_000_000 }); // 50 + 150 = 200
+    expect(await getTotalCost()).toBeCloseTo(200);
+  });
+
+  it("should fall back to gemini-3-flash pricing for unknown models", async () => {
+    await trackUsage("unknown/model", { promptTokens: 1_000_000, completionTokens: 1_000_000 });
+    expect(await getTotalCost()).toBeCloseTo(0.375);
+  });
+
+  it("should use custom pricing if configured", async () => {
+    configure({ pricing: { "my-model": { input: 1.0, output: 2.0 } } });
+    await trackUsage("my-model", { promptTokens: 1_000_000, completionTokens: 1_000_000 });
+    expect(await getTotalCost()).toBeCloseTo(3.0);
+  });
+});
diff --git a/src/assertion.ts b/src/assertion.ts
index f20b765..bcd1010 100644
--- a/src/assertion.ts
+++ b/src/assertion.ts
@@ -4,6 +4,7 @@ import { getModelId } from "./config";
 import { ASSERTION_MODEL_TIMEOUT, THINKING_BUDGET_DEFAULT } from "./constants";
 import { logger } from "./logger";
 import { resolveModel } from "./models";
+import { trackUsage } from "./cost";
 import { AssertionResult, AssertionOptions } from "./types";
 import { resolvePage, safeSnapshot, withTimeout } from "./utils";
 
@@ -130,7 +131,7 @@ Never hallucinate. Be truthful and if you are not sure, use a low confidence sco
   // Claude assertion function
   const getClaudeAssertion = async (): Promise<AssertionResult> => {
     // First get Claude's text response with thinking if enabled
-    const { text } = await generateText({
+    const textResult = await generateText({
       model: resolveModel(getModelId("assertionPrimary")),
       temperature: 0,
       providerOptions: thinkingEnabled
@@ -146,20 +147,28 @@ Never hallucinate. Be truthful and if you are not sure, use a low confidence sco
       messages,
     });
 
+    if (textResult.usage) {
+      await trackUsage(getModelId("assertionPrimary"), textResult.usage);
+    }
+
     // Convert Claude's response to structured format using Haiku
-    const { output } = await generateText({
+    const outputResult = await generateText({
       model: resolveModel(getModelId("assertionPrimary")),
       temperature: 0.1,
-      prompt: `Convert the following text output into a valid JSON object with the specified properties:\n\n${text}`,
+      prompt: `Convert the following text output into a valid JSON object with the specified properties:\n\n${textResult.text}`,
       output: Output.object({ schema: assertionSchema }),
     });
 
-    return output;
+    if (outputResult.usage) {
+      await trackUsage(getModelId("assertionPrimary"), outputResult.usage);
+    }
+
+    return outputResult.output;
   };
 
   // Gemini assertion function
   const getGeminiAssertion = async (): Promise<AssertionResult> => {
-    const { output } = await generateText({
+    const outputResult = await generateText({
       model: resolveModel(getModelId("assertionSecondary")),
       temperature: 0,
       providerOptions: thinkingEnabled
@@ -178,7 +187,11 @@ Never hallucinate. Be truthful and if you are not sure, use a low confidence sco
       output: Output.object({ schema: assertionSchema }),
     });
 
-    return output;
+    if (outputResult.usage) {
+      await trackUsage(getModelId("assertionSecondary"), outputResult.usage);
+    }
+
+    return outputResult.output;
   };
 
   // Arbiter function using Gemini 2.5 Pro with thinking enabled
@@ -241,7 +254,7 @@ Please carefully review the evidence (screenshot and accessibility snapshot (whe
       },
     ];
 
-    const { output } = await generateText({
+    const outputResult = await generateText({
       model: resolveModel(getModelId("assertionArbiter")),
       temperature: 0,
       providerOptions: {
@@ -258,7 +271,11 @@ Please carefully review the evidence (screenshot and accessibility snapshot (whe
       output: Output.object({ schema: assertionSchema }),
     });
 
-    return output;
+    if (outputResult.usage) {
+      await trackUsage(getModelId("assertionArbiter"), outputResult.usage);
+    }
+
+    return outputResult.output;
   };
 
   const runAssertion = async (attempt = 0): Promise<AssertionResult> => {
diff --git a/src/config.ts b/src/config.ts
index 40da346..e90957c 100644
--- a/src/config.ts
+++ b/src/config.ts
@@ -62,6 +62,18 @@ type Config = {
   };
   /** Base path for file uploads. Default: "./uploads" */
   uploadBasePath?: string;
+  /**
+   * Maximum allowed LLM usage cost per run in USD.
+   * If the cumulative cost of AI calls exceeds this budget, an error is thrown.
+   * Helps prevent bill shock.
+   */
+  maxCostPerRun?: number;
+  /**
+   * Custom pricing overrides for specific models.
+   * Key: model ID (e.g. "google/gemini-3-flash")
+   * Value: { input: number, output: number } where numbers are cost per 1M tokens.
+   */
+  pricing?: Record<string, { input: number; output: number }>;
 };
 
 let globalConfig: Config = {};
diff --git a/src/cost.ts b/src/cost.ts
new file mode 100644
index 0000000..93d4820
--- /dev/null
+++ b/src/cost.ts
@@ -0,0 +1,142 @@
+/**
+ * cost.ts
+ *
+ * Tracks LLM usage and calculates cost based on standard provider pricing.
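+ *
+ * Totals are aggregated across workers in Redis when a client is available,
+ * with an in-process counter as fallback. Prices are USD per 1M tokens and are
+ * a point-in-time snapshot; override them via configure({ pricing }).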
+ */
+
+import { logger } from "./logger";
+import { getConfig } from "./config";
+import { BudgetExceededError } from "./errors";
+import { redis } from "./redis";
+import * as fs from "fs";
+import * as path from "path";
+
+// Standard provider pricing (per 1M tokens, in USD)
+const PRICING: Record<string, { input: number; output: number }> = {
+  "google/gemini-3-flash": { input: 0.075, output: 0.30 },
+  "google/gemini-3-flash-preview": { input: 0.075, output: 0.30 },
+  "google/gemini-2.5-flash": { input: 0.075, output: 0.30 },
+  "google/gemini-3.1-pro-preview": { input: 1.25, output: 5.00 },
+  "anthropic/claude-4.5-haiku": { input: 0.25, output: 1.25 },
+  "anthropic/claude-haiku-4.5": { input: 0.25, output: 1.25 },
+  "gpt-5.5": { input: 5.00, output: 15.00 }, // Typical GPT-4o level pricing
+};
+
+let localTotalCost = 0;
+
+function getExecutionId() {
+  return process.env.executionId || "default";
+}
+
+function getRedisKey() {
+  return `passmark:run:cost:${getExecutionId()}`;
+}
+
+function getEffectivePricing(model: string): { input: number; output: number } {
+  const config = getConfig();
+
+  // 1. Check custom pricing overrides
+  if (config.pricing && config.pricing[model]) {
+    return config.pricing[model];
+  }
+
+  // 2. Check standard pricing (exact match)
+  const price = PRICING[model];
+  if (price) return price;
+
+  // 3. Try matching suffix (e.g. "openrouter/google/gemini-3-flash" -> "google/gemini-3-flash")
+  const match = Object.keys(PRICING).find((k) => model.endsWith(k));
+  if (match) return PRICING[match];
+
+  // 4. Default fallback
+  return PRICING["google/gemini-3-flash"];
+}
+
+export async function trackUsage(model: string, usage: any) {
+  if (!usage) return;
+
+  const promptTokens = usage.promptTokens ?? usage.prompt_tokens ?? 0;
+  const completionTokens = usage.completionTokens ?? usage.completion_tokens ?? 0;
+
+  const price = getEffectivePricing(model);
+  const cost = (promptTokens / 1_000_000) * price.input + (completionTokens / 1_000_000) * price.output;
+
+  let currentTotalCost = 0;
+
+  if (redis) {
+    const key = getRedisKey();
+    const newTotalStr = await redis.incrbyfloat(key, cost);
+    // Set a 24-hour TTL to prevent infinite cost accumulation from old runs.
+    await redis.expire(key, 86400);
+    currentTotalCost = parseFloat(newTotalStr);
+  } else {
+    localTotalCost += cost;
+    currentTotalCost = localTotalCost;
+  }
+
+  // Attach usage and cost to the active OpenTelemetry span
+  try {
+    const { trace } = require("@opentelemetry/api");
+    const activeSpan = trace.getActiveSpan();
+    if (activeSpan) {
+      activeSpan.setAttributes({
+        "ai.usage.prompt_tokens": promptTokens,
+        "ai.usage.completion_tokens": completionTokens,
+        "ai.usage.cost": cost,
+        "ai.model": model,
+      });
+    }
+  } catch (err) {
+    // Ignore if @opentelemetry/api is not installed.
+  }
+
+  const maxCost = getConfig().maxCostPerRun;
+  if (maxCost !== undefined && currentTotalCost > maxCost) {
+    throw new BudgetExceededError(`Maximum cost per run of $${maxCost} exceeded. Current total cost across all workers: $${currentTotalCost.toFixed(4)}`);
+  }
+}
+
+export async function getTotalCost(): Promise<number> {
+  if (redis) {
+    const val = await redis.get(getRedisKey());
+    return val ? parseFloat(val) : 0;
+  }
+  return localTotalCost;
+}
+
+export async function resetCostTracker() {
+  if (redis) {
+    await redis.del(getRedisKey());
+  }
+  localTotalCost = 0;
+}
+
+// Log the total cost at the end of the Node.js process.
+// Also generate a passmark-cost.json artifact in uploadBasePath if configured.
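+// When Redis is enabled, the authoritative cross-worker total lives under the
+// execution-scoped Redis key; the log and artifact below cover this worker only.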
+process.on("exit", () => { + // We use localTotalCost for the exit log because process.on("exit") is synchronous and we can't await Redis. + // However, since tests run in workers, we will log the local worker cost. + if (localTotalCost > 0 && !redis) { + logger.info(`💰 Total LLM Cost for this worker: $${localTotalCost.toFixed(4)}`); + } + + // Try to write artifact + try { + const uploadBase = getConfig().uploadBasePath || "./uploads"; + if (fs.existsSync(uploadBase)) { + const artifactPath = path.join(uploadBase, `passmark-cost-${process.pid}.json`); + fs.writeFileSync(artifactPath, JSON.stringify({ + workerCost: localTotalCost, + timestamp: new Date().toISOString() + }, null, 2)); + } + } catch (e) { + // Ignore artifact write errors + } +}); diff --git a/src/cua/loop.ts b/src/cua/loop.ts index 2a6cbe3..e65bfab 100644 --- a/src/cua/loop.ts +++ b/src/cua/loop.ts @@ -5,6 +5,7 @@ import { logger } from "../logger"; import { waitForDOMStabilization } from "../utils"; import { executeAction, type ComputerAction } from "./actions"; import { getOpenAIClient } from "./client"; +import { trackUsage } from "../cost"; export type RunCUALoopOptions = { page: Page; @@ -40,6 +41,10 @@ type CUAResponse = { id: string; output?: CUAOutputItem[]; output_text?: string; + usage?: { + prompt_tokens: number; + completion_tokens: number; + }; }; type OpenAIWithResponses = OpenAI & { @@ -107,6 +112,12 @@ export async function runCUALoop({ let response: CUAResponse; try { response = await openai.responses.create(initialRequest, { signal: abortSignal }); + if (response.usage) { + await trackUsage(model, { + promptTokens: response.usage.prompt_tokens, + completionTokens: response.usage.completion_tokens, + }); + } } catch (err: unknown) { const e = err as OpenAIErrorLike; logger.error( @@ -170,6 +181,13 @@ export async function runCUALoop({ }, { signal: abortSignal }, ); + + if (response.usage) { + await trackUsage(model, { + promptTokens: response.usage.prompt_tokens, + completionTokens: response.usage.completion_tokens, + }); + } } logger.warn(`[cua] loop hit maxSteps=${maxSteps} without model stopping`); diff --git a/src/errors.ts b/src/errors.ts index f10ced4..441caa5 100644 --- a/src/errors.ts +++ b/src/errors.ts @@ -89,4 +89,13 @@ export class ValidationError extends PassmarkError { constructor(message: string) { super(message, "VALIDATION_ERROR"); } +} + +/** + * Thrown when the LLM usage cost exceeds the configured maxCostPerRun budget. 
+ */
+export class BudgetExceededError extends PassmarkError {
+  constructor(message: string) {
+    super(message, "BUDGET_EXCEEDED");
+  }
 }
\ No newline at end of file
diff --git a/src/extract.ts b/src/extract.ts
index abff50d..bb17c29 100644
--- a/src/extract.ts
+++ b/src/extract.ts
@@ -2,6 +2,7 @@ import { generateText, Output } from "ai";
 import { z } from "zod";
 import { getModelId } from "./config";
 import { resolveModel } from "./models";
+import { trackUsage } from "./cost";
 
 const extractionSchema = z.object({
   extractedValue: z.string().describe("The extracted value based on the prompt"),
@@ -35,7 +36,7 @@ export async function extractDataWithAI({
   url: string;
   prompt: string;
 }): Promise<string> {
-  const { output } = await generateText({
+  const outputResult = await generateText({
     model: resolveModel(getModelId("utility")),
     temperature: 0,
     output: Output.object({ schema: extractionSchema }),
@@ -66,5 +67,9 @@ ${prompt}
 
 Return the extracted value.`,
   });
-  return output.extractedValue;
+  if (outputResult.usage) {
+    await trackUsage(getModelId("utility"), outputResult.usage);
+  }
+
+  return outputResult.output.extractedValue;
 }
diff --git a/src/index.ts b/src/index.ts
index 3bb3d56..84ceebe 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -41,6 +41,7 @@ import {
   resolveEmailPlaceholders,
 } from "./data-cache";
 import { getConfig, getMode, getModelId } from "./config";
+import { trackUsage } from "./cost";
 import { runCUALoop, buildRunStepsPromptCUA, buildRunUserFlowPromptCUA } from "./cua";
 import { extractDataWithAI } from "./extract";
 import { logger } from "./logger";
@@ -508,6 +509,10 @@ export const runSteps = async ({
     }),
   );
 
+  if (result.usage) {
+    await trackUsage(getModelId("stepExecution"), result.usage);
+  }
+
   // Cache the step action only if it was a single tool call (simple, deterministic action).
   // Multi-step actions are not cached as they may be non-deterministic.
   const allToolCalls = result.steps
@@ -671,7 +676,7 @@ export const runUserFlow = async ({
   );
 
   if (assertion) {
-    const { output } = await generateText({
+    const outputResult = await generateText({
       model: resolveModel(getModelId("utility")),
       prompt: `Convert the following text output into a valid JSON object with the specified properties:\n\n${text}`,
       output: Output.object({
@@ -686,7 +691,12 @@
       }),
     });
-    return output;
+
+    if (outputResult.usage) {
+      await trackUsage(getModelId("utility"), outputResult.usage);
+    }
+
+    return outputResult.output;
   }
 
   return text;
@@ -706,7 +716,7 @@ export const runUserFlow = async ({
   });
 
   try {
-    const { text } = await maybeWithSpan(
+    const textResult = await maybeWithSpan(
       { capability: "user_flow_execution", step: "agentic_tool_calling" },
       async () => {
         return generateText({
@@ -749,10 +759,14 @@
       },
     );
 
+    if (textResult.usage) {
+      await trackUsage(effort === "low" ? getModelId("userFlowLow") : getModelId("userFlowHigh"), textResult.usage);
getModelId("userFlowLow") : getModelId("userFlowHigh"), textResult.usage); + } + if (assertion) { - const { output } = await generateText({ + const outputResult = await generateText({ model: resolveModel(getModelId("utility")), - prompt: `Convert the following text output into a valid JSON object with the specified properties:\n\n${text}`, + prompt: `Convert the following text output into a valid JSON object with the specified properties:\n\n${textResult.text}`, output: Output.object({ schema: z.object({ assertionPassed: z.boolean().describe("Indicates whether the assertion passed or not."), @@ -766,10 +780,14 @@ export const runUserFlow = async ({ }), }); - return output; + if (outputResult.usage) { + await trackUsage(getModelId("utility"), outputResult.usage); + } + + return outputResult.output; } - return text; + return textResult.text; } catch (error: unknown) { logger.error({ err: error }, "Error during user flow execution"); } diff --git a/src/utils/index.ts b/src/utils/index.ts index 10e431b..dc2f9b0 100644 --- a/src/utils/index.ts +++ b/src/utils/index.ts @@ -12,6 +12,7 @@ import { z } from "zod"; import { getModelId } from "../config"; import { logger } from "../logger"; import { resolveModel } from "../models"; +import { trackUsage } from "../cost"; import { PageInput, WaitConditionResult, @@ -300,7 +301,7 @@ ${condition} Analyze the attached before and after screenshots and determine if the wait condition has been met. `; - const { output } = await generateText({ + const outputResult = await generateText({ model: resolveModel(getModelId("utility")), temperature: 0, messages: [ @@ -316,7 +317,11 @@ Analyze the attached before and after screenshots and determine if the wait cond output: Output.object({ schema: waitConditionSchema }), }); - return output; + if (outputResult.usage) { + await trackUsage(getModelId("utility"), outputResult.usage); + } + + return outputResult.output; }; while (Date.now() - startTime < timeout) {