Skip to content

Commit 2e66f2c

Browse files
committed
feat: add extractFeatures helper
1 parent fa24604 commit 2e66f2c

File tree

3 files changed

+39
-17
lines changed

3 files changed

+39
-17
lines changed
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import { Sample } from '../training/sample';
2+
import { FeatureExtractor } from './feature-extractor';
3+
4+
export const extractFeatures = async <D, T>({
5+
data,
6+
inputFeatureExtractors,
7+
outputFeatureExtractor
8+
}: {
9+
data: Array<D>;
10+
inputFeatureExtractors: Array<FeatureExtractor<D, T>>;
11+
outputFeatureExtractor: FeatureExtractor<D, T>;
12+
}): Promise<Array<Sample>> => {
13+
const samples = [];
14+
15+
for (const dataItem of data) {
16+
const [inputFeatures, outputFeature] = await Promise.all([
17+
Promise.all(
18+
inputFeatureExtractors.map((featureExtractor) => {
19+
return featureExtractor.extract(dataItem);
20+
})
21+
),
22+
outputFeatureExtractor.extract(dataItem)
23+
]);
24+
25+
const input = inputFeatures.map((feature) => feature.value);
26+
const output = [outputFeature.value];
27+
28+
samples.push({ input, output });
29+
}
30+
31+
return samples;
32+
}

packages/tfjs-node-helpers/src/feature-engineering/prepare-datasets-for-binary-classification.ts

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import { data, TensorContainer } from '@tensorflow/tfjs-node';
22
import { splitSamplesIntoTrainingValidationTestForBinaryClassification } from '../data-splitting/training-validation-test-for-binary-classification';
33
import { makeDataset } from '../utils/make-dataset';
4+
import { extractFeatures } from './extract-features';
45
import { FeatureExtractor } from './feature-extractor';
56

67
export const prepareDatasetsForBinaryClassification = async <D, T>({
@@ -24,23 +25,11 @@ export const prepareDatasetsForBinaryClassification = async <D, T>({
2425
validationDataset: data.Dataset<TensorContainer>;
2526
testingDataset: data.Dataset<TensorContainer>;
2627
}> => {
27-
const samples = [];
28-
29-
for (const dataItem of data) {
30-
const [inputFeatures, outputFeature] = await Promise.all([
31-
Promise.all(
32-
inputFeatureExtractors.map((featureExtractor) => {
33-
return featureExtractor.extract(dataItem);
34-
})
35-
),
36-
outputFeatureExtractor.extract(dataItem)
37-
]);
38-
39-
const input = inputFeatures.map((feature) => feature.value);
40-
const output = [outputFeature.value];
41-
42-
samples.push({ input, output });
43-
}
28+
const samples = await extractFeatures({
29+
data,
30+
inputFeatureExtractors,
31+
outputFeatureExtractor
32+
});
4433

4534
const { trainingSamples, validationSamples, testingSamples } = splitSamplesIntoTrainingValidationTestForBinaryClassification(
4635
samples,

packages/tfjs-node-helpers/src/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
export * from './classification/binary-classification-trainer';
22
export * from './classification/binary-classifier';
33
export * from './data-splitting/training-validation-test-for-binary-classification';
4+
export * from './feature-engineering/extract-features';
45
export * from './feature-engineering/feature-extractor';
56
export * from './feature-engineering/feature';
67
export * from './feature-engineering/prepare-datasets-for-binary-classification';

0 commit comments

Comments
 (0)