Skip to content

Commit ae33116

Browse files
committed
hellaswag2.ts and loaded_hellaswag.spec.ts used for testing the loaded model on the whole HellaSwag
1 parent c9e076e commit ae33116

File tree

2 files changed

+329
-0
lines changed

2 files changed

+329
-0
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import { expect } from 'chai';
2+
import path from 'path';
3+
import { AutoTokenizer, PreTrainedTokenizer } from '@xenova/transformers';
4+
import { GPT } from './index.js';
5+
import { GPTModel } from './model.js';
6+
import { loadWeightsFromJSON } from './load_weights.js';
7+
import { evaluate } from '../hellaswag2.js';
8+
9+
describe('GPT Model with Pretrained Weights on Full HellaSwag', () => {
10+
11+
let gptForTest: GPT;
12+
let tokenizer: PreTrainedTokenizer;
13+
14+
before(async function() {
15+
this.timeout(2400000000000);
16+
17+
console.log('Setting up benchmark: loading model, weights, tokenizer...');
18+
19+
console.time('Model+Tokenizer Loading Time');
20+
21+
const modelConfig = {
22+
modelType: 'gpt2' as const,
23+
contextLength: 1024
24+
};
25+
26+
const loadedGptModel = new GPTModel(modelConfig);
27+
const weightsFilename = 'gpt2_weights.jsonl';
28+
const weightsFileUrl = new URL(path.resolve(weightsFilename), 'file://').href;
29+
30+
await loadWeightsFromJSON(loadedGptModel, weightsFileUrl);
31+
32+
gptForTest = new GPT(modelConfig, loadedGptModel);
33+
tokenizer = await AutoTokenizer.from_pretrained('Xenova/gpt2');
34+
35+
console.timeEnd('Model+Tokenizer Loading Time');
36+
console.log('Setup complete.');
37+
});
38+
39+
after(() => {
40+
console.log('Tearing down test suite: disposing of model...');
41+
gptForTest?.[Symbol.dispose]();
42+
});
43+
44+
it('evaluates the loaded model on the entire HellaSwag dataset', async () => {
45+
console.log('\n--- Starting HellaSwag Benchmark ---');
46+
47+
console.time('Evaluation Time on HellaSwag');
48+
const accuracy = await evaluate(gptForTest, tokenizer, 10042, false);
49+
console.timeEnd('Evaluation Time on HellaSwag');
50+
51+
console.log(`\n--- Benchmark Complete ---`);
52+
console.log(`Final Accuracy on Full HellaSwag: ${(accuracy * 100).toFixed(2)}%`);
53+
54+
expect(accuracy).to.be.gt(0.20);
55+
expect(accuracy).to.be.lt(0.30);
56+
57+
console.log(`Accuracy is: ${(accuracy * 100).toFixed(2)}%`);
58+
console.log('Benchmark passed successfully.');
59+
60+
}).timeout(6000000000000);
61+
});

