Skip to content

Commit

Permalink
Logistic CMF for yelp dataset. Python scripts for getting PRF from pr…
Browse files Browse the repository at this point in the history
…ediction data.
  • Loading branch information
Nitish Gupta committed Dec 25, 2014
0 parents commit 304eac5
Show file tree
Hide file tree
Showing 51 changed files with 3,602 additions and 0 deletions.
19 changes: 19 additions & 0 deletions Project/.classpath
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
<?xml version="1.0" encoding="UTF-8"?>
<!-- Eclipse build-path descriptor: Java sources are read from "src" and compiled classes are written to "bin". -->
<classpath>
<classpathentry kind="src" path="src"/>
<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER"/>
<!-- NOTE(review): every "lib" entry below is an absolute path under /home/nitish/Thesis/Tools, so the build path only resolves on a machine with that exact directory layout. -->
<!-- JSON API jars. -->
<classpathentry kind="lib" path="/home/nitish/Thesis/Tools/JSON-API/javax.json-1.0.4.jar"/>
<classpathentry kind="lib" path="/home/nitish/Thesis/Tools/JSON-API/javax.json-api-1.0.jar"/>
<!-- Jars taken from the lib folder of the jfreechart-1.0.19 distribution. -->
<classpathentry kind="lib" path="/home/nitish/Thesis/Tools/jfreechart-1.0.19/lib/swtgraphics2d.jar"/>
<classpathentry kind="lib" path="/home/nitish/Thesis/Tools/jfreechart-1.0.19/lib/servlet.jar"/>
<classpathentry kind="lib" path="/home/nitish/Thesis/Tools/jfreechart-1.0.19/lib/orsonpdf-1.6-eval.jar"/>
<classpathentry kind="lib" path="/home/nitish/Thesis/Tools/jfreechart-1.0.19/lib/orsoncharts-1.4-eval-nofx.jar"/>
<classpathentry kind="lib" path="/home/nitish/Thesis/Tools/jfreechart-1.0.19/lib/junit-4.11.jar"/>
<classpathentry kind="lib" path="/home/nitish/Thesis/Tools/jfreechart-1.0.19/lib/jfreesvg-2.0.jar"/>
<classpathentry kind="lib" path="/home/nitish/Thesis/Tools/jfreechart-1.0.19/lib/jfreechart-1.0.19-swt.jar"/>
<classpathentry kind="lib" path="/home/nitish/Thesis/Tools/jfreechart-1.0.19/lib/jcommon-1.0.23.jar"/>
<classpathentry kind="lib" path="/home/nitish/Thesis/Tools/jfreechart-1.0.19/lib/hamcrest-core-1.3.jar"/>
<classpathentry kind="lib" path="/home/nitish/Thesis/Tools/jfreechart-1.0.19/lib/jfreechart-1.0.19.jar"/>
<classpathentry kind="lib" path="/home/nitish/Thesis/Tools/jfreechart-1.0.19/lib/jfreechart-1.0.19-experimental.jar"/>
<classpathentry kind="output" path="bin"/>
</classpath>
17 changes: 17 additions & 0 deletions Project/.project
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
<?xml version="1.0" encoding="UTF-8"?>
<!-- Eclipse project descriptor: registers the Java builder and the Java nature for this project. -->
<projectDescription>
<!-- NOTE(review): "logisitic" looks like a misspelling of "logistic"; it is the project identifier, so confirm nothing depends on it before renaming. -->
<name>logisitic_cmf_yelp</name>
<comment></comment>
<projects>
</projects>
<buildSpec>
<buildCommand>
<name>org.eclipse.jdt.core.javabuilder</name>
<arguments>
</arguments>
</buildCommand>
</buildSpec>
<natures>
<nature>org.eclipse.jdt.core.javanature</nature>
</natures>
</projectDescription>
Binary file added Project/bin/logisticCMF/Cell.class
Binary file not shown.
Binary file added Project/bin/logisticCMF/Eval.class
Binary file not shown.
Binary file added Project/bin/logisticCMF/Rating.class
Binary file not shown.
Binary file added Project/bin/logisticCMF/Util.class
Binary file not shown.
Binary file added Project/bin/logisticCMF/codeTest.class
Binary file not shown.
Binary file added Project/bin/logisticCMF/data.class
Binary file not shown.
Binary file added Project/bin/logisticCMF/embedding.class
Binary file not shown.
Binary file added Project/bin/logisticCMF/embeddings.class
Binary file not shown.
Binary file added Project/bin/logisticCMF/learner.class
Binary file not shown.
Binary file added Project/bin/logisticCMF/writeDataToFile.class
Binary file not shown.
Binary file not shown.
Binary file added Project/bin/postProcessing/Similarity.class
Binary file not shown.
Binary file added Project/bin/postProcessing/Util$1.class
Binary file not shown.
Binary file added Project/bin/postProcessing/Util.class
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added Project/bin/yelpDataProcessing/reviewData.class
Binary file not shown.
Binary file added Project/bin/yelpDataProcessing/reviewJson.class
Binary file not shown.
22 changes: 22 additions & 0 deletions Project/src/logisticCMF/Cell.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package logisticCMF;
import java.util.*;


