Skip to content

Commit 6565027

Browse files
committed
onnx-converter: new npm workspace to convert GPT2 from ONNX to TFJS
1 parent c3301b6 commit 6565027

File tree

17 files changed

+10484
-43
lines changed

17 files changed

+10484
-43
lines changed

cli/src/hellaswag_gpt.ts

Lines changed: 79 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,104 @@
1+
// import fs from 'fs';
2+
import fsPromise from 'node:fs/promises';
3+
4+
import { dirname } from 'path';
5+
import { fileURLToPath } from 'url';
6+
import { parse } from 'ts-command-line-args'
7+
18
import '@tensorflow/tfjs-node';
29
import fs from 'node:fs';
310
import path from 'node:path';
4-
import { Tokenizer, models } from '@epfml/discojs';
11+
import { models, serialization, Tokenizer } from '@epfml/discojs';
512
import { loadHellaSwag } from '@epfml/discojs-node';
13+
// import { AutoTokenizer } from '@xenova/transformers';
614

7-
const logFile = path.join('..', 'datasets', 'LogFile_hellaswag.txt');
8-
const logLines: string[] = [];
15+
const __dirname = dirname(fileURLToPath(import.meta.url));
916

17+
const logLines: string[] = [];
1018
function log(message: string) {
1119
console.log(message);
1220
logLines.push(message);
1321
}
1422