discojs/src/models/hellaswag2.ts

Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
import { promises as fsPromises } from 'fs';
2+
import fetch from 'node-fetch';
3+
import * as tf from '@tensorflow/tfjs';
4+
import { GPT } from './index.js';
5+
import { tokenize } from '../processing/text.js';
6+
import { PreTrainedTokenizer } from '@xenova/transformers';
7+
import * as readline from 'readline';
8+
import { fileURLToPath } from 'url';
9+
import path from 'path';
10+
import fs from 'fs';
11+
import { List } from 'immutable';
12+
import { ONNXModel } from './onnx.js';
13+
14+
15+
const HELLASWAG_URL = 'https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl';
16+
const __filename = fileURLToPath(import.meta.url);
17+
const __dirname = path.dirname(__filename);
18+
const LOCAL_FILE = path.resolve(__dirname, '../../../datasets/hellaswag_val.jsonl');
19+
20+
21+
async function fileExists(path_: string = LOCAL_FILE): Promise<boolean> {
22+
try {
23+
await fsPromises.access(path_);
24+
return true;
25+
} catch {
26+
return false;
27+
}
28+
}
29+
30+
// Download the HellaSwag dataset if it doesn't exist locally
31+
async function downloadHellaSwag(path_: string = LOCAL_FILE): Promise<void> {
32+
if (await fileExists(path_)) return;
33+
34+
const res = await fetch(HELLASWAG_URL);
35+
const fileStream = fs.createWriteStream(path_);
36+
37+
await new Promise<void>((resolve, reject) => {
38+
res.body?.pipe(fileStream);
39+
res.body?.on('error', reject);
40+
fileStream.on('error', reject);
41+
fileStream.on('finish', () => resolve());
42+
});
43+
}
44+
45+
/**
46+
* Represents a single example from the HellaSwag dataset.
47+
*
48+
* ctx - The context sentence or paragraph that sets up the situation.
49+
* endings - An array of four possible continuations of the context.
50+
* label - The index (0–3) of the correct ending in the `endings` array.
51+
*/
52+
interface HellaSwagExample {
53+
ctx: string;
54+
endings: string[];
55+
label: number;
56+
}
57+
58+
async function* loadExamples(limit = 100): AsyncGenerator<HellaSwagExample> {
59+
60+
// Read the dataset line by line
61+
const fileStream = fs.createReadStream(LOCAL_FILE, 'utf-8');
62+
const rl = readline.createInterface({ input: fileStream, crlfDelay: Infinity });
63+
64+
let count = 0;
65+
for await (const line of rl) {
66+
// Stop if the desired number of examples has been reached
67+
if (count++ >= limit) break;
68+
69+
try {
70+
const data = JSON.parse(line.trim()) as HellaSwagExample;
71+
yield { ctx: data.ctx, endings: data.endings, label: data.label };
72+
} catch (e) {
73+
console.error(`Failed to parse line ${count}:`, line);
74+
throw e;
75+
}
76+
}
77+
}
78+
79+
80+
// DEBUGGING FUNCTION LOADS A SINGLE EXAMPLE
81+
async function* loadExample(limit = 1, lineNumber?: number): AsyncGenerator<HellaSwagExample> {
82+
const fileStream = fs.createReadStream(LOCAL_FILE, 'utf-8');
83+
const rl = readline.createInterface({ input: fileStream, crlfDelay: Infinity });
84+
85+
let count = 0;
86+
for await (const line of rl) {
87+
if (!line.trim()) continue;
88+
89+
if (lineNumber !== undefined) {
90+
if (count === lineNumber) {
91+
const data = JSON.parse(line.trim()) as HellaSwagExample;
92+
yield { ctx: data.ctx, endings: data.endings, label: data.label };
93+
break; // only one line
94+
}
95+
} else {
96+
if (count >= limit) break;
97+
const data = JSON.parse(line.trim()) as HellaSwagExample;
98+
yield { ctx: data.ctx, endings: data.endings, label: data.label };
99+
}
100+
101+
count++;
102+
}
103+
}
104+
105+
// Computes the log likelihood of the input sequence using the tfjs model
106+
// The input sequence is expected to be a concatenation of the context and the ending
107+
// The function computes the log likelihood of each ending and returns the one with the loss of each ending
108+
// Sources:
109+
// https://github.com/karpathy/build-nanogpt/blob/master/hellaswag.py
110+
//https://www.youtube.com/watch?v=l8pRSuU81PU
111+
async function computeLogLikelihood(gpt: GPT, inputIds: number[], ctxLength: number): Promise<number> {
112+
const lossTensor = tf.tidy(() => {
113+
// Convert input sequence to shape [1, seq_len]
114+
const inputTensor = tf.tensor2d([inputIds], [1, inputIds.length], 'int32');
115+
116+
// Get model logits: [1, seq_len, vocab_size]
117+
const logits3D = gpt.extract().predict(inputTensor) as tf.Tensor3D;
118+
119+
// Shift logits to align with next-token targets
120+
const shiftedLogits = logits3D.slice([0, 0, 0], [1, inputIds.length - 1, -1]);
121+
122+
// Target tokens (next tokens), same length as shifted logits
123+
const shiftedTargets = inputIds.slice(1);
124+
const targetTensor = tf.tensor1d(shiftedTargets, 'int32');
125+
126+
// One-hot encode targets for cross-entropy loss
127+
const oneHotLabels = tf.oneHot(targetTensor, shiftedLogits.shape[2]);
128+
129+
// Compute per-token cross-entropy log-probabilities (unnormalized loss)
130+
const logProbs = tf.losses.softmaxCrossEntropy(oneHotLabels, shiftedLogits.squeeze());
131+
132+
// Create a mask to only include loss after the context length
133+
const mask = tf.tensor1d(inputIds.map((_, i) => (i >= ctxLength ? 1 : 0)), 'float32').slice(1);
134+
135+
// Apply the mask and average over the selected tokens
136+
const masked = logProbs.mul(mask);
137+
const loss = masked.sum().div(mask.sum());
138+
139+
return loss;
140+
});
141+
const lossNumber = await lossTensor.array();
142+
if (typeof lossNumber !== 'number') {
143+
throw new Error('got multiple loss')
144+
}
145+
return lossNumber;
146+
}
147+
148+
149+
// Computes the log likelihood of the input sequence using the ONNX model
150+
// The input sequence is expected to be a concatenation of the context and the ending
151+
// The function computes the log likelihood of each ending and returns the one with the loss of each ending
152+
// Sources:
153+
// https://github.com/karpathy/build-nanogpt/blob/master/hellaswag.py
154+
// https://www.youtube.com/watch?v=l8pRSuU81PU
155+
async function computeONNXLogLikelihood(model: ONNXModel, inputIds: number[], ctxLength: number): Promise<number> {
156+
const batchInput = List([List(inputIds)]); // [1, seq_len]
157+
158+
// Run model to get logits: flattened [T * V]
159+
const logitsTensor = await model.getLogits(batchInput);
160+
const logits = logitsTensor.data as number[];
161+
const [_B, T, V] = logitsTensor.dims;
162+
163+
// Reshape flattened logits into [T][V]
164+
const reshaped: number[][] = Array.from({ length: T }, (_, t) =>
165+
logits.slice(t * V, (t + 1) * V)
166+
);
167+
168+
// Shift targets (next-token prediction)
169+
const targets = inputIds.slice(1); // length = T - 1
170+
const logitsShifted = reshaped.slice(0, T - 1); // also length = T - 1
171+
172+
// Compute per-token cross-entropy loss manually
173+
const losses = logitsShifted.map((logit, i) => {
174+
const maxLogit = Math.max(...logit); // for numerical stability
175+
const exp = logit.map(x => Math.exp(x - maxLogit));
176+
const sumExp = exp.reduce((a, b) => a + b, 0);
177+
const probs = exp.map(e => e / sumExp); // softmax
178+
return -Math.log(probs[targets[i]]); // cross-entropy loss
179+
});
180+
181+
// Create a binary mask for non-context tokens
182+
const mask = inputIds.map((_, i) => (i >= ctxLength ? 1 : 0)).slice(1);
183+
184+
// Apply the mask to the losses
185+
const maskedLosses = losses.map((l, i) => l * mask[i]);
186+
187+
// Average the masked losses
188+
const totalLoss = maskedLosses.reduce((a, b) => a + b, 0);
189+
const sum = mask.reduce((a, b) => a + b, 0 as number);
190+
191+
return totalLoss / (sum || 1); // avoid division by 0
192+
}
193+
194+
195+
type Tokenizer = PreTrainedTokenizer;
196+
type ModelType = GPT | ONNXModel;
197+
198+
/**
199+
* Evaluates the model on the HellaSwag dataset.
200+
* model - The model to evaluate (either GPT or ONNXModel)
201+
* tokenizer - The tokenizer to use for tokenizing the input text
202+
* limit - The number of examples to evaluate on (default: 50)
203+
* print - Whether to print the results (default: true)
204+
* @returns The accuracy of the model on the dataset
205+
*/
206+
export async function evaluate(
207+
model: ModelType,
208+
tokenizer: Tokenizer,
209+
limit = 50, // Number of examples to evaluate on (set to 10042 for all examples)
210+
print = true,
211+
dataset_path: string = LOCAL_FILE
212+
): Promise<number> {
213+
await downloadHellaSwag(dataset_path);
214+
215+
let correct = 0;
216+
let total = 0;
217+
218+
for await (const example of loadExamples(limit)) {
219+
const endingTokens = example.endings.map(e =>
220+
tokenize(tokenizer, example.ctx + ' ' + e, {
221+
truncation: true,
222+
max_length: 128
223+
}).toArray()
224+
);
225+
226+
const ctxTokens = tokenize(tokenizer, example.ctx, {
227+
truncation: true,
228+
max_length: 128
229+
}).toArray();
230+
231+
let losses: number[] = [];
232+
233+
if (model instanceof GPT) {
234+
losses = await Promise.all(
235+
endingTokens.map(e =>
236+
computeLogLikelihood(model, e, ctxTokens.length)
237+
)
238+
);
239+
} else {
240+
losses = await Promise.all(
241+
endingTokens.map(e =>
242+
computeONNXLogLikelihood(model, e, ctxTokens.length)
243+
)
244+
);
245+
}
246+
247+
const pred = losses.indexOf(Math.min(...losses));
248+
if (pred === example.label) correct++;
249+
total++;
250+
251+
// Print the results
252+
if (print) {
253+
console.log(`\nExample #${total}`);
254+
console.log(`Context: ${example.ctx}`);
255+
example.endings.forEach((end, i) => {
256+
console.log(
257+
` ${i}: ${end} (loss: ${losses[i].toFixed(4)})${i === example.label ? ' <-- correct' : ''}${i === pred ? ' <-- picked' : ''}`
258+
);
259+
});
260+
const accuracy_temp = correct / total;
261+
console.log(`\n Accuracy on ${total} examples: ${(accuracy_temp * 100).toFixed(2)}%`);
262+
}
263+
}
264+
265+
const accuracy = correct / total;
266+
console.log(`\nFinal accuracy on ${total} examples: ${(accuracy * 100).toFixed(2)}%`);
267+
return accuracy;
268+
}

0 commit comments

Comments
 (0)