/**
 * A single data element: one relation, the entities that take part in it,
 * and the truth value recorded for that combination.
 */
public class Cell {
    /** Id of the relation this data element belongs to. */
    String relation_id;
    /** Ids of the entities participating in relation {@code relation_id}, in insertion order. */
    ArrayList<String> entity_ids;
    /** Truth value for this data element. */
    boolean truth;

    /**
     * Creates a cell with an empty entity list; the relation id is left unset
     * and the truth value keeps its default of {@code false}.
     */
    public Cell() {
        this.entity_ids = new ArrayList<String>();
    }

    /**
     * Creates a fully populated two-entity cell.
     *
     * @param r  relation id
     * @param e1 first entity id (stored at index 0)
     * @param e2 second entity id (stored at index 1)
     * @param t  truth value of the relation for this entity pair
     */
    public Cell(String r, String e1, String e2, boolean t) {
        this();
        this.relation_id = r;
        this.truth = t;
        this.entity_ids.add(e1);
        this.entity_ids.add(e2);
    }
}
157 changes: 157 additions & 0 deletions Project/src/logisticCMF/Eval.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
package logisticCMF;

import java.io.IOException;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

/**
 * Computes per-relation accuracy, precision, recall and F1 of the model's
 * thresholded predictions over a test or validation set.
 *
 * <p>All state lives in static maps that are rebuilt on every call to
 * {@link #getEvalMap(data, embeddings, String)}; the class keeps no per-instance state.
 */
public class Eval {
    static Map<String, Integer> relTrue = new HashMap<String, Integer>();          // Correct predictions per relation
    static Map<String, Integer> relCount = new HashMap<String, Integer>();         // Total evaluated cells per relation
    static Map<String, Double> relL2 = new HashMap<String, Double>();              // Sum of squared errors (truth - sigmoid)^2 per relation

    static Map<String, Integer> relActualTruth = new HashMap<String, Integer>();   // Cells whose actual value is true, per relation
    static Map<String, Integer> relPredTruth = new HashMap<String, Integer>();     // Cells predicted true, per relation
    static Map<String, Integer> relTruthCorrect = new HashMap<String, Integer>();  // Cells that are both actually true and predicted true, per relation

    // Map [Rel, <Accuracy, Precision, Recall, F1>]; also contains the count-weighted mean under the key "average".
    static Map<String, ArrayList<Double>> relEvalMap = new HashMap<String, ArrayList<Double>>();

