+// import fs from 'fs';
+import fsPromise from 'node:fs/promises';
+
+import { dirname } from 'path';
+import { fileURLToPath } from 'url';
+import { parse } from 'ts-command-line-args';
+
 import '@tensorflow/tfjs-node';
 import fs from 'node:fs';
 import path from 'node:path';
-import { Tokenizer, models } from '@epfml/discojs';
+import { models, serialization, Tokenizer } from '@epfml/discojs';
 import { loadHellaSwag } from '@epfml/discojs-node';
+// import { AutoTokenizer } from '@xenova/transformers';

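+// CLI flags (defined below): --model <onnx|gpt-tfjs-random|gpt-tfjs-pretrained>,
+// --numDataPoints <n>, --logFile <path>, --pretrainedModelPath <path>, --help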
-const logFile = path.join('..', 'datasets', 'LogFile_hellaswag.txt');
-const logLines: string[] = [];
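+// Recreate __dirname, which is not available in ES modules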
+const __dirname = dirname(fileURLToPath(import.meta.url));

+const logLines: string[] = [];
 function log(message: string) {
   console.log(message);
   logLines.push(message);
 }

-const hellaswagDataset: models.HellaSwagDataset = await loadHellaSwag(-1)
-
-async function evaluateTFJS(tokenizer: Tokenizer) {
-  const model = new models.GPT({ seed: 42 });
-  log('Evaluating TFJS GPT on HellaSwag...');
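+// Run the HellaSwag benchmark on the given model and log accuracy and runtime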
+async function evaluateModel(model: models.GPT | models.ONNXModel, numDataPoints = -1) {
+  const hellaswagDataset: models.HellaSwagDataset = await loadHellaSwag(numDataPoints);
+  const tokenizer = await Tokenizer.from_pretrained('Xenova/gpt2');
+  log('Starting the HellaSwag benchmark...');

   const start = Date.now();
-  const accuracy = await models.evaluate_hellaswag(model, tokenizer, hellaswagDataset, false);
+  const accuracy = await models.evaluate_hellaswag(model, tokenizer, hellaswagDataset, true);
   const duration = ((Date.now() - start) / 1000).toFixed(2);

-  log(`TFJS GPT Accuracy: ${(accuracy * 100).toFixed(2)}%`);
-  log(`TFJS GPT Evaluation Time: ${duration} seconds`);
+  log(`Final accuracy: ${(accuracy * 100).toFixed(2)}%`);
+  log(`Evaluation Time: ${duration} seconds`);
 }

-async function evaluateXenova(tokenizer: Tokenizer) {
-  const model = await models.ONNXModel.init_pretrained('Xenova/gpt2');
-  log('Evaluating Xenova GPT-2 (ONNX) on HellaSwag...');
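+// Model back-ends supported by this benchmark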
+const ModelTypes = ['onnx', 'gpt-tfjs-random', 'gpt-tfjs-pretrained'] as const;
+type ModelType = typeof ModelTypes[number];

-  const start = Date.now();
-  const accuracy = await models.evaluate_hellaswag(model, tokenizer, hellaswagDataset, false);
-  const duration = ((Date.now() - start) / 1000).toFixed(2);
-
-  log(`Xenova GPT-2 Accuracy: ${(accuracy * 100).toFixed(2)}%`);
-  log(`Xenova GPT-2 Evaluation Time: ${duration} seconds`);
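+// Command-line options, parsed with ts-command-line-args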
+interface HellaSwagArgs {
+  model: ModelType
+  numDataPoints: number
+  logFile: string
+  pretrainedModelPath: string
+  help?: boolean
 }

 async function main(): Promise<void> {
-  fs.writeFileSync(logFile, '', 'utf-8'); // Clear old log file
+  const defaultPretrainedModelPath = path.join(__dirname, '..', '..', 'onnx-converter', 'assets', 'model.json');
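+  // Parse CLI arguments; --help (-h) prints the usage guide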
+  const args = parse<HellaSwagArgs>({
+    model: {
+      type: (raw: string) => raw as ModelType,
+      description: `Model type, one of ${ModelTypes}`,
+      defaultValue: 'onnx'
+    },
+    numDataPoints: {
+      type: Number,
+      description: 'Number of HellaSwag datapoints to evaluate, set -1 for the whole benchmark',
+      defaultValue: -1
+    },
+    logFile: {
+      type: String,
+      description: 'Relative path to the log file, defaults to ./hellaswag.log',
+      defaultValue: 'hellaswag.log'
+    },
+    pretrainedModelPath: {
+      type: String,
+      description: 'If specifying gpt-tfjs-pretrained, provide the relative path to the TF.js pretrained model',
+      defaultValue: defaultPretrainedModelPath
+    },
+    help: {
+      type: Boolean,
+      optional: true,
+      alias: 'h',
+      description: 'Prints this usage guide'
+    }
+  }, { helpArg: 'help' });

-  const tokenizer = await Tokenizer.from_pretrained('Xenova/gpt2');
-  await evaluateTFJS(tokenizer);
-  log('\n---\n');
-  await evaluateXenova(tokenizer);
+  const logFile = path.join(__dirname, args.logFile);
+  fs.writeFileSync(logFile, '', 'utf-8'); // Clear the log file
+
+  let model: models.GPT | models.ONNXModel | undefined;
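+  // Instantiate the model selected via --model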
+  switch (args.model) {
+    case 'onnx':
+      log('Using ONNX pretrained model Xenova/gpt2');
+      model = await models.ONNXModel.init_pretrained('Xenova/gpt2');
+      break;
+    case 'gpt-tfjs-random':
+      log('Using GPT-TFJS with random initialization');
+      model = new models.GPT({ seed: 42 });
+      break;
+    case 'gpt-tfjs-pretrained': {
+      log('Using GPT-TFJS with pretrained weights');
+      if (args.pretrainedModelPath === undefined) {
+        throw new Error('If choosing gpt-tfjs-pretrained, provide the relative path to the TF.js pretrained model via `pretrainedModelPath`');
+      }
+      // Read and decode the serialized GPT model from disk
+      const encodedModel = await fsPromise.readFile(args.pretrainedModelPath);
+      model = await serialization.model.decode(encodedModel) as models.GPT;
+      break;
+    }
+    default:
+      throw new Error(`Unrecognized model type: ${args.model}`);
+  }
+  await evaluateModel(model, args.numDataPoints);

   fs.writeFileSync(logFile, logLines.join('\n'), 'utf-8');
   console.log(`\nResults written to ${logFile}`);