15-
const hellaswagDataset: models.HellaSwagDataset = await loadHellaSwag(-1)
16-
17-
async function evaluateTFJS(tokenizer: Tokenizer) {
18-
const model = new models.GPT({ seed: 42 });
19-
log('Evaluating TFJS GPT on HellaSwag...');
23+
async function evaluateModel(model: models.GPT | models.ONNXModel, numDataPoints = -1) {
24+
const hellaswagDataset: models.HellaSwagDataset = await loadHellaSwag(numDataPoints)
25+
const tokenizer = await Tokenizer.from_pretrained('Xenova/gpt2');
26+
log('Starting the HellaSwag benchmark...');
2027

2128
const start = Date.now();
22-
const accuracy = await models.evaluate_hellaswag(model, tokenizer, hellaswagDataset, false);
29+
const accuracy = await models.evaluate_hellaswag(model, tokenizer, hellaswagDataset, true);
2330
const duration = ((Date.now() - start) / 1000).toFixed(2);
2431

25-
log(`TFJS GPT Accuracy: ${(accuracy * 100).toFixed(2)}%`);
26-
log(`TFJS GPT Evaluation Time: ${duration} seconds`);
32+
log(`Final accuracy: ${(accuracy * 100).toFixed(2)}%`);
33+
log(`Evaluation Time: ${duration} seconds`);
2734
}
2835

29-
async function evaluateXenova(tokenizer: Tokenizer) {
30-
const model = await models.ONNXModel.init_pretrained('Xenova/gpt2');
31-
log('Evaluating Xenova GPT-2 (ONNX) on HellaSwag...');
36+
const ModelTypes = ['onnx', 'gpt-tfjs-random', 'gpt-tfjs-pretrained'] as const;
37+
type ModelType = typeof ModelTypes[number];
3238

33-
const start = Date.now();
34-
const accuracy = await models.evaluate_hellaswag(model, tokenizer, hellaswagDataset, false);
35-
const duration = ((Date.now() - start) / 1000).toFixed(2);
36-
37-
log(`Xenova GPT-2 Accuracy: ${(accuracy * 100).toFixed(2)}%`);
38-
log(`Xenova GPT-2 Evaluation Time: ${duration} seconds`);
39+
interface HellaSwagArgs {
40+
model: ModelType
41+
numDataPoints: number
42+
logFile: string
43+
pretrainedModelPath: string
44+
help?: boolean
3945
}
4046

4147
async function main(): Promise<void> {
42-
fs.writeFileSync(logFile, '', 'utf-8'); // Clear old log file
48+
const defaultPretrainedModelPath = path.join(__dirname, "..", "..", "onnx-converter", "assets", "model.json")
49+
const args = parse<HellaSwagArgs>({
50+
model: {
51+
type: (raw: string) => raw as ModelType,
52+
description: `Model type, one of ${ModelTypes}`,
53+
defaultValue: 'onnx'
54+
},
55+
numDataPoints: {
56+
type: Number,
57+
description: 'Number of HellaSwag datapoints to evaluate, set -1 for the whole benchmark',
58+
defaultValue: -1
59+
},
60+
logFile: {
61+
type: String,
62+
description: 'Relative path to the log file, default to ./hellaswag.log', defaultValue: 'hellaswag.log'
63+
},
64+
pretrainedModelPath: {
65+
type: String,
66+
description: 'If specifying gpt-tfjs-pretrained, provide the relative path to the TF.js pretrained model',
67+
defaultValue: defaultPretrainedModelPath
68+
},
69+
help: {
70+
type: Boolean,
71+
optional: true,
72+
alias: 'h',
73+
description: 'Prints this usage guide'
74+
}
75+
}, { helpArg: 'help' })
4376

44-
const tokenizer = await Tokenizer.from_pretrained('Xenova/gpt2');
45-
await evaluateTFJS(tokenizer);
46-
log('\n---\n');
47-
await evaluateXenova(tokenizer);
77+
const logFile = path.join(__dirname, args.logFile);
78+
fs.writeFileSync(logFile, '', 'utf-8'); // Clear the log file
79+
80+
let model: | models.GPT | models.ONNXModel | undefined;
81+
switch (args.model) {
82+
case 'onnx':
83+
log("Using ONNX pretrained model Xenova/gpt2")
84+
model = await models.ONNXModel.init_pretrained('Xenova/gpt2');
85+
break;
86+
case 'gpt-tfjs-random':
87+
log("Using GPT-TFJS with random initialization")
88+
model = new models.GPT({ seed: 42 });
89+
break;
90+
case 'gpt-tfjs-pretrained':
91+
log("Using GPT-TFJS with pretrained weights")
92+
if (args.pretrainedModelPath === undefined) {
93+
throw new Error("If choosing gpt-tfjs-pretrained, provide the relative path to the TF.js pretrained model `pretrainedModelPath")
94+
}
95+
const encodedModel = await fsPromise.readFile(args.pretrainedModelPath);
96+
model = await serialization.model.decode(encodedModel) as models.GPT;
97+
break;
98+
default:
99+
throw new Error(`Unrecognized model type: ${model}`);
100+
}
101+
await evaluateModel(model, args.numDataPoints);
48102

49103
fs.writeFileSync(logFile, logLines.join('\n'), 'utf-8');
50104
console.log(`\nResults written to ${logFile}`);

datasets/.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,6 @@
2020

2121
# GDHF demo
2222
/tinder_dog/
23+
24+
# HellaSwag benchmark
25+
hellaswag*

discojs/src/models/gpt/layers.spec.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,9 @@ describe('GPT Layers', () => {
174174
name: 'testCSA',
175175
contextLength: 5,
176176
nHead: 2,
177-
nEmbd: 8, // divisible by nHead, so head size = 4
178-
dropout: 0.0, // no dropout for deterministic tests
177+
nEmbd: 8, // divisible by nHead, so head size = 4
178+
attnDrop: 0.0, // no dropout for deterministic tests
179+
residDrop: 0.0,
179180
nLayer: 2,
180181
seed: 42
181182
};

discojs/src/models/hellaswag.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ type ModelType = GPT | ONNXModel;
126126
export async function evaluate(
127127
model: ModelType,
128128
tokenizer: Tokenizer,
129-
dataset: HellaSwagExample[],
129+
dataset: HellaSwagDataset,
130130
print = true
131131
): Promise<number> {
132132
let correct = 0;

onnx-converter/.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
node_modules
2+
assets
3+
dist

onnx-converter/README.md

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
## Usage
2+
3+
This workspace is currently used to convert the ONNX [GPT-2 model](https://huggingface.co/Xenova/gpt2) to TensorFlow.js. On the one hand, ONNX allows converting pretrained models from PyTorch or TensorFlow to the ONNX format, so many pretrained models already exist in ONNX format; however, ONNX libraries currently only support inference. On the other hand, TensorFlow.js doesn't have a converter that can handle recent Transformer models (despite having a [converter](https://github.com/tensorflow/tfjs/tree/master/tfjs-converter)), but TF.js does allow further training of models.
4+
5+
Therefore, we want to convert pretrained models such as GPT-2 from the ONNX format to TensorFlow.js in order to fine-tune them further. You can generate a TF.js `model.json` by running `npm run convert_onnx` in this workspace.
6+
7+
What the script does is:
8+
1. Read the ONNX GPT-2 model from [Xenova's repository](https://huggingface.co/Xenova/gpt2)
9+
2. Use the ONNX protobuf definition to read the file and iterate through the model layers. The ONNX JavaScript protobuf comes from [this repository](https://github.com/microsoft/onnxruntime/blob/main/js/web/lib/onnxjs/).
10+
3. Convert all weights to TF.js tensors
11+
4. Init a TF.js model with the loaded weights and export the model
12+
13+
Running `npm run convert_onnx` creates a GPT-tfjs `model.json` file in the `./assets/` folder.
14+
15+
## ONNX JS protobuf
16+
17+
The ONNX specification has limited support in JavaScript. We found an old JS implementation in the [ONNX Runtime Web repository](https://github.com/microsoft/onnxruntime/tree/main/js/web/lib/onnxjs/ort-schema/protobuf). We had to adapt their files as follows to be compatible with our newer environment:
18+
1. Copy `onnx.js` and `onnx.d.ts` from [the repository](https://github.com/microsoft/onnxruntime/tree/main/js/web/lib/onnxjs/ort-schema/protobuf) into `./onnx-converter/src/protobuf`
19+
2. Rename `onnx.js` to `onnx.cjs`
20+
3. Create `onnx-proto.js` as a wrapper around the protobuf definition:
21+
```js
22+
import { createRequire } from 'module';
23+
const require = createRequire(import.meta.url);
24+
const onnxModule = require('./onnx.cjs');
25+
26+
export const onnx = onnxModule.onnx;
27+
export default onnxModule;
28+
```
29+
4. Create `onnx-proto.d.ts` with the matching TypeScript definition:
30+
```ts
31+
export { onnx } from './onnx.js';
32+
declare const onnxModule: {
33+
onnx: typeof import('./onnx.js').onnx;
34+
};
35+
export default onnxModule;
36+
```

onnx-converter/package.json

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
{
2+
"name": "onnx-converter",
3+
"private": true,
4+
"type": "module",
5+
"main": "dist/gpt2_from_onnx.js",
6+
"scripts": {
7+
"convert_onnx": "npm run build && node dist/convert_onnx.js",
8+
"build": "tsc && cp -r src/protobuf dist",
9+
"lint": "npx eslint .",
10+
"test": ": nothing"
11+
},
12+
"author": "",
13+
"license": "ISC",
14+
"dependencies": {
15+
"@epfml/discojs-node": "*"
16+
},
17+
"devDependencies": {
18+
"nodemon": "3",
19+
"ts-command-line-args": "2"
20+
}
21+
}

0 commit comments

Comments
 (0)