Skip to content

Commit 8f70f53

Browse files
committed
Refactoring.
1 parent 64aea5e commit 8f70f53

File tree

8 files changed

+52
-83
lines changed

8 files changed

+52
-83
lines changed

src/main/java/com/github/felipexw/classifiers/bayes/MultinomialNaiveBayesClassifier.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ public void calculateProbs(List<LabeledInstance> instanceList) {
5454
if (instanceList == null || instanceList.isEmpty())
5555
throw new IllegalArgumentException("Args can't be null");
5656

57-
for (LabeledInstance<Label, Model> instance : instanceList) {
57+
for (LabeledInstance<Model> instance : instanceList) {
5858
String key = instance.getLabel().toString();
5959
if (!labels.containsKey(key)) {
6060
labels.put(instance.getLabel().toString(), 1);
@@ -86,8 +86,8 @@ public void calculatePosterioriProbability(LabeledInstance instance) {
8686
*/
8787
}
8888

89-
private void countFromLabels(List<LabeledInstance<Label, Model>> instances, LabeledInstance<Label, Model> instance) {
90-
for (LabeledInstance<Label, Model> featuresInstance : instances) {
89+
private void countFromLabels(List<LabeledInstance<Model>> instances, LabeledInstance<Model> instance) {
90+
for (LabeledInstance<Model> featuresInstance : instances) {
9191
if (featuresInstance.getLabel().toString().equalsIgnoreCase(instance.getLabel().toString())) {
9292
featuresInstance.setCount(featuresInstance.getCount() + 1);
9393
} else {

src/main/java/com/github/felipexw/classifiers/neighbors/SimpleKNNClassifier.java

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,11 @@ protected List<Neighbor> getNeighborsWithDistanceFromARootNeighboor(Neighbor nei
142142
LabeledInstance instance = neighbor.getInstance();
143143

144144
for (int j = -1; j < instances.size() - 1; j++) {
145-
LabeledInstance neighborInstance = instances.get(j + 1);
146-
double similarity = similarityCalculator.calculate(instance.getFeatures(), neighborInstance.getFeatures());
145+
LabeledInstance<Model> neighborInstance = instances.get(j + 1);
146+
Model<Double> model = instance.getModel();
147+
Model<Double> neighborModel = neighborInstance.getModel();
148+
149+
double similarity = similarityCalculator.calculate(model.getData(), neighborModel.getData());
147150
Neighbor neighborRoot = new Neighbor(neighborInstance, similarity, new DoubleFeatureExtractor());
148151
neighbors.add(neighborRoot);
149152

@@ -193,7 +196,7 @@ protected List<Neighbor> getAllNeighbors(Instance labeledInstance) {
193196
List<Neighbor> neighborses = new ArrayList<>();
194197
for (short i = 0; i < instances.size(); i++) {
195198
LabeledInstance trainingInstance = instances.get(i);
196-
double distance = similarityCalculator.calculate(labeledInstance.getFeatures(), trainingInstance.getFeatures());
199+
double distance = similarityCalculator.calculate(labeledInstance.getModel().getData(), trainingInstance.getModel().getData());
197200

198201
Neighbor neighbor = new Neighbor(trainingInstance, distance, featureExtractor);
199202
neighborses.add(neighbor);
@@ -207,7 +210,7 @@ public Prediction vote(List<Neighbor> neighbors) {
207210
Map<String, Integer> votes = new HashMap<>();
208211

209212
for (Neighbor neighbor : neighbors) {
210-
LabeledInstance<Label, Model> instance = neighbor.getInstance();
213+
LabeledInstance<Model> instance = neighbor.getInstance();
211214
if (!votes.containsKey(instance.getLabel()))
212215
votes.put(instance.getLabel().toString(), 1);
213216

@@ -222,7 +225,6 @@ public Prediction vote(List<Neighbor> neighbors) {
222225
int nearestNeighborIndex = getIndexOfNearestNeighboorVoted(mostVotedLabel, neighbors);
223226
Neighbor neighbor = neighbors.get(nearestNeighborIndex);
224227

225-
226228
return new Prediction(mostVotedLabel, neighbor.getDistance() / 100);
227229
}
228230

@@ -248,7 +250,7 @@ protected int getIndexOfNearestNeighboorVoted(String label, List<Neighbor> neigh
248250

249251
for (int i = 0; i < neighbors.size(); i++) {
250252
Neighbor neighbor = neighbors.get(i);
251-
LabeledInstance<Label, Model> instance = neighbor.getInstance();
253+
LabeledInstance<Model> instance = neighbor.getInstance();
252254
if (instance.getLabel().toString().equalsIgnoreCase(label) && neighbor.getDistance() < distance) {
253255
distance = neighbor.getDistance();
254256
index = i;

src/main/java/com/github/felipexw/core/Instance.java

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,14 @@
55
/**
66
* Created by felipe.appio on 23/08/2016.
77
*/
8-
public abstract class Instance<F> {
9-
protected List<F> features;
8+
public abstract class Instance<F extends Model> {
9+
protected final F model;
1010

11-
public abstract List<F> getFeatures();
11+
public Instance(F model) {
12+
this.model = model;
13+
}
14+
15+
public Model getModel() {
16+
return model;
17+
}
1218
}

src/main/java/com/github/felipexw/core/Label.java

Lines changed: 0 additions & 16 deletions
This file was deleted.

src/main/java/com/github/felipexw/core/LabeledInstance.java

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,24 @@
66
/**
77
* Created by felipe.appio on 23/08/2016.
88
*/
9-
public class LabeledInstance<L extends Label, F extends Model> extends Instance<F> {
10-
protected L label;
9+
public class LabeledInstance<F extends Model> extends Instance<F> {
10+
protected String label;
1111
private int count;
1212

13-
public LabeledInstance(List<F> features,L label) {
13+
public LabeledInstance(String label, F model) {
14+
super(model);
1415
this.label = label;
15-
this.features = features;
1616
}
1717

18-
public L getLabel() {
18+
public String getLabel() {
1919
return this.label;
2020
}
2121

22+
2223
public void setCount(int count) {
2324
this.count = count;
2425
}
2526

26-
@Override
27-
public List<F> getFeatures() {
28-
29-
return null;
30-
}
3127

3228
public int getCount() {
3329
return count;
@@ -38,7 +34,7 @@ public boolean equals(Object o) {
3834
if (this == o) return true;
3935
if (o == null || getClass() != o.getClass()) return false;
4036

41-
LabeledInstance<?, ?> that = (LabeledInstance<?, ?>) o;
37+
LabeledInstance<?> that = (LabeledInstance<?>) o;
4238

4339
return label.equals(that.label);
4440

src/main/java/com/github/felipexw/core/Model.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,20 @@
88
/**
99
* Created by felipe.appio on 31/08/2016.
1010
*/
11-
public class Model implements FeatureVector {
11+
public class Model<T> implements FeatureVector {
1212

1313
protected FeatureExtractor featureExtractor;
14-
protected List features;
14+
protected List<T> features;
1515

1616
public Model(FeatureExtractor featureExtractor) {
1717
this.featureExtractor = featureExtractor;
1818
}
1919

2020
@Override
21-
public List getData() {
22-
return featureExtractor.extract(features);
21+
public List<T> getData() {
22+
if (features == null)
23+
features = featureExtractor.extract(features);
24+
return features;
2325
}
2426

2527
}

src/main/java/com/github/felipexw/core/Prediction.java

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@
33
/**
44
* Created by felipe.appio on 24/08/2016.
55
*/
6-
public class Prediction<L> {
7-
private final L label;
6+
public class Prediction{
7+
private final String label;
88
private final double score;
99
private int count;
1010

11-
public L getLabel() {
11+
public String getLabel() {
1212
return label;
1313
}
1414

15-
public Prediction(L label, double score) {
15+
public Prediction(String label, double score) {
1616
this.label = label;
1717
this.score = score;
1818
}
@@ -29,21 +29,6 @@ public double getScore() {
2929
return score;
3030
}
3131

32-
@Override
33-
public boolean equals(Object o) {
34-
if (this == o) return true;
35-
if (!(o instanceof Prediction)) return false;
36-
37-
Prediction that = (Prediction) o;
38-
39-
return label.equals(that.label);
40-
}
41-
42-
@Override
43-
public int hashCode() {
44-
return label.hashCode();
45-
}
46-
4732
@Override
4833
public String toString() {
4934
return "Prediction{" +

src/test/java/SimpleKNNClassifierTest.java

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,14 @@
1-
import com.github.felipexw.classifiers.neighbors.KNNClassifier;
2-
import com.github.felipexw.classifiers.neighbors.Neighbor;
31
import com.github.felipexw.classifiers.neighbors.SimpleKNNClassifier;
4-
import com.github.felipexw.core.Label;
52
import com.github.felipexw.core.Model;
63
import com.github.felipexw.core.Prediction;
74
import com.github.felipexw.core.extraction.DoubleFeatureExtractor;
85
import com.github.felipexw.core.extraction.FeatureExtractor;
96
import com.github.felipexw.evaluations.metrics.EuclidianSimilarityCalculator;
107
import com.github.felipexw.core.LabeledInstance;
118
import com.google.common.truth.Truth;
12-
import javafx.scene.control.Labeled;
139
import org.junit.Before;
1410
import org.junit.Test;
15-
import com.google.common.truth.Truth.*;
1611

17-
import java.sql.Array;
1812
import java.util.*;
1913

2014
/**
@@ -24,9 +18,9 @@
2418
*/
2519
public class SimpleKNNClassifierTest {
2620
static class TestModel extends Model{
27-
private List<String> stringFeatures;
21+
private List<Double> stringFeatures;
2822

29-
public TestModel(FeatureExtractor featureExtractor, List<String> stringFeatures) {
23+
public TestModel(FeatureExtractor featureExtractor, List<Double> stringFeatures) {
3024
super(featureExtractor);
3125
this.features = stringFeatures;
3226
}
@@ -60,31 +54,31 @@ public void it_should_predict_a_negative_label(){
6054
the algorithm must predict the label (which its positive or negative) for the point E(1,3)
6155
*/
6256

63-
Label positiveLabel = new Label("positive");
64-
Label negativeLabel = new Label("negative");
57+
String positiveLabel = new String("positive");
58+
String negativeLabel = new String("negative");
6559
FeatureExtractor featureExtractor = new DoubleFeatureExtractor();
6660

67-
TestModel t1 = new TestModel(featureExtractor, Arrays.asList("3.0", "4.0"));
68-
LabeledInstance<Label, Model> pointA = new LabeledInstance<Label, Model>(Arrays.asList(t1), negativeLabel);
61+
TestModel t1 = new TestModel(featureExtractor, Arrays.asList(3d, 4d));
62+
LabeledInstance<Model> pointA = new LabeledInstance<Model>(negativeLabel, t1);
6963

70-
TestModel t2 = new TestModel(featureExtractor, Arrays.asList("3.0", "2.0"));
71-
LabeledInstance<Label, Model> pointB = new LabeledInstance<Label, Model>(Arrays.asList(t2), negativeLabel);
64+
TestModel t2 = new TestModel(featureExtractor, Arrays.asList(3d, 2d));
65+
LabeledInstance<Model> pointB = new LabeledInstance<Model>(negativeLabel, t2);
7266

73-
TestModel t3 = new TestModel(featureExtractor, Arrays.asList("4.0", "1.0"));
74-
LabeledInstance<Label, Model> pointC = new LabeledInstance<Label, Model>(Arrays.asList(t3), negativeLabel);
67+
TestModel t3 = new TestModel(featureExtractor, Arrays.asList(4d, 1d));
68+
LabeledInstance<Model> pointC = new LabeledInstance<Model>(positiveLabel, t3);
7569

76-
TestModel t4 = new TestModel(featureExtractor, Arrays.asList("5.0", "5.0"));
77-
LabeledInstance<Label, Model> pointD = new LabeledInstance<Label, Model>(Arrays.asList(t2), negativeLabel);
70+
TestModel t4 = new TestModel(featureExtractor, Arrays.asList(50d, 5d));
71+
LabeledInstance<Model> pointD = new LabeledInstance<Model>(positiveLabel, t4);
7872

79-
TestModel predictingTest = new TestModel(featureExtractor, Arrays.asList("1.0", "3.0"));
80-
LabeledInstance<Label, Model> pointE = new LabeledInstance<Label, Model>(Arrays.asList(predictingTest), negativeLabel);
73+
TestModel predictingTest = new TestModel(featureExtractor, Arrays.asList(1d, 3d));
74+
LabeledInstance<Model> pointE = new LabeledInstance<Model>(negativeLabel,predictingTest);
8175

8276
classifier.setK(5);
8377
classifier.train(Arrays.asList(pointA, pointB, pointC, pointD));
8478
Prediction predictedInstance = classifier.predict(pointE);
8579

8680
Truth.assertThat(predictedInstance.getLabel())
87-
.isEqualTo(negativeLabel);
81+
.isEqualTo(positiveLabel);
8882
}
8983

9084
}

0 commit comments

Comments
 (0)