|
| 1 | +import fs from 'node:fs/promises' |
| 2 | +import path from 'node:path' |
| 3 | +import { invariant } from '@epic-web/invariant' |
| 4 | +import { faker } from '@faker-js/faker' |
| 5 | +import { Client } from '@modelcontextprotocol/sdk/client/index.js' |
| 6 | +import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js' |
| 7 | +import { |
| 8 | + CreateMessageRequestSchema, |
| 9 | + type CreateMessageResult, |
| 10 | +} from '@modelcontextprotocol/sdk/types.js' |
| 11 | +import { test, beforeAll, afterAll, expect } from 'vitest' |
| 12 | +import { type z } from 'zod' |
| 13 | + |
| 14 | +let client: Client |
| 15 | +const EPIC_ME_DB_PATH = `./test.ignored/db.${process.env.VITEST_WORKER_ID}.sqlite` |
| 16 | + |
| 17 | +beforeAll(async () => { |
| 18 | + const dir = path.dirname(EPIC_ME_DB_PATH) |
| 19 | + await fs.mkdir(dir, { recursive: true }) |
| 20 | + client = new Client( |
| 21 | + { |
| 22 | + name: 'EpicMeTester', |
| 23 | + version: '1.0.0', |
| 24 | + }, |
| 25 | + { |
| 26 | + capabilities: { |
| 27 | + sampling: {}, |
| 28 | + }, |
| 29 | + }, |
| 30 | + ) |
| 31 | + const transport = new StdioClientTransport({ |
| 32 | + command: 'tsx', |
| 33 | + args: ['src/index.ts'], |
| 34 | + env: { |
| 35 | + ...process.env, |
| 36 | + EPIC_ME_DB_PATH, |
| 37 | + }, |
| 38 | + }) |
| 39 | + await client.connect(transport) |
| 40 | +}) |
| 41 | + |
| 42 | +afterAll(async () => { |
| 43 | + await client.transport?.close() |
| 44 | + await fs.unlink(EPIC_ME_DB_PATH) |
| 45 | +}) |
| 46 | + |
| 47 | +test('Tool Definition', async () => { |
| 48 | + const list = await client.listTools() |
| 49 | + const [firstTool] = list.tools |
| 50 | + invariant(firstTool, '🚨 No tools found') |
| 51 | + |
| 52 | + expect(firstTool).toEqual( |
| 53 | + expect.objectContaining({ |
| 54 | + name: expect.stringMatching(/^create_entry$/i), |
| 55 | + description: expect.stringMatching(/^create a new journal entry$/i), |
| 56 | + inputSchema: expect.objectContaining({ |
| 57 | + type: 'object', |
| 58 | + properties: expect.objectContaining({ |
| 59 | + title: expect.objectContaining({ |
| 60 | + type: 'string', |
| 61 | + description: expect.stringMatching(/title/i), |
| 62 | + }), |
| 63 | + content: expect.objectContaining({ |
| 64 | + type: 'string', |
| 65 | + description: expect.stringMatching(/content/i), |
| 66 | + }), |
| 67 | + }), |
| 68 | + }), |
| 69 | + }), |
| 70 | + ) |
| 71 | +}) |
| 72 | + |
| 73 | +async function deferred<ResolvedValue>() { |
| 74 | + const ref = {} as { |
| 75 | + promise: Promise<ResolvedValue> |
| 76 | + resolve: (value: ResolvedValue) => void |
| 77 | + reject: (reason?: any) => void |
| 78 | + value: ResolvedValue | undefined |
| 79 | + reason: any | undefined |
| 80 | + } |
| 81 | + ref.promise = new Promise<ResolvedValue>((resolve, reject) => { |
| 82 | + ref.resolve = (value) => { |
| 83 | + ref.value = value |
| 84 | + resolve(value) |
| 85 | + } |
| 86 | + ref.reject = (reason) => { |
| 87 | + ref.reason = reason |
| 88 | + reject(reason) |
| 89 | + } |
| 90 | + }) |
| 91 | + |
| 92 | + return ref |
| 93 | +} |
| 94 | + |
| 95 | +test('Sampling', async () => { |
| 96 | + const messageResultDeferred = await deferred<CreateMessageResult>() |
| 97 | + const messageRequestDeferred = |
| 98 | + await deferred<z.infer<typeof CreateMessageRequestSchema>>() |
| 99 | + |
| 100 | + client.setRequestHandler(CreateMessageRequestSchema, (r) => { |
| 101 | + messageRequestDeferred.resolve(r) |
| 102 | + return messageResultDeferred.promise |
| 103 | + }) |
| 104 | + |
| 105 | + const fakeTag1 = { |
| 106 | + name: faker.lorem.word(), |
| 107 | + description: faker.lorem.sentence(), |
| 108 | + } |
| 109 | + const fakeTag2 = { |
| 110 | + name: faker.lorem.word(), |
| 111 | + description: faker.lorem.sentence(), |
| 112 | + } |
| 113 | + |
| 114 | + const result = await client.callTool({ |
| 115 | + name: 'create_tag', |
| 116 | + arguments: fakeTag1, |
| 117 | + }) |
| 118 | + const tag1Resource = (result.content as any).find( |
| 119 | + (c: any) => c.type === 'resource', |
| 120 | + )?.resource |
| 121 | + invariant(tag1Resource, '🚨 No tag1 resource found') |
| 122 | + const newTag1 = JSON.parse(tag1Resource.text) as any |
| 123 | + invariant(newTag1.id, '🚨 No new tag1 found') |
| 124 | + |
| 125 | + const entry = { |
| 126 | + title: faker.lorem.words(3), |
| 127 | + content: faker.lorem.paragraphs(2), |
| 128 | + } |
| 129 | + await client.callTool({ |
| 130 | + name: 'create_entry', |
| 131 | + arguments: entry, |
| 132 | + }) |
| 133 | + const request = await messageRequestDeferred.promise |
| 134 | + |
| 135 | + expect(request).toEqual( |
| 136 | + expect.objectContaining({ |
| 137 | + method: 'sampling/createMessage', |
| 138 | + params: expect.objectContaining({ |
| 139 | + maxTokens: expect.any(Number), |
| 140 | + systemPrompt: expect.stringMatching(/example/i), |
| 141 | + messages: expect.arrayContaining([ |
| 142 | + expect.objectContaining({ |
| 143 | + role: 'user', |
| 144 | + content: expect.objectContaining({ |
| 145 | + type: 'text', |
| 146 | + text: expect.stringMatching(/entry/i), |
| 147 | + mimeType: 'application/json', |
| 148 | + }), |
| 149 | + }), |
| 150 | + ]), |
| 151 | + }), |
| 152 | + }), |
| 153 | + ) |
| 154 | + |
| 155 | + messageResultDeferred.resolve({ |
| 156 | + model: 'stub-model', |
| 157 | + stopReason: 'endTurn', |
| 158 | + role: 'assistant', |
| 159 | + content: { |
| 160 | + type: 'text', |
| 161 | + text: JSON.stringify([{ id: newTag1.id }, fakeTag2]), |
| 162 | + }, |
| 163 | + }) |
| 164 | + |
| 165 | + // give the server a chance to process the result |
| 166 | + await new Promise((resolve) => setTimeout(resolve, 100)) |
| 167 | +}) |
0 commit comments