Skip to content

Commit 254fc2e

Browse files
committed
Refactoring.
1 parent 058b5cc commit 254fc2e

File tree

12 files changed

+135
-44
lines changed

12 files changed

+135
-44
lines changed

src/main/java/com/github/felipexw/classifiers/bayes/MultinomialNaiveBayesClassifier.java

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
package com.github.felipexw.classifiers.bayes;
22

3-
import com.github.felipexw.core.Instance;
4-
import com.github.felipexw.core.LabeledInstance;
5-
import com.github.felipexw.core.Prediction;
3+
import com.github.felipexw.core.*;
64

75
import java.util.ArrayList;
86
import java.util.HashMap;
@@ -56,10 +54,10 @@ public void calculateProbs(List<LabeledInstance> instanceList) {
5654
if (instanceList == null || instanceList.isEmpty())
5755
throw new IllegalArgumentException("Args can't be null");
5856

59-
for (LabeledInstance<String, Double> instance : instanceList) {
60-
String key = instance.getLabel();
57+
for (LabeledInstance<Label, Model> instance : instanceList) {
58+
String key = instance.getLabel().toString();
6159
if (!labels.containsKey(key)) {
62-
labels.put(instance.getLabel(), 1);
60+
labels.put(instance.getLabel().toString(), 1);
6361
}
6462
else{
6563
labels.put(key, labels.get(key)+1);
@@ -88,9 +86,9 @@ public void calculatePosterioriProbability(LabeledInstance instance) {
8886
*/
8987
}
9088

91-
private void countFromLabels(List<LabeledInstance<String, Double>> instances, LabeledInstance<String, Double> instance) {
92-
for (LabeledInstance<String, Double> featuresInstance : instances) {
93-
if (featuresInstance.getLabel().equalsIgnoreCase(instance.getLabel())) {
89+
private void countFromLabels(List<LabeledInstance<Label, Model>> instances, LabeledInstance<Label, Model> instance) {
90+
for (LabeledInstance<Label, Model> featuresInstance : instances) {
91+
if (featuresInstance.getLabel().toString().equalsIgnoreCase(instance.getLabel().toString())) {
9492
featuresInstance.setCount(featuresInstance.getCount() + 1);
9593
} else {
9694
instance.setCount(1);

src/main/java/com/github/felipexw/classifiers/neighbors/KNNClassifier.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import com.github.felipexw.classifiers.Classifier;
44
import com.github.felipexw.core.CrossValidation;
5+
import com.github.felipexw.core.extraction.FeatureExtractor;
56
import com.github.felipexw.evaluations.metrics.SimilarityCalculator;
67
import com.github.felipexw.core.LabeledInstance;
78
import com.github.felipexw.core.Prediction;
@@ -18,6 +19,7 @@ public abstract class KNNClassifier implements Classifier, CrossValidation {
1819
protected Map<Neighbor, List<Neighbor>> features;
1920
protected List<LabeledInstance> instances;
2021
protected SimilarityCalculator similarityCalculator;
22+
protected FeatureExtractor featureExtractor;
2123

2224
public abstract Prediction vote(List<Neighbor> neighbors);
2325

src/main/java/com/github/felipexw/classifiers/neighbors/Neighbor.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
package com.github.felipexw.classifiers.neighbors;
22

33
import com.github.felipexw.core.LabeledInstance;
4+
import com.github.felipexw.core.Model;
5+
import com.github.felipexw.core.extraction.FeatureExtractor;
46

57
/**
68
* Created by felipe.appio on 24/08/2016.
79
*/
8-
public class Neighbor {
9-
private final LabeledInstance instance;
10-
private final double distance;
10+
public class Neighbor extends Model {
11+
private LabeledInstance instance;
12+
private double distance;
1113

12-
public Neighbor(LabeledInstance instance, double distance) {
14+
public Neighbor(LabeledInstance instance, double distance, FeatureExtractor featureExtractor) {
15+
super(featureExtractor);
1316
this.instance = instance;
1417
this.distance = distance;
1518
}

src/main/java/com/github/felipexw/classifiers/neighbors/SimpleKNNClassifier.java

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,22 @@
11
package com.github.felipexw.classifiers.neighbors;
22

3+
import com.github.felipexw.core.*;
4+
import com.github.felipexw.core.extraction.DoubleFeatureExtractor;
5+
import com.github.felipexw.core.extraction.FeatureExtractor;
36
import com.github.felipexw.evaluations.EvaluatorMetric;
47
import com.github.felipexw.evaluations.metrics.SimilarityCalculator;
5-
import com.github.felipexw.core.Instance;
6-
import com.github.felipexw.core.LabeledInstance;
7-
import com.github.felipexw.core.Prediction;
88
import com.google.common.collect.ImmutableMap;
99
import com.google.common.collect.Lists;
1010

1111
import java.util.*;
12-
import java.util.List;
1312

1413
/**
1514
* Created by felipe.appio on 23/08/2016.
1615
*/
1716
public class SimpleKNNClassifier extends KNNClassifier {
1817

1918

20-
public SimpleKNNClassifier(SimilarityCalculator similarityCalculator) {
19+
public SimpleKNNClassifier(SimilarityCalculator similarityCalculator, FeatureExtractor featureExtractor) {
2120
this.similarityCalculator = similarityCalculator;
2221
k = 5;
2322
}
@@ -53,7 +52,7 @@ protected List<LabeledInstance> getInstancesthatMaximizeAccuracy() {
5352
for (List<LabeledInstance> labeled : partitionedInstances) {
5453
for (LabeledInstance instance : labeled) {
5554
if (i != testIndex) {
56-
Neighbor neighbor = new Neighbor(instance, -1d);
55+
Neighbor neighbor = new Neighbor(instance, -1d, featureExtractor);
5756
List<Neighbor> neighbors = getNeighborsWithDistanceFromARootNeighboor(neighbor, k);
5857
features.put(neighbor, neighbors);
5958
} else
@@ -129,24 +128,29 @@ protected void setUpForTraining(List<LabeledInstance> instances) {
129128
protected void calculateFeatureSimilarities() {
130129
for (int i = 0; i < instances.size(); i++) {
131130
LabeledInstance instance = instances.get(i);
132-
Neighbor neighbor = new Neighbor(instance, -1d);
131+
Neighbor neighbor = new Neighbor(instance, -1d, featureExtractor);
133132

134133
List<Neighbor> neighbors = getNeighborsWithDistanceFromARootNeighboor(neighbor, this.k);
135134
features.put(neighbor, neighbors);
136135
}
137136
}
138137

139138
protected List<Neighbor> getNeighborsWithDistanceFromARootNeighboor(Neighbor neighbor, int threshold) {
140-
throw new UnsupportedOperationException("Continue the implementation");
141139

140+
141+
// LabeledInstance<Double, String> t = new LabeledInstance<>("2");
142+
143+
144+
Neighbor nei = new Neighbor(null, 0d, featureExtractor);
142145
List<Neighbor> neighbors = new ArrayList<>();
146+
143147
LabeledInstance instance = neighbor.getInstance();
144148

145149
for (int j = -1; j < instances.size() - 1; j++) {
146150
LabeledInstance neighborInstance = instances.get(j + 1);
147151
// double similarity = similarityCalculator.calculate(instance.getFeatures(), neighborInstance.getFeatures());
148152
double similarity = 0d;
149-
Neighbor neighborRoot = new Neighbor(neighborInstance, similarity);
153+
Neighbor neighborRoot = new Neighbor(neighborInstance, similarity, new DoubleFeatureExtractor());
150154
neighbors.add(neighborRoot);
151155
if (neighbors.size() == threshold)
152156
return neighbors;
@@ -163,7 +167,7 @@ public List<Neighbor> similarNeighbors(LabeledInstance trainingInstance, int k)
163167
if (features.containsKey(trainingInstance))
164168
features.get(trainingInstance);
165169

166-
Neighbor neighbor1 = new Neighbor(trainingInstance, -1d);
170+
Neighbor neighbor1 = new Neighbor(trainingInstance, -1d, featureExtractor);
167171
return getNeighborsWithDistanceFromARootNeighboor(neighbor1, k);
168172
}
169173

@@ -209,13 +213,13 @@ public Prediction vote(List<Neighbor> neighbors) {
209213
Map<String, Integer> votes = new HashMap<>();
210214

211215
for (Neighbor neighbor : neighbors) {
212-
LabeledInstance<String, Double> instance = neighbor.getInstance();
216+
LabeledInstance<Label, Model> instance = neighbor.getInstance();
213217
if (!votes.containsKey(instance.getLabel()))
214-
votes.put(instance.getLabel(), 1);
218+
votes.put(instance.getLabel().toString(), 1);
215219

216220
else {
217221
Integer count = votes.get(instance.getLabel());
218-
votes.put(instance.getLabel(), count + 1);
222+
votes.put(instance.getLabel().toString(), count + 1);
219223
}
220224

221225
}
@@ -250,8 +254,8 @@ protected int getIndexOfNearestNeighboorVoted(String label, List<Neighbor> neigh
250254

251255
for (int i = 0; i < neighbors.size(); i++) {
252256
Neighbor neighbor = neighbors.get(i);
253-
LabeledInstance<String, Double> instance = neighbor.getInstance();
254-
if (instance.getLabel().equalsIgnoreCase(label) && neighbor.getDistance() < distance) {
257+
LabeledInstance<Label, Model> instance = neighbor.getInstance();
258+
if (instance.getLabel().toString().equalsIgnoreCase(label) && neighbor.getDistance() < distance) {
255259
distance = neighbor.getDistance();
256260
index = i;
257261
}

src/main/java/com/github/felipexw/core/Instance.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
/**
66
* Created by felipe.appio on 23/08/2016.
77
*/
8-
public abstract class Instance<F> {
8+
public class Instance<F> {
99
protected List<F> features;
1010

1111
public List<F> getFeatures() {
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package com.github.felipexw.core;
2+
3+
/**
4+
* Created by felipe.appio on 31/08/2016.
5+
*/
6+
public class Label<T> {
7+
private final T label;
8+
9+
public T getLabel() {
10+
return label;
11+
}
12+
13+
public Label(T label) {
14+
this.label = label;
15+
}
16+
}
Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,32 @@
11
package com.github.felipexw.core;
22

3+
import java.util.Objects;
4+
35
/**
46
* Created by felipe.appio on 23/08/2016.
57
*/
6-
public class LabeledInstance<L, F> extends Instance<F> {
7-
protected final L label;
8-
private int count;
8+
public class LabeledInstance<L extends Label, F extends Model> extends Instance<String> {
9+
protected L label;
10+
private int count;
11+
12+
public LabeledInstance(L label) {
13+
this.label = label;
14+
}
915

10-
public LabeledInstance(L label) {
11-
this.label = label;
12-
}
16+
public L getLabel() {
17+
return this.label;
18+
}
1319

14-
public L getLabel() {
15-
return this.label;
16-
}
20+
@Override
21+
public String toString() {
22+
return label.toString();
23+
}
1724

18-
public void setCount(int count) {
19-
this.count = count;
20-
}
25+
public void setCount(int count) {
26+
this.count = count;
27+
}
2128

22-
public int getCount() {
23-
return count;
24-
}
29+
public int getCount() {
30+
return count;
31+
}
2532
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package com.github.felipexw.core;
2+
3+
import com.github.felipexw.core.extraction.FeatureExtractor;
4+
import com.github.felipexw.core.extraction.FeatureVector;
5+
6+
import java.util.List;
7+
8+
/**
9+
* Created by felipe.appio on 31/08/2016.
10+
*/
11+
public class Model implements FeatureVector {
12+
13+
private FeatureExtractor featureExtractor;
14+
15+
public Model(FeatureExtractor featureExtractor) {
16+
this.featureExtractor = featureExtractor;
17+
}
18+
19+
@Override
20+
public List getData(List source) {
21+
return featureExtractor.extract(source);
22+
}
23+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package com.github.felipexw.core.extraction;
2+
3+
import java.util.Arrays;
4+
import java.util.List;
5+
import java.util.Objects;
6+
7+
/**
8+
* Created by felipe.appio on 31/08/2016.
9+
*/
10+
public class DoubleFeatureExtractor implements FeatureExtractor<Double>{
11+
12+
@Override
13+
public List<Double> extract(List<Object> source) {
14+
return Arrays.asList(new Double(1d), new Double(2d), new Double(3d), new Double(4d));
15+
}
16+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package com.github.felipexw.core.extraction;
2+
3+
import java.util.List;
4+
5+
/**
6+
* Created by felipe.appio on 31/08/2016.
7+
*/
8+
public interface FeatureExtractor<T> {
9+
List<T> extract(List<Object> source);
10+
}

0 commit comments

Comments
 (0)