Skip to content

Commit 407ecfa

Browse files
Merge pull request #20 from RonasIT/14-asynchronously-loaded-datasets
Asynchronously loaded datasets
2 parents 40ee9a4 + 03e7334 commit 407ecfa

File tree

8 files changed

+140
-19
lines changed

8 files changed

+140
-19
lines changed

README.md

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,51 @@ await trainer.trainAndTest({
127127
});
128128
```
129129

130+
##### Loading data asynchronously
131+
132+
When working with a large dataset, you might find that the whole dataset
133+
can't fit in memory. In this situation you might want to load the data in
134+
chunks. To do this, you can define the asynchronous generators for
135+
`trainingDataset`, `validationDataset` and `testingDataset`.
136+
137+
This library provides the `makeChunkedDataset` helper to make it easier to
138+
create chunked datasets where chunks are controlled with `skip` and `take`
139+
parameters.
140+
141+
The `makeChunkedDataset` helper accepts the following parameters:
142+
143+
- `loadChunk` – an asynchronous function accepting the numeric `skip` and `take`
144+
parameters and returning an array of samples.
145+
- `chunkSize` – the number of samples loaded per chunk.
146+
- `batchSize` – the number of samples in each batch.
147+
148+
```typescript
149+
const loadTrainingSamplesChunk = async (skip: number, take: number): Promise<Array<Sample>> => {
150+
// Your samples chunk loading logic goes here. For example, you may want to
151+
// load samples from database, or from a remote data source.
152+
};
153+
154+
const makeTrainingDataset = (): data.Dataset<TensorContainer> => makeChunkedDataset({
155+
loadChunk: loadTrainingSamplesChunk,
156+
chunkSize: 32,
157+
batchSize: 32
158+
});
159+
160+
// You should also define similar functions for validationDataset and
161+
// testingDataset. We omit this for the sake of brevity.
162+
163+
const trainingDataset = makeTrainingDataset();
164+
const validationDataset = makeValidationDataset();
165+
const testingDataset = makeTestingDataset();
166+
167+
await trainer.trainAndTest({
168+
trainingDataset,
169+
validationDataset,
170+
testingDataset,
171+
printTestingResults: true
172+
});
173+
```
174+
130175
#### Saving the model
131176

132177
To save the trained model, you need to call the `save` method of the
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import { Sample } from '../training/sample';
2+
import { FeatureExtractor } from './feature-extractor';
3+
4+
export const extractFeatures = async <D, T>({
5+
data,
6+
inputFeatureExtractors,
7+
outputFeatureExtractor
8+
}: {
9+
data: Array<D>;
10+
inputFeatureExtractors: Array<FeatureExtractor<D, T>>;
11+
outputFeatureExtractor: FeatureExtractor<D, T>;
12+
}): Promise<Array<Sample>> => {
13+
const samples = [];
14+
15+
for (const dataItem of data) {
16+
const [inputFeatures, outputFeature] = await Promise.all([
17+
Promise.all(
18+
inputFeatureExtractors.map((featureExtractor) => {
19+
return featureExtractor.extract(dataItem);
20+
})
21+
),
22+
outputFeatureExtractor.extract(dataItem)
23+
]);
24+
25+
const input = inputFeatures.map((feature) => feature.value);
26+
const output = [outputFeature.value];
27+
28+
samples.push({ input, output });
29+
}
30+
31+
return samples;
32+
}

packages/tfjs-node-helpers/src/feature-engineering/prepare-datasets-for-binary-classification.ts

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import { data, TensorContainer } from '@tensorflow/tfjs-node';
22
import { splitSamplesIntoTrainingValidationTestForBinaryClassification } from '../data-splitting/training-validation-test-for-binary-classification';
3-
import { makeDataset } from '../utils/make-dataset-from-array';
3+
import { makeDataset } from '../utils/make-dataset';
4+
import { extractFeatures } from './extract-features';
45
import { FeatureExtractor } from './feature-extractor';
56

67
export const prepareDatasetsForBinaryClassification = async <D, T>({
@@ -24,23 +25,11 @@ export const prepareDatasetsForBinaryClassification = async <D, T>({
2425
validationDataset: data.Dataset<TensorContainer>;
2526
testingDataset: data.Dataset<TensorContainer>;
2627
}> => {
27-
const samples = [];
28-
29-
for (const dataItem of data) {
30-
const [inputFeatures, outputFeature] = await Promise.all([
31-
Promise.all(
32-
inputFeatureExtractors.map((featureExtractor) => {
33-
return featureExtractor.extract(dataItem);
34-
})
35-
),
36-
outputFeatureExtractor.extract(dataItem)
37-
]);
38-
39-
const input = inputFeatures.map((feature) => feature.value);
40-
const output = [outputFeature.value];
41-
42-
samples.push({ input, output });
43-
}
28+
const samples = await extractFeatures({
29+
data,
30+
inputFeatureExtractors,
31+
outputFeatureExtractor
32+
});
4433

4534
const { trainingSamples, validationSamples, testingSamples } = splitSamplesIntoTrainingValidationTestForBinaryClassification(
4635
samples,
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
export * from './classification/binary-classification-trainer';
22
export * from './classification/binary-classifier';
33
export * from './data-splitting/training-validation-test-for-binary-classification';
4+
export * from './feature-engineering/extract-features';
45
export * from './feature-engineering/feature-extractor';
56
export * from './feature-engineering/feature';
67
export * from './feature-engineering/prepare-datasets-for-binary-classification';
78
export * from './training/sample';
9+
export * from './utils/make-chunked-dataset';
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import { Tensor, tensor } from '@tensorflow/tfjs-node';
2+
import { Sample } from '../training/sample';
3+
4+
export const makeChunkedDatasetGenerator = async function* ({
5+
loadChunk,
6+
chunkSize
7+
}: {
8+
loadChunk: (skip: number, take: number) => Promise<Array<Sample>>,
9+
chunkSize: number
10+
}): AsyncGenerator<{ xs: Tensor, ys: Tensor }> {
11+
let skip = 0;
12+
let take = chunkSize;
13+
14+
while (true) {
15+
const samples = await loadChunk(skip, take);
16+
17+
for (const sample of samples) {
18+
yield {
19+
xs: tensor(sample.input),
20+
ys: tensor(sample.output)
21+
};
22+
}
23+
24+
if (samples.length < take) {
25+
break;
26+
}
27+
28+
skip += take;
29+
}
30+
};
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import { data, TensorContainer } from '@tensorflow/tfjs-node';
2+
import { Sample } from '../training/sample';
3+
import { makeChunkedDatasetGenerator } from '../utils/make-chunked-dataset-generator';
4+
5+
export const makeChunkedDataset = ({
6+
loadChunk,
7+
chunkSize,
8+
batchSize
9+
}: {
10+
loadChunk: (skip: number, take: number) => Promise<Array<Sample>>,
11+
chunkSize: number,
12+
batchSize: number
13+
}): data.Dataset<TensorContainer> => {
14+
return data
15+
.generator(
16+
() => makeChunkedDatasetGenerator({
17+
loadChunk,
18+
chunkSize
19+
}) as any
20+
)
21+
.batch(batchSize);
22+
};

packages/tfjs-node-helpers/src/utils/make-dataset-from-array.ts renamed to packages/tfjs-node-helpers/src/utils/make-dataset.ts

File renamed without changes.

packages/tfjs-node-helpers/tsconfig.lib.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
"compilerOptions": {
44
"outDir": "../../dist/out-tsc",
55
"declaration": true,
6-
"types": []
6+
"types": [],
7+
"lib": ["es2018"]
78
},
89
"include": ["**/*.ts"],
910
"exclude": ["jest.config.ts", "**/*.spec.ts", "**/*.test.ts"]

0 commit comments

Comments
 (0)