From 368af4c642ddf4f2bdfb989cae020c13b6800ce8 Mon Sep 17 00:00:00 2001 From: zhengxuyu Date: Tue, 28 Apr 2026 17:26:50 +0100 Subject: [PATCH] feat: add token usage tracking to runSteps, runUserFlow, and assert Surfaces AI SDK usage data (inputTokens, outputTokens, totalTokens) from every generateText call. runSteps() now returns a UsageResult with per-call breakdowns. runUserFlow() includes a usage field in its return value. Assertion, extraction, and wait-condition calls participate in tracking when a usageTracker is provided. New exports: createUsageTracker(), TokenUsage, UsageResult, UsageTracker. --- CHANGELOG.md | 1 + src/__tests__/assertion.test.ts | 74 +++++++- src/__tests__/extract-usage.test.ts | 63 +++++++ src/__tests__/integration/run-steps.test.ts | 51 +++++- .../integration/run-user-flow.test.ts | 162 ++++++++++++++++++ src/__tests__/usage.test.ts | 92 ++++++++++ src/assertion.ts | 100 ++++++----- src/extract.ts | 7 +- src/index.ts | 100 +++++++---- src/types.ts | 11 +- src/usage.ts | 72 ++++++++ src/utils/index.ts | 41 ++--- 12 files changed, 667 insertions(+), 107 deletions(-) create mode 100644 src/__tests__/extract-usage.test.ts create mode 100644 src/__tests__/integration/run-user-flow.test.ts create mode 100644 src/__tests__/usage.test.ts create mode 100644 src/usage.ts diff --git a/CHANGELOG.md b/CHANGELOG.md index 5dbea15..8f66fb4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- **Token usage tracking**: `runSteps()` now returns a `UsageResult` with per-call token breakdowns (`inputTokens`, `outputTokens`, `totalTokens`) and aggregate totals. `runUserFlow()` includes a `usage` field in its return value. New exported types: `TokenUsage`, `UsageResult`, `UsageTracker`. New exported function: `createUsageTracker()`. Assertion, extraction, and wait-condition AI calls all participate in tracking when a `usageTracker` is provided. 
- `maxRetries` option to `AssertionOptions` (default: `1`) to control how many times a failed assertion is retried with a fresh page snapshot and screenshot. Setting it to `0` disables retries. - `onRetry` callback to `AssertionOptions` that fires before each retry, receiving the retry index and the full `AssertionResult` from the previous attempt for debugging flaky assertions. - **CUA mode** (`configure({ ai: { mode: "cua" } })`): execute `runSteps` and `runUserFlow` through OpenAI's Responses API with the built-in `computer` tool. Screenshot-driven, coordinate-based actions via Playwright's `page.mouse` / `page.keyboard`. Requires `OPENAI_API_KEY` and `gateway: "none"`; Redis step caching is skipped in this mode because coordinate actions aren't portable across viewport sizes. diff --git a/src/__tests__/assertion.test.ts b/src/__tests__/assertion.test.ts index c07e587..7ab58f3 100644 --- a/src/__tests__/assertion.test.ts +++ b/src/__tests__/assertion.test.ts @@ -35,6 +35,7 @@ vi.mock("../utils", () => ({ import { assert } from "../assertion"; import { withTimeout } from "../utils"; import { generateText } from "ai"; +import { createUsageTracker } from "../usage"; function createMockPage() { return { @@ -67,13 +68,17 @@ function makeGenerateTextImpl(opts: { return { output: opts.claude } as any; } if (model.includes("gemini-3-flash")) { - const g = typeof opts.gemini === "function" ? (opts.gemini as () => AssertionObj)() : opts.gemini; + const g = + typeof opts.gemini === "function" ? (opts.gemini as () => AssertionObj)() : opts.gemini; return { output: g } as any; } if (model.includes("3.1-pro-preview")) { return { - output: - opts.arbiter ?? { assertionPassed: false, confidenceScore: 0, reasoning: "no arbiter set" }, + output: opts.arbiter ?? 
{ + assertionPassed: false, + confidenceScore: 0, + reasoning: "no arbiter set", + }, } as any; } return { output: { assertionPassed: false, confidenceScore: 0, reasoning: "unknown" } } as any; @@ -137,7 +142,11 @@ describe("assert consensus logic", () => { makeGenerateTextImpl({ claude: { assertionPassed: true, confidenceScore: 95, reasoning: "Claude: yes" }, gemini: { assertionPassed: false, confidenceScore: 30, reasoning: "Gemini: no" }, - arbiter: { assertionPassed: true, confidenceScore: 70, reasoning: "Arbiter: I side with Claude" }, + arbiter: { + assertionPassed: true, + confidenceScore: 70, + reasoning: "Arbiter: I side with Claude", + }, }) as any, ); @@ -160,7 +169,11 @@ describe("assert consensus logic", () => { makeGenerateTextImpl({ claude: { assertionPassed: true, confidenceScore: 60, reasoning: "Claude: yes" }, gemini: { assertionPassed: false, confidenceScore: 40, reasoning: "Gemini: no" }, - arbiter: { assertionPassed: false, confidenceScore: 45, reasoning: "Arbiter: I disagree, it fails" }, + arbiter: { + assertionPassed: false, + confidenceScore: 45, + reasoning: "Arbiter: I disagree, it fails", + }, }) as any, ); @@ -188,7 +201,11 @@ describe("assert consensus logic", () => { if (geminiCalls === 1) { throw new Error("transient model error"); } - return { assertionPassed: true, confidenceScore: 80, reasoning: "Gemini: ok after retry" }; + return { + assertionPassed: true, + confidenceScore: 80, + reasoning: "Gemini: ok after retry", + }; }, }) as any, ); @@ -210,7 +227,9 @@ describe("assert consensus logic", () => { const page = createMockPage(); // Make withTimeout reject once to simulate timeout - vi.mocked(withTimeout).mockImplementationOnce(() => Promise.reject(new Error("timed out")) as any); + vi.mocked(withTimeout).mockImplementationOnce( + () => Promise.reject(new Error("timed out")) as any, + ); vi.mocked(generateText).mockImplementation( makeGenerateTextImpl({ @@ -229,4 +248,45 @@ describe("assert consensus logic", () => { 
expect(res).toContain("✅ passed"); }); + + it("records usage data when usageTracker is provided", async () => { + const page = createMockPage(); + const tracker = createUsageTracker(); + + // Return usage data from mock generateText calls + vi.mocked(generateText).mockImplementation(async (args: any) => { + const model = String(args.model ?? ""); + const wantsStructured = Boolean(args.output); + const usage = { inputTokens: 100, outputTokens: 50, totalTokens: 150 }; + + if (!wantsStructured) { + return { text: "claude text", usage } as any; + } + if (model.includes("anthropic")) { + return { + output: { assertionPassed: true, confidenceScore: 90, reasoning: "Claude: ok" }, + usage, + } as any; + } + return { + output: { assertionPassed: true, confidenceScore: 80, reasoning: "Gemini: ok" }, + usage, + } as any; + }); + + await assert({ + page, + assertion: "The page shows items", + test: mockTest, + expect: ((a: unknown, _m?: string) => ({ toBe: (_v: unknown) => {} })) as any, + failSilently: true, + usageTracker: tracker, + }); + + const result = tracker.getResult(); + // Claude text call + Claude structured call + Gemini structured call = 3 calls + expect(result.details.length).toBeGreaterThanOrEqual(3); + expect(result.totalTokens).toBeGreaterThan(0); + expect(result.details.every((d) => d.operation === "assertion")).toBe(true); + }); }); diff --git a/src/__tests__/extract-usage.test.ts b/src/__tests__/extract-usage.test.ts new file mode 100644 index 0000000..67b2b62 --- /dev/null +++ b/src/__tests__/extract-usage.test.ts @@ -0,0 +1,63 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; + +vi.mock("../instrumentation", () => ({ axiomEnabled: false })); + +vi.mock("../models", () => ({ + resolveModel: (id: string) => id, +})); + +vi.mock("ai", async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + generateText: vi.fn(), + }; +}); + +import { extractDataWithAI } from "../extract"; +import { createUsageTracker } 
from "../usage"; +import { generateText } from "ai"; + +beforeEach(() => { + vi.clearAllMocks(); +}); + +describe("extractDataWithAI usage tracking", () => { + it("records usage when usageTracker is provided", async () => { + const tracker = createUsageTracker(); + + vi.mocked(generateText).mockResolvedValue({ + output: { extractedValue: "abc123" }, + usage: { inputTokens: 200, outputTokens: 30, totalTokens: 230 }, + } as any); + + const result = await extractDataWithAI({ + snapshot: "page content", + url: "https://example.com?token=abc123", + prompt: "Extract the token query parameter", + usageTracker: tracker, + }); + + expect(result).toBe("abc123"); + + const usage = tracker.getResult(); + expect(usage.details).toHaveLength(1); + expect(usage.details[0].operation).toBe("extraction"); + expect(usage.totalTokens).toBe(230); + }); + + it("works without usageTracker (backward compatible)", async () => { + vi.mocked(generateText).mockResolvedValue({ + output: { extractedValue: "value" }, + usage: { inputTokens: 100, outputTokens: 20, totalTokens: 120 }, + } as any); + + const result = await extractDataWithAI({ + snapshot: "content", + url: "https://example.com", + prompt: "Extract something", + }); + + expect(result).toBe("value"); + }); +}); diff --git a/src/__tests__/integration/run-steps.test.ts b/src/__tests__/integration/run-steps.test.ts index 8151f8e..88033dd 100644 --- a/src/__tests__/integration/run-steps.test.ts +++ b/src/__tests__/integration/run-steps.test.ts @@ -101,6 +101,7 @@ import { generateText } from "ai"; import { runCUALoop } from "../../cua"; import type { Page } from "@playwright/test"; import type { Step } from "../../types"; +import type { UsageResult } from "../../usage"; function createMockPage() { const mockLocator = { @@ -325,10 +326,7 @@ describe("runSteps", () => { it("call-level ai option applies to all steps without per-step override", async () => { const page = createMockPage(); - const steps: Step[] = [ - { description: "Step A" }, 
- { description: "Step B" }, - ]; + const steps: Step[] = [{ description: "Step A" }, { description: "Step B" }]; await runSteps({ page, @@ -382,4 +380,49 @@ describe("runSteps", () => { // generateText should be called because the step has bypassCache: true expect(generateText).toHaveBeenCalled(); }); + + it("returns usage data from AI step execution", async () => { + const page = createMockPage(); + const steps: Step[] = [{ description: "Step 1" }, { description: "Step 2" }]; + + vi.mocked(generateText).mockResolvedValue({ + text: "done", + steps: [], + usage: { inputTokens: 500, outputTokens: 100, totalTokens: 600 }, + } as any); + + const result = await runSteps({ + page, + userFlow: "usage tracking flow", + steps, + }); + + expect(result).toBeDefined(); + const usage = result as UsageResult; + expect(usage.details).toHaveLength(2); + expect(usage.totalTokens).toBe(1200); + expect(usage.details[0].operation).toBe("stepExecution"); + }); + + it("returns zero-token usage for cached steps", async () => { + const page = createMockPage(); + const steps: Step[] = [{ description: "Cached step" }]; + + vi.mocked(redis!.hgetall).mockResolvedValue({ + locator: 'getByRole("button", { name: "Go" })', + action: "click", + description: "Go button", + value: "", + }); + + const result = await runSteps({ + page, + userFlow: "cached usage flow", + steps, + }); + + expect(result).toBeDefined(); + const usage = result as UsageResult; + expect(usage.totalTokens).toBe(0); + }); }); diff --git a/src/__tests__/integration/run-user-flow.test.ts b/src/__tests__/integration/run-user-flow.test.ts new file mode 100644 index 0000000..8e39fe6 --- /dev/null +++ b/src/__tests__/integration/run-user-flow.test.ts @@ -0,0 +1,162 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; + +vi.mock("../../instrumentation", () => ({ axiomEnabled: false })); + +vi.mock("../../redis", () => ({ + redis: { + hgetall: vi.fn().mockResolvedValue({}), + hset: vi.fn().mockResolvedValue("OK"), + 
expire: vi.fn().mockResolvedValue(1), + }, +})); + +vi.mock("ai", async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + generateText: vi.fn(), + streamText: vi.fn(), + }; +}); + +vi.mock("axiom/ai", () => ({ + withSpan: vi.fn((_meta: unknown, fn: () => unknown) => fn()), + wrapAISDKModel: vi.fn((model: unknown) => model), + wrapTool: vi.fn((_name: unknown, tool: unknown) => tool), + initAxiomAI: vi.fn(), + RedactionPolicy: { AxiomDefault: {} }, +})); + +vi.mock("../../models", () => ({ + resolveModel: vi.fn().mockReturnValue("mocked-model"), +})); + +vi.mock("../../tools", () => ({ + getAItools: vi.fn().mockReturnValue({ + tools: {}, + getPendingCacheData: vi.fn().mockReturnValue(null), + clearPendingCacheData: vi.fn(), + }), +})); + +vi.mock("../../utils", () => ({ + runLocatorCode: vi.fn().mockResolvedValue(undefined), + safeSnapshot: vi.fn().mockResolvedValue("snapshot content"), + verifyActionEffect: vi.fn().mockResolvedValue(undefined), + waitForCondition: vi.fn().mockResolvedValue(undefined), + waitForDOMStabilization: vi.fn().mockResolvedValue(undefined), + generatePhoneNumber: vi.fn().mockReturnValue("1234567890"), + resolvePage: vi.fn((input: unknown) => input), +})); + +vi.mock("../../extract", () => ({ + extractDataWithAI: vi.fn().mockResolvedValue("extracted-value"), +})); + +vi.mock("../../assertion", () => ({ + assert: vi.fn().mockResolvedValue("assertion passed"), +})); + +vi.mock("../../logger", () => ({ + logger: { + info: vi.fn(), + debug: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + }, +})); + +vi.mock("../../email", () => ({ + extractEmailContent: vi.fn(), + generateEmail: vi.fn().mockReturnValue("test@example.com"), +})); + +vi.mock("../../utils/secure-script-runner", () => ({ + runSecureScript: vi.fn().mockResolvedValue(undefined), +})); + +vi.mock("../../cua", () => ({ + runCUALoop: vi.fn().mockResolvedValue("cua-result"), + buildRunStepsPromptCUA: vi.fn().mockReturnValue("cua-prompt"), + 
buildRunUserFlowPromptCUA: vi.fn().mockReturnValue("cua-userflow-prompt"), +})); + +import { runUserFlow } from "../../index"; +import { resetConfig } from "../../config"; +import { generateText } from "ai"; +import type { Page } from "@playwright/test"; + +function createMockPage() { + const mockContext = { on: vi.fn(), off: vi.fn() }; + return { + locator: vi + .fn() + .mockReturnValue({ click: vi.fn(), fill: vi.fn(), describe: vi.fn().mockReturnThis() }), + getByRole: vi.fn().mockReturnValue({ click: vi.fn() }), + ariaSnapshot: vi.fn().mockResolvedValue("snapshot"), + screenshot: vi.fn().mockResolvedValue(Buffer.from("fake")), + url: vi.fn().mockReturnValue("https://example.com"), + evaluate: vi.fn().mockResolvedValue(undefined), + waitForLoadState: vi.fn().mockResolvedValue(undefined), + context: vi.fn().mockReturnValue(mockContext), + } as unknown as Page; +} + +describe("runUserFlow usage tracking", () => { + beforeEach(() => { + vi.clearAllMocks(); + resetConfig(); + }); + + it("returns usage data alongside text result", async () => { + vi.mocked(generateText).mockResolvedValue({ + text: "Flow completed successfully", + steps: [], + usage: { inputTokens: 800, outputTokens: 200, totalTokens: 1000 }, + } as any); + + const result = await runUserFlow({ + page: createMockPage(), + userFlow: "Test flow", + steps: "Navigate to page and verify", + }); + + expect(result).toBeDefined(); + expect(result!.text).toBe("Flow completed successfully"); + expect(result!.usage).toBeDefined(); + expect(result!.usage.totalTokens).toBe(1000); + expect(result!.usage.details[0].operation).toBe("userFlow"); + }); + + it("includes assertion usage when assertion is provided", async () => { + let callCount = 0; + vi.mocked(generateText).mockImplementation(async (args: any) => { + callCount++; + if (callCount === 1) { + // Main flow execution + return { + text: "Done", + steps: [], + usage: { inputTokens: 800, outputTokens: 200, totalTokens: 1000 }, + } as any; + } + // Assertion 
parsing call + return { + output: { assertionPassed: true, confidenceScore: 90, reasoning: "ok" }, + usage: { inputTokens: 100, outputTokens: 50, totalTokens: 150 }, + } as any; + }); + + const result = await runUserFlow({ + page: createMockPage(), + userFlow: "Test flow", + steps: "Navigate and check", + assertion: "Page shows content", + }); + + expect(result).toBeDefined(); + expect(result!.usage).toBeDefined(); + expect(result!.usage.totalTokens).toBe(1150); + expect(result!.usage.details).toHaveLength(2); + }); +}); diff --git a/src/__tests__/usage.test.ts b/src/__tests__/usage.test.ts new file mode 100644 index 0000000..321ec46 --- /dev/null +++ b/src/__tests__/usage.test.ts @@ -0,0 +1,92 @@ +import { describe, it, expect } from "vitest"; +import { createUsageTracker } from "../usage"; + +describe("UsageTracker", () => { + it("should start with empty state", () => { + const tracker = createUsageTracker(); + const result = tracker.getResult(); + + expect(result.details).toEqual([]); + expect(result.totalInputTokens).toBe(0); + expect(result.totalOutputTokens).toBe(0); + expect(result.totalTokens).toBe(0); + }); + + it("should track a single AI call", () => { + const tracker = createUsageTracker(); + tracker.record({ + model: "google/gemini-3-flash", + operation: "stepExecution", + usage: { inputTokens: 1000, outputTokens: 200, totalTokens: 1200 }, + }); + + const result = tracker.getResult(); + expect(result.details).toHaveLength(1); + expect(result.details[0]).toEqual({ + model: "google/gemini-3-flash", + operation: "stepExecution", + inputTokens: 1000, + outputTokens: 200, + totalTokens: 1200, + }); + expect(result.totalTokens).toBe(1200); + }); + + it("should accumulate multiple AI calls", () => { + const tracker = createUsageTracker(); + tracker.record({ + model: "google/gemini-3-flash", + operation: "stepExecution", + usage: { inputTokens: 1000, outputTokens: 200, totalTokens: 1200 }, + }); + tracker.record({ + model: "anthropic/claude-haiku-4.5", + 
operation: "assertion", + usage: { inputTokens: 500, outputTokens: 100, totalTokens: 600 }, + }); + + const result = tracker.getResult(); + expect(result.details).toHaveLength(2); + expect(result.totalInputTokens).toBe(1500); + expect(result.totalOutputTokens).toBe(300); + expect(result.totalTokens).toBe(1800); + }); + + it("should handle missing usage data gracefully", () => { + const tracker = createUsageTracker(); + tracker.record({ + model: "google/gemini-3-flash", + operation: "stepExecution", + usage: undefined, + }); + + const result = tracker.getResult(); + expect(result.details).toHaveLength(1); + expect(result.details[0].inputTokens).toBe(0); + expect(result.details[0].outputTokens).toBe(0); + expect(result.details[0].totalTokens).toBe(0); + expect(result.totalTokens).toBe(0); + }); + + it("should merge another tracker's results", () => { + const parent = createUsageTracker(); + parent.record({ + model: "google/gemini-3-flash", + operation: "stepExecution", + usage: { inputTokens: 1000, outputTokens: 200, totalTokens: 1200 }, + }); + + const child = createUsageTracker(); + child.record({ + model: "anthropic/claude-haiku-4.5", + operation: "assertion", + usage: { inputTokens: 500, outputTokens: 100, totalTokens: 600 }, + }); + + parent.merge(child); + + const result = parent.getResult(); + expect(result.details).toHaveLength(2); + expect(result.totalTokens).toBe(1800); + }); +}); diff --git a/src/assertion.ts b/src/assertion.ts index f20b765..9632e93 100644 --- a/src/assertion.ts +++ b/src/assertion.ts @@ -9,9 +9,7 @@ import { resolvePage, safeSnapshot, withTimeout } from "./utils"; const assertionSchema = z.object({ assertionPassed: z.boolean().describe("Indicates whether the assertion passed or not."), - confidenceScore: z - .number() - .describe("Confidence score of the assertion, between 0 and 100."), + confidenceScore: z.number().describe("Confidence score of the assertion, between 0 and 100."), reasoning: z .string() .describe( @@ -57,6 +55,7 @@ 
export const assert = async ({ failSilently, maxRetries = 1, onRetry = (retryCount: number, previousResult: AssertionResult) => {}, + usageTracker, }: AssertionOptions): Promise => { const thinkingEnabled = effort === "high"; @@ -65,31 +64,33 @@ export const assert = async ({ const imageContent = images ? images.map((image) => ({ type: "image" as const, image })) : [ - { - type: "image" as const, - image: (await resolvePage(page).screenshot({ fullPage: false })).toString("base64"), - }, - ]; + { + type: "image" as const, + image: (await resolvePage(page).screenshot({ fullPage: false })).toString("base64"), + }, + ]; const basePrompt = ` You are an AI-powered QA Agent designed to test web applications. You have access to the following information. Based on this information, you'll tell us whether the assertion provided below should pass or not. -${!images - ? ` +${ + !images + ? ` - An accessibility snapshot of the current page, which provides a detailed structure of the DOM - A screenshot of the current page` - : "- Screenshots from various stages of the user flow" - } + : "- Screenshots from various stages of the user flow" +} -${!images - ? ` +${ + !images + ? ` ${snapshot} ` - : "" - } + : "" +} ${assertion} @@ -130,53 +131,68 @@ Never hallucinate. Be truthful and if you are not sure, use a low confidence sco // Claude assertion function const getClaudeAssertion = async (): Promise => { // First get Claude's text response with thinking if enabled - const { text } = await generateText({ + const { text, usage: textUsage } = await generateText({ model: resolveModel(getModelId("assertionPrimary")), temperature: 0, providerOptions: thinkingEnabled ? 
{ - anthropic: { - thinking: { type: "enabled", budgetTokens: THINKING_BUDGET_DEFAULT }, - }, - openrouter: { - reasoning: { max_tokens: THINKING_BUDGET_DEFAULT }, - }, - } + anthropic: { + thinking: { type: "enabled", budgetTokens: THINKING_BUDGET_DEFAULT }, + }, + openrouter: { + reasoning: { max_tokens: THINKING_BUDGET_DEFAULT }, + }, + } : undefined, messages, }); + usageTracker?.record({ + model: getModelId("assertionPrimary"), + operation: "assertion", + usage: textUsage, + }); // Convert Claude's response to structured format using Haiku - const { output } = await generateText({ + const { output, usage: structuredUsage } = await generateText({ model: resolveModel(getModelId("assertionPrimary")), temperature: 0.1, prompt: `Convert the following text output into a valid JSON object with the specified properties:\n\n${text}`, output: Output.object({ schema: assertionSchema }), }); + usageTracker?.record({ + model: getModelId("assertionPrimary"), + operation: "assertion", + usage: structuredUsage, + }); return output; }; // Gemini assertion function const getGeminiAssertion = async (): Promise => { - const { output } = await generateText({ + const { output, usage: geminiUsage } = await generateText({ model: resolveModel(getModelId("assertionSecondary")), temperature: 0, providerOptions: thinkingEnabled ? 
{ - google: { - thinkingConfig: { - thinkingBudget: THINKING_BUDGET_DEFAULT, + google: { + thinkingConfig: { + thinkingBudget: THINKING_BUDGET_DEFAULT, + }, }, - }, - openrouter: { - reasoning: { max_tokens: THINKING_BUDGET_DEFAULT }, - }, - } + openrouter: { + reasoning: { max_tokens: THINKING_BUDGET_DEFAULT }, + }, + } : undefined, messages, output: Output.object({ schema: assertionSchema }), }); + usageTracker?.record({ + model: getModelId("assertionSecondary"), + operation: "assertion", + usage: geminiUsage, + }); return output; }; @@ -199,14 +215,15 @@ Gemini's Assessment: - Confidence: ${geminiResult.confidenceScore}% - Reasoning: ${geminiResult.reasoning} -${!images - ? ` +${ + !images + ? ` ${snapshot} ` - : "" - } + : "" +} ${assertion} @@ -241,7 +258,7 @@ Please carefully review the evidence (screenshot and accessibility snapshot (whe }, ]; - const { output } = await generateText({ + const { output, usage: arbiterUsage } = await generateText({ model: resolveModel(getModelId("assertionArbiter")), temperature: 0, providerOptions: { @@ -257,6 +274,11 @@ Please carefully review the evidence (screenshot and accessibility snapshot (whe messages: arbiterMessages, output: Output.object({ schema: assertionSchema }), }); + usageTracker?.record({ + model: getModelId("assertionArbiter"), + operation: "assertion", + usage: arbiterUsage, + }); return output; }; diff --git a/src/extract.ts b/src/extract.ts index abff50d..e5e72b0 100644 --- a/src/extract.ts +++ b/src/extract.ts @@ -2,6 +2,7 @@ import { generateText, Output } from "ai"; import { z } from "zod"; import { getModelId } from "./config"; import { resolveModel } from "./models"; +import type { UsageTracker } from "./usage"; const extractionSchema = z.object({ extractedValue: z.string().describe("The extracted value based on the prompt"), @@ -30,12 +31,14 @@ export async function extractDataWithAI({ snapshot, url, prompt, + usageTracker, }: { snapshot: string; url: string; prompt: string; + usageTracker?: 
UsageTracker; }): Promise { - const { output } = await generateText({ + const { output, usage } = await generateText({ model: resolveModel(getModelId("utility")), temperature: 0, output: Output.object({ schema: extractionSchema }), @@ -66,5 +69,7 @@ ${prompt} Return the extracted value.`, }); + usageTracker?.record({ model: getModelId("utility"), operation: "extraction", usage }); + return output.extractedValue; } diff --git a/src/index.ts b/src/index.ts index c069c9e..49aa2ed 100644 --- a/src/index.ts +++ b/src/index.ts @@ -45,6 +45,7 @@ import { runCUALoop, buildRunStepsPromptCUA, buildRunUserFlowPromptCUA } from ". import { extractDataWithAI } from "./extract"; import { logger } from "./logger"; import { resolveModel } from "./models"; +import { createUsageTracker, type UsageResult } from "./usage"; import { runSecureScript } from "./utils/secure-script-runner"; import { createTabManager } from "./utils/tab-manager"; import { @@ -104,7 +105,8 @@ export const runSteps = async ({ executionId, failAssertionsSilently, ai: callLevelAi, -}: RunStepsOptions) => { +}: RunStepsOptions): Promise => { + const usageTracker = createUsageTracker(); executionId = executionId || process.env.executionId; // Track all open tabs for this run. The active page is updated automatically @@ -126,7 +128,8 @@ export const runSteps = async ({ const isPlaywrightRetry = test ? test.info().retry > 0 : false; if (isPlaywrightRetry) { logger.debug( - `Playwright retry detected (retry #${test!.info().retry + `Playwright retry detected (retry #${ + test!.info().retry }). 
Bypassing cache and using AI only.`, ); } @@ -280,25 +283,21 @@ export const runSteps = async ({ } try { - await maybeWithSpan( - { capability: "step_execution", step: "cua_loop" }, - () => - runCUALoop({ - page: tabManager.active(), - instruction: buildRunStepsPromptCUA({ - auth, - steps: processedSteps, - step, - userFlow, - stepIndex: i, - }), - maxSteps: STEP_EXECUTION_MAX_STEPS, - abortSignal: AbortSignal.timeout(STEP_EXECUTION_TIMEOUT), - onReasoning: onReasoning - ? (reasoning) => onReasoning({ id, reasoning }) - : undefined, - gateway: effectiveAi.gateway, + await maybeWithSpan({ capability: "step_execution", step: "cua_loop" }, () => + runCUALoop({ + page: tabManager.active(), + instruction: buildRunStepsPromptCUA({ + auth, + steps: processedSteps, + step, + userFlow, + stepIndex: i, }), + maxSteps: STEP_EXECUTION_MAX_STEPS, + abortSignal: AbortSignal.timeout(STEP_EXECUTION_TIMEOUT), + onReasoning: onReasoning ? (reasoning) => onReasoning({ id, reasoning }) : undefined, + gateway: effectiveAi.gateway, + }), ); } catch (error: unknown) { logger.error({ err: error }, `CUA step execution failed: ${step.description}`); @@ -455,9 +454,9 @@ export const runSteps = async ({ let pageScreenshotBeforeApplyingAction: string = ""; if (step.waitUntil) { - pageScreenshotBeforeApplyingAction = (await tabManager.active().screenshot({ fullPage: false })).toString( - "base64", - ); + pageScreenshotBeforeApplyingAction = ( + await tabManager.active().screenshot({ fullPage: false }) + ).toString("base64"); } const stepModelId = effectiveAi.getModelId("stepExecution"); @@ -485,7 +484,7 @@ export const runSteps = async ({ openrouter: { reasoning: { effort: "medium", - exclude: true + exclude: true, }, }, }, @@ -514,6 +513,8 @@ export const runSteps = async ({ }), ); + usageTracker.record({ model: stepModelId, operation: "stepExecution", usage: result.usage }); + // Cache the step action only if it was a single tool call (simple, deterministic action). 
// Multi-step actions are not cached as they may be non-deterministic. const allToolCalls = result.steps @@ -544,6 +545,7 @@ export const runSteps = async ({ previousSteps: processedSteps.slice(0, i), currentStep: step, nextStep: processedSteps[i + 1], + usageTracker, }); } @@ -556,6 +558,7 @@ export const runSteps = async ({ snapshot, url, prompt: step.extract.prompt, + usageTracker, }); const placeholderKey = `{{run.${step.extract.as}}}` as keyof typeof localValues; (localValues as Record)[placeholderKey] = extracted; @@ -611,6 +614,7 @@ export const runSteps = async ({ failSilently: failAssertionsSilently, maxRetries: 1, onRetry: (retryCount, previousResult) => {}, + usageTracker, }); if (onReasoning) { @@ -625,6 +629,8 @@ export const runSteps = async ({ } } } + + return usageTracker.getResult(); }; /** @@ -660,6 +666,7 @@ export const runUserFlow = async ({ thinkingBudget = THINKING_BUDGET_DEFAULT, ai: callLevelAi, }: UserFlowOptions) => { + const usageTracker = createUsageTracker(); const abortController = new AbortController(); const effectiveAi = resolveAI(callLevelAi); @@ -680,12 +687,14 @@ export const runUserFlow = async ({ ); if (assertion) { - const { output } = await generateText({ + const { output, usage: parseUsage } = await generateText({ model: resolveModel(effectiveAi.getModelId("utility"), effectiveAi.gateway), prompt: `Convert the following text output into a valid JSON object with the specified properties:\n\n${text}`, output: Output.object({ schema: z.object({ - assertionPassed: z.boolean().describe("Indicates whether the assertion passed or not."), + assertionPassed: z + .boolean() + .describe("Indicates whether the assertion passed or not."), confidenceScore: z .number() .describe("Confidence score of the assertion, between 0 and 100."), @@ -695,10 +704,15 @@ export const runUserFlow = async ({ }), }), }); - return output; + usageTracker.record({ + model: effectiveAi.getModelId("utility"), + operation: "userFlow", + usage: parseUsage, + }); 
+ return { ...output, usage: usageTracker.getResult() }; } - return text; + return { text, usage: usageTracker.getResult() }; } catch (error: unknown) { logger.error({ err: error }, "Error during CUA user flow execution"); return; @@ -715,7 +729,7 @@ export const runUserFlow = async ({ }); try { - const { text } = await maybeWithSpan( + const { text, usage: flowUsage } = await maybeWithSpan( { capability: "user_flow_execution", step: "agentic_tool_calling" }, async () => { return generateText({ @@ -758,8 +772,14 @@ export const runUserFlow = async ({ }, ); + const flowModelId = + effort === "low" + ? effectiveAi.getModelId("userFlowLow") + : effectiveAi.getModelId("userFlowHigh"); + usageTracker.record({ model: flowModelId, operation: "userFlow", usage: flowUsage }); + if (assertion) { - const { output } = await generateText({ + const { output, usage: parseUsage } = await generateText({ model: resolveModel(effectiveAi.getModelId("utility"), effectiveAi.gateway), prompt: `Convert the following text output into a valid JSON object with the specified properties:\n\n${text}`, output: Output.object({ @@ -774,11 +794,16 @@ export const runUserFlow = async ({ }), }), }); + usageTracker.record({ + model: effectiveAi.getModelId("utility"), + operation: "userFlow", + usage: parseUsage, + }); - return output; + return { ...output, usage: usageTracker.getResult() }; } - return text; + return { text, usage: usageTracker.getResult() }; } catch (error: unknown) { logger.error({ err: error }, "Error during user flow execution"); } @@ -833,4 +858,13 @@ export { extractEmailContent, generateEmail } from "./email"; export { assert } from "./assertion"; export type { AssertionResult } from "./types"; -export { PassmarkError, StepExecutionError, ValidationError, AIModelError, CacheError, ConfigurationError } from "./errors"; +export { createUsageTracker } from "./usage"; +export type { TokenUsage, UsageResult } from "./usage"; +export { + PassmarkError, + StepExecutionError, + 
ValidationError, + AIModelError, + CacheError, + ConfigurationError, +} from "./errors"; diff --git a/src/types.ts b/src/types.ts index e4ff76e..0912ad6 100644 --- a/src/types.ts +++ b/src/types.ts @@ -9,6 +9,7 @@ import { TestType, } from "@playwright/test"; import type { AIOverride } from "./config"; +import type { UsageTracker } from "./usage"; import type { TabManager } from "./utils/tab-manager"; export type PageInput = Page | TabManager; @@ -84,6 +85,8 @@ export type AssertionOptions = { images?: string[]; maxRetries?: number; onRetry?: (retryCount: number, previousResult: AssertionResult) => void; + /** When provided, token usage from AI calls is recorded into this tracker. */ + usageTracker?: UsageTracker; }; export type WaitConditionResult = { @@ -101,6 +104,8 @@ export type WaitForConditionOptions = { initialInterval?: number; // Initial wait interval in ms which will be increased exponentially timeout?: number; // We'll stop trying after this time maxInterval?: number; // Maximum wait interval in ms + /** When provided, token usage from AI calls is recorded into this tracker. 
*/ + usageTracker?: UsageTracker; }; export type RunStepsOptions = { @@ -135,9 +140,9 @@ export type RunStepsOptions = { */ ai?: AIOverride; } & ( - | { + | { assertions: Omit[]; expect: Expect<{}>; } - | { assertions?: never; expect?: never } - ); + | { assertions?: never; expect?: never } +); diff --git a/src/usage.ts b/src/usage.ts new file mode 100644 index 0000000..3b89c25 --- /dev/null +++ b/src/usage.ts @@ -0,0 +1,72 @@ +export type TokenUsage = { + model: string; + operation: string; + inputTokens: number; + outputTokens: number; + totalTokens: number; +}; + +export type UsageResult = { + details: TokenUsage[]; + totalInputTokens: number; + totalOutputTokens: number; + totalTokens: number; +}; + +type RecordInput = { + model: string; + operation: string; + usage: + | { + inputTokens?: number | undefined; + outputTokens?: number | undefined; + totalTokens?: number | undefined; + } + | undefined; +}; + +export type UsageTracker = { + record(input: RecordInput): void; + merge(other: UsageTracker): void; + getResult(): UsageResult; +}; + +export function createUsageTracker(): UsageTracker { + const details: TokenUsage[] = []; + + return { + record({ model, operation, usage }: RecordInput) { + details.push({ + model, + operation, + inputTokens: usage?.inputTokens ?? 0, + outputTokens: usage?.outputTokens ?? 0, + totalTokens: usage?.totalTokens ?? 
0, + }); + }, + + merge(other: UsageTracker) { + const otherResult = other.getResult(); + details.push(...otherResult.details); + }, + + getResult(): UsageResult { + let totalInputTokens = 0; + let totalOutputTokens = 0; + let totalTokens = 0; + + for (const d of details) { + totalInputTokens += d.inputTokens; + totalOutputTokens += d.outputTokens; + totalTokens += d.totalTokens; + } + + return { + details: [...details], + totalInputTokens, + totalOutputTokens, + totalTokens, + }; + }, + }; +} diff --git a/src/utils/index.ts b/src/utils/index.ts index a637ec1..df8ec04 100644 --- a/src/utils/index.ts +++ b/src/utils/index.ts @@ -12,11 +12,7 @@ import { z } from "zod"; import { getModelId } from "../config"; import { logger } from "../logger"; import { resolveModel } from "../models"; -import { - PageInput, - WaitConditionResult, - WaitForConditionOptions, -} from "../types"; +import { PageInput, WaitConditionResult, WaitForConditionOptions } from "../types"; import type { TabManager } from "./tab-manager"; /** @@ -66,7 +62,7 @@ export const withTimeout = ( export const safeSnapshot = async (input: PageInput, timeout = SNAPSHOT_TIMEOUT) => { const attempt = async () => { return await resolvePage(input).ariaSnapshot({ mode: "ai", timeout }); - } + }; try { const snapshot = await attempt(); @@ -191,7 +187,9 @@ export async function waitForDOMStabilization( (error instanceof Error && error.message?.includes("navigation")) ) { // Navigation occurred - wait for the page to be ready - await resolvePage(input).waitForLoadState("domcontentloaded").catch(() => { }); + await resolvePage(input) + .waitForLoadState("domcontentloaded") + .catch(() => {}); return; } // Re-throw other errors @@ -256,6 +254,7 @@ export async function waitForCondition({ initialInterval = WAIT_CONDITION_INITIAL_INTERVAL, maxInterval = WAIT_CONDITION_MAX_INTERVAL, timeout = WAIT_CONDITION_TIMEOUT, + usageTracker, }: WaitForConditionOptions): Promise { await waitForDOMStabilization(page); // Ensure DOM 
is stable before starting @@ -263,9 +262,9 @@ export async function waitForCondition({ let currentInterval = initialInterval; const checkCondition = async (): Promise<WaitConditionResult> => { - const pageScreenshotAfterApplyingAction = (await resolvePage(page).screenshot({ fullPage: false })).toString( - "base64", - ); + const pageScreenshotAfterApplyingAction = ( + await resolvePage(page).screenshot({ fullPage: false }) + ).toString("base64"); const prompt = ` You are an AI-powered QA Agent designed to test web applications. @@ -273,15 +272,16 @@ You are an AI-powered QA Agent designed to test web applications. You are helping to determine if a wait condition has been met during a test flow. -${previousSteps.length > 0 - ? `Previous steps completed:\n${previousSteps - .map( - (s, i) => - `${i + 1}. ${s.description}\n${s.data ? ` Data: ${JSON.stringify(s.data)}` : ""}`, - ) - .join("\n")}` - : "No previous steps." - } +${ + previousSteps.length > 0 + ? `Previous steps completed:\n${previousSteps + .map( + (s, i) => + `${i + 1}. ${s.description}\n${s.data ? ` Data: ${JSON.stringify(s.data)}` : ""}`, + ) + .join("\n")}` + : "No previous steps." +} Last executed step: ${currentStep.description} ${nextStep ? `Next step: ${nextStep.description}` : ""} @@ -310,7 +310,7 @@ ${condition} Analyze the attached before and after screenshots and determine if the wait condition has been met. `; - const { output } = await generateText({ + const { output, usage } = await generateText({ model: resolveModel(getModelId("utility")), temperature: 0, messages: [ @@ -325,6 +325,7 @@ Analyze the attached before and after screenshots and determine if the wait cond ], output: Output.object({ schema: waitConditionSchema }), }); + usageTracker?.record({ model: getModelId("utility"), operation: "waitCondition", usage }); return output; };