    /**
     * Evaluates the embeddings on one split and returns the resulting metric map.
     *
     * @param Data data set holding the {@code testData} and {@code valData} cells
     * @param e    embeddings used to score each cell
     * @param set  {@code "test"} to evaluate on the test split; any other value evaluates on the validation split
     * @return map from relation id (plus {@code "average"}) to {@code [accuracy, precision, recall, f1]}
     */
    public static Map<String, ArrayList<Double>> getEvalMap(data Data, embeddings e, String set){
        refreshMaps();
        if (set.equals("test")) {
            for (Cell cell : Data.testData) {
                updateTestMaps(cell, e);
            }
        } else {
            for (Cell cell : Data.valData) {
                updateTestMaps(cell, e);
            }
        }
        makeRelationEvalMap(); // Final Map [Relation, <Accuracy, Precision, Recall, F1>]
        return relEvalMap;
    }

    /** Discards all accumulated statistics by replacing every map with a fresh, empty one. */
    public static void refreshMaps(){
        relTrue = new HashMap<String, Integer>();
        relCount = new HashMap<String, Integer>();
        relL2 = new HashMap<String, Double>();
        relActualTruth = new HashMap<String, Integer>();
        relPredTruth = new HashMap<String, Integer>();
        relTruthCorrect = new HashMap<String, Integer>();
        relEvalMap = new HashMap<String, ArrayList<Double>>();
    }

    /** Adds {@code delta} to the counter stored under {@code key}, treating a missing key as zero. */
    private static void increment(Map<String, Integer> counts, String key, int delta){
        Integer current = counts.get(key);
        counts.put(key, current == null ? delta : current + delta);
    }

    /** Counts one more evaluated cell for the cell's relation. */
    public static void addInRelCountMap(Cell cell){
        increment(relCount, cell.relation_id, 1);
    }

    /**
     * Updates the precision/recall counters of a relation with one prediction.
     *
     * @param relation relation id
     * @param actual   1 if the cell is actually true, else 0
     * @param pred     1 if the cell was predicted true, else 0
     * @param correct  1 if {@code pred} equals {@code actual}, else 0
     */
    public static void addInprfRelationMap(String relation, Integer actual, Integer pred, Integer correct){
        // Make sure the relation is present even if it only ever sees true negatives,
        // because makeRelationEvalMap iterates over relActualTruth's keys.
        if (!relActualTruth.containsKey(relation)) {
            relActualTruth.put(relation, 0);
            relPredTruth.put(relation, 0);
            relTruthCorrect.put(relation, 0);
        }

        // True negatives contribute to neither precision nor recall; skipping them also
        // keeps 'correct' from being counted as a true positive.
        if (actual == 0 && pred == 0) {
            return;
        }

        // Fix: the first cell of each relation used to be dropped here because the
        // accumulation sat in the else-branch of the initialisation above.
        increment(relActualTruth, relation, actual);
        increment(relPredTruth, relation, pred);
        increment(relTruthCorrect, relation, correct);
    }

    /** Adds {@code correct} (0 or 1) to the number of correct predictions of the cell's relation. */
    public static void addInRelAccuracyMap(Cell cell, Integer correct){
        increment(relTrue, cell.relation_id, correct);
    }

    /** Adds the squared error {@code l2} to the running sum kept for the cell's relation. */
    public static void addInRelL2Map(Cell cell, double l2){
        Double current = relL2.get(cell.relation_id);
        // Fix: the running sum used to be read from relTrue (the correct-prediction count)
        // instead of relL2, which corrupted the accumulated squared error.
        relL2.put(cell.relation_id, current == null ? l2 : current + l2);
    }

