Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions src/__tests__/context.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,34 @@ describe("WebMCPProvider availability", () => {
deleteNativeModelContext();
}
});

it("available is true when native API lacks unregisterTool (Chrome 148+)", async () => {
const native = {
registerTool() {},
// no unregisterTool — simulates Chrome 148+
};
Object.defineProperty(navigator, "modelContext", {
value: native,
configurable: true,
enumerable: true,
writable: false,
});

try {
const { getByTestId } = render(
<WebMCPProvider name="test" version="1.0">
<StatusDisplay />
</WebMCPProvider>,
);

await waitFor(() => {
expect(getByTestId("status")).toHaveTextContent("yes");
});
expect(navigator.modelContext).toBe(native);
} finally {
deleteNativeModelContext();
}
});
});

// ─── Polyfill lifecycle ───────────────────────────────────────────
Expand Down
137 changes: 137 additions & 0 deletions src/hooks/__tests__/useMcpTool.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -1090,3 +1090,140 @@ describe("provider warning", () => {
);
});
});

// ─── Signal-only native API (Chrome 148+) ────────────────────────

describe("signal-only native API (no unregisterTool)", () => {
type ToolEntry = { name: string; abortCleanup: () => void };

function installSignalOnlyNative() {
const registered: ToolEntry[] = [];

const native = {
registerTool(
tool: { name: string; [key: string]: unknown },
opts?: { signal?: AbortSignal },
) {
const entry: ToolEntry = { name: tool.name, abortCleanup: () => {} };
registered.push(entry);
if (opts?.signal) {
const handler = () => {
const idx = registered.findIndex((t) => t.name === tool.name);
if (idx !== -1) registered.splice(idx, 1);
};
opts.signal.addEventListener("abort", handler, { once: true });
entry.abortCleanup = () => opts.signal?.removeEventListener("abort", handler);
}
},
// no unregisterTool — simulates Chrome 148+
};

Object.defineProperty(navigator, "modelContext", {
value: native,
configurable: true,
enumerable: true,
writable: true,
});

return registered;
}

function deleteModelContext() {
const desc = Object.getOwnPropertyDescriptor(navigator, "modelContext");
if (desc) {
Object.defineProperty(navigator, "modelContext", {
value: undefined,
configurable: true,
writable: true,
});
delete navigator.modelContext;
}
}

afterEach(() => {
deleteModelContext();
});

it("registers and unregisters via abort on mount/unmount", async () => {
const registered = installSignalOnlyNative();

const { unmount } = render(
<ToolComponent
config={{
name: "greet",
description: "Say hello",
handler: async () => OK_RESULT,
}}
/>,
);

await act(async () => {});
expect(registered.some((t) => t.name === "greet")).toBe(true);

unmount();
await act(async () => {});
expect(registered.some((t) => t.name === "greet")).toBe(false);
});

it("handles StrictMode double-mount with signal-only API", async () => {
const registered = installSignalOnlyNative();

const { unmount } = render(
<StrictMode>
<ToolComponent
config={{
name: "greet",
description: "Say hello",
handler: async () => OK_RESULT,
}}
/>
</StrictMode>,
);

await act(async () => {});
const greetTools = registered.filter((t) => t.name === "greet");
expect(greetTools.length).toBe(1);

unmount();
await act(async () => {});
expect(registered.some((t) => t.name === "greet")).toBe(false);
});

it("re-registers with fresh signal on prop change", async () => {
const registered = installSignalOnlyNative();
const registerSpy = vi.spyOn(
navigator.modelContext as NonNullable<typeof navigator.modelContext>,
"registerTool",
);

const { rerender } = render(
<ToolComponent
config={{
name: "greet",
description: "Say hello",
handler: async () => OK_RESULT,
}}
/>,
);

await act(async () => {});
expect(registerSpy).toHaveBeenCalledTimes(1);
expect(registered.some((t) => t.name === "greet")).toBe(true);

// Change description to trigger re-registration
rerender(
<ToolComponent
config={{
name: "greet",
description: "Say hello v2",
handler: async () => OK_RESULT,
}}
/>,
);

await act(async () => {});
expect(registerSpy).toHaveBeenCalledTimes(2);
// Old signal aborted old entry, new entry registered
expect(registered.filter((t) => t.name === "greet").length).toBe(1);
});
});
7 changes: 5 additions & 2 deletions src/hooks/useMcpTool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,10 @@ export function useMcpTool(
},
};

const controller = new AbortController();

