Skip to content

Commit 57aa746

Browse files
committed
feat: add makeChunkedDataset helper
1 parent 2e66f2c commit 57aa746

File tree

4 files changed

+55
-1
lines changed

4 files changed

+55
-1
lines changed

packages/tfjs-node-helpers/src/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ export * from './feature-engineering/feature-extractor';
66
export * from './feature-engineering/feature';
77
export * from './feature-engineering/prepare-datasets-for-binary-classification';
88
export * from './training/sample';
9+
export * from './utils/make-chunked-dataset';
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import { Tensor, tensor } from '@tensorflow/tfjs-node';
2+
import { Sample } from '../training/sample';
3+
4+
/**
 * Lazily streams training samples as `{ xs, ys }` tensor pairs, fetching the
 * underlying data one chunk at a time so the whole dataset never has to be
 * loaded at once.
 *
 * Each sample's `input` is passed to `tensor()` to produce `xs`, and its
 * `output` to produce `ys`.
 *
 * @param options.loadChunk Called as `loadChunk(skip, take)`; resolves to the
 *   samples for that window. Resolving with fewer than `take` samples
 *   (including an empty array) signals that the data source is exhausted.
 * @param options.chunkSize Number of samples requested per `loadChunk` call.
 *   Must be greater than zero.
 * @returns An async generator yielding one `{ xs, ys }` pair per sample.
 * @throws {RangeError} On the first `next()` call, when `chunkSize` is zero,
 *   negative or `NaN`.
 */
export const makeChunkedDatasetGenerator = async function* ({
  loadChunk,
  chunkSize
}: {
  loadChunk: (skip: number, take: number) => Promise<Array<Sample>>,
  chunkSize: number
}): AsyncGenerator<{ xs: Tensor, ys: Tensor }> {
  // With a zero, negative or NaN chunk size the "short chunk" stop condition
  // below can never be satisfied, so the loop would call loadChunk forever.
  if (!(chunkSize > 0)) {
    throw new RangeError(`chunkSize must be greater than 0, received ${chunkSize}`);
  }

  let skip = 0;

  while (true) {
    const samples = await loadChunk(skip, chunkSize);

    for (const sample of samples) {
      yield {
        xs: tensor(sample.input),
        ys: tensor(sample.output)
      };
    }

    // A chunk shorter than requested means there is nothing left to load.
    if (samples.length < chunkSize) {
      break;
    }

    skip += chunkSize;
  }
};
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import { data, TensorContainer } from '@tensorflow/tfjs-node';
2+
import { Sample } from '../training/sample';
3+
import { makeChunkedDatasetGenerator } from '../utils/make-chunked-dataset-generator';
4+
5+
/**
 * Builds a batched tfjs dataset on top of `makeChunkedDatasetGenerator`.
 *
 * @param options.loadChunk Forwarded unchanged to the chunked generator; it is
 *   called as `loadChunk(skip, take)` and resolves to the samples of that
 *   window.
 * @param options.chunkSize Forwarded unchanged to the chunked generator as the
 *   number of samples requested per `loadChunk` call.
 * @param options.batchSize Passed to `Dataset.batch()`.
 * @returns The dataset produced by `data.generator(...).batch(batchSize)`.
 */
export const makeChunkedDataset = ({
  loadChunk,
  chunkSize,
  batchSize
}: {
  loadChunk: (skip: number, take: number) => Promise<Array<Sample>>,
  chunkSize: number,
  batchSize: number
}): data.Dataset<TensorContainer> => {
  // Factory handed to `data.generator`; every call creates a fresh async
  // generator over the chunked source.
  // NOTE(review): the `any` cast is kept from the original, presumably because
  // the `data.generator` typings in use do not accept an AsyncGenerator —
  // confirm against the installed tfjs version before removing it.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const createSampleIterator = () => makeChunkedDatasetGenerator({ loadChunk, chunkSize }) as any;

  const unbatched = data.generator(createSampleIterator);

  return unbatched.batch(batchSize);
};

packages/tfjs-node-helpers/tsconfig.lib.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
"compilerOptions": {
44
"outDir": "../../dist/out-tsc",
55
"declaration": true,
6-
"types": []
6+
"types": [],
7+
"lib": ["es2018"]
78
},
89
"include": ["**/*.ts"],
910
"exclude": ["jest.config.ts", "**/*.spec.ts", "**/*.test.ts"]

0 commit comments

Comments
 (0)