    /**
     * Builds {@code relEvalMap}: for every relation seen so far stores
     * {@code [accuracy, precision, recall, f1]} rounded to 3 places, plus an
     * {@code "average"} entry weighted by the number of evaluated cells per relation.
     *
     * <p>Precision, recall and F1 are NaN when their denominator is zero, and a NaN in
     * any relation propagates into the corresponding weighted average.
     */
    public static void makeRelationEvalMap(){
        double wf1 = 0.0, waccuracy = 0.0, wp = 0.0, wr = 0.0;
        int total = 0;
        for (String rel : relActualTruth.keySet()) {
            int count = relCount.get(rel);
            double accuracy = ((double) relTrue.get(rel)) / count;
            double precision = (double) relTruthCorrect.get(rel) / relPredTruth.get(rel);
            double recall = (double) relTruthCorrect.get(rel) / relActualTruth.get(rel);
            double f1 = 2 * precision * recall / (precision + recall);

            ArrayList<Double> metrics = new ArrayList<Double>();
            metrics.add(round(accuracy, 3));   // Accuracy
            metrics.add(round(precision, 3));  // Precision
            metrics.add(round(recall, 3));     // Recall
            metrics.add(round(f1, 3));         // F1
            relEvalMap.put(rel, metrics);

            // Weighted sums use the unrounded values.
            wf1 += count * f1;
            wp += count * precision;
            wr += count * recall;
            waccuracy += count * accuracy;
            total += count;
        }

        // With no relations at all, total is 0 and every average is NaN (0.0 / 0).
        ArrayList<Double> average = new ArrayList<Double>();
        average.add(round(waccuracy / total, 3)); // Accuracy
        average.add(round(wp / total, 3));        // Precision
        average.add(round(wr / total, 3));        // Recall
        average.add(round(wf1 / total, 3));       // F1
        relEvalMap.put("average", average);
    }

    /**
     * Scores one cell with the embeddings, thresholds the sigmoid at 0.5 and folds the
     * outcome into all per-relation counters.
     *
     * @param cell cell to evaluate
     * @param e    embeddings used to compute the score
     */
    public static void updateTestMaps(Cell cell, embeddings e){
        double dot = e.dot(cell, learner.enableBias, e.K, learner.ealpha, learner.onlyAlpha);
        double sigmdot = learner.sigm(dot);
        int pred = (sigmdot >= 0.5) ? 1 : 0;
        int truth = (cell.truth) ? 1 : 0;
        double l2 = (truth - sigmdot) * (truth - sigmdot);
        int correct = (pred == truth) ? 1 : 0;

        addInprfRelationMap(cell.relation_id, truth, pred, correct);
        addInRelAccuracyMap(cell, correct);
        addInRelCountMap(cell);
        addInRelL2Map(cell, l2);
    }

    /** Prints precision, recall, F1 and accuracy of every entry in {@code relEvalMap} to standard output. */
    public static void printEval(){
        for (String rel : relEvalMap.keySet()) {
            ArrayList<Double> eval = relEvalMap.get(rel);
            System.out.print(rel + " : ");
            System.out.println("P : " + eval.get(1) + " R : " + eval.get(2) + " F1 : " + eval.get(3) + " Accuracy : " + eval.get(0));
        }
    }

    /**
     * Rounds {@code value} half-up to the given number of decimal places.
     *
     * @param value  number to round; NaN and infinite values are returned unchanged
     * @param places number of decimal places to keep
     * @return the rounded value
     * @throws IllegalArgumentException if {@code places} is negative
     */
    public static double round(double value, int places) {
        if (places < 0) throw new IllegalArgumentException();
        // BigDecimal cannot represent NaN or infinities and would throw NumberFormatException.
        if (Double.isNaN(value) || Double.isInfinite(value))
            return value;

        BigDecimal bd = new BigDecimal(value);
        bd = bd.setScale(places, RoundingMode.HALF_UP);
        return bd.doubleValue();
    }

}
177 changes: 177 additions & 0 deletions Project/src/logisticCMF/Rating.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
package logisticCMF;
import java.io.*;
import java.util.*;



/**
 * Loads the star ratings of one location folder and turns them into four binary
 * relations {@code busrate-<loc>-2} .. {@code busrate-<loc>-5}, where the relation
 * with suffix {@code k} is true for a (business, user) pair when its rating is at least {@code k}.
 *
 * <p>{@code busRate[0]} holds the threshold-2 relation and is the only one that is split
 * randomly; {@code busRate[1..3]} (thresholds 3, 4, 5) reuse that same
 * train/validation/test assignment so a pair always lands in the same split.
 */
