Skip to content

Commit 6f74c17

Browse files
committed
Add test case to calculate the posterior probability.
1 parent 21d373d commit 6f74c17

File tree

9 files changed

+172
-110
lines changed

9 files changed

+172
-110
lines changed

build.gradle

Lines changed: 7 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
description = "ml-dive application"
1+
description = "knn-java-library"
22

3-
apply plugin: "war"
43
apply plugin: "java"
5-
apply plugin: "jetty"
64

7-
sourceCompatibility = "1.6"
8-
targetCompatibility = "1.6"
5+
6+
sourceCompatibility = "1.8"
7+
targetCompatibility = "1.8"
98

109
group = "com.github.felipexw"
1110
version = "1.0-SNAPSHOT"
@@ -50,37 +49,16 @@ sourceSets {
5049
}
5150

5251
dependencies {
52+
// https://mvnrepository.com/artifact/com.google.truth/truth
53+
compile group: 'com.google.truth', name: 'truth', version: '0.29'
54+
compile group: 'com.google.guava', name: 'guava', version: '19.0'
5355

54-
compile "org.apache.tapestry:tapestry-core:5.4.1"
55-
56-
// Uncomment this to add support for file uploads:
57-
// compile "org.apache.tapestry:tapestry-upload:5.4.1"
58-
59-
// CoffeeScript & Less support, plus resource minification:
60-
compile "org.apache.tapestry:tapestry-webresources:5.4.1"
61-
62-
test "org.apache.tapestry:tapestry-test:5.4.1"
63-
64-
// Log implementation choose one:
65-
// Log4j 1.x
6656
runtime "log4j:log4j:1.2.17"
6757
runtime "org.slf4j:slf4j-log4j12:1.7.19"
68-
// Logback
69-
// runtime "ch.qos.logback:logback-classic:1.0.13"
7058

71-
provided "javax.servlet:servlet-api:2.5"
7259
}
7360

