Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- **Token usage tracking**: `runSteps()` now returns a `UsageResult` with per-call token breakdowns (`inputTokens`, `outputTokens`, `totalTokens`) and aggregate totals. `runUserFlow()` includes a `usage` field in its return value. New exported types: `TokenUsage`, `UsageResult`, `UsageTracker`. New exported function: `createUsageTracker()`. Assertion, extraction, and wait-condition AI calls all participate in tracking when a `usageTracker` is provided.
- `maxRetries` option to `AssertionOptions` (default: `1`) to control how many times a failed assertion is retried with a fresh page snapshot and screenshot. Setting it to `0` disables retries.
- `onRetry` callback to `AssertionOptions` that fires before each retry, receiving the retry index and the full `AssertionResult` from the previous attempt for debugging flaky assertions.
- **CUA mode** (`configure({ ai: { mode: "cua" } })`): execute `runSteps` and `runUserFlow` through OpenAI's Responses API with the built-in `computer` tool. Screenshot-driven, coordinate-based actions via Playwright's `page.mouse` / `page.keyboard`. Requires `OPENAI_API_KEY` and `gateway: "none"`; Redis step caching is skipped in this mode because coordinate actions aren't portable across viewport sizes.
Expand Down
74 changes: 67 additions & 7 deletions src/__tests__/assertion.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ vi.mock("../utils", () => ({
import { assert } from "../assertion";
import { withTimeout } from "../utils";
import { generateText } from "ai";
import { createUsageTracker } from "../usage";

function createMockPage() {
return {
Expand Down Expand Up @@ -67,13 +68,17 @@ function makeGenerateTextImpl(opts: {
return { output: opts.claude } as any;
}
if (model.includes("gemini-3-flash")) {
const g = typeof opts.gemini === "function" ? (opts.gemini as () => AssertionObj)() : opts.gemini;
const g =
typeof opts.gemini === "function" ? (opts.gemini as () => AssertionObj)() : opts.gemini;
return { output: g } as any;
}
if (model.includes("3.1-pro-preview")) {
return {
output:
opts.arbiter ?? { assertionPassed: false, confidenceScore: 0, reasoning: "no arbiter set" },
output: opts.arbiter ?? {
assertionPassed: false,
confidenceScore: 0,
reasoning: "no arbiter set",
},
} as any;
}
return { output: { assertionPassed: false, confidenceScore: 0, reasoning: "unknown" } } as any;
Expand Down Expand Up @@ -137,7 +142,11 @@ describe("assert consensus logic", () => {
makeGenerateTextImpl({
claude: { assertionPassed: true, confidenceScore: 95, reasoning: "Claude: yes" },
gemini: { assertionPassed: false, confidenceScore: 30, reasoning: "Gemini: no" },
arbiter: { assertionPassed: true, confidenceScore: 70, reasoning: "Arbiter: I side with Claude" },
arbiter: {
assertionPassed: true,
confidenceScore: 70,
reasoning: "Arbiter: I side with Claude",
},
}) as any,
);

Expand All @@ -160,7 +169,11 @@ describe("assert consensus logic", () => {
makeGenerateTextImpl({
claude: { assertionPassed: true, confidenceScore: 60, reasoning: "Claude: yes" },
gemini: { assertionPassed: false, confidenceScore: 40, reasoning: "Gemini: no" },
arbiter: { assertionPassed: false, confidenceScore: 45, reasoning: "Arbiter: I disagree, it fails" },
arbiter: {
assertionPassed: false,
confidenceScore: 45,
reasoning: "Arbiter: I disagree, it fails",
},
}) as any,
);

Expand Down Expand Up @@ -188,7 +201,11 @@ describe("assert consensus logic", () => {
if (geminiCalls === 1) {
throw new Error("transient model error");
}
return { assertionPassed: true, confidenceScore: 80, reasoning: "Gemini: ok after retry" };
return {
assertionPassed: true,
confidenceScore: 80,
reasoning: "Gemini: ok after retry",
};
},
}) as any,
);
Expand All @@ -210,7 +227,9 @@ describe("assert consensus logic", () => {
const page = createMockPage();

// Make withTimeout reject once to simulate timeout
vi.mocked(withTimeout).mockImplementationOnce(() => Promise.reject(new Error("timed out")) as any);
vi.mocked(withTimeout).mockImplementationOnce(
() => Promise.reject(new Error("timed out")) as any,
);

vi.mocked(generateText).mockImplementation(
makeGenerateTextImpl({
Expand All @@ -229,4 +248,45 @@ describe("assert consensus logic", () => {

expect(res).toContain("✅ passed");
});

// Verifies that every AI call made during an assertion is recorded on the
// supplied tracker and tagged with the "assertion" operation.
it("records usage data when usageTracker is provided", async () => {
  const page = createMockPage();
  const tracker = createUsageTracker();

  vi.mocked(generateText).mockImplementation(async (args: any) => {
    const modelId = String(args.model ?? "");
    // Fresh usage payload per call, mirroring what a real provider returns.
    const usage = { inputTokens: 100, outputTokens: 50, totalTokens: 150 };

    // Non-structured calls (no `output` schema) return plain text.
    if (!args.output) {
      return { text: "claude text", usage } as any;
    }
    // Structured calls return a verdict keyed off the model id.
    const verdict = modelId.includes("anthropic")
      ? { assertionPassed: true, confidenceScore: 90, reasoning: "Claude: ok" }
      : { assertionPassed: true, confidenceScore: 80, reasoning: "Gemini: ok" };
    return { output: verdict, usage } as any;
  });

  await assert({
    page,
    assertion: "The page shows items",
    test: mockTest,
    expect: ((a: unknown, _m?: string) => ({ toBe: (_v: unknown) => {} })) as any,
    failSilently: true,
    usageTracker: tracker,
  });

  const result = tracker.getResult();
  // Claude text call + Claude structured call + Gemini structured call = 3 calls
  expect(result.details.length).toBeGreaterThanOrEqual(3);
  expect(result.totalTokens).toBeGreaterThan(0);
  expect(result.details.every((d) => d.operation === "assertion")).toBe(true);
});
});
63 changes: 63 additions & 0 deletions src/__tests__/extract-usage.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Tests for token-usage tracking in extractDataWithAI.
// NOTE: the vi.mock() factories below are hoisted by vitest above the
// imports, so they take effect before ../extract is loaded.
import { describe, it, expect, vi, beforeEach } from "vitest";

// Disable Axiom instrumentation so no telemetry is attempted in tests.
vi.mock("../instrumentation", () => ({ axiomEnabled: false }));

// Identity model resolver: pass model ids through unchanged.
vi.mock("../models", () => ({
  resolveModel: (id: string) => id,
}));

// Keep the real "ai" module but replace generateText with a controllable mock.
vi.mock("ai", async (importOriginal) => {
  const actual = await importOriginal<typeof import("ai")>();
  return {
    ...actual,
    generateText: vi.fn(),
  };
});

import { extractDataWithAI } from "../extract";
import { createUsageTracker } from "../usage";
import { generateText } from "ai";

// Reset mock call history and queued implementations between tests.
beforeEach(() => {
  vi.clearAllMocks();
});

describe("extractDataWithAI usage tracking", () => {
  it("records usage when usageTracker is provided", async () => {
    const tracker = createUsageTracker();

    // Single AI call returning both the extracted value and a usage payload.
    vi.mocked(generateText).mockResolvedValue({
      output: { extractedValue: "abc123" },
      usage: { inputTokens: 200, outputTokens: 30, totalTokens: 230 },
    } as any);

    const result = await extractDataWithAI({
      snapshot: "page content",
      url: "https://example.com?token=abc123",
      prompt: "Extract the token query parameter",
      usageTracker: tracker,
    });

    expect(result).toBe("abc123");

    // Exactly one call was made, tagged as an extraction, totals passed through.
    const usage = tracker.getResult();
    expect(usage.details).toHaveLength(1);
    expect(usage.details[0].operation).toBe("extraction");
    expect(usage.totalTokens).toBe(230);
  });

  // Omitting usageTracker must not change extraction behavior.
  it("works without usageTracker (backward compatible)", async () => {
    vi.mocked(generateText).mockResolvedValue({
      output: { extractedValue: "value" },
      usage: { inputTokens: 100, outputTokens: 20, totalTokens: 120 },
    } as any);

    const result = await extractDataWithAI({
      snapshot: "content",
      url: "https://example.com",
      prompt: "Extract something",
    });

    expect(result).toBe("value");
  });
});
51 changes: 47 additions & 4 deletions src/__tests__/integration/run-steps.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ import { generateText } from "ai";
import { runCUALoop } from "../../cua";
import type { Page } from "@playwright/test";
import type { Step } from "../../types";
import type { UsageResult } from "../../usage";

function createMockPage() {
const mockLocator = {
Expand Down Expand Up @@ -325,10 +326,7 @@ describe("runSteps", () => {
it("call-level ai option applies to all steps without per-step override", async () => {
const page = createMockPage();

const steps: Step[] = [
{ description: "Step A" },
{ description: "Step B" },
];
const steps: Step[] = [{ description: "Step A" }, { description: "Step B" }];

await runSteps({
page,
Expand Down Expand Up @@ -382,4 +380,49 @@ describe("runSteps", () => {
// generateText should be called because the step has bypassCache: true
expect(generateText).toHaveBeenCalled();
});

// Two AI-executed steps, each reporting 600 total tokens, should aggregate
// into a UsageResult with one detail entry per step and a 1200-token total.
it("returns usage data from AI step execution", async () => {
  const page = createMockPage();

  // Every generateText call resolves with the same fixed usage payload.
  vi.mocked(generateText).mockResolvedValue({
    text: "done",
    steps: [],
    usage: { inputTokens: 500, outputTokens: 100, totalTokens: 600 },
  } as any);

  const flowSteps: Step[] = [{ description: "Step 1" }, { description: "Step 2" }];
  const result = await runSteps({
    page,
    userFlow: "usage tracking flow",
    steps: flowSteps,
  });

  expect(result).toBeDefined();
  const usage = result as UsageResult;
  expect(usage.details).toHaveLength(2);
  expect(usage.totalTokens).toBe(1200);
  expect(usage.details[0].operation).toBe("stepExecution");
});

// A cache hit skips the AI call entirely, so the returned usage must be zero.
it("returns zero-token usage for cached steps", async () => {
  const page = createMockPage();
  const steps: Step[] = [{ description: "Cached step" }];

  // Simulate a Redis cache hit with a previously stored step resolution.
  // NOTE(review): `redis` is mocked in this file's preamble (not visible in
  // this hunk); the `!` assumes the mock export is non-null — confirm there.
  vi.mocked(redis!.hgetall).mockResolvedValue({
    locator: 'getByRole("button", { name: "Go" })',
    action: "click",
    description: "Go button",
    value: "",
  });

  const result = await runSteps({
    page,
    userFlow: "cached usage flow",
    steps,
  });

  expect(result).toBeDefined();
  const usage = result as UsageResult;
  // No generateText call happened, so nothing was tracked.
  expect(usage.totalTokens).toBe(0);
});
});
Loading