public class Rating {

    /** Which split of a {@code data} object a derived cell is appended to. */
    private enum Split { TRAIN, TEST, VAL }

    public Map<String, Map<String, Integer>> ratings; // Map[busId, Map[User, Rating]]
    public String loc;
    data [] busRate = new data[4];

    /**
     * Reads the ratings of {@code folder} and builds all four relation data sets.
     *
     * @param folder location folder name; also used in the relation ids
     * @param valp   validation percentage passed on to {@code splitTrainTestValidation}
     * @param testp  test percentage passed on to {@code splitTrainTestValidation}
     * @throws IOException if the review file cannot be read or is malformed
     */
    public Rating(String folder, double valp, double testp) throws IOException{
        ratings = new HashMap<String, Map<String, Integer>>();
        loc = folder;
        for (int i = 0; i < busRate.length; i++) {
            busRate[i] = new data();
        }
        readRatingData(folder);
        makeTwoRatingDataCells(valp, testp);
    }

    /** Merges the four relation data sets, prints their statistics and runs a learn-and-test pass on them. */
    public void testRatings(){
        data mergeData = new data();
        ArrayList<data> tomerge = new ArrayList<data>();
        for (data d : busRate) {
            tomerge.add(d);
        }
        mergeData.addDataAfterSplit(tomerge);
        mergeData.dataStats();
        codeTest.learnAndTest(mergeData, 30, false, 0, false, 0);
    }

    /**
     * Parses {@code <user.dir>/../Dataset/data/<folder>/reviews.txt} into {@code ratings}.
     *
     * <p>The file is read as {@code key:value} lines. {@code bus_id} and {@code user_id}
     * lines set the current business and user; a {@code star} line stores its value
     * (truncated to an int) for the current pair, overwriting any earlier rating of the
     * same pair. Lines with any other key are ignored.
     *
     * @param folder location folder name
     * @throws NumberFormatException if a {@code star} value is not a number
     * @throws IOException if the file cannot be read, a recognised key has no value,
     *         or a {@code star} line appears before any {@code bus_id} line
     */
    public void readRatingData(String folder) throws NumberFormatException, IOException{
        String address = System.getProperty("user.dir")+"/../Dataset/data/"+ folder +"/reviews.txt";
        BufferedReader br = new BufferedReader(new FileReader(address));
        // try/finally so the file handle is released even when parsing throws.
        try {
            String line;
            int lineNo = 0;
            String busid = null, userid = null;
            while ((line = br.readLine()) != null) {
                lineNo++;
                String[] array = line.split(":");
                String key = array.length > 0 ? array[0].trim() : "";
                boolean isBus = key.equals("bus_id");
                boolean isUser = key.equals("user_id");
                boolean isStar = key.equals("star");
                if (!isBus && !isUser && !isStar) {
                    continue;
                }
                // A recognised key without a value used to fail with a bare
                // ArrayIndexOutOfBoundsException; report where the problem is instead.
                if (array.length < 2) {
                    throw new IOException("Missing value for '" + key + "' at line " + lineNo + " of " + address);
                }
                String value = array[1].trim();

                if (isBus) {
                    busid = value;
                    if (!ratings.containsKey(busid)) {
                        ratings.put(busid, new HashMap<String, Integer>());
                    }
                } else if (isUser) {
                    userid = value;
                } else {
                    if (busid == null) {
                        throw new IOException("'star' before any 'bus_id' at line " + lineNo + " of " + address);
                    }
                    double t = Double.parseDouble(value);
                    ratings.get(busid).put(userid, (int) t);
                }
            }
        } finally {
            br.close();
        }
    }