7461
test {
75-
useTestNG()
76-
77-
options.suites("src/test/conf/testng.xml")
78-
79-
systemProperties["tapestry.service-reloading-enabled"] = "false"
80-
systemProperties["tapestry.execution-mode"] = "development"
81-
82-
maxHeapSize = "600M"
83-
8462
jvmArgs("-XX:MaxPermSize=256M")
8563

8664
enableAssertions = true

gradle/wrapper/gradle-wrapper.jar

-52.4 KB
Binary file not shown.

gradle/wrapper/gradle-wrapper.properties

Lines changed: 0 additions & 6 deletions
This file was deleted.

pom.xml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,15 @@
2424
<artifactId>truth</artifactId>
2525
<version>0.28</version>
2626
</dependency>
27+
28+
<!-- https://mvnrepository.com/artifact/junit/junit -->
29+
<dependency>
30+
<groupId>junit</groupId>
31+
<artifactId>junit</artifactId>
32+
<version>4.12</version>
33+
</dependency>
34+
35+
2736
</dependencies>
2837
<build>
2938
<finalName>ml-dive</finalName>

src/main/java/com/github/felipexw/classifier/bayes/MultinomialNaiveBayesClassifier.java

Lines changed: 59 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -2,85 +2,91 @@
22

33
import com.github.felipexw.classifier.Classifier;
44
import com.github.felipexw.classifier.CrossValidateClassifier;
5-
import com.github.felipexw.types.Instance;
65
import com.github.felipexw.types.LabeledInstance;
76
import com.github.felipexw.types.LabeledTrainingInstance;
87
import com.github.felipexw.types.PredictedInstance;
8+
99
import java.util.ArrayList;
1010
import java.util.Arrays;
1111
import java.util.HashMap;
1212
import java.util.List;
13-
import java.util.Set;
1413

1514
/**
1615
* Created by felipe.appio on 29/08/2016.
1716
*/
18-
public class MultinomialNaiveBayesClassifier extends NaiveBayes
19-
implements Classifier, CrossValidateClassifier {
17+
public class MultinomialNaiveBayesClassifier extends NaiveBayes {
2018

21-
private void init() {
22-
prioriProbs = new HashMap<>();
23-
posterioriProbs = new HashMap<>();
24-
labels = new HashMap<>();
25-
}
19+
private void init() {
20+
prioriProbs = new HashMap<>();
21+
posterioriProbs = new HashMap<>();
22+
labels = new HashMap<>();
23+
}
2624

27-
@Override public void train(List<LabeledTrainingInstance> instances) {
25+
public MultinomialNaiveBayesClassifier() {
26+
init();
27+
}
2828

29-
}
29+
@Override
30+
public void train(List<LabeledTrainingInstance> instances) {
31+
calculateProbs(instances);
32+
}
3033

31-
@Override public PredictedInstance predict(LabeledTrainingInstance instance) {
32-
return null;
33-
}
34+
@Override
35+
public PredictedInstance predict(LabeledTrainingInstance instance) {
36+
return null;
37+
}
3438

35-
@Override public List<PredictedInstance> predict(List<LabeledTrainingInstance> instances) {
36-
List<PredictedInstance> predictions = new ArrayList<>();
39+
@Override
40+
public List<PredictedInstance> predict(List<LabeledTrainingInstance> instances) {
41+
List<PredictedInstance> predictions = new ArrayList<>();
3742

38-
instances.forEach((instance) -> {
39-
predictions.add(predict(instance));
40-
});
43+
instances.forEach((instance) -> {
44+
predictions.add(predict(instance));
45+
});
4146

42-
return predictions;
43-
}
47+
return predictions;
48+
}
4449

45-
@Override public void train(List<LabeledTrainingInstance> instances, int k) {
50+
@Override
51+
public void train(List<LabeledTrainingInstance> instances, int k) {
4652

47-
}
53+
}
4854

49-
@Override
50-
public void calculatePrioriProbs(List<LabeledTrainingInstance> instanceList) {
51-
for (LabeledTrainingInstance instance : instanceList) {
52-
if (!labels.containsKey(instance.getLabel())) {
53-
labels.put(instance.getLabel(), 1);
54-
}
55+
@Override
56+
public void calculateProbs(List<LabeledTrainingInstance> instanceList) {
57+
for (LabeledTrainingInstance instance : instanceList) {
58+
if (!labels.containsKey(instance.getLabel())) {
59+
labels.put(instance.getLabel(), 1);
60+
}
5561

56-
calculatePosterioriProbability(instance);
62+
calculatePosterioriProbability(instance);
63+
}
5764
}
58-
}
5965

60-
@Override
61-
public void calculatePosterioriProbability(LabeledTrainingInstance instance) {
62-
double[] features = instance.getFeatures();
66+
@Override
67+
public void calculatePosterioriProbability(LabeledTrainingInstance instance) {
68+
double[] features = instance.getFeatures();
6369

64-
for (int i = 0; i < features.length; i++) {
65-
double key = features[i];
70+
for (int i = 0; i < features.length; i++) {
71+
double key = features[i];
6672

67-
if (!this.posterioriProbs.containsKey(key)) {
68-
List<LabeledInstance> instances = Arrays.asList(new LabeledInstance(instance.getLabel()));
69-
this.posterioriProbs.put(key, instances);
70-
} else {
71-
countFromLabels(this.posterioriProbs.get(key), instance);
72-
}
73+
if (!this.posterioriProbs.containsKey(key)) {
74+
List<LabeledInstance> instances = Arrays.asList(new LabeledInstance(instance.getLabel()));
75+
this.posterioriProbs.put(String.valueOf(key), instances);
76+
} else {
77+
countFromLabels(this.posterioriProbs.get(key), instance);
78+
}
79+
}
7380
}
74-
}
75-
76-
private void countFromLabels(List<LabeledInstance> instances, LabeledInstance instance) {
77-
for (LabeledInstance featuresInstance : instances) {
78-
if (featuresInstance.getLabel().equalsIgnoreCase(instance.getLabel())) {
79-
featuresInstance.setCount(featuresInstance.getCount() + 1);
80-
} else {
81-
instance.setCount(1);
82-
instances.add(instance);
83-
}
81+
82+
private void countFromLabels(List<LabeledInstance> instances, LabeledInstance instance) {
83+
for (LabeledInstance featuresInstance : instances) {
84+
if (featuresInstance.getLabel().equalsIgnoreCase(instance.getLabel())) {
85+
featuresInstance.setCount(featuresInstance.getCount() + 1);
86+
} else {
87+
instance.setCount(1);
88+
instances.add(instance);
89+
}
90+
}
8491
}
85-
}
8692
}
Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package com.github.felipexw.classifier.bayes;
22

3+
import com.github.felipexw.classifier.Classifier;
4+
import com.github.felipexw.classifier.CrossValidateClassifier;
35
import com.github.felipexw.types.LabeledInstance;
46
import com.github.felipexw.types.LabeledTrainingInstance;
57
import java.util.List;
@@ -8,16 +10,20 @@
810
/**
911
* Created by felipe.appio on 29/08/2016.
1012
*/
11-
public abstract class NaiveBayes {
12-
protected Map<Double, List<LabeledInstance>> posterioriProbs;
13+
public abstract class NaiveBayes implements Classifier, CrossValidateClassifier {
14+
protected Map<String, List<LabeledInstance>> posterioriProbs;
1315
protected Map<String, Double> prioriProbs;
1416
protected Map<String, Integer> labels;
1517

16-
public abstract void calculatePrioriProbs(List<LabeledTrainingInstance> instanceList);
18+
public abstract void calculateProbs(List<LabeledTrainingInstance> instanceList);
1719

1820
public Map<String, Double> getPrioriProbs() {
1921
return prioriProbs;
2022
}
2123

24+
public Map<String, List<LabeledInstance>> getPosterioriProbs() {
25+
return posterioriProbs;
26+
}
27+
2228
public abstract void calculatePosterioriProbability(LabeledTrainingInstance instances);
2329
}

src/main/main.iml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<module type="JAVA_MODULE" version="4">
3+
<component name="NewModuleRootManager" inherit-compiler-output="true">
4+
<exclude-output />
5+
<content url="file://$MODULE_DIR$">
6+
<sourceFolder url="file://$MODULE_DIR$/java" isTestSource="false" />
7+
</content>
8+
<orderEntry type="inheritedJdk" />
9+
<orderEntry type="sourceFolder" forTests="false" />
10+
</component>
11+
</module>
Lines changed: 65 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import com.github.felipexw.classifier.bayes.MultinomialNaiveBayesClassifier;
22
import com.github.felipexw.classifier.bayes.NaiveBayes;
3+
import com.github.felipexw.types.LabeledInstance;
34
import com.github.felipexw.types.LabeledTrainingInstance;
5+
46
import java.util.Arrays;
57
import java.util.List;
68
import java.util.Map;
9+
710
import org.junit.Before;
811
import org.junit.Test;
912

@@ -14,30 +17,73 @@
1417
*/
1518
public class MultinomialNaiveBayesClassifierTest {
1619

17-
private NaiveBayes naiveBayesClassifier;
20+
private NaiveBayes naiveBayesClassifier;
21+
22+
@Before
23+
public void setUp() {
24+
naiveBayesClassifier = new MultinomialNaiveBayesClassifier();
25+
}
26+
27+
@Test
28+
public void it_should_calculate_a_priori_probs() {
29+
String negativeLabel = "negative";
30+
String positiveLabel = "positive";
31+
32+
List<LabeledTrainingInstance> training =
33+
Arrays.asList(new LabeledTrainingInstance(new double[]{2}, negativeLabel),
34+
new LabeledTrainingInstance(new double[]{2}, negativeLabel),
35+
new LabeledTrainingInstance(new double[]{2}, positiveLabel));
36+
37+
Map<String, Double> probs = naiveBayesClassifier.getPrioriProbs();
38+
39+
assertThat(probs.get(negativeLabel))
40+
.isEqualTo(2);
41+
assertThat(probs.get(positiveLabel))
42+
.isEqualTo(1);
43+
}
44+
45+
@Test
46+
public void it_should_calculate_the_posterior_probabilities() {
47+
/**
48+
* given the following sentences:
49+
* 1) 'windows sucks' (negative)
50+
* 2) 'amd sucks' (negative)
51+
* 3) 'intel its good' (positive)
52+
* 4) 'linux its good' (positive)
53+
* it should calculate the posterior probabilities
54+
*/
55+
56+
String negativeLabel = "negative";
57+
String positiveLabel = "positive";
58+
59+
double[] featuresA = {"windows".hashCode(), "sucks".hashCode()};
60+
double[] featuresB = {"amd".hashCode(), "sucks".hashCode()};
61+
double[] featuresC = {"intel".hashCode(), "its".hashCode(), "good".hashCode()};
62+
double[] featuresD = {"linux".hashCode(), "its".hashCode(), "good".hashCode()};
63+
64+
LabeledTrainingInstance intance1 = new LabeledTrainingInstance(featuresA, negativeLabel);
65+
LabeledTrainingInstance intance2 = new LabeledTrainingInstance(featuresB, negativeLabel);
66+
LabeledTrainingInstance intance3 = new LabeledTrainingInstance(featuresC, positiveLabel);
67+
LabeledTrainingInstance intance4 = new LabeledTrainingInstance(featuresD, positiveLabel);
68+
69+
naiveBayesClassifier.train(Arrays.asList(intance1, intance2, intance3, intance4));
70+
Map<String, List<LabeledInstance>> probs = naiveBayesClassifier.getPosterioriProbs();;
71+
72+
/*for(Double key: probs.keySet()){
73+
1874
19-
@Before
20-
public void setUp() {
21-
naiveBayesClassifier = new MultinomialNaiveBayesClassifier();
22-
}
2375
24-
@Test
25-
public void it_should_calculate_a_priori_probs() {
26-
String negativeLabel = "negative";
27-
String positiveLabel = "positive";
76+
}*/
2877

29-
List<LabeledTrainingInstance> training =
30-
Arrays.asList(new LabeledTrainingInstance(new double[] {2}, negativeLabel),
31-
new LabeledTrainingInstance(new double[] {2}, negativeLabel),
32-
new LabeledTrainingInstance(new double[] {2}, positiveLabel));
78+
List<LabeledInstance> is = probs.get(String.valueOf(Double.parseDouble(String.valueOf("its".hashCode()))));
79+
for(LabeledInstance instance: is){
80+
System.out.println(instance.getCount());
81+
}
3382

34-
Map<String, Double> probs = naiveBayesClassifier.getPrioriProbs();
83+
assertThat(probs.size())
84+
.isEqualTo(7);
3585

36-
assertThat(probs.get(negativeLabel))
37-
.isEqualTo(2);
38-
assertThat(probs.get(positiveLabel))
39-
.isEqualTo(1);
40-
}
86+
}
4187

4288

4389
}

src/test/test.iml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<module type="JAVA_MODULE" version="4">
3+
<component name="NewModuleRootManager" inherit-compiler-output="true">
4+
<exclude-output />
5+
<content url="file://$MODULE_DIR$">
6+
<sourceFolder url="file://$MODULE_DIR$/java" isTestSource="true" />
7+
</content>
8+
<orderEntry type="inheritedJdk" />
9+
<orderEntry type="sourceFolder" forTests="false" />
10+
<orderEntry type="module" module-name="main" />
11+
</component>
12+
</module>

0 commit comments

Comments
 (0)