Skip to content

Commit ff56ed6

Browse files
committed
Refactoring when_train_its_called_it_should_calculate_the_distance_between_the_neighbors
1 parent 1580773 commit ff56ed6

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

src/test/java/SimpleKNNClassifierTest.java

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import com.github.felipexw.classifiers.neighbors.Neighbor;
12
import com.github.felipexw.classifiers.neighbors.SimpleKNNClassifier;
23
import com.github.felipexw.core.Model;
34
import com.github.felipexw.core.Prediction;
@@ -81,6 +82,42 @@ the algorithm must predict the label (which its positive or negative) for the po
8182
.isEqualTo(negativeLabel);
8283
}
8384

85+
@Test
86+
public void when_train_its_called_it_should_calculate_the_distance_between_the_neighbors(){
87+
/*
88+
given a set of negative points:
89+
- A(2,4); B(3,2)
90+
and a set of positive points:
91+
- D(4,1); D(5,5)
92+
the algorithm must predict the label (which its positive or negative) for the point E(1,3)
93+
*/
94+
String positiveLabel = "positive";
95+
String negativeLabel = "negative";
96+
97+
LabeledInstance pointA = new LabeledInstance(negativeLabel, new TestModel(null, Arrays.asList(2d, 4d)));
98+
LabeledInstance pointB = new LabeledInstance(negativeLabel, new TestModel(null, Arrays.asList(3d, 2d)));
99+
100+
LabeledInstance pointC = new LabeledInstance(positiveLabel, new TestModel(null, Arrays.asList(4d, 1d)));
101+
LabeledInstance pointD = new LabeledInstance(positiveLabel, new TestModel(null, Arrays.asList(5d, 5d)));
102+
103+
LabeledInstance pointE = new LabeledInstance(negativeLabel, new TestModel(null, Arrays.asList(7d, 7d)));
104+
105+
classifier.setK(2);
106+
classifier.train(Arrays.asList(pointA, pointB, pointC, pointD));
107+
List<Neighbor> similarNeighbors = classifier.similarNeighbors(pointE, 2);
108+
109+
Neighbor n1 = new Neighbor(new LabeledInstance(null, pointA.getModel()), 0d, null);
110+
Neighbor n2 = new Neighbor(new LabeledInstance(null, pointB.getModel()), 0d, null);
111+
112+
Truth.assertThat(similarNeighbors)
113+
.containsAllIn(Arrays.asList(n1, n2));
114+
}
115+
116+
@Test(expected = IllegalArgumentException.class)
117+
public void when_similarNeighbors_its_called_with_null_neighbors_args_it_should_raise_an_exception(){
118+
classifier.similarNeighbors(null, 10);
119+
}
120+
84121
}
85122

86123

0 commit comments

Comments
 (0)