Skip to content

Commit 2f6afbb

Browse files
authored
feat: add embeddings package (#9)
1 parent 2426a68 commit 2f6afbb

File tree

6 files changed

+648
-0
lines changed

6 files changed

+648
-0
lines changed

packages/embeddings/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Embeddings
2+
3+
This package provides Apps Script functions for generating text embeddings with Vertex AI and for calculating the similarity between embeddings.

packages/embeddings/package.json

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
{
2+
"name": "@repository/embeddings",
3+
"version": "0.1.0",
4+
"scripts": {
5+
"check": "tsc --noEmit",
6+
"test": "vitest"
7+
},
8+
"author": "Justin Poehnelt <jpoehnelt@google.com>",
9+
"license": "Apache-2.0",
10+
"devDependencies": {
11+
"@types/google-apps-script": "^1.0.97",
12+
"vitest": "^3.0.9"
13+
},
14+
"type": "module",
15+
"private": true,
16+
"main": "./src/index.ts",
17+
"types": "./src/index.ts"
18+
}
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
import { beforeEach, describe, expect, it, vi } from "vitest";
2+
import { getTextEmbeddings, similarity, similarityEmoji } from "./index.js";
3+
4+
// Stand-ins for the Apps Script globals that ./index.js reads at call time.

// OAuth token source: always yields a fixed fake token.
global.ScriptApp = {
  getOAuthToken: vi.fn().mockReturnValue("mock-token"),
} as unknown as typeof ScriptApp;

// Script properties: only PROJECT_ID is defined; every other key reads as null.
const mockScriptProperties = {
  getProperty: vi.fn().mockImplementation((key) => {
    if (key === "PROJECT_ID") {
      return "mock-project-id";
    }
    return null;
  }),
};
global.PropertiesService = {
  getScriptProperties: vi.fn().mockReturnValue(mockScriptProperties),
} as unknown as typeof PropertiesService;

// Kept as a named mock so each test can set its responses and inspect its calls.
const fetchAll = vi.fn();
global.UrlFetchApp = { fetchAll } as unknown as typeof UrlFetchApp;
20+
21+
describe("similarity", () => {
  it("calculates cosine similarity correctly", () => {
    const cases = [
      // Parallel vectors (should be 1.0)
      { a: [1, 2, 3], b: [2, 4, 6], expected: 1.0 },
      // Orthogonal vectors (should be 0.0)
      { a: [1, 0, 0], b: [0, 1, 0], expected: 0.0 },
      // Opposite vectors (should be -1.0)
      { a: [1, 2, 3], b: [-1, -2, -3], expected: -1.0 },
    ];

    for (const { a, b, expected } of cases) {
      expect(similarity(a, b)).toBeCloseTo(expected);
    }
  });

  it("throws an error when vectors have different lengths", () => {
    const compareMismatched = () => similarity([1, 2, 3, 4], [1, 2, 3]);

    expect(compareMismatched).toThrow("Vectors must have the same length");
  });
});
39+
40+
describe("similarityEmoji", () => {
  // Each tier is checked at a mid-range value AND at its exact lower bound,
  // since the bounds are inclusive (>=) and are the easiest place to regress.
  const cases: Array<[number, string]> = [
    [1.0, "🔥"], // Very high (>=0.9)
    [0.9, "🔥"], // Very high, exact boundary
    [0.8, "✅"], // High (>=0.7 and <0.9)
    [0.7, "✅"], // High, exact boundary
    [0.6, "👍"], // Medium (>=0.5 and <0.7)
    [0.5, "👍"], // Medium, exact boundary
    [0.4, "🤔"], // Low (>=0.3 and <0.5)
    [0.3, "🤔"], // Low, exact boundary
    [0.2, "❌"], // Very low (<0.3)
    [-1.0, "❌"], // Opposite vectors still fall in the lowest tier
  ];

  it.each(cases)("maps %f to %s", (value, emoji) => {
    expect(similarityEmoji(value)).toBe(emoji);
  });
});
49+
50+
describe("getTextEmbeddings", () => {
  // Builds a minimal HTTP response stand-in for a successful predict call,
  // carrying one prediction per vector passed in.
  const okResponse = (...vectors: number[][]) => ({
    getResponseCode: () => 200,
    getContentText: () =>
      JSON.stringify({
        predictions: vectors.map((values) => ({ embeddings: { values } })),
      }),
  });

  beforeEach(() => {
    vi.clearAllMocks();
    fetchAll.mockReturnValue([okResponse([0.1, 0.2, 0.3])]);
  });

  it("handles single string input", () => {
    const result = getTextEmbeddings("test text");

    expect(fetchAll).toHaveBeenCalledTimes(1);
    const requests = fetchAll.mock.calls[0][0];
    expect(requests).toHaveLength(1);

    const payload = JSON.parse(requests[0].payload);
    expect(payload.instances[0].content).toBe("test text");

    expect(result).toEqual([[0.1, 0.2, 0.3]]);
  });

  it("handles array of strings input", () => {
    // Two texts fit in one batch, so a single response carries both predictions.
    fetchAll.mockReturnValue([okResponse([0.1, 0.2, 0.3], [0.4, 0.5, 0.6])]);

    const result = getTextEmbeddings(["text1", "text2"]);

    const requests = fetchAll.mock.calls[0][0];
    expect(requests).toHaveLength(1);
    expect(JSON.parse(requests[0].payload).instances).toEqual([
      { content: "text1" },
      { content: "text2" },
    ]);
    expect(result).toEqual([
      [0.1, 0.2, 0.3],
      [0.4, 0.5, 0.6],
    ]);
  });

  it("splits more than five inputs across multiple requests", () => {
    fetchAll.mockReturnValue([
      okResponse([1], [2], [3], [4], [5]),
      okResponse([6]),
    ]);

    const result = getTextEmbeddings(["a", "b", "c", "d", "e", "f"]);

    expect(fetchAll).toHaveBeenCalledTimes(1);
    const requests = fetchAll.mock.calls[0][0];
    expect(requests).toHaveLength(2);
    expect(JSON.parse(requests[0].payload).instances).toHaveLength(5);
    expect(JSON.parse(requests[1].payload).instances).toHaveLength(1);
    expect(result).toEqual([[1], [2], [3], [4], [5], [6]]);
  });

  it("uses custom options", () => {
    getTextEmbeddings("test", {
      model: "custom-model",
      parameters: { outputDimensionality: 256 },
      projectId: "custom-project",
      region: "custom-region",
      token: "custom-token",
    });

    const [request] = fetchAll.mock.calls[0][0];
    expect(request.url).toContain("custom-region");
    expect(request.url).toContain("custom-model");
    expect(request.url).toContain("custom-project");
    expect(request.headers.Authorization).toBe("Bearer custom-token");
    expect(JSON.parse(request.payload).parameters).toEqual({
      outputDimensionality: 256,
    });
  });

  it("throws the response body when a request fails", () => {
    fetchAll.mockReturnValue([
      {
        getResponseCode: () => 500,
        getContentText: () => "internal error",
      },
    ]);

    expect(() => getTextEmbeddings("test")).toThrow("internal error");
  });
});

packages/embeddings/src/index.ts

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
// Default model ID, used when `Options.model` is not supplied.
const MODEL_ID = "text-embedding-005";
2+
// Default region, used when `Options.region` is not supplied.
const REGION = "us-central1";
3+
4+
/**
 * Model-level parameters, serialized as the `parameters` field of the
 * request body.
 *
 * NOTE(review): the per-field descriptions below are assumptions based on the
 * field names — confirm against the Vertex AI text-embeddings API reference.
 */
interface Parameters {
  /** Presumably truncates over-long input instead of failing the request. */
  autoTruncate?: boolean;

  /** Presumably the requested length of each returned embedding vector. */
  outputDimensionality?: number;
}
8+
9+
/**
 * A single input item, sent as one element of the `instances` array in the
 * request body.
 */
interface Instance {
  /** The text to generate an embedding for. */
  content: string;

  /**
   * Intended downstream use of the embedding.
   * NOTE(review): semantics are defined by the model API, not by this file —
   * confirm the accepted values against the API reference.
   */
  task_type?:
    | "RETRIEVAL_DOCUMENT"
    | "RETRIEVAL_QUERY"
    | "SEMANTIC_SIMILARITY"
    | "CLASSIFICATION"
    | "CLUSTERING"
    | "QUESTION_ANSWERING"
    | "FACT_VERIFICATION"
    | "CODE_RETRIEVAL_QUERY";

  /** Optional title accompanying the content (presumably for document inputs). */
  title?: string;
}
22+
23+
/**
 * Optional settings for an embeddings request. Every field may be omitted, in
 * which case the documented default is used.
 */
interface Options {
  /**
   * ID of the project that hosts the model.
   * @default `PropertiesService.getScriptProperties().getProperty("PROJECT_ID")`
   */
  projectId?: string;

  /**
   * ID of the model to call.
   * @default 'text-embedding-005'
   */
  model?: string;

  /**
   * Extra model parameters, sent as the `parameters` field of the request body.
   * @default {}
   */
  parameters?: Parameters;

  /**
   * Region that hosts the model.
   * @default 'us-central1'
   */
  region?: string;

  /**
   * OAuth token sent as the bearer credential of the request.
   * @default `ScriptApp.getOAuthToken()`
   */
  token?: string;
}
56+
57+
/**
 * Reads the project ID from the `PROJECT_ID` script property.
 *
 * @returns The non-empty value of the `PROJECT_ID` script property.
 * @throws Error if the property is missing or empty.
 */
const getProjectId = (): string => {
  const scriptProperties = PropertiesService.getScriptProperties();
  const configuredId = scriptProperties.getProperty("PROJECT_ID");

  if (configuredId) {
    return configuredId;
  }

  throw new Error("PROJECT_ID not found in script properties");
};
66+
67+
/**
 * Generate embeddings for one or more pieces of text.
 *
 * A single string is treated as a one-element list, so the result is always
 * an array of vectors.
 *
 * @param contentOrContentArray - The text, or list of texts, to embed.
 * @param options - Options for the embeddings generation.
 * @returns The embedding vectors produced by `getBatchedEmbeddings`.
 *
 * @see https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api
 */
export function getTextEmbeddings(
  contentOrContentArray: string | string[],
  options: Options = {},
): number[][] {
  let texts: string[];
  if (Array.isArray(contentOrContentArray)) {
    texts = contentOrContentArray;
  } else {
    texts = [contentOrContentArray];
  }

  // Wrap each text in the instance shape expected by the batched call.
  const instances = texts.map((text) => ({ content: text }));

  return getBatchedEmbeddings(instances, options);
}
89+
90+
/**
 * Generate embeddings for the given instances in parallel UrlFetchApp requests.
 *
 * Instances are split into groups of at most five. Each group becomes one
 * `:predict` request, and all requests are issued through a single
 * `UrlFetchApp.fetchAll` call. Results are concatenated in response order
 * (presumably matching request order — confirm against the fetchAll docs).
 *
 * @param instances - The instances to generate embeddings for.
 * @param options - Options for the embeddings generation.
 * @returns The embedding vector of every prediction, flattened across
 *   responses. Empty when `instances` is empty.
 * @throws Error if any response has a status other than 200 (the message is
 *   the response body), or if a 200 response lacks the expected
 *   `predictions[].embeddings.values` structure.
 *
 * @see https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api
 */
export function getBatchedEmbeddings(
  instances: Instance[],
  {
    parameters = {},
    model = MODEL_ID,
    projectId = getProjectId(),
    region = REGION,
    token = ScriptApp.getOAuthToken(),
  }: Options = {},
): number[][] {
  // Shape of the response fields this function reads; everything is optional
  // because the body comes from the network and is validated below.
  type PredictResponse = {
    predictions?: Array<{ embeddings?: { values?: number[] } }>;
  };

  // Nothing to embed: skip the network round trip entirely.
  if (instances.length === 0) {
    return [];
  }

  // The endpoint is identical for every batch, so build it once.
  const url = `https://${region}-aiplatform.googleapis.com/v1/projects/${projectId}/locations/${region}/publishers/google/models/${model}:predict`;

  const requests = chunkArray(instances, 5).map((batch) => ({
    url,
    method: "post" as const,
    headers: {
      Authorization: `Bearer ${token}`,
      "Content-Type": "application/json",
    },
    // Non-2xx responses are returned instead of thrown so the body can be
    // surfaced as the error message below.
    muteHttpExceptions: true,
    contentType: "application/json",
    payload: JSON.stringify({
      instances: batch,
      parameters,
    }),
  }));

  const responses = UrlFetchApp.fetchAll(requests);

  // First pass: fail on any HTTP error before looking at any payload.
  const results = responses.map((response): PredictResponse | null => {
    const status = response.getResponseCode();
    const body = response.getContentText();
    if (status !== 200) {
      // Keep the body as the message; fall back to the status when it is empty.
      throw new Error(body || `Request failed with HTTP status ${status}`);
    }

    return JSON.parse(body);
  });

  // Second pass: validate the structure and pull out the vectors.
  return results.flatMap((result) => {
    const predictions = result?.predictions;
    if (!Array.isArray(predictions)) {
      throw new Error("Unexpected response: missing `predictions` array");
    }

    return predictions.map((prediction) => {
      const values = prediction?.embeddings?.values;
      if (!Array.isArray(values)) {
        throw new Error(
          "Unexpected response: prediction is missing `embeddings.values`",
        );
      }
      return values;
    });
  });
}
142+
143+
/**
 * Calculates the dot product of two vectors.
 *
 * If the vectors differ in length, only the leading components they have in
 * common are multiplied; the extra tail of the longer vector is ignored.
 *
 * @param x - The first vector.
 * @param y - The second vector.
 * @returns The sum of the pairwise products, or 0 for empty input.
 */
function dotProduct_(x: number[], y: number[]): number {
  const sharedLength = Math.min(x.length, y.length);

  return x
    .slice(0, sharedLength)
    .reduce((sum, component, index) => sum + component * y[index], 0);
}
155+
156+
/**
 * Calculates the magnitude (Euclidean length) of a vector.
 *
 * @param x - The vector.
 * @returns The square root of the sum of squared components; 0 for an empty
 *   vector.
 */
function magnitude(x: number[]): number {
  const sumOfSquares = x.reduce(
    (total, component) => total + component ** 2,
    0,
  );

  return Math.sqrt(sumOfSquares);
}
167+
168+
/**
 * Calculates the cosine similarity between two vectors.
 *
 * @param x - The first vector.
 * @param y - The second vector.
 * @returns The cosine similarity value between -1 and 1. Returns 0 when
 *   either vector has zero magnitude (including two empty vectors), since the
 *   cosine is undefined there.
 * @throws Error if the vectors have different lengths.
 */
export function similarity(x: number[], y: number[]): number {
  if (x.length !== y.length) {
    throw new Error("Vectors must have the same length");
  }

  const denominator = magnitude(x) * magnitude(y);

  // A zero-magnitude vector would make this 0 / 0 = NaN. Report "no
  // similarity" instead so callers always receive a comparable number.
  if (denominator === 0) {
    return 0;
  }

  const cosine = dotProduct_(x, y) / denominator;

  // Floating-point rounding can land a hair outside [-1, 1] for (anti)parallel
  // vectors; clamp so the documented range holds.
  return Math.max(-1, Math.min(1, cosine));
}
180+
181+
/**
 * Returns an emoji representing the similarity value.
 *
 * The value is matched against descending inclusive lower bounds; anything
 * that meets none of them (including NaN) gets the lowest-tier emoji.
 *
 * @param value - The similarity value.
 * @returns The emoji of the first tier whose lower bound `value` reaches.
 */
export const similarityEmoji = (value: number): string => {
  // Ordered from highest to lowest threshold; the first match wins.
  const tiers: ReadonlyArray<readonly [number, string]> = [
    [0.9, "🔥"], // Very high similarity
    [0.7, "✅"], // High similarity
    [0.5, "👍"], // Medium similarity
    [0.3, "🤔"], // Low similarity
  ];

  for (const [lowerBound, emoji] of tiers) {
    if (value >= lowerBound) {
      return emoji;
    }
  }

  return "❌"; // Very low similarity
};
192+
193+
/**
 * Splits an array into consecutive chunks of at most `size` elements.
 *
 * The input array is not modified; every chunk is a new array. The final
 * chunk may be shorter than `size`.
 *
 * @param array - The array to split.
 * @param size - Maximum number of elements per chunk; must be a positive
 *   integer.
 * @returns The chunks in original order; empty when `array` is empty.
 * @throws RangeError if `size` is not a positive integer.
 */
function chunkArray<T>(array: T[], size: number): T[][] {
  // A zero or negative step would never advance the loop below, and NaN or
  // fractional steps would produce empty or uneven chunks.
  if (!Number.isInteger(size) || size < 1) {
    throw new RangeError(`Chunk size must be a positive integer, got ${size}`);
  }

  const chunks: T[][] = [];
  for (let start = 0; start < array.length; start += size) {
    chunks.push(array.slice(start, start + size));
  }
  return chunks;
}

packages/embeddings/tsconfig.json

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
{
2+
"compilerOptions": {
3+
"module": "NodeNext",
4+
"target": "ES2022",
5+
"lib": ["esnext"],
6+
"strict": true,
7+
"esModuleInterop": true,
8+
"skipLibCheck": true,
9+
"types": ["@types/google-apps-script"],
10+
"experimentalDecorators": true
11+
},
12+
"include": ["src/**/*.ts"],
13+
"exclude": ["node_modules", "dist"]
14+
}

0 commit comments

Comments
 (0)