1+ import { promises as fsPromises } from 'fs' ;
2+ import fetch from 'node-fetch' ;
3+ import * as tf from '@tensorflow/tfjs' ;
4+ import { GPT } from './index.js' ;
5+ import { tokenize } from '../processing/text.js' ;
6+ import { PreTrainedTokenizer } from '@xenova/transformers' ;
7+ import * as readline from 'readline' ;
8+ import { fileURLToPath } from 'url' ;
9+ import path from 'path' ;
10+ import fs from 'fs' ;
11+ import { List } from 'immutable' ;
12+ import { ONNXModel } from './onnx.js' ;
13+
14+
15+ const HELLASWAG_URL = 'https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl' ;
16+ const __filename = fileURLToPath ( import . meta. url ) ;
17+ const __dirname = path . dirname ( __filename ) ;
18+ const LOCAL_FILE = path . resolve ( __dirname , '../../../datasets/hellaswag_val.jsonl' ) ;
19+
20+
21+ async function fileExists ( path_ : string = LOCAL_FILE ) : Promise < boolean > {
22+ try {
23+ await fsPromises . access ( path_ ) ;
24+ return true ;
25+ } catch {
26+ return false ;
27+ }
28+ }
29+
30+ // Download the HellaSwag dataset if it doesn't exist locally
31+ async function downloadHellaSwag ( path_ : string = LOCAL_FILE ) : Promise < void > {
32+ if ( await fileExists ( path_ ) ) return ;
33+
34+ const res = await fetch ( HELLASWAG_URL ) ;
35+ const fileStream = fs . createWriteStream ( path_ ) ;
36+
37+ await new Promise < void > ( ( resolve , reject ) => {
38+ res . body ?. pipe ( fileStream ) ;
39+ res . body ?. on ( 'error' , reject ) ;
40+ fileStream . on ( 'error' , reject ) ;
41+ fileStream . on ( 'finish' , ( ) => resolve ( ) ) ;
42+ } ) ;
43+ }
44+
45+ /**
46+ * Represents a single example from the HellaSwag dataset.
47+ *
48+ * ctx - The context sentence or paragraph that sets up the situation.
49+ * endings - An array of four possible continuations of the context.
50+ * label - The index (0–3) of the correct ending in the `endings` array.
51+ */
52+ interface HellaSwagExample {
53+ ctx : string ;
54+ endings : string [ ] ;
55+ label : number ;
56+ }
57+
58+ async function * loadExamples ( limit = 100 ) : AsyncGenerator < HellaSwagExample > {
59+
60+ // Read the dataset line by line
61+ const fileStream = fs . createReadStream ( LOCAL_FILE , 'utf-8' ) ;
62+ const rl = readline . createInterface ( { input : fileStream , crlfDelay : Infinity } ) ;
63+
64+ let count = 0 ;
65+ for await ( const line of rl ) {
66+ // Stop if the desired number of examples has been reached
67+ if ( count ++ >= limit ) break ;
68+
69+ try {
70+ const data = JSON . parse ( line . trim ( ) ) as HellaSwagExample ;
71+ yield { ctx : data . ctx , endings : data . endings , label : data . label } ;
72+ } catch ( e ) {
73+ console . error ( `Failed to parse line ${ count } :` , line ) ;
74+ throw e ;
75+ }
76+ }
77+ }
78+
79+
80+ // DEBUGGING FUNCTION LOADS A SINGLE EXAMPLE
81+ async function * loadExample ( limit = 1 , lineNumber ?: number ) : AsyncGenerator < HellaSwagExample > {
82+ const fileStream = fs . createReadStream ( LOCAL_FILE , 'utf-8' ) ;
83+ const rl = readline . createInterface ( { input : fileStream , crlfDelay : Infinity } ) ;
84+
85+ let count = 0 ;
86+ for await ( const line of rl ) {
87+ if ( ! line . trim ( ) ) continue ;
88+
89+ if ( lineNumber !== undefined ) {
90+ if ( count === lineNumber ) {
91+ const data = JSON . parse ( line . trim ( ) ) as HellaSwagExample ;
92+ yield { ctx : data . ctx , endings : data . endings , label : data . label } ;
93+ break ; // only one line
94+ }
95+ } else {
96+ if ( count >= limit ) break ;
97+ const data = JSON . parse ( line . trim ( ) ) as HellaSwagExample ;
98+ yield { ctx : data . ctx , endings : data . endings , label : data . label } ;
99+ }
100+
101+ count ++ ;
102+ }
103+ }
104+
105+ // Computes the log likelihood of the input sequence using the tfjs model
106+ // The input sequence is expected to be a concatenation of the context and the ending
107+ // The function computes the log likelihood of each ending and returns the one with the loss of each ending
108+ // Sources:
109+ // https://github.com/karpathy/build-nanogpt/blob/master/hellaswag.py
110+ //https://www.youtube.com/watch?v=l8pRSuU81PU
111+ async function computeLogLikelihood ( gpt : GPT , inputIds : number [ ] , ctxLength : number ) : Promise < number > {
112+ const lossTensor = tf . tidy ( ( ) => {
113+ // Convert input sequence to shape [1, seq_len]
114+ const inputTensor = tf . tensor2d ( [ inputIds ] , [ 1 , inputIds . length ] , 'int32' ) ;
115+
116+ // Get model logits: [1, seq_len, vocab_size]
117+ const logits3D = gpt . extract ( ) . predict ( inputTensor ) as tf . Tensor3D ;
118+
119+ // Shift logits to align with next-token targets
120+ const shiftedLogits = logits3D . slice ( [ 0 , 0 , 0 ] , [ 1 , inputIds . length - 1 , - 1 ] ) ;
121+
122+ // Target tokens (next tokens), same length as shifted logits
123+ const shiftedTargets = inputIds . slice ( 1 ) ;
124+ const targetTensor = tf . tensor1d ( shiftedTargets , 'int32' ) ;
125+
126+ // One-hot encode targets for cross-entropy loss
127+ const oneHotLabels = tf . oneHot ( targetTensor , shiftedLogits . shape [ 2 ] ) ;
128+
129+ // Compute per-token cross-entropy log-probabilities (unnormalized loss)
130+ const logProbs = tf . losses . softmaxCrossEntropy ( oneHotLabels , shiftedLogits . squeeze ( ) ) ;
131+
132+ // Create a mask to only include loss after the context length
133+ const mask = tf . tensor1d ( inputIds . map ( ( _ , i ) => ( i >= ctxLength ? 1 : 0 ) ) , 'float32' ) . slice ( 1 ) ;
134+
135+ // Apply the mask and average over the selected tokens
136+ const masked = logProbs . mul ( mask ) ;
137+ const loss = masked . sum ( ) . div ( mask . sum ( ) ) ;
138+
139+ return loss ;
140+ } ) ;
141+ const lossNumber = await lossTensor . array ( ) ;
142+ if ( typeof lossNumber !== 'number' ) {
143+ throw new Error ( 'got multiple loss' )
144+ }
145+ return lossNumber ;
146+ }
147+
148+
149+ // Computes the log likelihood of the input sequence using the ONNX model
150+ // The input sequence is expected to be a concatenation of the context and the ending
151+ // The function computes the log likelihood of each ending and returns the one with the loss of each ending
152+ // Sources:
153+ // https://github.com/karpathy/build-nanogpt/blob/master/hellaswag.py
154+ // https://www.youtube.com/watch?v=l8pRSuU81PU
155+ async function computeONNXLogLikelihood ( model : ONNXModel , inputIds : number [ ] , ctxLength : number ) : Promise < number > {
156+ const batchInput = List ( [ List ( inputIds ) ] ) ; // [1, seq_len]
157+
158+ // Run model to get logits: flattened [T * V]
159+ const logitsTensor = await model . getLogits ( batchInput ) ;
160+ const logits = logitsTensor . data as number [ ] ;
161+ const [ _B , T , V ] = logitsTensor . dims ;
162+
163+ // Reshape flattened logits into [T][V]
164+ const reshaped : number [ ] [ ] = Array . from ( { length : T } , ( _ , t ) =>
165+ logits . slice ( t * V , ( t + 1 ) * V )
166+ ) ;
167+
168+ // Shift targets (next-token prediction)
169+ const targets = inputIds . slice ( 1 ) ; // length = T - 1
170+ const logitsShifted = reshaped . slice ( 0 , T - 1 ) ; // also length = T - 1
171+
172+ // Compute per-token cross-entropy loss manually
173+ const losses = logitsShifted . map ( ( logit , i ) => {
174+ const maxLogit = Math . max ( ...logit ) ; // for numerical stability
175+ const exp = logit . map ( x => Math . exp ( x - maxLogit ) ) ;
176+ const sumExp = exp . reduce ( ( a , b ) => a + b , 0 ) ;
177+ const probs = exp . map ( e => e / sumExp ) ; // softmax
178+ return - Math . log ( probs [ targets [ i ] ] ) ; // cross-entropy loss
179+ } ) ;
180+
181+ // Create a binary mask for non-context tokens
182+ const mask = inputIds . map ( ( _ , i ) => ( i >= ctxLength ? 1 : 0 ) ) . slice ( 1 ) ;
183+
184+ // Apply the mask to the losses
185+ const maskedLosses = losses . map ( ( l , i ) => l * mask [ i ] ) ;
186+
187+ // Average the masked losses
188+ const totalLoss = maskedLosses . reduce ( ( a , b ) => a + b , 0 ) ;
189+ const sum = mask . reduce ( ( a , b ) => a + b , 0 as number ) ;
190+
191+ return totalLoss / ( sum || 1 ) ; // avoid division by 0
192+ }
193+
194+
195+ type Tokenizer = PreTrainedTokenizer ;
196+ type ModelType = GPT | ONNXModel ;
197+
198+ /**
199+ * Evaluates the model on the HellaSwag dataset.
200+ * model - The model to evaluate (either GPT or ONNXModel)
201+ * tokenizer - The tokenizer to use for tokenizing the input text
202+ * limit - The number of examples to evaluate on (default: 50)
203+ * print - Whether to print the results (default: true)
204+ * @returns The accuracy of the model on the dataset
205+ */
206+ export async function evaluate (
207+ model : ModelType ,
208+ tokenizer : Tokenizer ,
209+ limit = 50 , // Number of examples to evaluate on (set to 10042 for all examples)
210+ print = true ,
211+ dataset_path : string = LOCAL_FILE
212+ ) : Promise < number > {
213+ await downloadHellaSwag ( dataset_path ) ;
214+
215+ let correct = 0 ;
216+ let total = 0 ;
217+
218+ for await ( const example of loadExamples ( limit ) ) {
219+ const endingTokens = example . endings . map ( e =>
220+ tokenize ( tokenizer , example . ctx + ' ' + e , {
221+ truncation : true ,
222+ max_length : 128
223+ } ) . toArray ( )
224+ ) ;
225+
226+ const ctxTokens = tokenize ( tokenizer , example . ctx , {
227+ truncation : true ,
228+ max_length : 128
229+ } ) . toArray ( ) ;
230+
231+ let losses : number [ ] = [ ] ;
232+
233+ if ( model instanceof GPT ) {
234+ losses = await Promise . all (
235+ endingTokens . map ( e =>
236+ computeLogLikelihood ( model , e , ctxTokens . length )
237+ )
238+ ) ;
239+ } else {
240+ losses = await Promise . all (
241+ endingTokens . map ( e =>
242+ computeONNXLogLikelihood ( model , e , ctxTokens . length )
243+ )
244+ ) ;
245+ }
246+
247+ const pred = losses . indexOf ( Math . min ( ...losses ) ) ;
248+ if ( pred === example . label ) correct ++ ;
249+ total ++ ;
250+
251+ // Print the results
252+ if ( print ) {
253+ console . log ( `\nExample #${ total } ` ) ;
254+ console . log ( `Context: ${ example . ctx } ` ) ;
255+ example . endings . forEach ( ( end , i ) => {
256+ console . log (
257+ ` ${ i } : ${ end } (loss: ${ losses [ i ] . toFixed ( 4 ) } )${ i === example . label ? ' <-- correct' : '' } ${ i === pred ? ' <-- picked' : '' } `
258+ ) ;
259+ } ) ;
260+ const accuracy_temp = correct / total ;
261+ console . log ( `\n Accuracy on ${ total } examples: ${ ( accuracy_temp * 100 ) . toFixed ( 2 ) } %` ) ;
262+ }
263+ }
264+
265+ const accuracy = correct / total ;
266+ console . log ( `\nFinal accuracy on ${ total } examples: ${ ( accuracy * 100 ) . toFixed ( 2 ) } %` ) ;
267+ return accuracy ;
268+ }
0 commit comments