-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.cpp
More file actions
59 lines (44 loc) · 1.82 KB
/
main.cpp
File metadata and controls
59 lines (44 loc) · 1.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "Dataset.h"
#include "MultilayerPerceptron.h"
#include "ActivationFunction.h"
#include "CostFunction.h"
#include "stopwatch/Stopwatch.hpp"
int main()
{
Stopwatch stopwatch{};
// Create a network with 3 input neurons, one hidden layer with 3 neurons, and an output layer of 2 neurons.
Sigmoid hiddenLayerActivationFunction{};
Softmax outputLayerActivationFunction{};
MeanSquaredError costFunction{};
auto network = std::make_unique<MultilayerPerceptron>(
std::vector<int>{784, 100, 10},
hiddenLayerActivationFunction,
outputLayerActivationFunction,
costFunction);
stopwatch.addMeasurement("Before reading data.");
Dataset *trainingData = new MnistDataset("train-images-idx3-ubyte", "train-labels-idx1-ubyte");
stopwatch.addMeasurement("Data read");
network->Train(trainingData, 0.5, 100, 10);
stopwatch.addMeasurement("Network trained");
Dataset *testData = new MnistDataset("t10k-images-idx3-ubyte", "t10k-labels-idx1-ubyte");
int correctClassifications = 0;
while(!testData->EndOfData())
{
FeatureVector fv = testData->GetNextFeatureVector();
vector<double> predicted = network->ForwardPass(fv.data);
string predictedClass = testData->ClassificationToString(predicted);
string actualClass = testData->ClassificationToString(fv.label);
if(predictedClass == actualClass) { ++correctClassifications; }
cout << "Predicted: " << predictedClass << endl;
for(double element : predicted) { cout << element << " ";}
cout << endl;
cout << "Actual: " << actualClass << endl;
for(double element : fv.label) { cout << element << " ";}
cout << endl;
}
stopwatch.addMeasurement("Accuracy test completed");
cout << "Correct classifications: " << correctClassifications << endl;
cout << "Accuracy: " << (correctClassifications / 10000.0) << endl;
cout << stopwatch.getTimingTrace();
}