try {
mc.registerTool(descriptor);
mc.registerTool(descriptor, { signal: controller.signal });
} catch (err) {
warnOnce(
`register-${cfg.name}`,
Expand All @@ -241,7 +243,8 @@ export function useMcpTool(
return;
}
TOOL_OWNER_BY_NAME.delete(cfg.name);
mc.unregisterTool(cfg.name);
mc.unregisterTool?.(cfg.name);
controller.abort();
};
}, [
ctx.available,
Expand Down
1 change: 1 addition & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ export type {
CallToolResult,
McpToolConfigJsonSchema,
McpToolConfigZod,
RegisterToolOptions,
ToolAnnotations,
ToolExecutionState,
UseMcpToolReturn,
Expand Down
91 changes: 91 additions & 0 deletions src/polyfill/__tests__/registry.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -124,4 +124,95 @@ describe("createRegistry", () => {
const stored = registry.getTools().get("test_tool");
expect(stored?.description).toBe("A test tool");
});

describe("AbortSignal support", () => {
it("removes tool when abort signal fires", async () => {
const registry = createRegistry();
const cb = vi.fn();
registry.onToolsChanged(cb);

const controller = new AbortController();
registry.registerTool(makeTool(), { signal: controller.signal });
expect(registry.getTools().has("test_tool")).toBe(true);

await Promise.resolve();
expect(cb).toHaveBeenCalledTimes(1);

controller.abort();
expect(registry.getTools().has("test_tool")).toBe(false);

await Promise.resolve();
expect(cb).toHaveBeenCalledTimes(2);
});

it("skips registration when signal is already aborted", () => {
const registry = createRegistry();
registry.registerTool(makeTool(), { signal: AbortSignal.abort() });
expect(registry.getTools().has("test_tool")).toBe(false);
});

it("abort after manual unregisterTool is a safe no-op", async () => {
const registry = createRegistry();
const cb = vi.fn();
registry.onToolsChanged(cb);

const controller = new AbortController();
registry.registerTool(makeTool(), { signal: controller.signal });

await Promise.resolve();
expect(cb).toHaveBeenCalledTimes(1);

registry.unregisterTool("test_tool");
await Promise.resolve();
expect(cb).toHaveBeenCalledTimes(2);

controller.abort();
await Promise.resolve();
// no extra notification — tool was already removed
expect(cb).toHaveBeenCalledTimes(2);
});

it("fires notification via microtask when abort removes a tool", async () => {
const registry = createRegistry();
const cb = vi.fn();
registry.onToolsChanged(cb);

const controller = new AbortController();
registry.registerTool(makeTool(), { signal: controller.signal });
await Promise.resolve();

controller.abort();
// notification not yet fired (queued as microtask)
expect(cb).toHaveBeenCalledTimes(1);
await Promise.resolve();
expect(cb).toHaveBeenCalledTimes(2);
});

it("stale abort does not remove a same-name re-registration", async () => {
const registry = createRegistry();
const cb = vi.fn();
registry.onToolsChanged(cb);

const controller1 = new AbortController();
registry.registerTool(makeTool(), { signal: controller1.signal });
await Promise.resolve();
expect(cb).toHaveBeenCalledTimes(1);

// Unregister, then re-register with a new signal
registry.unregisterTool("test_tool");
await Promise.resolve();
expect(cb).toHaveBeenCalledTimes(2);

const controller2 = new AbortController();
registry.registerTool(makeTool(), { signal: controller2.signal });
await Promise.resolve();
expect(cb).toHaveBeenCalledTimes(3);

// Abort the OLD signal — should NOT remove the new registration
controller1.abort();
await Promise.resolve();
expect(registry.getTools().has("test_tool")).toBe(true);
expect(cb).toHaveBeenCalledTimes(3);
});
});
});
24 changes: 21 additions & 3 deletions src/polyfill/registry.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import type { ToolDescriptor } from "../types";
import type { RegisterToolOptions, ToolDescriptor } from "../types";

export interface RegistryInternal {
registerTool(tool: ToolDescriptor): void;
registerTool(tool: ToolDescriptor, options?: RegisterToolOptions): void;
unregisterTool(name: string): void;
getTools(): ReadonlyMap<string, ToolDescriptor>;
onToolsChanged(callback: (() => void) | null): void;
Expand All @@ -22,7 +22,7 @@ export function createRegistry(): RegistryInternal {
}

return {
registerTool(tool: ToolDescriptor): void {
registerTool(tool: ToolDescriptor, options?: RegisterToolOptions): void {
if (typeof tool.name !== "string" || tool.name === "") {
throw new DOMException("Tool name must be a non-empty string", "InvalidStateError");
}
Expand All @@ -36,11 +36,29 @@ export function createRegistry(): RegistryInternal {
throw new DOMException(`Tool "${tool.name}" is already registered`, "InvalidStateError");
}

if (options?.signal?.aborted) {
return;
}

tools.set(tool.name, {
...tool,
inputSchema: tool.inputSchema ?? { type: "object", properties: {} },
});
scheduleNotification();

if (options?.signal) {
const name = tool.name;
const stored = tools.get(name);
options.signal.addEventListener(
"abort",
() => {
if (tools.get(name) === stored && tools.delete(name)) {
scheduleNotification();
}
},
{ once: true },
);
}
},

unregisterTool(name: string): void {
Expand Down
8 changes: 6 additions & 2 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,13 @@ export interface WebMCPStatus {
available: boolean;
}

export interface RegisterToolOptions {
signal?: AbortSignal;
}

export interface ModelContext {
registerTool(tool: ToolDescriptor): void;
unregisterTool(name: string): void;
registerTool(tool: ToolDescriptor, options?: RegisterToolOptions): void;
unregisterTool?(name: string): void;
}

export interface ModelContextTestingToolInfo {
Expand Down
Loading