diff --git a/src/__tests__/context.test.tsx b/src/__tests__/context.test.tsx index 68c2bff..854c4b4 100644 --- a/src/__tests__/context.test.tsx +++ b/src/__tests__/context.test.tsx @@ -78,6 +78,34 @@ describe("WebMCPProvider availability", () => { deleteNativeModelContext(); } }); + + it("available is true when native API lacks unregisterTool (Chrome 148+)", async () => { + const native = { + registerTool() {}, + // no unregisterTool — simulates Chrome 148+ + }; + Object.defineProperty(navigator, "modelContext", { + value: native, + configurable: true, + enumerable: true, + writable: false, + }); + + try { + const { getByTestId } = render( + + + , + ); + + await waitFor(() => { + expect(getByTestId("status")).toHaveTextContent("yes"); + }); + expect(navigator.modelContext).toBe(native); + } finally { + deleteNativeModelContext(); + } + }); }); // ─── Polyfill lifecycle ─────────────────────────────────────────── diff --git a/src/hooks/__tests__/useMcpTool.test.tsx b/src/hooks/__tests__/useMcpTool.test.tsx index 8dee248..2a4510b 100644 --- a/src/hooks/__tests__/useMcpTool.test.tsx +++ b/src/hooks/__tests__/useMcpTool.test.tsx @@ -1090,3 +1090,140 @@ describe("provider warning", () => { ); }); }); + +// ─── Signal-only native API (Chrome 148+) ──────────────────────── + +describe("signal-only native API (no unregisterTool)", () => { + type ToolEntry = { name: string; abortCleanup: () => void }; + + function installSignalOnlyNative() { + const registered: ToolEntry[] = []; + + const native = { + registerTool( + tool: { name: string; [key: string]: unknown }, + opts?: { signal?: AbortSignal }, + ) { + const entry: ToolEntry = { name: tool.name, abortCleanup: () => {} }; + registered.push(entry); + if (opts?.signal) { + const handler = () => { + const idx = registered.findIndex((t) => t.name === tool.name); + if (idx !== -1) registered.splice(idx, 1); + }; + opts.signal.addEventListener("abort", handler, { once: true }); + entry.abortCleanup = () => opts.signal?.removeEventListener("abort", handler); + } + }, + // no unregisterTool — simulates Chrome 148+ + }; + + Object.defineProperty(navigator, "modelContext", { + value: native, + configurable: true, + enumerable: true, + writable: true, + }); + + return registered; + } + + function deleteModelContext() { + const desc = Object.getOwnPropertyDescriptor(navigator, "modelContext"); + if (desc) { + Object.defineProperty(navigator, "modelContext", { + value: undefined, + configurable: true, + writable: true, + }); + delete navigator.modelContext; + } + } + + afterEach(() => { + deleteModelContext(); + }); + + it("registers and unregisters via abort on mount/unmount", async () => { + const registered = installSignalOnlyNative(); + + const { unmount } = render( + OK_RESULT, + }} + />, + ); + + await act(async () => {}); + expect(registered.some((t) => t.name === "greet")).toBe(true); + + unmount(); + await act(async () => {}); + expect(registered.some((t) => t.name === "greet")).toBe(false); + }); + + it("handles StrictMode double-mount with signal-only API", async () => { + const registered = installSignalOnlyNative(); + + const { unmount } = render( + + OK_RESULT, + }} + /> + , + ); + + await act(async () => {}); + const greetTools = registered.filter((t) => t.name === "greet"); + expect(greetTools.length).toBe(1); + + unmount(); + await act(async () => {}); + expect(registered.some((t) => t.name === "greet")).toBe(false); + }); + + it("re-registers with fresh signal on prop change", async () => { + const registered = installSignalOnlyNative(); + const registerSpy = vi.spyOn( + navigator.modelContext as NonNullable, + "registerTool", + ); + + const { rerender } = render( + OK_RESULT, + }} + />, + ); + + await act(async () => {}); + expect(registerSpy).toHaveBeenCalledTimes(1); + expect(registered.some((t) => t.name === "greet")).toBe(true); + + // Change description to trigger re-registration + rerender( + OK_RESULT, + }} + />, + ); + + await act(async () => {}); + expect(registerSpy).toHaveBeenCalledTimes(2); + // Old signal aborted old entry, new entry registered + expect(registered.filter((t) => t.name === "greet").length).toBe(1); + }); +}); diff --git a/src/hooks/useMcpTool.ts b/src/hooks/useMcpTool.ts index e9ed4fb..3ea3f59 100644 --- a/src/hooks/useMcpTool.ts +++ b/src/hooks/useMcpTool.ts @@ -223,8 +223,10 @@ export function useMcpTool( }, }; + const controller = new AbortController(); + try { - mc.registerTool(descriptor); + mc.registerTool(descriptor, { signal: controller.signal }); } catch (err) { warnOnce( `register-${cfg.name}`, @@ -241,7 +243,8 @@ export function useMcpTool( return; } TOOL_OWNER_BY_NAME.delete(cfg.name); - mc.unregisterTool(cfg.name); + mc.unregisterTool?.(cfg.name); + controller.abort(); }; }, [ ctx.available, diff --git a/src/index.ts b/src/index.ts index b576bf6..7bfbd8a 100644 --- a/src/index.ts +++ b/src/index.ts @@ -4,6 +4,7 @@ export type { CallToolResult, McpToolConfigJsonSchema, McpToolConfigZod, + RegisterToolOptions, ToolAnnotations, ToolExecutionState, UseMcpToolReturn, diff --git a/src/polyfill/__tests__/registry.test.ts b/src/polyfill/__tests__/registry.test.ts index d6246ce..892f1e9 100644 --- a/src/polyfill/__tests__/registry.test.ts +++ b/src/polyfill/__tests__/registry.test.ts @@ -124,4 +124,95 @@ describe("createRegistry", () => { const stored = registry.getTools().get("test_tool"); expect(stored?.description).toBe("A test tool"); }); + + describe("AbortSignal support", () => { + it("removes tool when abort signal fires", async () => { + const registry = createRegistry(); + const cb = vi.fn(); + registry.onToolsChanged(cb); + + const controller = new AbortController(); + registry.registerTool(makeTool(), { signal: controller.signal }); + expect(registry.getTools().has("test_tool")).toBe(true); + + await Promise.resolve(); + expect(cb).toHaveBeenCalledTimes(1); + + controller.abort(); + expect(registry.getTools().has("test_tool")).toBe(false); + + await Promise.resolve(); + expect(cb).toHaveBeenCalledTimes(2); + }); + + it("skips registration when signal is already aborted", () => { + const registry = createRegistry(); + registry.registerTool(makeTool(), { signal: AbortSignal.abort() }); + expect(registry.getTools().has("test_tool")).toBe(false); + }); + + it("abort after manual unregisterTool is a safe no-op", async () => { + const registry = createRegistry(); + const cb = vi.fn(); + registry.onToolsChanged(cb); + + const controller = new AbortController(); + registry.registerTool(makeTool(), { signal: controller.signal }); + + await Promise.resolve(); + expect(cb).toHaveBeenCalledTimes(1); + + registry.unregisterTool("test_tool"); + await Promise.resolve(); + expect(cb).toHaveBeenCalledTimes(2); + + controller.abort(); + await Promise.resolve(); + // no extra notification — tool was already removed + expect(cb).toHaveBeenCalledTimes(2); + }); + + it("fires notification via microtask when abort removes a tool", async () => { + const registry = createRegistry(); + const cb = vi.fn(); + registry.onToolsChanged(cb); + + const controller = new AbortController(); + registry.registerTool(makeTool(), { signal: controller.signal }); + await Promise.resolve(); + + controller.abort(); + // notification not yet fired (queued as microtask) + expect(cb).toHaveBeenCalledTimes(1); + await Promise.resolve(); + expect(cb).toHaveBeenCalledTimes(2); + }); + + it("stale abort does not remove a same-name re-registration", async () => { + const registry = createRegistry(); + const cb = vi.fn(); + registry.onToolsChanged(cb); + + const controller1 = new AbortController(); + registry.registerTool(makeTool(), { signal: controller1.signal }); + await Promise.resolve(); + expect(cb).toHaveBeenCalledTimes(1); + + // Unregister, then re-register with a new signal + registry.unregisterTool("test_tool"); + await Promise.resolve(); + expect(cb).toHaveBeenCalledTimes(2); + + const controller2 = new AbortController(); + registry.registerTool(makeTool(), { signal: controller2.signal }); + await Promise.resolve(); + expect(cb).toHaveBeenCalledTimes(3); + + // Abort the OLD signal — should NOT remove the new registration + controller1.abort(); + await Promise.resolve(); + expect(registry.getTools().has("test_tool")).toBe(true); + expect(cb).toHaveBeenCalledTimes(3); + }); + }); }); diff --git a/src/polyfill/registry.ts b/src/polyfill/registry.ts index 53146c8..f6b3699 100644 --- a/src/polyfill/registry.ts +++ b/src/polyfill/registry.ts @@ -1,7 +1,7 @@ -import type { ToolDescriptor } from "../types"; +import type { RegisterToolOptions, ToolDescriptor } from "../types"; export interface RegistryInternal { - registerTool(tool: ToolDescriptor): void; + registerTool(tool: ToolDescriptor, options?: RegisterToolOptions): void; unregisterTool(name: string): void; getTools(): ReadonlyMap; onToolsChanged(callback: (() => void) | null): void; @@ -22,7 +22,7 @@ export function createRegistry(): RegistryInternal { } return { - registerTool(tool: ToolDescriptor): void { + registerTool(tool: ToolDescriptor, options?: RegisterToolOptions): void { if (typeof tool.name !== "string" || tool.name === "") { throw new DOMException("Tool name must be a non-empty string", "InvalidStateError"); } @@ -36,11 +36,29 @@ export function createRegistry(): RegistryInternal { throw new DOMException(`Tool "${tool.name}" is already registered`, "InvalidStateError"); } + if (options?.signal?.aborted) { + return; + } + tools.set(tool.name, { ...tool, inputSchema: tool.inputSchema ?? { type: "object", properties: {} }, }); scheduleNotification(); + + if (options?.signal) { + const name = tool.name; + const stored = tools.get(name); + options.signal.addEventListener( + "abort", + () => { + if (tools.get(name) === stored && tools.delete(name)) { + scheduleNotification(); + } + }, + { once: true }, + ); + } }, unregisterTool(name: string): void { diff --git a/src/types.ts b/src/types.ts index 922eea6..ba21e9a 100644 --- a/src/types.ts +++ b/src/types.ts @@ -136,9 +136,13 @@ export interface WebMCPStatus { available: boolean; } +export interface RegisterToolOptions { + signal?: AbortSignal; +} + export interface ModelContext { - registerTool(tool: ToolDescriptor): void; - unregisterTool(name: string): void; + registerTool(tool: ToolDescriptor, options?: RegisterToolOptions): void; + unregisterTool?(name: string): void; } export interface ModelContextTestingToolInfo {