    /**
     * Builds the threshold-2 relation from every stored rating, splits it into
     * train/validation/test, derives the threshold 3-5 relations from that split and
     * prints the number of ratings.
     *
     * @param valPerc  validation percentage passed on to {@code splitTrainTestValidation}
     * @param testPerc test percentage passed on to {@code splitTrainTestValidation}
     */
    public void makeTwoRatingDataCells(double valPerc, double testPerc){
        String relation = "busrate-"+loc+"-" + "2";
        for (Map.Entry<String, Map<String, Integer>> busEntry : ratings.entrySet()) {
            String bus = busEntry.getKey();
            for (Map.Entry<String, Integer> userEntry : busEntry.getValue().entrySet()) {
                int rate = userEntry.getValue();
                busRate[0].Data.add(new Cell(relation, bus, userEntry.getKey(), rate >= 2));
            }
        }
        busRate[0].splitTrainTestValidation(valPerc, testPerc);
        makeRestRatingDataCells();

        System.out.println("Ratings : " + busRate[0].Data.size());
    }

    /**
     * Mirrors each cell of the threshold-2 split into the threshold 3, 4 and 5 relations,
     * keeping it in the same split (train, test or validation).
     */
    public void makeRestRatingDataCells(){
        String r = "busrate-"+loc+"-";
        for (Cell cell : busRate[0].trainData) {
            mirrorCell(Split.TRAIN, r, cell);
        }
        for (Cell cell : busRate[0].testData) {
            mirrorCell(Split.TEST, r, cell);
        }
        for (Cell cell : busRate[0].valData) {
            mirrorCell(Split.VAL, r, cell);
        }
    }

    /** Looks up the rating behind a threshold-2 cell and adds its threshold 3-5 cells to the given split. */
    private void mirrorCell(Split split, String r, Cell cell){
        String b = cell.entity_ids.get(0), u = cell.entity_ids.get(1);
        int rate = ratings.get(b).get(u);
        // Ratings above 5 produce no derived cells.
        if (rate > 5) {
            return;
        }
        addRestCells(split, r, b, u, rate >= 3, rate >= 4, rate >= 5);
    }

    /**
     * Appends one cell per threshold (3, 4, 5) to {@code busRate[1..3]}: once to the
     * full {@code Data} list and once, as a separate object, to the requested split.
     */
    private void addRestCells(Split split, String r, String b, String u, boolean t3, boolean t4, boolean t5){
        boolean[] truths = { t3, t4, t5 };
        for (int i = 0; i < truths.length; i++) {
            String relation = r + (i + 3);
            data target = busRate[i + 1];
            target.Data.add(new Cell(relation, b, u, truths[i]));
            Cell copy = new Cell(relation, b, u, truths[i]);
            switch (split) {
                case TRAIN:
                    target.trainData.add(copy);
                    break;
                case TEST:
                    target.testData.add(copy);
                    break;
                default:
                    target.valData.add(copy);
                    break;
            }
        }
    }

    /** Adds the threshold 3, 4 and 5 cells of pair ({@code b}, {@code u}) to {@code Data} and the training split. */
    public void addCellsinRestTrain(String r, String b, String u, boolean t3, boolean t4, boolean t5){
        addRestCells(Split.TRAIN, r, b, u, t3, t4, t5);
    }

    /** Adds the threshold 3, 4 and 5 cells of pair ({@code b}, {@code u}) to {@code Data} and the test split. */
    public void addCellsinRestTest(String r, String b, String u, boolean t3, boolean t4, boolean t5){
        addRestCells(Split.TEST, r, b, u, t3, t4, t5);
    }

    /** Adds the threshold 3, 4 and 5 cells of pair ({@code b}, {@code u}) to {@code Data} and the validation split. */
    public void addCellsinRestVal(String r, String b, String u, boolean t3, boolean t4, boolean t5){
        addRestCells(Split.VAL, r, b, u, t3, t4, t5);
    }

}

Loading

0 comments on commit 304eac5

Please sign in to comment.