-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Logistic CMF for yelp dataset. Python scripts for getting PRF from pr…
…ediction data.
- Loading branch information
Nitish Gupta
committed
Dec 25, 2014
0 parents
commit 304eac5
Showing
51 changed files
with
3,602 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
<?xml version="1.0" encoding="UTF-8"?> | ||
<classpath> | ||
<classpathentry kind="src" path="src"/> | ||
<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER"/> | ||
<classpathentry kind="lib" path="/home/nitish/Thesis/Tools/JSON-API/javax.json-1.0.4.jar"/> | ||
<classpathentry kind="lib" path="/home/nitish/Thesis/Tools/JSON-API/javax.json-api-1.0.jar"/> | ||
<classpathentry kind="lib" path="/home/nitish/Thesis/Tools/jfreechart-1.0.19/lib/swtgraphics2d.jar"/> | ||
<classpathentry kind="lib" path="/home/nitish/Thesis/Tools/jfreechart-1.0.19/lib/servlet.jar"/> | ||
<classpathentry kind="lib" path="/home/nitish/Thesis/Tools/jfreechart-1.0.19/lib/orsonpdf-1.6-eval.jar"/> | ||
<classpathentry kind="lib" path="/home/nitish/Thesis/Tools/jfreechart-1.0.19/lib/orsoncharts-1.4-eval-nofx.jar"/> | ||
<classpathentry kind="lib" path="/home/nitish/Thesis/Tools/jfreechart-1.0.19/lib/junit-4.11.jar"/> | ||
<classpathentry kind="lib" path="/home/nitish/Thesis/Tools/jfreechart-1.0.19/lib/jfreesvg-2.0.jar"/> | ||
<classpathentry kind="lib" path="/home/nitish/Thesis/Tools/jfreechart-1.0.19/lib/jfreechart-1.0.19-swt.jar"/> | ||
<classpathentry kind="lib" path="/home/nitish/Thesis/Tools/jfreechart-1.0.19/lib/jcommon-1.0.23.jar"/> | ||
<classpathentry kind="lib" path="/home/nitish/Thesis/Tools/jfreechart-1.0.19/lib/hamcrest-core-1.3.jar"/> | ||
<classpathentry kind="lib" path="/home/nitish/Thesis/Tools/jfreechart-1.0.19/lib/jfreechart-1.0.19.jar"/> | ||
<classpathentry kind="lib" path="/home/nitish/Thesis/Tools/jfreechart-1.0.19/lib/jfreechart-1.0.19-experimental.jar"/> | ||
<classpathentry kind="output" path="bin"/> | ||
</classpath> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
<?xml version="1.0" encoding="UTF-8"?> | ||
<projectDescription> | ||
<name>logisitic_cmf_yelp</name> | ||
<comment></comment> | ||
<projects> | ||
</projects> | ||
<buildSpec> | ||
<buildCommand> | ||
<name>org.eclipse.jdt.core.javabuilder</name> | ||
<arguments> | ||
</arguments> | ||
</buildCommand> | ||
</buildSpec> | ||
<natures> | ||
<nature>org.eclipse.jdt.core.javanature</nature> | ||
</natures> | ||
</projectDescription> |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
package logisticCMF; | ||
import java.util.*; | ||
|
||
|
||
public class Cell { | ||
String relation_id; // Id of relation in this data element | ||
ArrayList<String> entity_ids; // List of entities participating in Relation ' relation_id ' | ||
boolean truth; // Truth Value for this data element | ||
|
||
public Cell(){ | ||
entity_ids = new ArrayList<String>(); | ||
} | ||
|
||
public Cell(String r, String e1, String e2, boolean t){ | ||
entity_ids = new ArrayList<String>(); | ||
relation_id = r; | ||
entity_ids.add(e1); | ||
entity_ids.add(e2); | ||
truth = t; | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
package logisticCMF; | ||
|
||
import java.io.IOException; | ||
import java.math.BigDecimal; | ||
import java.math.RoundingMode; | ||
import java.util.ArrayList; | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
|
||
public class Eval { | ||
static Map<String, Integer> relTrue = new HashMap<String, Integer>(); // Correct predictions per relation | ||
static Map<String, Integer> relCount = new HashMap<String, Integer>(); // Total test size per relation | ||
static Map<String, Double> relL2 = new HashMap<String, Double>(); | ||
|
||
static Map<String, Integer> relActualTruth = new HashMap<String, Integer>(); // Actual Truth values in each relation | ||
static Map<String, Integer> relPredTruth = new HashMap<String, Integer>(); // Truth predicted in each relation | ||
static Map<String, Integer> relTruthCorrect = new HashMap<String, Integer>(); // Correct Truth predictions in each relation | ||
|
||
static Map<String, ArrayList<Double>> relEvalMap = new HashMap<String, ArrayList<Double>>(); //Map [Rel,<Accuracy, Precision, Recall, F1>] Also contains "weighted average" as relation | ||
|
||
|
||
public static Map<String, ArrayList<Double>> getEvalMap(data Data, embeddings e, String set){ | ||
refreshMaps(); | ||
if(set.equals("test")){ | ||
for(Cell cell : Data.testData) | ||
updateTestMaps(cell, e); | ||
} | ||
else{ | ||
for(Cell cell : Data.valData) | ||
updateTestMaps(cell, e); | ||
} | ||
makeRelationEvalMap(); // Final Map [Relation, <Accuracy, Precision, Recall, F1>] | ||
return relEvalMap; | ||
} | ||
|
||
public static void refreshMaps(){ | ||
relTrue = new HashMap<String, Integer>(); | ||
relCount = new HashMap<String, Integer>(); | ||
relL2 = new HashMap<String, Double>(); | ||
relActualTruth = new HashMap<String, Integer>(); // Actual Truth values in each relation | ||
relPredTruth = new HashMap<String, Integer>(); // Truth predicted in each relation | ||
relTruthCorrect = new HashMap<String, Integer>(); // Correct Truth predictions in each relation | ||
relEvalMap = new HashMap<String, ArrayList<Double>>(); | ||
} | ||
|
||
public static void addInRelCountMap(Cell cell){ | ||
if(!relCount.containsKey(cell.relation_id)) | ||
relCount.put(cell.relation_id, 1); | ||
else | ||
relCount.put(cell.relation_id, relCount.get(cell.relation_id)+1); | ||
} | ||
|
||
public static void addInprfRelationMap(String relation, Integer actual, Integer pred, Integer correct){ | ||
if(!relActualTruth.containsKey(relation)){ | ||
relActualTruth.put(relation, 0); | ||
relPredTruth.put(relation, 0); | ||
relTruthCorrect.put(relation, 0); | ||
} | ||
|
||
else{ | ||
if(!(actual == 0 && pred == 0)){ | ||
relActualTruth.put(relation, relActualTruth.get(relation) + actual); | ||
relPredTruth.put(relation, relPredTruth.get(relation) + pred); | ||
relTruthCorrect.put(relation, relTruthCorrect.get(relation) + correct); | ||
} | ||
} | ||
} | ||
|
||
public static void addInRelAccuracyMap(Cell cell, Integer correct){ | ||
if(!relTrue.containsKey(cell.relation_id)) | ||
relTrue.put(cell.relation_id, correct); | ||
else | ||
relTrue.put(cell.relation_id, relTrue.get(cell.relation_id)+correct); | ||
} | ||
|
||
public static void addInRelL2Map(Cell cell, double l2){ | ||
if(!relL2.containsKey(cell.relation_id)) | ||
relL2.put(cell.relation_id, l2); | ||
else | ||
relL2.put(cell.relation_id, relTrue.get(cell.relation_id) + l2); | ||
} | ||
|
||
//Map [Rel, <Accuracy, Precision, Recall, F1> ] | ||
public static void makeRelationEvalMap(){ | ||
//makeTestMaps(Data, e); | ||
double f = 0; double wf1=0.0, waccuracy=0.0, wp = 0.0, wr = 0.0; int total = 0; | ||
for(String rel : relActualTruth.keySet()){ | ||
double accuracy = ((double)relTrue.get(rel))/relCount.get(rel); | ||
double precision = (double)relTruthCorrect.get(rel) / relPredTruth.get(rel) ; | ||
double recall = (double) relTruthCorrect.get(rel) / relActualTruth.get(rel) ; | ||
double f1 = 2*precision*recall / (precision + recall) ; | ||
//System.out.println("a : " + accuracy + " p : " + precision + " r : " + recall + " f1 : " +f1); | ||
relEvalMap.put(rel, new ArrayList<Double>()); | ||
relEvalMap.get(rel).add(round(accuracy, 3)); // Accuracy | ||
relEvalMap.get(rel).add(round(precision, 3)); // Precision | ||
relEvalMap.get(rel).add(round(recall, 3)); // Recall | ||
relEvalMap.get(rel).add(round(f1, 3)); // F1 | ||
wf1 += (relCount.get(rel)*f1); | ||
wp += (relCount.get(rel)*precision); | ||
wr += (relCount.get(rel)*recall); | ||
waccuracy += (relCount.get(rel)*accuracy); | ||
total += relCount.get(rel); | ||
} | ||
wf1 = round(wf1/total, 3); waccuracy = round(waccuracy/total, 3); wp = round(wp/total, 3); wr = round(wr/total, 3); | ||
relEvalMap.put("average", new ArrayList<Double>()); | ||
relEvalMap.get("average").add(round(waccuracy, 3)); // Accuracy | ||
relEvalMap.get("average").add(round(wp, 3)); // Precision | ||
relEvalMap.get("average").add(round(wr, 3)); // Recall | ||
relEvalMap.get("average").add(round(wf1, 3)); // F1 | ||
|
||
} | ||
|
||
public static void updateTestMaps(Cell cell, embeddings e){ | ||
int correct = 0; int t =0, f=0; int c = 0; double l2Sum = 0.0; | ||
double dot = e.dot(cell, learner.enableBias, e.K, learner.ealpha, learner.onlyAlpha); | ||
double sigmdot = learner.sigm(dot); | ||
int pred = (sigmdot >= 0.5) ? 1 : 0; | ||
int truth = (cell.truth) ? 1 : 0; | ||
double l2 = (truth - sigmdot)*(truth-sigmdot); | ||
l2Sum += l2; | ||
//System.out.println(sigmdot + " " + pred + " " + truth + " " + l2); | ||
|
||
if(pred == truth) | ||
correct = 1; | ||
else | ||
correct = 0; | ||
c += correct; | ||
//System.out.println("rel : " + truth + " pred : " + pred); | ||
//System.out.println(cell.relation_id + " : " + pred); | ||
addInprfRelationMap(cell.relation_id, truth, pred, correct); | ||
addInRelAccuracyMap(cell, correct); | ||
addInRelCountMap(cell); | ||
addInRelL2Map(cell, l2); | ||
} | ||
|
||
public static void printEval(){ | ||
for(String rel : relEvalMap.keySet()){ | ||
ArrayList<Double> eval = relEvalMap.get(rel); | ||
System.out.print(rel + " : "); | ||
System.out.println("P : " + eval.get(1) + " R : " + eval.get(2) + " F1 : " + eval.get(3) + " Accuracy : " + eval.get(0)); | ||
} | ||
} | ||
|
||
|
||
public static double round(double value, int places) { | ||
if (places < 0) throw new IllegalArgumentException(); | ||
if(Double.isNaN(value)) | ||
return Double.NaN; | ||
|
||
BigDecimal bd = new BigDecimal(value); | ||
bd = bd.setScale(places, RoundingMode.HALF_UP); | ||
return bd.doubleValue(); | ||
} | ||
|
||
|
||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
package logisticCMF; | ||
import java.io.*; | ||
import java.util.*; | ||
|
||
|
||
|
||
public class Rating { | ||
|
||
public Map<String, Map<String, Integer>> ratings; // Map[busId, Map[User, Rating]] | ||
public String loc; | ||
data [] busRate = new data[4]; | ||
|
||
public Rating(String folder, double valp, double testp) throws IOException{ | ||
ratings = new HashMap<String, Map<String, Integer>>(); | ||
loc = folder; | ||
busRate[0] = new data(); | ||
busRate[1] = new data(); | ||
busRate[2] = new data(); | ||
busRate[3] = new data(); | ||
readRatingData(folder); | ||
makeTwoRatingDataCells(valp, testp); | ||
/*for(data rd : busRate){ | ||
rd.dataStats(); | ||
}*/ | ||
} | ||
|
||
public void testRatings(){ | ||
data mergeData = new data(); | ||
ArrayList<data> tomerge = new ArrayList<data>(); | ||
for(data d : busRate) | ||
tomerge.add(d); | ||
mergeData.addDataAfterSplit(tomerge); | ||
mergeData.dataStats(); | ||
codeTest.learnAndTest(mergeData, 30, false, 0, false, 0); | ||
} | ||
|
||
public void readRatingData(String folder) throws NumberFormatException, IOException{ | ||
String address = System.getProperty("user.dir")+"/../Dataset/data/"+ folder +"/reviews.txt"; | ||
BufferedReader br = new BufferedReader(new FileReader(address)); | ||
String line; int countl = 0, countcell = 0; | ||
String busid = null, userid = null; boolean value = false; | ||
while((line = br.readLine()) != null){ | ||
countl++; | ||
String[] array = line.split(":"); | ||
if( array[0].trim().equals("bus_id")){ | ||
busid = array[1].trim(); | ||
if(!ratings.containsKey(busid)) | ||
ratings.put(busid, new HashMap<String, Integer>()); | ||
} | ||
if( array[0].trim().equals("user_id")) | ||
userid = array[1].trim(); | ||
if( array[0].trim().equals("star")){ | ||
double t = Double.parseDouble(array[1].trim()); | ||
/*if(ratings.get(busid).containsKey(userid)){ | ||
int r = ratings.get(busid).get(userid); | ||
System.out.println("user already exists " + busid + " " + userid); | ||
}*/ | ||
ratings.get(busid).put(userid, (int)t); | ||
countcell++; | ||
} | ||
} | ||
br.close(); | ||
//System.out.println("Ratings : " + countcell); | ||
|
||
} | ||
|
||
public void makeTwoRatingDataCells(double valPerc, double testPerc){ | ||
String relation = "busrate-"+loc+"-"; | ||
int count2 = 0; int count = 0; | ||
// Read all rating data from Map and create busRate data object for threshold = 2 | ||
for(String bus : ratings.keySet()){ | ||
for(String user : ratings.get(bus).keySet()){ | ||
int rate = ratings.get(bus).get(user); | ||
count++; | ||
Cell cell = new Cell(); | ||
cell.relation_id = relation+"2"; | ||
cell.entity_ids.add(bus); | ||
cell.entity_ids.add(user); | ||
if(rate >= 2){ | ||
cell.truth = true; | ||
busRate[0].Data.add(cell); | ||
} | ||
else{ | ||
cell.truth = false; | ||
busRate[0].Data.add(cell); | ||
} | ||
} | ||
} | ||
busRate[0].splitTrainTestValidation(valPerc, testPerc); | ||
//busRate[0].dataStats(); | ||
makeRestRatingDataCells(); | ||
|
||
System.out.println("Ratings : " + busRate[0].Data.size()); | ||
} | ||
|
||
public void makeRestRatingDataCells(){ | ||
String r = "busrate-"+loc+"-"; | ||
for(Cell cell : busRate[0].trainData){ | ||
String b = cell.entity_ids.get(0), u = cell.entity_ids.get(1); | ||
int rate = ratings.get(b).get(u); | ||
if(rate <= 2) | ||
addCellsinRestTrain(r, b, u, false, false, false); | ||
|
||
if(rate == 3) | ||
addCellsinRestTrain(r, b, u, true, false, false); | ||
|
||
if(rate == 4) | ||
addCellsinRestTrain(r, b, u, true, true, false); | ||
|
||
if(rate == 5) | ||
addCellsinRestTrain(r, b, u, true, true, true); | ||
} | ||
|
||
for(Cell cell : busRate[0].testData){ | ||
String b = cell.entity_ids.get(0), u = cell.entity_ids.get(1); | ||
int rate = ratings.get(b).get(u); | ||
if(rate <= 2) | ||
addCellsinRestTest(r, b, u, false, false, false); | ||
if(rate == 3) | ||
addCellsinRestTest(r, b, u, true, false, false); | ||
if(rate == 4) | ||
addCellsinRestTest(r, b, u, true, true, false); | ||
if(rate == 5) | ||
addCellsinRestTest(r, b, u, true, true, true); | ||
} | ||
|
||
for(Cell cell : busRate[0].valData){ | ||
String b = cell.entity_ids.get(0), u = cell.entity_ids.get(1); | ||
int rate = ratings.get(b).get(u); | ||
if(rate <= 2) | ||
addCellsinRestVal(r, b, u, false, false, false); | ||
if(rate == 3) | ||
addCellsinRestVal(r, b, u, true, false, false); | ||
if(rate == 4) | ||
addCellsinRestVal(r, b, u, true, true, false); | ||
if(rate == 5) | ||
addCellsinRestVal(r, b, u, true, true, true); | ||
} | ||
|
||
} | ||
|
||
public void addCellsinRestTrain(String r, String b, String u, boolean t3, boolean t4, boolean t5){ | ||
busRate[1].Data.add(new Cell(r+3, b, u, t3)); | ||
busRate[1].trainData.add(new Cell(r+3, b, u, t3)); | ||
|
||
busRate[2].Data.add(new Cell(r+4, b, u, t4)); | ||
busRate[2].trainData.add(new Cell(r+4, b, u, t4)); | ||
|
||
busRate[3].Data.add(new Cell(r+5, b, u, t5)); | ||
busRate[3].trainData.add(new Cell(r+5, b, u, t5)); | ||
} | ||
|
||
public void addCellsinRestTest(String r, String b, String u, boolean t3, boolean t4, boolean t5){ | ||
busRate[1].Data.add(new Cell(r+3, b, u, t3)); | ||
busRate[1].testData.add(new Cell(r+3, b, u, t3)); | ||
|
||
busRate[2].Data.add(new Cell(r+4, b, u, t4)); | ||
busRate[2].testData.add(new Cell(r+4, b, u, t4)); | ||
|
||
busRate[3].Data.add(new Cell(r+5, b, u, t5)); | ||
busRate[3].testData.add(new Cell(r+5, b, u, t5)); | ||
} | ||
|
||
public void addCellsinRestVal(String r, String b, String u, boolean t3, boolean t4, boolean t5){ | ||
busRate[1].Data.add(new Cell(r+3, b, u, t3)); | ||
busRate[1].valData.add(new Cell(r+3, b, u, t3)); | ||
|
||
busRate[2].Data.add(new Cell(r+4, b, u, t4)); | ||
busRate[2].valData.add(new Cell(r+4, b, u, t4)); | ||
|
||
busRate[3].Data.add(new Cell(r+5, b, u, t5)); | ||
busRate[3].valData.add(new Cell(r+5, b, u, t5)); | ||
} | ||
|
||
} | ||
|
Oops, something went wrong.