diff --git a/Project/.classpath b/Project/.classpath
new file mode 100644
index 0000000..e197418
--- /dev/null
+++ b/Project/.classpath
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<classpath>
+	<classpathentry kind="src" path="src"/>
+	<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER"/>
+	<classpathentry kind="output" path="bin"/>
+</classpath>
diff --git a/Project/.project b/Project/.project
new file mode 100644
index 0000000..772bcfe
--- /dev/null
+++ b/Project/.project
@@ -0,0 +1,17 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<projectDescription>
+	<name>logisitic_cmf_yelp</name>
+	<comment></comment>
+	<projects>
+	</projects>
+	<buildSpec>
+		<buildCommand>
+			<name>org.eclipse.jdt.core.javabuilder</name>
+			<arguments>
+			</arguments>
+		</buildCommand>
+	</buildSpec>
+	<natures>
+		<nature>org.eclipse.jdt.core.javanature</nature>
+	</natures>
+</projectDescription>
diff --git a/Project/bin/logisticCMF/Cell.class b/Project/bin/logisticCMF/Cell.class
new file mode 100644
index 0000000..4da479f
Binary files /dev/null and b/Project/bin/logisticCMF/Cell.class differ
diff --git a/Project/bin/logisticCMF/Eval.class b/Project/bin/logisticCMF/Eval.class
new file mode 100644
index 0000000..f9ea524
Binary files /dev/null and b/Project/bin/logisticCMF/Eval.class differ
diff --git a/Project/bin/logisticCMF/Rating.class b/Project/bin/logisticCMF/Rating.class
new file mode 100644
index 0000000..2688baa
Binary files /dev/null and b/Project/bin/logisticCMF/Rating.class differ
diff --git a/Project/bin/logisticCMF/Util.class b/Project/bin/logisticCMF/Util.class
new file mode 100644
index 0000000..3815de8
Binary files /dev/null and b/Project/bin/logisticCMF/Util.class differ
diff --git a/Project/bin/logisticCMF/codeTest.class b/Project/bin/logisticCMF/codeTest.class
new file mode 100644
index 0000000..f745351
Binary files /dev/null and b/Project/bin/logisticCMF/codeTest.class differ
diff --git a/Project/bin/logisticCMF/data.class b/Project/bin/logisticCMF/data.class
new file mode 100644
index 0000000..54f8f0e
Binary files /dev/null and b/Project/bin/logisticCMF/data.class differ
diff --git a/Project/bin/logisticCMF/embedding.class b/Project/bin/logisticCMF/embedding.class
new file mode 100644
index 0000000..2a99242
Binary files /dev/null and b/Project/bin/logisticCMF/embedding.class differ
diff --git a/Project/bin/logisticCMF/embeddings.class b/Project/bin/logisticCMF/embeddings.class
new file mode 100644
index 0000000..2af64dd
Binary files /dev/null and b/Project/bin/logisticCMF/embeddings.class differ
diff --git a/Project/bin/logisticCMF/learner.class b/Project/bin/logisticCMF/learner.class
new file mode 100644
index 0000000..5dae12a
Binary files /dev/null and b/Project/bin/logisticCMF/learner.class differ
diff --git a/Project/bin/logisticCMF/writeDataToFile.class b/Project/bin/logisticCMF/writeDataToFile.class
new file mode 100644
index 0000000..5ab457e
Binary files /dev/null and b/Project/bin/logisticCMF/writeDataToFile.class differ
diff --git a/Project/bin/postProcessing/EntityEmbeddings.class b/Project/bin/postProcessing/EntityEmbeddings.class
new file mode 100644
index 0000000..24cea7f
Binary files /dev/null and b/Project/bin/postProcessing/EntityEmbeddings.class differ
diff --git a/Project/bin/postProcessing/Similarity.class b/Project/bin/postProcessing/Similarity.class
new file mode 100644
index 0000000..58ce3b5
Binary files /dev/null and b/Project/bin/postProcessing/Similarity.class differ
diff --git a/Project/bin/postProcessing/Util$1.class b/Project/bin/postProcessing/Util$1.class
new file mode 100644
index 0000000..c4a5b10
Binary files /dev/null and b/Project/bin/postProcessing/Util$1.class differ
diff --git a/Project/bin/postProcessing/Util.class b/Project/bin/postProcessing/Util.class
new file mode 100644
index 0000000..aa47689
Binary files /dev/null and b/Project/bin/postProcessing/Util.class differ
diff --git a/Project/bin/yelpDataProcessing/AttributeCategory.class b/Project/bin/yelpDataProcessing/AttributeCategory.class
new file mode 100644
index 0000000..e786bb7
Binary files /dev/null and b/Project/bin/yelpDataProcessing/AttributeCategory.class differ
diff --git a/Project/bin/yelpDataProcessing/ProcessYelpJson.class b/Project/bin/yelpDataProcessing/ProcessYelpJson.class
new file mode 100644
index 0000000..18cb83e
Binary files /dev/null and b/Project/bin/yelpDataProcessing/ProcessYelpJson.class differ
diff --git a/Project/bin/yelpDataProcessing/reviewData.class b/Project/bin/yelpDataProcessing/reviewData.class
new file mode 100644
index 0000000..9fffa50
Binary files /dev/null and b/Project/bin/yelpDataProcessing/reviewData.class differ
diff --git a/Project/bin/yelpDataProcessing/reviewJson.class b/Project/bin/yelpDataProcessing/reviewJson.class
new file mode 100644
index 0000000..86e2bfb
Binary files /dev/null and b/Project/bin/yelpDataProcessing/reviewJson.class differ
diff --git a/Project/src/logisticCMF/Cell.java b/Project/src/logisticCMF/Cell.java
new file mode 100644
index 0000000..e7c851e
--- /dev/null
+++ b/Project/src/logisticCMF/Cell.java
@@ -0,0 +1,22 @@
+package logisticCMF;
+import java.util.*;
+
+
+public class Cell {
+ String relation_id; // Id of relation in this data element
+	ArrayList<String> entity_ids;	// List of entities participating in relation 'relation_id'
+ boolean truth; // Truth Value for this data element
+
+ public Cell(){
+		entity_ids = new ArrayList<String>();
+ }
+
+ public Cell(String r, String e1, String e2, boolean t){
+		entity_ids = new ArrayList<String>();
+ relation_id = r;
+ entity_ids.add(e1);
+ entity_ids.add(e2);
+ truth = t;
+ }
+
+}
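
Review note: a `Cell` is the atomic observation of the factorization, one (relation, entity pair, truth-value) triple. A minimal usage sketch (ids are hypothetical; the relation-id format mirrors what `data.readBusAtt` produces later in this patch):

```java
// Hypothetical ids, for illustration only.
Cell byFields = new Cell("b-att-EDH", "someBusinessId", "WiFi", true);

Cell byHand = new Cell();                // no-arg constructor: fill the fields manually
byHand.relation_id = "b-att-EDH";
byHand.entity_ids.add("someBusinessId"); // index 0: row entity (business)
byHand.entity_ids.add("WiFi");           // index 1: column entity (attribute)
byHand.truth = true;
```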
diff --git a/Project/src/logisticCMF/Eval.java b/Project/src/logisticCMF/Eval.java
new file mode 100644
index 0000000..5fd2542
--- /dev/null
+++ b/Project/src/logisticCMF/Eval.java
@@ -0,0 +1,157 @@
+package logisticCMF;
+
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.math.RoundingMode;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Map;
+
+public class Eval {
+	static Map<String, Integer> relTrue = new HashMap<String, Integer>();		// Correct predictions per relation
+	static Map<String, Integer> relCount = new HashMap<String, Integer>();		// Total test size per relation
+	static Map<String, Double> relL2 = new HashMap<String, Double>();
+
+	static Map<String, Integer> relActualTruth = new HashMap<String, Integer>();	// Actual true labels in each relation
+	static Map<String, Integer> relPredTruth = new HashMap<String, Integer>();		// True labels predicted in each relation
+	static Map<String, Integer> relTruthCorrect = new HashMap<String, Integer>();	// Correctly predicted true labels in each relation
+
+	static Map<String, ArrayList<Double>> relEvalMap = new HashMap<String, ArrayList<Double>>();	// Map [Rel, <Accuracy, P, R, F1>]; also holds an "average" row
+
+
+	public static Map<String, ArrayList<Double>> getEvalMap(data Data, embeddings e, String set){
+ refreshMaps();
+ if(set.equals("test")){
+ for(Cell cell : Data.testData)
+ updateTestMaps(cell, e);
+ }
+ else{
+ for(Cell cell : Data.valData)
+ updateTestMaps(cell, e);
+ }
+		makeRelationEvalMap();	// Final Map [Relation, <Accuracy, Precision, Recall, F1>]
+ return relEvalMap;
+ }
+
+ public static void refreshMaps(){
+		relTrue = new HashMap<String, Integer>();
+		relCount = new HashMap<String, Integer>();
+		relL2 = new HashMap<String, Double>();
+		relActualTruth = new HashMap<String, Integer>();	// Actual true labels in each relation
+		relPredTruth = new HashMap<String, Integer>();		// True labels predicted in each relation
+		relTruthCorrect = new HashMap<String, Integer>();	// Correctly predicted true labels in each relation
+		relEvalMap = new HashMap<String, ArrayList<Double>>();
+ }
+
+ public static void addInRelCountMap(Cell cell){
+ if(!relCount.containsKey(cell.relation_id))
+ relCount.put(cell.relation_id, 1);
+ else
+ relCount.put(cell.relation_id, relCount.get(cell.relation_id)+1);
+ }
+
+	public static void addInprfRelationMap(String relation, Integer actual, Integer pred, Integer correct){
+		if(!relActualTruth.containsKey(relation)){
+			relActualTruth.put(relation, 0);
+			relPredTruth.put(relation, 0);
+			relTruthCorrect.put(relation, 0);
+		}
+
+		// True negatives (actual == 0 && pred == 0) are excluded: these counters feed
+		// precision/recall over the positive class only.
+		if(!(actual == 0 && pred == 0)){
+			relActualTruth.put(relation, relActualTruth.get(relation) + actual);
+			relPredTruth.put(relation, relPredTruth.get(relation) + pred);
+			relTruthCorrect.put(relation, relTruthCorrect.get(relation) + correct);
+		}
+	}
+
+ public static void addInRelAccuracyMap(Cell cell, Integer correct){
+ if(!relTrue.containsKey(cell.relation_id))
+ relTrue.put(cell.relation_id, correct);
+ else
+ relTrue.put(cell.relation_id, relTrue.get(cell.relation_id)+correct);
+ }
+
+	public static void addInRelL2Map(Cell cell, double l2){
+		if(!relL2.containsKey(cell.relation_id))
+			relL2.put(cell.relation_id, l2);
+		else
+			relL2.put(cell.relation_id, relL2.get(cell.relation_id) + l2);
+	}
+
+	// Map [Rel, <Accuracy, Precision, Recall, F1>]
+ public static void makeRelationEvalMap(){
+ //makeTestMaps(Data, e);
+		double wf1 = 0.0, waccuracy = 0.0, wp = 0.0, wr = 0.0; int total = 0;
+ for(String rel : relActualTruth.keySet()){
+ double accuracy = ((double)relTrue.get(rel))/relCount.get(rel);
+ double precision = (double)relTruthCorrect.get(rel) / relPredTruth.get(rel) ;
+ double recall = (double) relTruthCorrect.get(rel) / relActualTruth.get(rel) ;
+ double f1 = 2*precision*recall / (precision + recall) ;
+ //System.out.println("a : " + accuracy + " p : " + precision + " r : " + recall + " f1 : " +f1);
+			relEvalMap.put(rel, new ArrayList<Double>());
+ relEvalMap.get(rel).add(round(accuracy, 3)); // Accuracy
+ relEvalMap.get(rel).add(round(precision, 3)); // Precision
+ relEvalMap.get(rel).add(round(recall, 3)); // Recall
+ relEvalMap.get(rel).add(round(f1, 3)); // F1
+ wf1 += (relCount.get(rel)*f1);
+ wp += (relCount.get(rel)*precision);
+ wr += (relCount.get(rel)*recall);
+ waccuracy += (relCount.get(rel)*accuracy);
+ total += relCount.get(rel);
+ }
+ wf1 = round(wf1/total, 3); waccuracy = round(waccuracy/total, 3); wp = round(wp/total, 3); wr = round(wr/total, 3);
+ relEvalMap.put("average", new ArrayList());
+ relEvalMap.get("average").add(round(waccuracy, 3)); // Accuracy
+ relEvalMap.get("average").add(round(wp, 3)); // Precision
+ relEvalMap.get("average").add(round(wr, 3)); // Recall
+ relEvalMap.get("average").add(round(wf1, 3)); // F1
+
+ }
+
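+	// Scores one test/validation cell: prediction = sigmoid(embedding dot product)
+	// thresholded at 0.5, then the per-relation accuracy, P/R/F counters and the
+	// squared-error sum are updated.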
+ public static void updateTestMaps(Cell cell, embeddings e){
+		double dot = e.dot(cell, learner.enableBias, e.K, learner.ealpha, learner.onlyAlpha);
+		double sigmdot = learner.sigm(dot);
+		int pred = (sigmdot >= 0.5) ? 1 : 0;
+		int truth = (cell.truth) ? 1 : 0;
+		double l2 = (truth - sigmdot) * (truth - sigmdot);
+		//System.out.println(sigmdot + " " + pred + " " + truth + " " + l2);
+
+		int correct = (pred == truth) ? 1 : 0;
+ //System.out.println("rel : " + truth + " pred : " + pred);
+ //System.out.println(cell.relation_id + " : " + pred);
+ addInprfRelationMap(cell.relation_id, truth, pred, correct);
+ addInRelAccuracyMap(cell, correct);
+ addInRelCountMap(cell);
+ addInRelL2Map(cell, l2);
+ }
+
+ public static void printEval(){
+ for(String rel : relEvalMap.keySet()){
+			ArrayList<Double> eval = relEvalMap.get(rel);
+ System.out.print(rel + " : ");
+ System.out.println("P : " + eval.get(1) + " R : " + eval.get(2) + " F1 : " + eval.get(3) + " Accuracy : " + eval.get(0));
+ }
+ }
+
+
+ public static double round(double value, int places) {
+ if (places < 0) throw new IllegalArgumentException();
+ if(Double.isNaN(value))
+ return Double.NaN;
+
+ BigDecimal bd = new BigDecimal(value);
+ bd = bd.setScale(places, RoundingMode.HALF_UP);
+ return bd.doubleValue();
+ }
+
+
+
+}
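
What `makeRelationEvalMap` computes, on made-up counts: precision and recall are taken over the positive class only (true negatives never reach `relTruthCorrect`), and the `"average"` row weights each relation's scores by its test-cell count. A self-contained sketch, all numbers hypothetical:

```java
// Hypothetical counts for one relation: 200 test cells, 150 predicted correctly,
// 80 cells actually true, 70 predicted true, 60 of those predictions correct.
double accuracy  = 150.0 / 200.0;                 // relTrue / relCount
double precision = 60.0 / 70.0;                   // relTruthCorrect / relPredTruth
double recall    = 60.0 / 80.0;                   // relTruthCorrect / relActualTruth
double f1        = 2 * precision * recall / (precision + recall);
// The "average" row is sum(count_r * metric_r) / sum(count_r) over relations r.
```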
diff --git a/Project/src/logisticCMF/Rating.java b/Project/src/logisticCMF/Rating.java
new file mode 100644
index 0000000..00d1b26
--- /dev/null
+++ b/Project/src/logisticCMF/Rating.java
@@ -0,0 +1,177 @@
+package logisticCMF;
+import java.io.*;
+import java.util.*;
+
+
+
+public class Rating {
+
+	public Map<String, Map<String, Integer>> ratings;	// Map[busId, Map[userId, rating]]
+ public String loc;
+ data [] busRate = new data[4];
+
+ public Rating(String folder, double valp, double testp) throws IOException{
+		ratings = new HashMap<String, Map<String, Integer>>();
+ loc = folder;
+ busRate[0] = new data();
+ busRate[1] = new data();
+ busRate[2] = new data();
+ busRate[3] = new data();
+ readRatingData(folder);
+ makeTwoRatingDataCells(valp, testp);
+ /*for(data rd : busRate){
+ rd.dataStats();
+ }*/
+ }
+
+ public void testRatings(){
+ data mergeData = new data();
+		ArrayList<data> tomerge = new ArrayList<data>();
+ for(data d : busRate)
+ tomerge.add(d);
+ mergeData.addDataAfterSplit(tomerge);
+ mergeData.dataStats();
+ codeTest.learnAndTest(mergeData, 30, false, 0, false, 0);
+ }
+
+ public void readRatingData(String folder) throws NumberFormatException, IOException{
+ String address = System.getProperty("user.dir")+"/../Dataset/data/"+ folder +"/reviews.txt";
+ BufferedReader br = new BufferedReader(new FileReader(address));
+ String line; int countl = 0, countcell = 0;
+ String busid = null, userid = null; boolean value = false;
+ while((line = br.readLine()) != null){
+ countl++;
+ String[] array = line.split(":");
+ if( array[0].trim().equals("bus_id")){
+ busid = array[1].trim();
+ if(!ratings.containsKey(busid))
+					ratings.put(busid, new HashMap<String, Integer>());
+ }
+ if( array[0].trim().equals("user_id"))
+ userid = array[1].trim();
+ if( array[0].trim().equals("star")){
+ double t = Double.parseDouble(array[1].trim());
+ /*if(ratings.get(busid).containsKey(userid)){
+ int r = ratings.get(busid).get(userid);
+ System.out.println("user already exists " + busid + " " + userid);
+
+ }*/
+ ratings.get(busid).put(userid, (int)t);
+ countcell++;
+ }
+ }
+ br.close();
+ //System.out.println("Ratings : " + countcell);
+
+ }
+
+ public void makeTwoRatingDataCells(double valPerc, double testPerc){
+ String relation = "busrate-"+loc+"-";
+ int count2 = 0; int count = 0;
+ // Read all rating data from Map and create busRate data object for threshold = 2
+ for(String bus : ratings.keySet()){
+ for(String user : ratings.get(bus).keySet()){
+ int rate = ratings.get(bus).get(user);
+ count++;
+ Cell cell = new Cell();
+ cell.relation_id = relation+"2";
+ cell.entity_ids.add(bus);
+ cell.entity_ids.add(user);
+				cell.truth = (rate >= 2);
+				busRate[0].Data.add(cell);
+ }
+ }
+ busRate[0].splitTrainTestValidation(valPerc, testPerc);
+ //busRate[0].dataStats();
+ makeRestRatingDataCells();
+
+ System.out.println("Ratings : " + busRate[0].Data.size());
+ }
+
+ public void makeRestRatingDataCells(){
+ String r = "busrate-"+loc+"-";
+ for(Cell cell : busRate[0].trainData){
+ String b = cell.entity_ids.get(0), u = cell.entity_ids.get(1);
+ int rate = ratings.get(b).get(u);
+ if(rate <= 2)
+ addCellsinRestTrain(r, b, u, false, false, false);
+
+ if(rate == 3)
+ addCellsinRestTrain(r, b, u, true, false, false);
+
+ if(rate == 4)
+ addCellsinRestTrain(r, b, u, true, true, false);
+
+ if(rate == 5)
+ addCellsinRestTrain(r, b, u, true, true, true);
+ }
+
+ for(Cell cell : busRate[0].testData){
+ String b = cell.entity_ids.get(0), u = cell.entity_ids.get(1);
+ int rate = ratings.get(b).get(u);
+ if(rate <= 2)
+ addCellsinRestTest(r, b, u, false, false, false);
+ if(rate == 3)
+ addCellsinRestTest(r, b, u, true, false, false);
+ if(rate == 4)
+ addCellsinRestTest(r, b, u, true, true, false);
+ if(rate == 5)
+ addCellsinRestTest(r, b, u, true, true, true);
+ }
+
+ for(Cell cell : busRate[0].valData){
+ String b = cell.entity_ids.get(0), u = cell.entity_ids.get(1);
+ int rate = ratings.get(b).get(u);
+ if(rate <= 2)
+ addCellsinRestVal(r, b, u, false, false, false);
+ if(rate == 3)
+ addCellsinRestVal(r, b, u, true, false, false);
+ if(rate == 4)
+ addCellsinRestVal(r, b, u, true, true, false);
+ if(rate == 5)
+ addCellsinRestVal(r, b, u, true, true, true);
+ }
+
+ }
+
+ public void addCellsinRestTrain(String r, String b, String u, boolean t3, boolean t4, boolean t5){
+ busRate[1].Data.add(new Cell(r+3, b, u, t3));
+ busRate[1].trainData.add(new Cell(r+3, b, u, t3));
+
+ busRate[2].Data.add(new Cell(r+4, b, u, t4));
+ busRate[2].trainData.add(new Cell(r+4, b, u, t4));
+
+ busRate[3].Data.add(new Cell(r+5, b, u, t5));
+ busRate[3].trainData.add(new Cell(r+5, b, u, t5));
+ }
+
+ public void addCellsinRestTest(String r, String b, String u, boolean t3, boolean t4, boolean t5){
+ busRate[1].Data.add(new Cell(r+3, b, u, t3));
+ busRate[1].testData.add(new Cell(r+3, b, u, t3));
+
+ busRate[2].Data.add(new Cell(r+4, b, u, t4));
+ busRate[2].testData.add(new Cell(r+4, b, u, t4));
+
+ busRate[3].Data.add(new Cell(r+5, b, u, t5));
+ busRate[3].testData.add(new Cell(r+5, b, u, t5));
+ }
+
+ public void addCellsinRestVal(String r, String b, String u, boolean t3, boolean t4, boolean t5){
+ busRate[1].Data.add(new Cell(r+3, b, u, t3));
+ busRate[1].valData.add(new Cell(r+3, b, u, t3));
+
+ busRate[2].Data.add(new Cell(r+4, b, u, t4));
+ busRate[2].valData.add(new Cell(r+4, b, u, t4));
+
+ busRate[3].Data.add(new Cell(r+5, b, u, t5));
+ busRate[3].valData.add(new Cell(r+5, b, u, t5));
+ }
+
+}
+
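The four `busRate` matrices encode a 1-5 star rating ordinally as four binary relations: `busrate-<loc>-k` is true iff rating >= k, for k in 2..5. Since relations 3-5 are derived from the already-split threshold-2 matrix, a (business, user) pair always lands in the same train/val/test fold across all four. A sketch of the expansion for one review (ids hypothetical):

```java
int rating = 4;                       // one hypothetical 4-star review
String relation = "busrate-EDH-";     // same prefix format Rating builds
for (int k = 2; k <= 5; k++) {
    boolean truth = (rating >= k);    // relation k means "rating >= k"
    Cell c = new Cell(relation + k, "someBusId", "someUserId", truth);
    // k=2 -> true, k=3 -> true, k=4 -> true, k=5 -> false;
    // c belongs in busRate[k-2], in the same split as the source cell.
}
```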
diff --git a/Project/src/logisticCMF/Util.java b/Project/src/logisticCMF/Util.java
new file mode 100644
index 0000000..4e4ef05
--- /dev/null
+++ b/Project/src/logisticCMF/Util.java
@@ -0,0 +1,156 @@
+package logisticCMF;
+
+import java.io.IOException;
+import java.util.ArrayList;
+
+
+
+/*import org.jfree.chart.ChartFactory;
+import org.jfree.chart.ChartFrame;
+import org.jfree.chart.JFreeChart;
+import org.jfree.chart.axis.ValueAxis;
+import org.jfree.chart.plot.PlotOrientation;
+import org.jfree.chart.plot.XYPlot;
+import org.jfree.data.xy.XYSeries;
+import org.jfree.data.xy.XYSeriesCollection;*/
+import java.util.*;
+
+//import postProcessing.data;
+
+
+public class Util {
+
+ static Random seed = new Random(100);
+
+	public static HashSet<String> getColdEntites(ArrayList<Cell> data, int index, double perc){
+		HashSet<String> entities = new HashSet<String>();
+		for(Cell cell : data)
+			entities.add(cell.entity_ids.get(index));
+
+		ArrayList<String> ents = new ArrayList<String>(entities);
+		Collections.shuffle(ents, seed);
+		HashSet<String> coldEntities = new HashSet<String>(ents.subList(0, (int)(perc*entities.size()/100.0)));
+		return coldEntities;
+
+ }
+
+ public static int getNegSampleSize(data Data){
+ return (int)(((double)Data.Data.size())/Data.entityIds.get(0).size());
+ }
+
+ public static double getMatrixDetails(data Data){
+ System.out.println("rows : " + Data.entityIds.get(0).size());
+ System.out.println("cols : " + Data.entityIds.get(1).size());
+ System.out.println("size : " + Data.Data.size());
+ return ((double)Data.Data.size())/Data.entityIds.get(0).size();
+
+ }
+
+	/*public static void plotGraph(ArrayList<Double> x, ArrayList<Double> y, String filename) throws IOException{
+		XYSeries series = new XYSeries(filename);
+		for(int i = 0; i < x.size(); i++)
+			series.add(x.get(i), y.get(i));
+		// ... remaining JFreeChart plotting code (commented out with the imports above) ...
+	}*/
+
+	// Sanity check: a cold-start split is valid iff no test-set entity (at the checked
+	// index) appears anywhere in the train or validation data.
+	public static void checkColdStartSanity(data D){
+		HashSet<String> entities = new HashSet<String>();
+ int index = 0, flag = 0;
+ for(Cell cell : D.trainData){
+ entities.add(cell.entity_ids.get(index));
+ }
+ for(Cell cell : D.valData){
+ entities.add(cell.entity_ids.get(index));
+ }
+ for(Cell cell : D.testData){
+ if(entities.contains(cell.entity_ids.get(index))){
+ System.out.print("WRONG COLD START SPLIT, Index : " + index + ", ");
+ flag = 1;
+ break;
+ }
+ }
+ if(flag == 0){
+ System.out.print("Index ColdStart : " + index + ", ");
+ }
+
+ index = 1; flag = 0;
+		entities = new HashSet<String>();
+ for(Cell cell : D.trainData){
+ entities.add(cell.entity_ids.get(index));
+ }
+ for(Cell cell : D.valData){
+ entities.add(cell.entity_ids.get(index));
+ }
+ for(Cell cell : D.testData){
+ if(entities.contains(cell.entity_ids.get(index))){
+ System.out.println("NO COLD START SPLIT, Index : " + index);
+ flag = 1;
+ break;
+ }
+ }
+ if(flag == 0)
+ System.out.println("Index ColdStart : " + index);
+
+ }
+
+	public static void countEntities(ArrayList<Cell> data){
+		HashSet<String> e1 = new HashSet<String>();
+		HashSet<String> e2 = new HashSet<String>();
+
+ for(Cell cell : data){
+ e1.add(cell.entity_ids.get(0));
+ e2.add(cell.entity_ids.get(1));
+ }
+
+ System.out.println("e1 : " + e1.size() + " e2 : "+e2.size());
+
+ }
+
+	public static void implicitColdStart(ArrayList<Cell> trdata, ArrayList<Cell> testData){
+		HashSet<String> tre1 = new HashSet<String>();
+		HashSet<String> tre2 = new HashSet<String>();
+
+ for(Cell cell : trdata){
+ tre1.add(cell.entity_ids.get(0));
+ tre2.add(cell.entity_ids.get(1));
+ }
+
+ System.out.println("tre1 : " + tre1.size() + " tre2 : "+ tre2.size());
+
+		HashSet<String> tee1 = new HashSet<String>();
+		HashSet<String> tee2 = new HashSet<String>();
+ for(Cell cell : testData){
+ tee1.add(cell.entity_ids.get(0));
+ tee2.add(cell.entity_ids.get(1));
+ }
+
+ System.out.println("tee1 : " + tee1.size() + " tee2 : "+ tee2.size());
+ tee1.removeAll(tre1);
+ tee2.removeAll(tre2);
+
+
+ System.out.println("e1 cold : " + tee1.size() + " e2 cold : " + tee2.size() );
+
+ }
+}
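
`Util.getColdEntites` is what makes the cold-start splits cold: it samples `perc`% of the distinct entities at position `index`, and `data.splitColdStart` then routes every cell touching a sampled entity into the test set. A minimal sketch of the intended contract (folder name hypothetical; wrapped in a method because the readers throw `IOException`):

```java
public static void coldStartSketch() throws IOException {
    data d = new data();
    d.readBusAtt(System.getProperty("user.dir") + "/../Dataset/data/EDH/busAtt.txt", "EDH");
    d.splitColdStart(15.0, 15.0, 0);  // 15% validation; 15% of businesses held out cold
    Util.checkColdStartSanity(d);     // flags any test entity that leaks into train/val
    Util.implicitColdStart(d.trainData, d.testData); // counts test entities unseen in training
}
```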
diff --git a/Project/src/logisticCMF/codeTest.java b/Project/src/logisticCMF/codeTest.java
new file mode 100644
index 0000000..e3460a1
--- /dev/null
+++ b/Project/src/logisticCMF/codeTest.java
@@ -0,0 +1,676 @@
+package logisticCMF;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Map;
+import java.util.Random;
+
+public class codeTest {
+
+ static Random rand = new Random();
+
+ private static final long MEGABYTE = 1024L * 1024L;
+ public static long bytesToMegabytes(long bytes) {
+ return bytes / MEGABYTE;
+ }
+ public static void getMemoryDetails(){
+ // Get the Java runtime
+ Runtime runtime = Runtime.getRuntime();
+		// Optionally run the garbage collector first (left disabled)
+		//runtime.gc();
+		// Report current heap usage
+		System.out.println("Total memory : " + bytesToMegabytes(runtime.totalMemory()) + " Free memory : " + bytesToMegabytes(runtime.freeMemory()) +
+				" Used memory in megabytes : " + bytesToMegabytes(runtime.totalMemory() - runtime.freeMemory()));
+
+ }
+
+	// Input: data already split into train/validation/test, and the factor dimension K.
+	// Trains until convergence on the validation set, then evaluates the best model on the test set.
+ public static embeddings learnAndTest(data Data, int K, boolean busWord, int busWordNegSamSize, boolean userWord, int userWordNegSamSize){
+ embeddings e = new embeddings(Data, K);
+ learner l = new learner();
+ embeddings eBest = l.learnAndStop1(Data, e, false, false, false, busWord, busWordNegSamSize, userWord, userWordNegSamSize);
+ System.out.println("learning done");
+ Eval.getEvalMap(Data, eBest, "test");
+ Eval.printEval();
+ return eBest;
+ }
+
+ public static data readAttributes(String folder, double valPerc, double testPerc, boolean coldStart, int index) throws IOException{
+ data att = new data();
+ att.readBusAtt(System.getProperty("user.dir")+"/../Dataset/data/"+ folder +"/busAtt.txt", folder);
+ if(coldStart)
+ att.splitColdStart(valPerc, testPerc, index);
+ else
+ att.splitTrainTestValidation(valPerc, testPerc);
+
+ return att;
+ }
+
+ public static data readCategories(String folder, int pruneThresh) throws IOException{
+ data cat = new data();
+ cat.readAndCompleteCategoryData(System.getProperty("user.dir")+"/../Dataset/data/"+ folder +"/busCat.txt", pruneThresh, folder);
+ cat.splitTrainTestValidation(0.0, 0.0);
+ return cat;
+ }
+
+ public static data readRatings(String folder, double valPerc, double testPerc, boolean coldStart, int index) throws IOException{
+ data rate = new data();
+ rate.readRatingData(System.getProperty("user.dir")+"/../Dataset/data/"+ folder +"/reviews.txt", folder);
+ if(coldStart)
+ rate.splitColdStart(valPerc, testPerc, index);
+ else
+ rate.splitTrainTestValidation(valPerc, testPerc);
+ return rate;
+ }
+
+ public static data readReviewData(String folder, int occThresh, boolean busWord, boolean userWord, double valPerc, double testPerc) throws IOException{
+ data rD = new data();
+ rD.readReviewData(System.getProperty("user.dir")+"/../Dataset/data/"+ folder +"/reviews_textProc.txt"); // Makes EnWord maps (Business and User Word maps) and Word-Count map
+ rD.pruneVocab_EntityMap(occThresh);
+ if(busWord)
+ rD.makeEnWordCells("b-word");
+ if(userWord)
+ rD.makeEnWordCells("u-word");
+ rD.splitTrainTestValidation(valPerc, testPerc);
+ return rD;
+ }
+
+ public static void completeEvaluation(String folder, data A, data C, data R, data W, boolean busWord, int bwNS, boolean userWord, int uwNS,
+ boolean attCold, int coldIndexAtt, boolean rateCold, int coldIndexRate, String folderToWriteData) throws IOException{
+
+ System.out.print("Attribute Cold Start, ");
+ Util.checkColdStartSanity(A);
+ Util.implicitColdStart(A.trainData, A.testData);
+ System.out.print("Rate Cold Start, ");
+ Util.checkColdStartSanity(R);
+ Util.implicitColdStart(R.trainData, R.testData);
+
+		ArrayList<data> tomerge = new ArrayList<data>();
+ data mergeData = new data();
+
+ getMemoryDetails();
+
+ System.out.println("################################################### "+folder+" - Att-Word ######################################");
+ tomerge.clear();
+ tomerge.add(A);
+ tomerge.add(W);
+		mergeData = new data(W);	// copy constructor carries over W's word maps
+ mergeData.addDataAfterSplit(tomerge);
+ mergeData.dataStats();
+ embeddings e = learnAndTest(mergeData, 30, busWord, bwNS, userWord, uwNS);
+ if(!(busWord || userWord))
+ writeDataToFile.writePrediction(folderToWriteData+"A-A", folder, A, e);
+ if(busWord && !userWord)
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+BW", folder, A, e);
+ if(userWord && !busWord)
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+UW", folder, A, e);
+ if(userWord && busWord)
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+BUW", folder, A, e);
+ getMemoryDetails();
+
+ System.out.println("################################################### "+folder+" - Att-Cat-Word ######################################");
+ //rD = readReviewData(folder, 10, busWord, false, 0.0, 0.0);
+		tomerge = new ArrayList<data>();
+ tomerge.clear();
+ tomerge.add(A);
+ tomerge.add(C);
+ tomerge.add(W);
+		mergeData = new data(W);	// copy constructor carries over W's word maps
+ mergeData.addDataAfterSplit(tomerge);
+ mergeData.dataStats();
+ e = learnAndTest(mergeData, 30, busWord, bwNS, userWord, uwNS);
+ if(!(busWord || userWord))
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+C", folder, A, e);
+ if(busWord && !userWord)
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+C+BW", folder, A, e);
+ if(userWord && !busWord)
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+C+UW", folder, A, e);
+ if(busWord && userWord)
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+C+BUW", folder, A, e);
+ System.gc();
+
+ System.out.println("################################################### "+folder+" - Rating - Word ######################################");
+		tomerge = new ArrayList<data>();
+ tomerge.clear();
+ tomerge.add(R);
+ tomerge.add(W);
+		mergeData = new data(W);	// copy constructor carries over W's word maps
+ mergeData.addDataAfterSplit(tomerge);
+ mergeData.dataStats();
+ e = learnAndTest(mergeData, 30, busWord, bwNS, userWord, uwNS);
+ if(!(busWord || userWord))
+ writeDataToFile.writePrediction(folderToWriteData+"R-R", folder, R, e);
+ if(busWord && !userWord)
+ writeDataToFile.writePrediction(folderToWriteData+"R-R+BW", folder, R, e);
+ if(userWord && !busWord)
+ writeDataToFile.writePrediction(folderToWriteData+"R-R+UW", folder, R, e);
+ if(busWord && userWord)
+ writeDataToFile.writePrediction(folderToWriteData+"R-R+BUW", folder, R, e);
+ System.gc();
+
+ System.out.println("################################################### "+folder+" - Rating - Cat - Word ######################################");
+		tomerge = new ArrayList<data>();
+ tomerge.clear();
+ tomerge.add(R);
+ tomerge.add(C);
+ tomerge.add(W);
+		mergeData = new data(W);	// copy constructor carries over W's word maps
+ mergeData.addDataAfterSplit(tomerge);
+ mergeData.dataStats();
+ e = learnAndTest(mergeData, 30, busWord, bwNS, userWord, uwNS);
+ if(!(busWord || userWord))
+ writeDataToFile.writePrediction(folderToWriteData+"R-R+C", folder, R, e);
+ if(busWord && !userWord)
+ writeDataToFile.writePrediction(folderToWriteData+"R-R+C+BW", folder, R, e);
+ if(userWord && !busWord)
+ writeDataToFile.writePrediction(folderToWriteData+"R-R+C+UW", folder, R, e);
+ if(busWord && userWord)
+ writeDataToFile.writePrediction(folderToWriteData+"R-R+C+BUW", folder, R, e);
+ System.gc();
+
+ System.out.println("################################################### "+folder+" - Att Rate - Word ######################################");
+		tomerge = new ArrayList<data>();
+ tomerge.clear();
+ tomerge.add(A);
+ tomerge.add(R);
+ tomerge.add(W);
+		mergeData = new data(W);	// copy constructor carries over W's word maps
+ mergeData.addDataAfterSplit(tomerge);
+ mergeData.dataStats();
+ e = learnAndTest(mergeData, 30, busWord, bwNS, userWord, uwNS);
+ if(!(busWord || userWord)){
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+R", folder, A, e);
+ writeDataToFile.writePrediction(folderToWriteData+"R-A+R", folder, R, e);
+ }
+ if(busWord && !userWord){
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+R+BW", folder, A, e);
+ writeDataToFile.writePrediction(folderToWriteData+"R-A+R+BW", folder, R, e);
+ }
+ if(userWord && !busWord){
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+R+UW", folder, A, e);
+ writeDataToFile.writePrediction(folderToWriteData+"R-A+R+UW", folder, R, e);
+ }
+ if(userWord && busWord){
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+R+BUW", folder, A, e);
+ writeDataToFile.writePrediction(folderToWriteData+"R-A+R+BUW", folder, R, e);
+ }
+ System.gc();
+
+
+ System.out.println("################################################### "+folder+" - Att Cat Rate Word ######################################");
+		tomerge = new ArrayList<data>();
+ tomerge.clear();
+ tomerge.add(A);
+ tomerge.add(R);
+ tomerge.add(C);
+ tomerge.add(W);
+		mergeData = new data(W);	// copy constructor carries over W's word maps
+ mergeData.addDataAfterSplit(tomerge);
+ mergeData.dataStats();
+ e = learnAndTest(mergeData, 30, busWord, bwNS, userWord, uwNS);
+ if(!(busWord || userWord)){
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+C+R", folder, A, e);
+ writeDataToFile.writePrediction(folderToWriteData+"R-A+C+R", folder, R, e);
+ }
+ if(busWord && !userWord){
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+C+R+BW", folder, A, e);
+ writeDataToFile.writePrediction(folderToWriteData+"R-A+C+R+BW", folder, R, e);
+ }
+ if(userWord && !busWord){
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+C+R+UW", folder, A, e);
+ writeDataToFile.writePrediction(folderToWriteData+"R-A+C+R+UW", folder, R, e);
+ }
+ if(busWord && userWord){
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+C+R+BUW", folder, A, e);
+ writeDataToFile.writePrediction(folderToWriteData+"R-A+C+R+BUW", folder, R, e);
+ }
+ getMemoryDetails();
+
+ if(busWord || userWord){
+ if(busWord && !userWord){
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"words-bw", W, 1, e);
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"attributes-bw", A, 1, e);
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"categories-bw", C, 1, e);
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"business-bw", R, 0, e);
+ }
+ if(userWord && !busWord){
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"words-uw", W, 1, e);
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"attributes-uw", A, 1, e);
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"categories-uw", C, 1, e);
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"business-uw", R, 0, e);
+ }
+ if(userWord && busWord){
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"words-buw", W, 1, e);
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"attributes-buw", A, 1, e);
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"categories-buw", C, 1, e);
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"business-buw", R, 0, e);
+ }
+ }
+
+ }
+
+ public static void AttBusColdCompleteEvaluation(String folder, data A, data C, data R, data W, boolean busWord, int bwNS, boolean userWord, int uwNS, String folderToWriteData) throws IOException{
+ System.out.print("Attribute Cold Start, ");
+ Util.checkColdStartSanity(A);
+ Util.implicitColdStart(A.trainData, A.testData);
+
+		ArrayList<data> tomerge = new ArrayList<data>();
+ data mergeData = new data();
+
+ getMemoryDetails();
+
+ System.out.println("################################################### "+folder+" - Att-Word ######################################");
+ tomerge.clear();
+ tomerge.add(A);
+ tomerge.add(W);
+ mergeData = new data(W);
+ mergeData.addDataAfterSplit(tomerge);
+ mergeData.dataStats();
+ embeddings e = learnAndTest(mergeData, 30, busWord, bwNS, userWord, uwNS);
+ if(!(busWord || userWord))
+ writeDataToFile.writePrediction(folderToWriteData+"A-A", folder, A, e);
+ if(busWord && !userWord)
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+BW", folder, A, e);
+ if(userWord && !busWord)
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+UW", folder, A, e);
+ if(userWord && busWord)
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+BUW", folder, A, e);
+ getMemoryDetails();
+
+ System.out.println("################################################### "+folder+" - Att-Cat-Word ######################################");
+		tomerge = new ArrayList<data>();
+ tomerge.clear();
+ tomerge.add(A);
+ tomerge.add(C);
+ tomerge.add(W);
+ mergeData = new data(W);
+ mergeData.addDataAfterSplit(tomerge);
+ mergeData.dataStats();
+ e = learnAndTest(mergeData, 30, busWord, bwNS, userWord, uwNS);
+ if(!(busWord || userWord))
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+C", folder, A, e);
+ if(busWord && !userWord)
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+C+BW", folder, A, e);
+ if(userWord && !busWord)
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+C+UW", folder, A, e);
+ if(busWord && userWord)
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+C+BUW", folder, A, e);
+ System.gc();
+
+ System.out.println("################################################### "+folder+" - Att Rate - Word ######################################");
+		tomerge = new ArrayList<data>();
+ tomerge.clear();
+ tomerge.add(A);
+ tomerge.add(R);
+ tomerge.add(W);
+ mergeData = new data(W);
+ mergeData.addDataAfterSplit(tomerge);
+ mergeData.dataStats();
+ e = learnAndTest(mergeData, 30, busWord, bwNS, userWord, uwNS);
+ if(!(busWord || userWord)){
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+R", folder, A, e);
+ }
+ if(busWord && !userWord){
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+R+BW", folder, A, e);
+ }
+ if(userWord && !busWord){
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+R+UW", folder, A, e);
+ }
+ if(userWord && busWord){
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+R+BUW", folder, A, e);
+ }
+ System.gc();
+
+
+ System.out.println("################################################### "+folder+" - Att Cat Rate Word ######################################");
+		tomerge = new ArrayList<data>();
+ tomerge.clear();
+ tomerge.add(A);
+ tomerge.add(R);
+ tomerge.add(C);
+ tomerge.add(W);
+ mergeData = new data(W);
+ mergeData.addDataAfterSplit(tomerge);
+ mergeData.dataStats();
+ e = learnAndTest(mergeData, 30, busWord, bwNS, userWord, uwNS);
+ if(!(busWord || userWord)){
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+C+R", folder, A, e);
+ }
+ if(busWord && !userWord){
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+C+R+BW", folder, A, e);
+ }
+ if(userWord && !busWord){
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+C+R+UW", folder, A, e);
+ }
+ if(busWord && userWord){
+ writeDataToFile.writePrediction(folderToWriteData+"A-A+C+R+BUW", folder, A, e);
+ }
+ getMemoryDetails();
+
+ if(busWord || userWord){
+ if(busWord && !userWord){
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"words-bw", W, 1, e);
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"attributes-bw", A, 1, e);
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"categories-bw", C, 1, e);
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"business-bw", R, 0, e);
+ }
+ if(userWord && !busWord){
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"words-uw", W, 1, e);
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"attributes-uw", A, 1, e);
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"categories-uw", C, 1, e);
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"business-uw", R, 0, e);
+ }
+ if(userWord && busWord){
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"words-buw", W, 1, e);
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"attributes-buw", A, 1, e);
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"categories-buw", C, 1, e);
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"business-buw", R, 0, e);
+ }
+ }
+
+ }
+
+ public static void RateColdCompleteEvaluation(String folder, data A, data C, data R, data W, boolean busWord, int bwNS, boolean userWord, int uwNS,
+ String folderToWriteData) throws IOException{
+
+ System.out.print("Rate Cold Start, ");
+ Util.checkColdStartSanity(R);
+ Util.implicitColdStart(R.trainData, R.testData);
+
+		ArrayList<data> tomerge = new ArrayList<data>();
+ data mergeData = new data();
+
+ getMemoryDetails();
+
+ System.out.println("################################################### "+folder+" - Rating - Word ######################################");
+		tomerge = new ArrayList<data>();
+ tomerge.clear();
+ tomerge.add(R);
+ tomerge.add(W);
+ mergeData = new data(W);
+ mergeData.addDataAfterSplit(tomerge);
+ mergeData.dataStats();
+ embeddings e = learnAndTest(mergeData, 30, busWord, bwNS, userWord, uwNS);
+ if(!(busWord || userWord))
+ writeDataToFile.writePrediction(folderToWriteData+"R-R", folder, R, e);
+ if(busWord && !userWord)
+ writeDataToFile.writePrediction(folderToWriteData+"R-R+BW", folder, R, e);
+ if(userWord && !busWord)
+ writeDataToFile.writePrediction(folderToWriteData+"R-R+UW", folder, R, e);
+ if(busWord && userWord)
+ writeDataToFile.writePrediction(folderToWriteData+"R-R+BUW", folder, R, e);
+ System.gc();
+
+ System.out.println("################################################### "+folder+" - Rating - Cat - Word ######################################");
+		tomerge = new ArrayList<data>();
+ tomerge.clear();
+ tomerge.add(R);
+ tomerge.add(C);
+ tomerge.add(W);
+ mergeData = new data(W);
+ mergeData.addDataAfterSplit(tomerge);
+ mergeData.dataStats();
+ e = learnAndTest(mergeData, 30, busWord, bwNS, userWord, uwNS);
+ if(!(busWord || userWord))
+ writeDataToFile.writePrediction(folderToWriteData+"R-R+C", folder, R, e);
+ if(busWord && !userWord)
+ writeDataToFile.writePrediction(folderToWriteData+"R-R+C+BW", folder, R, e);
+ if(userWord && !busWord)
+ writeDataToFile.writePrediction(folderToWriteData+"R-R+C+UW", folder, R, e);
+ if(busWord && userWord)
+ writeDataToFile.writePrediction(folderToWriteData+"R-R+C+BUW", folder, R, e);
+ System.gc();
+
+ System.out.println("################################################### "+folder+" - Att Rate - Word ######################################");
+		tomerge = new ArrayList<data>();
+ tomerge.clear();
+ tomerge.add(A);
+ tomerge.add(R);
+ tomerge.add(W);
+ mergeData = new data(W);
+ mergeData.addDataAfterSplit(tomerge);
+ mergeData.dataStats();
+ e = learnAndTest(mergeData, 30, busWord, bwNS, userWord, uwNS);
+ if(!(busWord || userWord)){
+ writeDataToFile.writePrediction(folderToWriteData+"R-A+R", folder, R, e);
+ }
+ if(busWord && !userWord){
+ writeDataToFile.writePrediction(folderToWriteData+"R-A+R+BW", folder, R, e);
+ }
+ if(userWord && !busWord){
+ writeDataToFile.writePrediction(folderToWriteData+"R-A+R+UW", folder, R, e);
+ }
+ if(userWord && busWord){
+ writeDataToFile.writePrediction(folderToWriteData+"R-A+R+BUW", folder, R, e);
+ }
+ System.gc();
+
+
+ System.out.println("################################################### "+folder+" - Att Cat Rate Word ######################################");
+		tomerge = new ArrayList<data>();
+ tomerge.clear();
+ tomerge.add(A);
+ tomerge.add(R);
+ tomerge.add(C);
+ tomerge.add(W);
+ mergeData = new data(W);
+ mergeData.addDataAfterSplit(tomerge);
+ mergeData.dataStats();
+ e = learnAndTest(mergeData, 30, busWord, bwNS, userWord, uwNS);
+ if(!(busWord || userWord)){
+ writeDataToFile.writePrediction(folderToWriteData+"R-A+C+R", folder, R, e);
+ }
+ if(busWord && !userWord){
+ writeDataToFile.writePrediction(folderToWriteData+"R-A+C+R+BW", folder, R, e);
+ }
+ if(userWord && !busWord){
+ writeDataToFile.writePrediction(folderToWriteData+"R-A+C+R+UW", folder, R, e);
+ }
+ if(busWord && userWord){
+ writeDataToFile.writePrediction(folderToWriteData+"R-A+C+R+BUW", folder, R, e);
+ }
+ getMemoryDetails();
+
+ if(busWord || userWord){
+ if(busWord && !userWord){
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"words-bw", W, 1, e);
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"attributes-bw", A, 1, e);
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"categories-bw", C, 1, e);
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"business-bw", R, 0, e);
+ }
+ if(userWord && !busWord){
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"words-uw", W, 1, e);
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"attributes-uw", A, 1, e);
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"categories-uw", C, 1, e);
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"business-uw", R, 0, e);
+ }
+ if(userWord && busWord){
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"words-buw", W, 1, e);
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"attributes-buw", A, 1, e);
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"categories-buw", C, 1, e);
+ writeDataToFile.writeEmbeddings(folder, folderToWriteData+"business-buw", R, 0, e);
+ }
+ }
+ System.gc();
+ }
+
+ public static void performAttBusColdEvaluation(String folder) throws IOException{
+ String folderToWriteData = "AttBusCold/";
+ data A = readAttributes(folder, 15.0, 15.0, true, 0);
+ data C = readCategories(folder, 5);
+ data R = readRatings(folder, 0.0, 0.0, false, 1);
+ data W = new data();
+
+ AttBusColdCompleteEvaluation(folder, A, C, R, W, false, 0, false, 0, folderToWriteData);
+ System.gc();
+
+ // Business - Words
+ W = readReviewData(folder, 10, true, false, 0.0, 0.0);
+ int bwNS = Util.getNegSampleSize(W);
+ AttBusColdCompleteEvaluation(folder, A, C, R, W, true, bwNS, false, 0, folderToWriteData);
+ System.gc();
+
+ // User - Words
+ W = readReviewData(folder, 10, false, true, 0.0, 0.0);
+ int uwNS = Util.getNegSampleSize(W);
+ AttBusColdCompleteEvaluation(folder, A, C, R, W, false, 0, true, uwNS, folderToWriteData);
+ System.gc();
+
+ // BusWords and UserWords
+ /*W = readReviewData(folder, 10, true, true, 0.0, 0.0);
+ AttBusColdCompleteEvaluation(folder, A, C, R, W, true, bwNS, true, uwNS, folderToWriteData);
+ System.gc();*/
+ }
+
+ public static void performRateBusColdEvaluation(String folder) throws IOException{
+ String folderToWriteData = "RateBusCold/";
+ data A = readAttributes(folder, 0.0, 0.0, false, 0);
+ data C = readCategories(folder, 5);
+ data R = readRatings(folder, 15.0, 15.0, true, 0);
+ data W = new data();
+
+ RateColdCompleteEvaluation(folder, A, C, R, W, false, 0, false, 0, folderToWriteData);
+ System.gc();
+
+ // Business - Words
+ W = readReviewData(folder, 10, true, false, 0.0, 0.0);
+ int bwNS = Util.getNegSampleSize(W);
+ RateColdCompleteEvaluation(folder, A, C, R, W, true, bwNS, false, 0, folderToWriteData);
+ System.gc();
+
+ // User - Words
+ W = readReviewData(folder, 10, false, true, 0.0, 0.0);
+ int uwNS = Util.getNegSampleSize(W);
+ RateColdCompleteEvaluation(folder, A, C, R, W, false, 0, true, uwNS, folderToWriteData);
+ System.gc();
+
+ // BusWords and UserWords
+ /*W = readReviewData(folder, 10, true, true, 0.0, 0.0);
+ RateColdCompleteEvaluation(folder, A, C, R, W, true, bwNS, true, uwNS, folderToWriteData);
+ System.gc();*/
+ }
+
+ public static void performRateUserColdEvaluation(String folder) throws IOException{
+ String folderToWriteData = "RateUserCold/";
+ data A = readAttributes(folder, 0.0, 0.0, false, 0);
+ data C = readCategories(folder, 5);
+ data R = readRatings(folder, 15.0, 15.0, true, 1);
+ data W = new data();
+
+ RateColdCompleteEvaluation(folder, A, C, R, W, false, 0, false, 0, folderToWriteData);
+ System.gc();
+
+ // Business - Words
+ W = readReviewData(folder, 10, true, false, 0.0, 0.0);
+ int bwNS = Util.getNegSampleSize(W);
+ RateColdCompleteEvaluation(folder, A, C, R, W, true, bwNS, false, 0, folderToWriteData);
+ System.gc();
+
+ // User - Words
+ W = readReviewData(folder, 10, false, true, 0.0, 0.0);
+ int uwNS = Util.getNegSampleSize(W);
+ RateColdCompleteEvaluation(folder, A, C, R, W, false, 0, true, uwNS, folderToWriteData);
+ System.gc();
+
+ // BusWords and UserWords
+ /*W = readReviewData(folder, 10, true, true, 0.0, 0.0);
+ RateColdCompleteEvaluation(folder, A, C, R, W, true, bwNS, true, uwNS, folderToWriteData);
+ System.gc();*/
+ }
+
+ public static void performHeldOutEvaluation(String folder) throws IOException{
+ String folderToWriteData = "HeldOut/";
+ data A = readAttributes(folder, 15.0, 15.0, false, 0);
+ data C = readCategories(folder, 5);
+ data R = readRatings(folder, 15.0, 15.0, false, 1);
+ data W = new data();
+
+ // No Words
+ completeEvaluation(folder, A, C, R, W, false, 0, false, 0, false, 0, false, 0, folderToWriteData);
+ System.gc();
+
+ // Business - Words
+ W = readReviewData(folder, 10, true, false, 0.0, 0.0);
+ int bwNS = Util.getNegSampleSize(W);
+ completeEvaluation(folder, A, C, R, W, true, bwNS, false, 0, false, 0, false, 0, folderToWriteData);
+ System.gc();
+
+ // User - Words
+ W = readReviewData(folder, 10, false, true, 0.0, 0.0);
+ int uwNS = Util.getNegSampleSize(W);
+ completeEvaluation(folder, A, C, R, W, false, 0, true, uwNS, false, 0, false, 0, folderToWriteData);
+ System.gc();
+
+ // BusWords and UserWords
+ /*W = readReviewData(folder, 10, true, true, 0.0, 0.0);
+ getMemoryDetails();
+ completeEvaluation(folder, A, C, R, W, true, bwNS, true, uwNS, false, 0, false, 0, folderToWriteData);
+ System.gc();*/
+ }
+
+ // To test one dataset completely and write embeddings for (A + R + C + W)
+ public static void main(String [] args) throws Exception {
+		String folder = (args.length > 0) ? args[0] : "EDH";
+		String todo = (args.length > 1) ? args[1] : "heldOut";
+
+ if(todo.equals("heldOut"))
+ performHeldOutEvaluation(folder);
+ if(todo.equals("attBusCold"))
+ performAttBusColdEvaluation(folder);
+ if(todo.equals("rateBusCold"))
+ performRateBusColdEvaluation(folder);
+ if(todo.equals("rateUserCold"))
+ performRateUserColdEvaluation(folder);
+ //attBusColdEvaluations(folder);
+ //rateBusColdEvaluations(folder);
+ }
+
+ // To make the sizes table
+ /*public static void main(String [] args) throws Exception {
+
+ String folder = "WI";
+ data A = readAttributes(folder, 15.0, 15.0, false, 0);
+ //data C = readCategories(folder, 5);
+ data R = readRatings(folder, 0.0, 0.0, false, 1);
+ //data BW = readReviewData(folder, 10, true, false, 0.0, 0.0);
+ //data UW = readReviewData(folder, 10, false, true, 0.0, 0.0);
+
+ A.dataStats();
+
+ Util.getMatrixDetails(A);
+ Util.getMatrixDetails(C);
+ Util.getMatrixDetails(R);
+ //Util.getMatrixDetails(BW);
+
+
+
+ }*/
+
+}
\ No newline at end of file
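
End to end, `main` runs: read each matrix, split it, merge with `addDataAfterSplit` (which preserves the per-matrix splits), factorize via `learnAndTest` with K = 30, then write predictions and embeddings. A condensed sketch of the held-out path for one dataset, reusing the hard-coded default folder:

```java
public static void heldOutSketch() throws IOException {
    String folder = "EDH";
    data A = codeTest.readAttributes(folder, 15.0, 15.0, false, 0); // 15% val / 15% test
    data C = codeTest.readCategories(folder, 5);                    // drop categories seen <= 5 times
    data R = codeTest.readRatings(folder, 15.0, 15.0, false, 1);

    ArrayList<data> tomerge = new ArrayList<data>();
    tomerge.add(A); tomerge.add(C); tomerge.add(R);
    data merged = new data();
    merged.addDataAfterSplit(tomerge);   // union of cells; train/val/test kept intact
    embeddings e = codeTest.learnAndTest(merged, 30, false, 0, false, 0);
}
```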
diff --git a/Project/src/logisticCMF/data.java b/Project/src/logisticCMF/data.java
new file mode 100644
index 0000000..a26f514
--- /dev/null
+++ b/Project/src/logisticCMF/data.java
@@ -0,0 +1,601 @@
+package logisticCMF;
+import java.io.*;
+import java.util.*;
+
+import javax.json.Json;
+import javax.json.JsonNumber;
+import javax.json.JsonObject;
+import javax.json.JsonReader;
+import javax.json.JsonString;
+import javax.json.JsonValue;
+
+public class data {
+
+	public ArrayList<Cell> trainData;
+	public ArrayList<Cell> testData;
+	public ArrayList<Cell> valData;
+	public ArrayList<Cell> Data;
+	public Map<String, Integer> relDataCount;	// Map [Relation Id, Count]
+	ArrayList<HashSet<String>> entityIds;		// entityIds.get(0): row entities, entityIds.get(1): column entities
+
+	public Map<String, Integer> wordCount = new HashMap<String, Integer>();	// Map [Word -> Count] in the original review data
+	public Map<String, HashSet<String>> busWord = new HashMap<String, HashSet<String>>();	// Map [business -> words used for it]; pruned by word-occurrence threshold
+	public Map<String, HashSet<String>> userWord = new HashMap<String, HashSet<String>>();	// Map [user -> words they used]; pruned by word-occurrence threshold
+	public ArrayList<String> words = new ArrayList<String>();	// The vocab of the review data; pruned by word-occurrence threshold
+	public static Random seed = new Random(50);	// Also defined in embeddings & learner classes
+
+
+	public data(){
+		trainData = new ArrayList<Cell>();
+		valData = new ArrayList<Cell>();
+		testData = new ArrayList<Cell>();
+		Data = new ArrayList<Cell>();
+		relDataCount = new HashMap<String, Integer>();
+		wordCount = new HashMap<String, Integer>();
+		busWord = new HashMap<String, HashSet<String>>();
+		userWord = new HashMap<String, HashSet<String>>();
+		words = new ArrayList<String>();
+		entityIds = new ArrayList<HashSet<String>>();
+		entityIds.add(new HashSet<String>());
+		entityIds.add(new HashSet<String>());
+	}
+
+	public data(data D){
+		busWord = D.busWord;
+		userWord = D.userWord;
+		words = D.words;
+		wordCount = D.wordCount;
+		trainData = new ArrayList<Cell>();
+		valData = new ArrayList<Cell>();
+		testData = new ArrayList<Cell>();
+		Data = new ArrayList<Cell>();
+		relDataCount = new HashMap<String, Integer>();
+		entityIds = new ArrayList<HashSet<String>>();
+		entityIds.add(new HashSet<String>());
+		entityIds.add(new HashSet<String>());
+	}
+
+ public void addToEntitySets(Cell cell){
+ entityIds.get(0).add(cell.entity_ids.get(0));
+ entityIds.get(1).add(cell.entity_ids.get(1));
+ }
+
+ public void readBusAtt(String fileAddress, String rel) throws IOException{
+ BufferedReader br = new BufferedReader(new FileReader(fileAddress));
+ String relation = "b-att-"+rel;
+ String line;
+ String bId = "";
+ int count = 0;
+ while( (line = br.readLine()) != null ){
+ //System.out.println(line);
+ String [] arr = line.split(":");
+ if(arr.length >= 2){
+ if(arr[0].trim().equals("business_id"))
+ bId = arr[1].trim();
+
+ else{
+ String att = arr[0].trim();
+ double t = Double.parseDouble(arr[1].trim());
+ Cell cell = new Cell();
+ cell.relation_id = relation;
+ cell.entity_ids.add(bId);
+ cell.entity_ids.add(att);
+				cell.truth = (t == 1.0);
+ addToEntitySets(cell);
+ Data.add(cell);
+ count++;
+ }
+
+
+ }
+ }
+ br.close();
+ System.out.println(relation + " : " + count);
+ }
+
+ public void printData(){
+ for(Cell cell : Data){
+ System.out.print("\n"+cell.relation_id + ", ");
+ for(String e : cell.entity_ids)
+ System.out.print(e + ", ");
+ System.out.print(cell.truth);
+ }
+ }
+
+ public void makeRestaurantSet(String fileAddress) throws IOException{
+ BufferedWriter bw = new BufferedWriter(new FileWriter(fileAddress+"1"));
+		HashMap<String, Double> restaurants = new HashMap<String, Double>();
+ for(Cell cell : Data){
+ if(restaurants.containsKey(cell.entity_ids.get(0)))
+ continue;
+ else{
+ restaurants.put(cell.entity_ids.get(0), 1.0);
+ bw.write(cell.entity_ids.get(0) + "\n");
+ }
+ }
+ bw.close();
+ }
+
+ public void refreshTrainTest(){
+ trainData = new ArrayList();
+ valData = new ArrayList();
+ testData = new ArrayList();
+ }
+
+ public void readRatingData(String fileAddress, String rel) throws IOException{
+		Map<String, Map<String, Integer>> ratings = new HashMap<String, Map<String, Integer>>();
+ BufferedReader br = new BufferedReader(new FileReader(fileAddress));
+ String line; int countl = 0, countcell = 0;
+ String relation = "b-u-rate-"+rel;
+ String busid = null, userid = null; boolean value = false;
+ while((line = br.readLine()) != null){
+ countl++;
+ String[] array = line.split(":");
+ if( array[0].trim().equals("bus_id")){
+ busid = array[1].trim();
+ if(!ratings.containsKey(busid))
+					ratings.put(busid, new HashMap<String, Integer>());
+ }
+ if( array[0].trim().equals("user_id"))
+ userid = array[1].trim();
+ if( array[0].trim().equals("star")){
+ double t = Double.parseDouble(array[1].trim());
+ ratings.get(busid).put(userid, (int)t);
+ }
+ }
+ br.close();
+ for(String bus : ratings.keySet()){
+ for(String user : ratings.get(bus).keySet()){
+ int rate = ratings.get(bus).get(user);
+ Cell cell = new Cell();
+ cell.relation_id = relation;
+ cell.entity_ids.add(bus);
+ cell.entity_ids.add(user);
+				cell.truth = (rate >= 4);
+ Data.add(cell);
+ addToEntitySets(cell);
+ countcell++;
+ }
+ }
+ System.out.println("Ratings read : " + countcell);
+ }
+
+ public void countLines(String fileAddress) throws IOException{
+ BufferedReader br = new BufferedReader(new FileReader(fileAddress));
+ String line; String relation = "restaurant-user-word"; int count = 0;
+ String busid = null, userid = null; boolean value = false;
+ while((line = br.readLine()) != null){
+ count ++;
+ }
+ System.out.println(count);
+ }
+
+	public void splitTrainTestValidation(double valPercentage, double testPercentage){
+		Collections.shuffle(Data, seed);
+		int valSize = (int)((valPercentage/100.0)*Data.size());
+		int testSize = (int)((testPercentage/100.0)*Data.size());
+		for(int j = 0; j < Data.size(); j++){
+			if(j < valSize)
+				valData.add(Data.get(j));
+			else if(j < valSize + testSize)
+				testData.add(Data.get(j));
+			else
+				trainData.add(Data.get(j));
+		}
+	}
+
+	// Hold out perc% of the entities at position 'index': all of their cells become
+	// the test set, so every test entity is unseen in training.
+	public void splitColdStart(double valPercentage, double testPerc, int index){
+		HashSet<String> coldEntities = Util.getColdEntites(Data, index, testPerc);
+ int countTest = 0, i=0;
+ for(Cell cell : Data){
+ if(coldEntities.contains(cell.entity_ids.get(index))){
+ testData.add(cell);
+ countTest++;
+ }
+ else
+ trainData.add(cell);
+ }
+ Collections.shuffle(trainData, seed);
+		Iterator<Cell> it = trainData.iterator();
+ int cellValidation = (int)((valPercentage/100)*Data.size());
+ System.out.println(cellValidation);
+ while(i < cellValidation){
+ Cell cell = it.next();
+ valData.add(cell);
+ it.remove();
+ i++;
+ }
+ }
+
+ public void addDatainExisting(data d){
+ Collections.shuffle(Data, seed); Collections.shuffle(d.Data, seed);
+ for(Cell cell : d.Data)
+ Data.add(cell);
+
+ Collections.shuffle(Data, seed);
+ }
+
+	public void addDataAfterSplit(ArrayList<data> dataList){
+ for(data d : dataList){
+ for(Cell cell : d.Data)
+ Data.add(cell);
+ for (Cell cell : d.trainData)
+ trainData.add(cell);
+ for(Cell cell : d.valData)
+ valData.add(cell);
+ for(Cell cell : d.testData)
+ testData.add(cell);
+ }
+
+ Collections.shuffle(Data, seed);
+ Collections.shuffle(trainData, seed);
+ Collections.shuffle(valData, seed);
+ Collections.shuffle(testData, seed);
+ }
+
+ // Used by dataStats() function
+	public void countTrueFalse(ArrayList<Cell> data, Map<String, Integer> relTrue, Map<String, Integer> relFalse){
+ for(Cell cell : data){
+ int t = (cell.truth) ? 1 : 0; int f = (cell.truth) ? 0 : 1;
+ if(t == 1){
+ if(!relTrue.containsKey(cell.relation_id))
+ relTrue.put(cell.relation_id, 1);
+ else
+ relTrue.put(cell.relation_id, relTrue.get(cell.relation_id) + 1);
+ }
+ if(f == 1){
+ if(!relFalse.containsKey(cell.relation_id))
+ relFalse.put(cell.relation_id, 1);
+ else
+ relFalse.put(cell.relation_id, relFalse.get(cell.relation_id) + 1);
+ }
+ }
+ }
+
+ public void dataStats(){
+		Map<String, Integer> relTrue = new HashMap<String, Integer>();
+		Map<String, Integer> relFalse = new HashMap<String, Integer>();
+ countTrueFalse(Data, relTrue, relFalse);
+ System.out.println("\nData Stats");
+ for(String rel : relTrue.keySet())
+ System.out.println(rel + " : " + "t : " + relTrue.get(rel) + " f : " + relFalse.get(rel));
+
+
+
+ relTrue.clear();
+ relFalse.clear();
+ countTrueFalse(trainData, relTrue, relFalse);
+ System.out.println("Train Data Stats");
+ for(String rel : relTrue.keySet())
+ System.out.println(rel + " : " + "t : " + relTrue.get(rel) + " f : " + relFalse.get(rel));
+
+ relTrue.clear();
+ relFalse.clear();
+ countTrueFalse(valData, relTrue, relFalse);
+ System.out.println("Validation Data Stats");
+ for(String rel : relTrue.keySet())
+ System.out.println(rel + " : " + "t : " + relTrue.get(rel) + " f : " + relFalse.get(rel));
+
+ relTrue.clear();
+ relFalse.clear();
+ countTrueFalse(testData, relTrue, relFalse);
+ System.out.println("Test Data Stats");
+ for(String rel : relTrue.keySet())
+ System.out.println(rel + " : " + "t : " + relTrue.get(rel) + " f : " + relFalse.get(rel));
+
+ }
+
+ public void reviewDataStats(int entityId, int thresh, boolean removeEntities){
+		Map<String, Integer> users = new HashMap<String, Integer>();	// Map [entityId, count in data]
+		Set<String> e1 = new HashSet<String>();
+		Set<String> e2 = new HashSet<String>();
+		Set<String> setUsers = new HashSet<String>();
+ int count = 0, max = -1, min = 1000000000;
+ for(Cell cell : Data){
+ String user = cell.entity_ids.get(entityId); /// CHECK WHICH ENTITY MAP IS CREATED
+ e1.add(cell.entity_ids.get(0));
+ e2.add(cell.entity_ids.get(1));
+ setUsers.add(cell.entity_ids.get(entityId));
+ if(!users.containsKey(user)){
+ users.put(user, 1);
+ }
+ else{
+ int revCount = users.get(user);
+ int newRevCount = revCount+1 ;
+ users.put(user, newRevCount);
+ }
+ }
+
+ for(String user : users.keySet()){
+ if(users.get(user) > max)
+ max = users.get(user);
+ if(users.get(user) < min)
+ min = users.get(user);
+ if(users.get(user) <= thresh){
+ count++;
+ setUsers.remove(user);
+ //System.out.println(user + " : " + users.get(user));
+ }
+ }
+ int iterates = 0;
+ if(removeEntities){
+			for(Iterator<Cell> itr = Data.iterator(); itr.hasNext();){
+ iterates++;
+ Cell cell = itr.next();
+ if(!setUsers.contains(cell.entity_ids.get(entityId))){
+ itr.remove();
+ }
+ }
+ }
+ System.out.println("Total iterates = "+iterates);
+ System.out.println(e1.size() + " : " + e2.size() + " : " +users.keySet().size() + " : " + setUsers.size());
+ System.out.println("count : " + count + " max = " + max + " min : " + min);
+ }
+
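+	// Materializes the full business x category matrix: for every business, a cell is
+	// created for EVERY category above the count threshold, true iff the business
+	// actually carries it, so negative cells are explicit rather than sampled.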
+ public void readAndCompleteCategoryData(String fileAddress, int thresh, String rel) throws NumberFormatException, IOException{
+ BufferedReader br = new BufferedReader(new FileReader(fileAddress));
+ Map<String, ArrayList<String>> resCat = new HashMap<String, ArrayList<String>>();
+ Map<String, Integer> catCount = new HashMap<String, Integer>();
+ Set<String> categories = new HashSet<String>();
+ String line; String relation = "b-cat-"+rel;
+ int count = 0;
+ while( (line = br.readLine()) != null) {
+
+ String[] array = line.split(":");
+ resCat.put(array[0].trim(), new ArrayList<String>());
+
+ String [] cats = array[1].trim().split(";"); // Delimiter used in resCat file. Refer to YelpChallenge-yeldData.java
+ for(String cat : cats){
+ String category = cat.trim();
+ if(category.length()>1){
+ categories.add(category);
+ resCat.get(array[0].trim()).add(category);
+
+ // To build Map[Category, Count]
+ if(!catCount.containsKey(category))
+ catCount.put(category, 1);
+ else{
+ int cat_count = catCount.get(category);
+ cat_count++;
+ catCount.put(category, cat_count);
+ }
+ }
+ }
+ }
+ br.close();
+
+
+ int cC = 0;
+ for(String c : catCount.keySet()){
+ if(catCount.get(c) > thresh)
+ cC++;
+ }
+ //System.out.println("Categories after pruning : " + cC);
+
+ for(String res : resCat.keySet()){
+ int categoriesConsidered = 0;
+ for(String cat : categories){
+ if(catCount.get(cat) > thresh){
+ categoriesConsidered++;
+ Cell cell = new Cell();
+ cell.relation_id = relation;
+ cell.entity_ids.add(res);
+ cell.entity_ids.add(cat);
+ if(resCat.get(res).contains(cat))
+ cell.truth = true;
+ else
+ cell.truth = false;
+ count++;
+ Data.add(cell);
+ addToEntitySets(cell);
+ }
+ cC = categoriesConsidered;
+ }
+ }
+ System.out.println(relation + " : " + count + " categories read : " + cC);
+ }
+
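+ // Shuffles Data and keeps only the first perc% of cells; the iterator removes everything past the cutoff.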
+ public void reduceDataSize(double perc){
+ Collections.shuffle(Data, seed);
+ int iterations = 0; int size = Data.size(); int stopAt = (int)((perc/100.0)*size);
+ System.out.println(stopAt);
+ for(Iterator<Cell> itr = Data.iterator(); itr.hasNext();){
+ iterations++;
+ if(iterations <= stopAt)
+ itr.next();
+ else{
+ itr.next();
+ itr.remove();
+ }
+ }
+ }
+
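+ // Reads the pre-processed review file (blocks of "bus_id : ...", "user_id : ...", "star : ...", "text: ...") and fills the business -> words and user -> words maps.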
+ public void readReviewData(String fileAddress) throws IOException{
+ BufferedReader br = new BufferedReader(new FileReader(fileAddress));
+ String line; String bId = null; String userId = null; int countl = 0, countR = 0;
+
+ while((line = br.readLine()) != null){
+ countl++;
+ String[] array = line.split(":");
+
+ if( array[0].trim().equals("user_id")){
+ countR++;
+ userId = array[1].trim();
+ if(!userWord.containsKey(userId))
+ userWord.put(userId, new HashSet<String>());
+ }
+
+ if( array[0].trim().equals("bus_id")){
+ bId = array[1].trim();
+ if(!busWord.containsKey(bId))
+ busWord.put(bId, new HashSet<String>());
+ }
+
+ if(array[0].trim().equals("text")){
+ String [] tokens = array[1].trim().split(" ");
+ if(tokens.length > 0){
+ for(String word : tokens){
+ word = word.trim();
+ if(word.length() >= 3){
+ addWordInMap(word);
+ busWord.get(bId).add(word);
+ userWord.get(userId).add(word);
+ }
+ }
+ }
+ }
+ /*if(countl % 100000 == 0)
+ System.out.println("line : "+countl);*/
+ }
+ System.out.println("Total No. of Reviews : " + countR);
+ }
+
+ public void addWordInMap(String word){
+ if(!wordCount.containsKey(word))
+ wordCount.put(word, 0);
+ int wcount = wordCount.get(word);
+ wcount++;
+ wordCount.put(word, wcount);
+ }
+
+ public void getMapStats(Map<String, Set<String>> enWord){
+ int potentialResWordCells = 0;
+ int min=100000000;
+ System.out.println("No. of entities = " + enWord.keySet().size());
+ System.out.println("No. of total words in Vocab = " + words.size());
+
+ for(String en : enWord.keySet()){
+ min = (enWord.get(en).size() < min) ? enWord.get(en).size() : min;
+ for(String word : enWord.get(en)){
+ potentialResWordCells++;
+ }
+
+ }
+ System.out.println("Potential Entity-Word Cells : " + potentialResWordCells);
+ System.out.println("Words per entiy : " + ((double) potentialResWordCells)/enWord.keySet().size());
+ System.out.println("Min No. of Words in Entity : " + min);
+
+ }
+
+ public void pruneVocab_EntityMap(int occThresh){
+ makePrunedWordList(occThresh);
+ pruneEntityWordMap(busWord, occThresh);
+ pruneEntityWordMap(userWord, occThresh);
+ }
+
+ // Remove words from Map[Entity -> Set[words]] that occur at most occThresh times in the dictionary. If an entity's word set becomes empty, remove the entity from the Map.
+ public void pruneEntityWordMap(Map<String, Set<String>> enWord, int occThresh){
+ Iterator<String> it = enWord.keySet().iterator();
+ while(it.hasNext()){
+ String en = it.next();
+ Iterator<String> itr = enWord.get(en).iterator();
+ while(itr.hasNext()){
+ String word = itr.next();
+ if(wordCount.get(word) <= occThresh)
+ itr.remove();
+ }
+ if(enWord.get(en).size() == 0)
+ it.remove();
+ }
+ }
+
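+ // For each threshold i in [start, end], prints how many vocabulary words occur more than i times.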
+ public void getWordCountStats(int start, int end){
+ int count = 0;
+ for(int i = start; i<=end; i++){
+ for(String word : wordCount.keySet()){
+ if(wordCount.get(word) > i)
+ count++;
+ }
+ System.out.println("Words with greater that " + i + " count : " + count);
+ count = 0;
+ }
+ }
+
+ // Build the list of words whose frequency exceeds the given threshold.
+ public void makePrunedWordList(int occThresh){
+ words = new ArrayList<String>(); int count = 0;
+ for(String word : wordCount.keySet()){
+ if(wordCount.get(word) > occThresh){
+ count++;
+ words.add(word);
+ }
+ }
+ System.out.println("Words with greater than occurence of " + occThresh + " : " + words.size());
+ }
+
+ public void makeEnWordCells(String relation){
+ if(relation.equals("b-word")){
+ enWordCells(busWord, relation);
+ }
+ else if(relation.equals("u-word")){
+ enWordCells(userWord, relation);
+ }
+ }
+
+ public void enWordCells(Map> enWord, String relation){
+ for(String en : enWord.keySet()){
+ for(String word : enWord.get(en)){
+ Cell cell = new Cell();
+ cell.relation_id = relation;
+ cell.entity_ids.add(en);
+ cell.entity_ids.add(word);
+ cell.truth = true;
+ Data.add(cell);
+ addToEntitySets(cell);
+ }
+ }
+ }
+
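+ // Word relations only ever produce true cells (see enWordCells); false cells are sampled fresh each epoch by the learner.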
+ public ArrayList<Cell> getNegativeSamples(String relation, int negSamplesPerEntity){
+ ArrayList<Cell> negSamples = new ArrayList<Cell>();
+ int negSamplesDone = 0;
+ if(relation.equals("b-word")){
+ for(String en : busWord.keySet()){
+ while(negSamplesDone < negSamplesPerEntity){
+ Cell cell = genNegSample(busWord, en, relation);
+ negSamplesDone++;
+ negSamples.add(cell);
+ }
+ negSamplesDone = 0;
+ }
+
+ }
+ else if(relation.equals("u-word")){
+ for(String en : userWord.keySet()){
+ while(negSamplesDone < negSamplesPerEntity){
+ Cell cell = genNegSample(userWord, en, relation);
+ negSamplesDone++;
+ negSamples.add(cell);
+ }
+ negSamplesDone = 0;
+ }
+ }
+ return negSamples;
+ }
+
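+ // Rejection sampling : draw random words from the pruned vocabulary until one is found that the entity never used; that (entity, word) pair becomes a false cell.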
+ public Cell genNegSample(Map<String, Set<String>> enWord, String en, String relation){
+ Cell cell = new Cell();
+ cell.relation_id = relation;
+ boolean found = false;
+ while(!found){
+ int pos = randInt(0, words.size() - 1);
+ if(!enWord.get(en).contains(words.get(pos))){
+ cell.entity_ids.add(en);
+ cell.entity_ids.add(words.get(pos));
+ cell.truth = false;
+ found = true;
+ }
+ }
+ return cell;
+ }
+
+ public static int randInt(int min, int max) {
+ // nextInt is normally exclusive of the top value,
+ // so add 1 to make it inclusive
+ int randomNum = seed.nextInt((max - min) + 1) + min;
+
+ return randomNum;
+ }
+
+}
diff --git a/Project/src/logisticCMF/embedding.java b/Project/src/logisticCMF/embedding.java
new file mode 100644
index 0000000..91f38b5
--- /dev/null
+++ b/Project/src/logisticCMF/embedding.java
@@ -0,0 +1,29 @@
+package logisticCMF;
+import java.util.*;
+
+/*
+ * Embedding Class - Stores embedding for each entity which includes :
+ * Bias for each relation the entity belongs to
+ * K- Dimensional Latent Vector for each entity
+*/
+
+public class embedding {
+ Map<String, Double> bias; // Map : [Relation_id, bias]
+ double[] vector; // K-Dimensional Latent Vector
+ static Random rand = new Random(20);
+
+ // For entity realized first time. Put Bias = 0.0 for relation and initialize latent vector to random
+ public embedding(int K){
+ bias = new HashMap();
+ vector = new double[K];
+ for(int k=0; k<K; k++)
+ vector[k] = rand.nextGaussian()*0.001;
+ }
+
+ // Register a relation for this entity : bias for that relation starts at 0.0
+ public void addRelation(String relation_id){
+ if(!bias.containsKey(relation_id))
+ bias.put(relation_id, 0.0);
+ }
+}
diff --git a/Project/src/logisticCMF/embeddings.java b/Project/src/logisticCMF/embeddings.java
new file mode 100644
--- /dev/null
+++ b/Project/src/logisticCMF/embeddings.java
+package logisticCMF;
+import java.util.*;
+
+public class embeddings {
+
+ Map<String, embedding> embs; // Map[Entity_Id, Embedding]
+ Map<String, Double> alpha; // Map [Relation_Id, Alpha (matrix mean) ]
+ HashSet<String> relIds; // Set of relation Ids
+ HashSet<String> entityIds; // Set of Entity Ids
+ int K ;
+ static Random rand = new Random();
+
+ // Instantiating Embeddings for Data
+ public embeddings(data YData, int lK){
+ K = lK;
+ alpha = new HashMap<String, Double>();
+ embs = new HashMap<String, embedding>();
+ for(Cell cell : YData.Data){
+ for(String entityId : cell.entity_ids){
+ if(!embs.containsKey(entityId))
+ embs.put(entityId, new embedding(K));
+
+ embs.get(entityId).addRelation(cell.relation_id);
+ }
+ }
+ System.out.println("Unique Entites in Database : " + embs.keySet().size());
+ computeAlpha(YData);
+ }
+
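+ // NOTE: a shallow copy - the embs and alpha maps are shared with e, so a snapshot taken through this constructor still tracks later updates to e.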
+ public embeddings(embeddings e){
+ this.alpha = e.alpha;
+ this.embs = e.embs;
+ this.relIds = e.relIds;
+ this.entityIds = e.entityIds;
+ this.K = e.K;
+ }
+
+ //Compute Alpha - Relation wise mean of values
+ public void computeAlpha(data D){
+ Map<String, Integer> relSum = new HashMap<String, Integer>(); // Sum of truth values in each relation
+ Map<String, Integer> relCount = new HashMap<String, Integer>(); // Count of cells in each relation
+
+ for(Cell cell : D.trainData){
+ if(!relSum.containsKey(cell.relation_id)){
+ if(cell.truth == true)
+ relSum.put(cell.relation_id, 1);
+ else
+ relSum.put(cell.relation_id, 0);
+ relCount.put(cell.relation_id, 1);
+ }
+ else{
+ int sum = relSum.get(cell.relation_id);
+ int count = relCount.get(cell.relation_id);
+ if(cell.truth == true){
+ sum += 1;
+ }
+ count++;
+ relSum.put(cell.relation_id, sum);
+ relCount.put(cell.relation_id, count);
+ }
+
+ }
+
+ for(String relId : relSum.keySet()){
+ double a = ((double)relSum.get(relId))/relCount.get(relId);
+ a = Math.log((a / (1-a)));
+ //alpha.put(relId, a);
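+ // The empirical log-odds computed above is left unused; alpha is instead initialized to small Gaussian noise and learned like the other parameters.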
+ alpha.put(relId, rand.nextGaussian()*0.001);
+ }
+
+
+ }
+
+
+ // Product of the entity vectors in a cell at dimension k, leaving one entity out (used as the gradient coefficient for that entity)
+ public double coeffVector(Cell cell, String leaveEntity, int k){
+ double result=1.0;
+ for(String entityId : cell.entity_ids){
+ if(!entityId.equals(leaveEntity))
+ result *= embs.get(entityId).vector[k];
+ }
+ return result;
+ }
+
+ // Prediction score for a cell : relation alpha + entity biases + sum over k of the product of entity vectors
+ public double dot(Cell cell, boolean enableBias, int K, boolean ealpha, boolean onlyAlpha){
+ double result=0.0;
+
+ if(ealpha)
+ result += alpha.get(cell.relation_id);
+ if(!onlyAlpha){
+ if(enableBias){
+ for(String entityId : cell.entity_ids){
+ if(!embs.containsKey(entityId))
+ System.out.println("Entity not found : " + entityId);
+ if(!embs.get(entityId).bias.containsKey(cell.relation_id))
+ System.out.println("Relation Not found : " + cell.relation_id);
+ result += embs.get(entityId).bias.get(cell.relation_id);
+
+
+ }
+ }
+
+ for(int k = 0; k<K; k++)
+ result += coeffVector(cell, "", k); // "" leaves no entity out : full product at dimension k
+ }
+ return result;
+ }
+}
diff --git a/Project/src/logisticCMF/learner.java b/Project/src/logisticCMF/learner.java
new file mode 100644
--- /dev/null
+++ b/Project/src/logisticCMF/learner.java
+package logisticCMF;
+
+import java.util.*;
+
+public class learner {
+
+ static boolean enableBias = true, ealpha = true, onlyAlpha = false; // model flags (also read by writeDataToFile)
+ static double eta = 0.05; // SGD learning rate (assumed value)
+ static Random seed = new Random(10); // shuffling RNG (assumed seed)
+
+ public static double sigm(double x){
+ return 1.0 / (1.0 + Math.exp(-x));
+ }
+
+ // One stochastic-gradient step on a cell : err = truth - sigm(score); coeffVector supplies the leave-one-out gradient coefficient
+ public static void update(Cell cell, boolean enableBias, embeddings e, boolean ealpha, boolean onlyAlpha){
+ double err = ((cell.truth) ? 1.0 : 0.0) - sigm(e.dot(cell, enableBias, e.K, ealpha, onlyAlpha));
+ if(ealpha)
+ e.alpha.put(cell.relation_id, e.alpha.get(cell.relation_id) + eta*err);
+ if(!onlyAlpha){
+ for(String entityId : cell.entity_ids){
+ embedding em = e.embs.get(entityId);
+ if(enableBias)
+ em.bias.put(cell.relation_id, em.bias.get(cell.relation_id) + eta*err);
+ for(int k = 0; k<e.K; k++)
+ em.vector[k] += eta * err * e.coeffVector(cell, entityId, k);
+ }
+ }
+ }
+
+ public static embeddings learn(data Data, embeddings embedings, boolean busWord, boolean userWord, int busWordNegSamSize, int userWordNegSamSize){
+ boolean notConverged = true;
+ int epoch = 0, bestEpoch = 0, dropfor = 0, nochange = 0;
+ double maxF1 = 0.0, maxAcc = 0.0;
+ ArrayList<Cell> trainData = new ArrayList<Cell>();
+ embeddings eBest = new embeddings(embedings);
+
+ Map<String, ArrayList<Double>> evalMap = Eval.getEvalMap(Data, eBest, "test");
+ Eval.printEval();
+ while(notConverged){
+
+ trainData.clear();
+ trainData.addAll(Data.trainData);
+ System.out.println("Train Data Original : " + trainData.size());
+ if(busWord){
+ ArrayList negSamples = Data.getNegativeSamples("b-word", busWordNegSamSize);
+ trainData.addAll(negSamples);
+ }
+ if(userWord){
+ ArrayList negSamples = Data.getNegativeSamples("u-word", userWordNegSamSize);
+ trainData.addAll(negSamples);
+ }
+ System.out.println("trainData : " + trainData.size());
+ codeTest.getMemoryDetails();
+ System.gc();
+ Collections.shuffle(trainData, seed); // Shuffle List of Training Data before each iteration of learning parameters
+ for(Cell cell : trainData){
+ update(cell, enableBias, embedings, ealpha, onlyAlpha);
+ }
+ //System.out.println("Train Data size :" + trainData.size());
+ epoch++;
+ System.out.print(epoch + " ");
+ System.gc();
+ if(epoch%5 == 0){
+ System.out.println("################## Epoch : " + epoch + " ############");
+ evalMap = Eval.getEvalMap(Data, embedings, "validation");
+ Eval.printEval();
+ double wf1 = evalMap.get("average").get(3), wacc = evalMap.get("average").get(0);
+ if(epoch == 5){
+ //maxF1 = wf1; maxAcc = wacc;
+ maxF1 = 0.0; maxAcc = 0.0;
+ dropfor = 0; nochange = 0;
+ bestEpoch = epoch;
+ eBest = new embeddings(embedings);
+ }
+ else{
+ //System.out.println("dF : " + dropfor + " nc : " + nochange + " maxF1 : " + maxF1 + " maxAcc : " + maxAcc);
+ if(wf1 > maxF1 || wacc > maxAcc){
+ bestEpoch = epoch;
+ eBest = new embeddings(embedings);
+ maxF1 = wf1;
+ maxAcc = wacc;
+ dropfor = 0; nochange = 0;
+ }
+ else{
+ if(wf1 == maxF1 || wacc == maxAcc)
+ nochange++;
+ else
+ dropfor++;
+ }
+ }
+
+ }
+ if(dropfor >= 3 || nochange >= 4) /// CONDITIONS for STOPPING
+ notConverged = false;
+ }
+ System.out.println("TRAINING CONVEREGED, BEST EPOCH = " + bestEpoch);
+ return eBest;
+ }
+
+
+}
diff --git a/Project/src/logisticCMF/writeDataToFile.java b/Project/src/logisticCMF/writeDataToFile.java
new file mode 100644
index 0000000..158173a
--- /dev/null
+++ b/Project/src/logisticCMF/writeDataToFile.java
@@ -0,0 +1,33 @@
+package logisticCMF;
+
+import java.util.*;
+import java.io.*;
+
+public class writeDataToFile {
+ public static void writePrediction(String fileName, String folder, data Data, embeddings e) throws IOException{
+ String fileAddress = System.getProperty("user.dir")+"/../Embeddings_Prediction_Data/"+ folder +"/pred-data/" + fileName;
+ BufferedWriter bw = new BufferedWriter(new FileWriter(fileAddress));
+ for(Cell cell : Data.testData){
+ double dot = e.dot(cell, learner.enableBias, e.K, learner.ealpha, learner.onlyAlpha);
+ double sigmdot = learner.sigm(dot);
+ int truth = (cell.truth) ? 1 : 0;
+ String e1 = cell.entity_ids.get(0), e2 = cell.entity_ids.get(1);
+ bw.write(e1 + " :: " + e2 + " :: " + sigmdot + " :: " + truth +"\n") ;
+ }
+ bw.close();
+ }
+
+ public static void writeEmbeddings(String folder, String fileName, data Data, int entityNumber, embeddings e) throws IOException{
+ String fileAddress = System.getProperty("user.dir")+"/../Embeddings_Prediction_Data/"+ folder +"/embeddings/" + fileName;
+ BufferedWriter bw = new BufferedWriter(new FileWriter(fileAddress));
+ for(String entity : Data.entityIds.get(entityNumber)){
+ bw.write(entity + " :: ");
+ embedding em = e.embs.get(entity);
+ for(int i=0; i<em.vector.length; i++)
+ bw.write(em.vector[i] + ", ");
+ bw.write("\n");
+ }
+ bw.close();
+ }
+}
diff --git a/Project/src/postProcessing/EntityEmbeddings.java b/Project/src/postProcessing/EntityEmbeddings.java
new file mode 100644
--- /dev/null
+++ b/Project/src/postProcessing/EntityEmbeddings.java
+package postProcessing;
+
+import java.util.*;
+import java.io.*;
+
+// Reads back the embeddings files written by writeDataToFile and exposes Map[Entity, Latent Vector]
+public class EntityEmbeddings {
+
+ Map<String, double[]> entityVector;
+
+ public EntityEmbeddings(String folder, String fileName, int K) throws IOException{
+ entityVector = new HashMap<String, double[]>();
+ readEmbeddings(folder, fileName, K);
+ }
+
+ public void readEmbeddings(String folder, String fileName, int K) throws IOException{
+ String fileAddress = System.getProperty("user.dir")+"/../Embeddings_Prediction_Data/"+ folder +"/embeddings/" + fileName;
+ BufferedReader br = new BufferedReader(new FileReader(fileAddress));
+ String line;
+ while((line = br.readLine()) != null){
+ String[] array = line.split("::");
+
+ // Reading entity from embeddings file and creating
+ String entity = array[0].trim();
+ if(!entityVector.containsKey(entity))
+ entityVector.put(entity, new double[K]);
+ else
+ System.out.println("Duplicate Entity. ERROR !!!!!!");
+
+ // Reading latent vector and storing in Map
+ String[] vec = array[1].trim().split(",");
+ for(int i=0; i<K; i++)
+ entityVector.get(entity)[i] = Double.parseDouble(vec[i].trim());
+ }
+ br.close();
+ }
+}
diff --git a/Project/src/postProcessing/Similarity.java b/Project/src/postProcessing/Similarity.java
new file mode 100644
--- /dev/null
+++ b/Project/src/postProcessing/Similarity.java
+package postProcessing;
+
+import java.util.*;
+import java.io.*;
+
+public class Similarity {
+
+ // For every entity in a, rank all entities of b by dot-product similarity and keep the K nearest
+ public Map<String, ArrayList<String>> KNN(EntityEmbeddings a, EntityEmbeddings b, int K){
+ Map<String, ArrayList<String>> kNN = new HashMap<String, ArrayList<String>>();
+ for(String en1 : a.entityVector.keySet()){
+ double [] vec1 = a.entityVector.get(en1);
+ Map<String, Double> distances = new HashMap<String, Double>();
+ for(String en2 : b.entityVector.keySet()){
+ double[] vec2 = b.entityVector.get(en2);
+ distances.put(en2, Util.dotProd(vec1, vec2));
+ }
+
+ kNN.put(en1, Util.getKNN(distances, K));
+ }
+
+ return kNN;
+ }
+
+ public Map<String, ArrayList<String>> getSimilarity(EntityEmbeddings ee1, EntityEmbeddings ee2){
+ Map<String, ArrayList<String>> simMap = KNN(ee1, ee2, 20);
+
+ for(String en : simMap.keySet()){
+ System.out.print(en + " : ");
+ for(String nn : simMap.get(en)){
+ System.out.print(nn + ", ");
+ }
+ System.out.println();
+ }
+
+ return simMap;
+ }
+
+
+ public static void main(String [] args) throws IOException {
+ String folder = "AZ";
+ String evaluation = "HeldOut";
+ System.out.println("Start");
+
+ EntityEmbeddings attributes = new EntityEmbeddings(folder, evaluation+"/"+"attributes-bw", 30);
+ EntityEmbeddings words = new EntityEmbeddings(folder, evaluation+"/"+"words-bw", 30);
+ EntityEmbeddings categories = new EntityEmbeddings(folder, evaluation+"/"+"categories-bw", 30);
+ EntityEmbeddings business = new EntityEmbeddings(folder, evaluation+"/"+"business-bw", 30);
+
+ Similarity s = new Similarity();
+ s.getSimilarity(categories, words);
+
+
+
+
+ }
+
+}
diff --git a/Project/src/postProcessing/Util.java b/Project/src/postProcessing/Util.java
new file mode 100644
index 0000000..5ce7190
--- /dev/null
+++ b/Project/src/postProcessing/Util.java
@@ -0,0 +1,62 @@
+package postProcessing;
+import java.util.*;
+import java.util.Map.Entry;
+
+public class Util {
+
+ public static double sigm(double x){
+ return 1.0 / (1.0 + Math.exp(-x));
+ }
+
+ public static double norm(double [] vec){
+ double norm = 0.0;
+ for(int i=0; i<vec.length; i++)
+ norm += vec[i]*vec[i];
+ return Math.sqrt(norm);
+ }
+
+ public static double dotProd(double[] vec1, double[] vec2){
+ double dot = 0.0;
+ for(int i=0; i<vec1.length; i++)
+ dot += vec1[i]*vec2[i];
+ return dot;
+ }
+
+ // Sort the map entries by value (descending) and return the keys of the top K
+ public static ArrayList<String> getKNN(Map<String, Double> map, int K){
+ ArrayList<String> top = new ArrayList<String>();
+ Set<Entry<String, Double>> set = map.entrySet();
+ List<Entry<String, Double>> list = new ArrayList<Entry<String, Double>>(set);
+ Collections.sort( list, new Comparator<Map.Entry<String, Double>>()
+ {
+ public int compare( Map.Entry<String, Double> o1, Map.Entry<String, Double> o2 )
+ {
+ return (o2.getValue()).compareTo( o1.getValue() );
+ }
+ } );
+
+ for(int i=0; i<K; i++)
+ top.add(list.get(i).getKey());
+ return top;
+
+ /*for(Map.Entry<String, Double> entry : list)
+ System.out.println(entry.getKey() + " ==== " + entry.getValue());*/
+ }
+
+ public static void writeSelectEmbeddingsToFile(String folder, String selectionEmbeddings, Map<String, ArrayList<String>> simMap){
+
+ }
+
+}
diff --git a/Project/src/yelpDataProcessing/AttributeCategory.java b/Project/src/yelpDataProcessing/AttributeCategory.java
new file mode 100644
index 0000000..25147de
--- /dev/null
+++ b/Project/src/yelpDataProcessing/AttributeCategory.java
@@ -0,0 +1,311 @@
+package yelpDataProcessing;
+
+import java.util.*;
+import java.io.*;
+
+import javax.json.*;
+import javax.json.spi.*;
+
+/*
+ * Reads the yelp_dataset_restaurant json file and creates
+ * - resAtt.txt - The restaurant-attribute data that can directly be factorized
+ * - resCat.txt - The restaurant-category data that can be factorized but needs negative data
+ */
+
+
+public class AttributeCategory {
+
+ Map<String, Map<String, Integer>> attributes; // Map [Attribute, Map[SubAttribute, 1]]
+ Set<String> categories;
+ Map<String, Integer> catCount; // Map [Category, Occurrence Count]
+ Map<String, Integer> busCatCount; // Map [Restaurant, No. of Category Count]
+ Set<String> catReduced; // Set [Categories] - Thresholded in terms of occurrence
+ Map<String, ArrayList<String>> busCat; // Map [Res, List[Categories]]
+
+
+ public void printCategories() {
+ for(String cat : categories){
+ System.out.println(cat);
+ }
+ }
+
+ public void printAttributes() {
+ for(String att : attributes.keySet()){
+ System.out.print(att + " : ");
+ for(String subatt : attributes.get(att).keySet())
+ System.out.print(subatt+", ");
+ System.out.print("\n");
+ }
+ System.out.println("Total Attributes : " + attributes.keySet().size());
+ }
+
+ private void buildCategorySet(String folder) throws IOException {
+ BufferedReader br = new BufferedReader(new FileReader(System.getProperty("user.dir")+"/../Dataset/json/"+folder+"/business"));
+ categories = new HashSet<String>();
+ catCount = new HashMap<String, Integer>();
+ busCatCount = new HashMap<String, Integer>();
+ busCat = new HashMap<String, ArrayList<String>>();
+
+ String line;
+ int count = 0;
+ while(( (line = br.readLine()) != null) ){
+ JsonReader reader = Json.createReader(new StringReader(line));
+ JsonObject object = reader.readObject();
+ if(object.get("type").toString().equals("\"business\"")){
+ if(object.get("categories").getValueType().toString().equals("ARRAY")){
+ JsonArray cat = (JsonArray) object.get("categories");
+ JsonValue b_id = object.get("business_id");
+ JsonString bus_id = (JsonString) b_id;
+ busCatCount.put(bus_id.getString(), cat.size());
+ busCat.put(bus_id.getString(), new ArrayList<String>());
+ for(JsonValue s : cat){
+ JsonString c = (JsonString) s;
+ String category = c.getString();
+ categories.add(category);
+ busCat.get(bus_id.getString()).add(category);
+ if(!catCount.containsKey(category))
+ catCount.put(category, 1);
+ else
+ catCount.put(category, catCount.get(category)+1 );
+ }
+ }
+ count++;
+ }
+ }
+ //return categories;
+ System.out.println("Businesses Read : " + count);
+ }
+
+ private void buildThresholdCatSet(int thresh){
+ catReduced = new HashSet<String>();
+ for(String cat : catCount.keySet()){
+ if(catCount.get(cat) >= thresh){
+ catReduced.add(cat);
+ }
+ }
+ }
+
+ /* Reads the Yelp JSON Dataset and makes attributes hashmap */
+ private void readAttributes(String folder) throws IOException{
+ BufferedReader br = new BufferedReader(new FileReader(System.getProperty("user.dir")+"/../Dataset/json/"+folder+"/business"));
+ attributes = new HashMap<String, Map<String, Integer>>();
+ String line;
+ while(( (line = br.readLine()) != null) ){
+ JsonReader reader = Json.createReader(new StringReader(line));
+ JsonObject object = reader.readObject();
+ //if(object.containsKey("attributes")){
+ if(object.get("type").toString().equals("\"business\"")){
+ if(object.get("attributes").getValueType().toString() == "OBJECT"){
+ JsonObject attributeObject = (JsonObject) object.getJsonObject("attributes");
+ for(String key : attributeObject.keySet()){
+ if(!attributes.containsKey(key)){ // To Make Attributes Keys Set
+ attributes.put(key, new HashMap<String, Integer>());
+ getValueSet(attributeObject.get(key), key, attributes);
+ }
+ else
+ getValueSet(attributeObject.get(key), key, attributes);
+ }
+ }
+ }
+ }
+ br.close();
+ }
+
+ /* Creates a dataset with business Id and values for attributes and stores them in a file. Uses the attributes hashmap. */
+ private void buildBusiness_AttributeDataset(String folder) throws IOException{
+ BufferedReader br = new BufferedReader(new FileReader(System.getProperty("user.dir")+"/../Dataset/json/"+folder+"/business"));
+ BufferedWriter bw = new BufferedWriter(new FileWriter(System.getProperty("user.dir")+"/../Dataset/data/" + folder + "/busAtt.txt"));
+ String line;
+ int count = 0;
+
+ while(( (line = br.readLine()) != null) && count < 42151 ){
+ StringBuilder str = new StringBuilder();
+ JsonReader reader = Json.createReader(new StringReader(line));
+ JsonObject object = reader.readObject();
+ JsonValue b_id = object.get("business_id");
+ JsonString bus_id = (JsonString) b_id;
+ bw.write("business_id : "+bus_id.getString()+"\n");
+
+ if(object.get("attributes").getValueType().toString() == "OBJECT"){
+ JsonObject attributeObject = (JsonObject) object.getJsonObject("attributes");
+ navigateObjectForValues(attributeObject, null, null, str);
+ //System.out.println(str.toString());
+ bw.write(str.toString()+"\n\n");
+
+ }
+ count++;
+ }
+ System.out.println("No. of Business : "+count);
+ br.close();
+ bw.close();
+ }
+
+ /* Recursively walks an attribute JSON value and appends "key_subKey_value : 0/1" feature lines to str */
+ public void navigateObjectForValues(JsonValue tree, String key, String prevKey, StringBuilder str) {
+ switch(tree.getValueType()) {
+ case OBJECT:
+ //System.out.println("OBJECT");
+ JsonObject object = (JsonObject) tree;
+ for (String name : object.keySet())
+ navigateObjectForValues(object.get(name), name, key, str);
+ break;
+ case ARRAY:
+ break;
+ case STRING:
+ JsonString st = (JsonString) tree;
+ if (key!= null){
+ if(prevKey != null){
+ //System.out.println(prevKey+"_" + key + "_"+st.getString()+" : "+1);
+ str.append(prevKey+"_" + key + "_"+st.getString()+" : "+1 + "\n");
+ if(attributes.get(key).keySet().size() != 0)
+ for(String s: attributes.get(key).keySet())
+ if(!s.equals(st.getString()))
+ //System.out.println(prevKey+"_" + key + "_"+s+" : "+0);
+ str.append(prevKey+"_" + key + "_"+s+" : "+0+"\n");
+ }
+ else{
+ //System.out.println(key + "_"+st.getString()+" : "+1);
+ str.append(key + "_"+st.getString()+" : "+1+"\n");
+
+ if(attributes.get(key).keySet().size() != 0)
+ for(String s: attributes.get(key).keySet())
+ if(!s.equals(st.getString()))
+ //System.out.println(key + "_"+s+" : "+0);
+ str.append(key + "_"+s+" : "+0+"\n");
+ }
+ }
+
+
+ break;
+ case NUMBER:
+ /*if (key!= null)
+ if(prevKey != null)
+ System.out.print(prevKey+"_" + key + " : ");
+ else
+ System.out.print(key + " : ");
+ JsonNumber num = (JsonNumber) tree;
+ System.out.println(num.toString());
+ */
+ break;
+ case TRUE:
+ if (key!= null)
+ if(prevKey != null)
+ //System.out.print(prevKey+"_" + key + " : ");
+ str.append(prevKey+"_" + key + " : "+1+"\n");
+ else
+// //System.out.print(key + " : " + 1);
+ str.append(key + " : " + 1+"\n");
+ //System.out.println(1);
+ break;
+ case FALSE:
+ case NULL:
+ if (key!= null)
+ if(prevKey != null)
+ //System.out.print(prevKey+"_" + key + " : ");
+ str.append(prevKey+"_"+key + " : " + 0+"\n");
+ else
+ //System.out.print(key + " : ");
+ str.append(key + " : " + 0+"\n");
+ //System.out.println(0);
+ break;
+ }
+ }
+
+ private void getValueSet(JsonValue attribute, String attributeName, Map<String, Map<String, Integer>> attributes){
+ switch(attribute.getValueType()){
+
+ case STRING:
+ JsonString st = (JsonString) attribute;
+ if(!attributes.get(attributeName).containsKey(st.getString()))
+ attributes.get(attributeName).put(st.getString(), 1);
+ break;
+
+ case NUMBER:
+ if(!attributes.get(attributeName).containsKey("NUMBER"))
+ attributes.get(attributeName).put("NUMBER", 1);
+ break;
+ case TRUE:
+ if(!attributes.get(attributeName).containsKey("TRUE"))
+ attributes.get(attributeName).put("TRUE", 1);
+ break;
+ case FALSE:
+ if(!attributes.get(attributeName).containsKey("FALSE"))
+ attributes.get(attributeName).put("FALSE", 1);
+ break;
+ case OBJECT:
+ JsonObject att = (JsonObject) attribute;
+ for(String subAttributeName : att.keySet())
+ if(!attributes.get(attributeName).containsKey(subAttributeName))
+ attributes.get(attributeName).put(subAttributeName, 1);
+ break;
+ }
+ }
+
+ // Writes Restaurant : Category1, Category2, ... - to file resCat.txt
+ private void writeResCatToFile(String folder) throws IOException{
+ BufferedWriter bw = new BufferedWriter(new FileWriter(System.getProperty("user.dir")+"/../Dataset/data/" + folder +"/busCat.txt"));
+ for(String res : busCat.keySet()){
+ bw.write(res + " : ");
+ for(String cat : busCat.get(res))
+ bw.write(cat + "; ");
+ bw.write("\n");
+ }
+ bw.close();
+ }
+
+ private void makeCitySet(String dataset) throws IOException {
+ BufferedReader br = new BufferedReader(new FileReader(System.getProperty("user.dir")+"/../Dataset/json"+dataset));
+ String line; int count = 0;
+ Map<String, Integer> states = new HashMap<String, Integer>();
+ while(( (line = br.readLine()) != null) && count < 42151){
+ JsonReader reader = Json.createReader(new StringReader(line));
+ JsonObject object = reader.readObject();
+
+ JsonString city = (JsonString) object.get("state");
+ String ci = city.getString();
+ if(!states.containsKey(ci))
+ states.put(ci, 0);
+ int sc = states.get(ci);
+ sc++;
+ states.put(ci, sc);
+
+ count++;
+ }
+
+ for(String state : states.keySet()){
+ System.out.println(state + " : " + states.get(state));
+ }
+ System.out.println(states.keySet().size());
+ }
+
+
+ public static void main(String[] args) throws Exception{
+ System.out.println("Hello");
+ String State = "NV";
+
+ AttributeCategory data = new AttributeCategory();
+
+
+ // Read file for attributes and write to a file
+ data.readAttributes(State);
+ //data.printAttributes();
+ //data.buildBusiness_AttributeDataset(State);
+
+
+ data.buildCategorySet(State);
+ data.printCategories();
+ //data.writeResCatToFile(State);
+
+ int c = 0;
+
+
+ //data.makeCitySet("yelp_dataset");
+
+
+
+
+
+ }
+
+
+}
diff --git a/Project/src/yelpDataProcessing/ProcessYelpJson.java b/Project/src/yelpDataProcessing/ProcessYelpJson.java
new file mode 100644
index 0000000..94f52ad
--- /dev/null
+++ b/Project/src/yelpDataProcessing/ProcessYelpJson.java
@@ -0,0 +1,220 @@
+package yelpDataProcessing;
+
+import java.io.*;
+import java.util.*;
+
+import javax.json.Json;
+import javax.json.JsonArray;
+import javax.json.JsonNumber;
+import javax.json.JsonObject;
+import javax.json.JsonReader;
+import javax.json.JsonString;
+import javax.json.JsonValue;
+import javax.json.stream.JsonParsingException;
+
+
+/*
+ * Processes Yelp Dataset Json and creates different files
+ * - yelp_business - Business dataset in json format.
+ * - yelp_dataset_restaurant - Restaurant Json data from Yelp
+ * -
+ *
+ *
+ */
+
+public class ProcessYelpJson {
+
+ Set<String> busIds = new HashSet<String>();
+
+ // Create file - yelp_reviews - in dataset/json
+ public void createCompleteReviewJson(String yelpDataset) throws IOException{
+ BufferedReader br = new BufferedReader(new FileReader(System.getProperty("user.dir")+"/../Dataset/json/"+yelpDataset));
+ BufferedWriter bw = new BufferedWriter(new FileWriter(System.getProperty("user.dir")+"/../Dataset/json/complete/reviews"));
+ String line;
+ int count = 0; int cr = 0;
+ while( ((line = br.readLine()) != null) && count < 1199227){
+ if(count <= 73770)
+ count++;
+ else{
+ cr++;
+ bw.write(line+"\n");
+ count++;
+ }
+ }
+ bw.close();
+ br.close();
+ System.out.println("Reviews : " + cr);
+
+
+ }
+
+ // Creates file - yelp_business - in dataset/json folder
+ public void createCompleteBusinessJson(String yelpDataset) throws IOException{
+ BufferedReader br = new BufferedReader(new FileReader(System.getProperty("user.dir")+"/../Dataset/json/"+yelpDataset));
+ BufferedWriter bw = new BufferedWriter(new FileWriter(System.getProperty("user.dir")+"/../Dataset/json/complete/business"));
+ String line; int count = 0;
+
+ while(( (line = br.readLine()) != null) && count < 42151 ){
+ bw.write(line + "\n");
+ count++;
+ }
+ bw.close();
+ System.out.println("No. of Businesses : " + count);
+
+ }
+
+ public void createStateBusinessJson(String folder, String state) throws IOException{
+ BufferedReader br = new BufferedReader(new FileReader(System.getProperty("user.dir")+"/../Dataset/json/"+folder+"/business"));
+ BufferedWriter bw = new BufferedWriter(new FileWriter(System.getProperty("user.dir")+"/../Dataset/json/"+state+"/business"));
+
+ String line;
+ int count = 0; int cr = 0;
+ while(( (line = br.readLine()) != null) ){
+ JsonReader reader = Json.createReader(new StringReader(line));
+ JsonObject object = reader.readObject();
+ if(object.get("type").toString().equals("\"business\"")){
+ JsonValue s = object.get("state");
+ JsonString st = (JsonString) s;
+ String place = st.getString();
+ if(place.equals(state)){
+ bw.write(line + "\n");
+ cr++;
+ }
+ }
+ }
+ bw.close();
+ System.out.println("No. of Businesses in "+state + " : " + cr);
+ }
+
+
+ // Creates file - yelp_dataset_restaurant - in dataset/json folder
+ public void createRestaurantJson(String folder) throws IOException{
+ BufferedReader br = new BufferedReader(new FileReader(System.getProperty("user.dir")+"/../Dataset/json/"+folder+"/business"));
+ BufferedWriter bw = new BufferedWriter(new FileWriter(System.getProperty("user.dir")+"/../Dataset/json/Restaurant/business"));
+
+ String line;
+ int count = 0; int cr = 0;
+ while(( (line = br.readLine()) != null) ){
+ JsonReader reader = Json.createReader(new StringReader(line));
+ JsonObject object = reader.readObject();
+ if(object.get("type").toString().equals("\"business\"")){
+ if(object.get("categories").getValueType().toString().equals("ARRAY")){
+ JsonArray cat = (JsonArray) object.get("categories");
+ for(JsonValue s : cat){
+ JsonString c = (JsonString) s;
+ String category = c.getString();
+ if(category.equals("Restaurants")){
+ bw.write(line + "\n");
+ cr ++;
+ }
+ }
+ }
+ }
+ }
+ bw.close();
+ System.out.println("No. of Restaurants : " + cr);
+ }
+
+
+ // Create file - yelp_reviews_restaurant - in dataset/json
+ public void createBusReviewJson(String folder_complete, String folder) throws IOException{
+ makeBusIdSet(folder);
+ BufferedReader br = new BufferedReader(new FileReader(System.getProperty("user.dir")+"/../Dataset/json/"+folder_complete+"/reviews"));
+ BufferedWriter bw = new BufferedWriter(new FileWriter(System.getProperty("user.dir")+"/../Dataset/json/"+folder+"/reviews"));
+ int count = 0;
+ String line;
+ while(( (line = br.readLine()) != null) ){
+ JsonReader reader = Json.createReader(new StringReader(line));
+ JsonObject object = reader.readObject();
+ if(object.get("type").toString().equals("\"review\"")){
+ JsonValue b_id = object.get("business_id");
+ JsonString bus_id = (JsonString) b_id;
+ String bid = bus_id.getString();
+ if(busIds.contains(bid)){
+ bw.write(line+"\n");
+ count++;
+ }
+ }
+ }
+ bw.close();
+ System.out.println("Reviews Count : " + count);
+
+ }
+
+ // Make a Set of Res-Ids to extract reviews
+ public void makeBusIdSet(String folder) throws IOException{
+ busIds = new HashSet<String>();
+ BufferedReader br = new BufferedReader(new FileReader(System.getProperty("user.dir")+"/../Dataset/json/"+folder+"/business"));
+ String line;
+ while(( (line = br.readLine()) != null)){
+ JsonReader reader = Json.createReader(new StringReader(line));
+ JsonObject object = reader.readObject();
+ if(object.get("type").toString().equals("\"business\"")){
+ JsonValue b_id = object.get("business_id");
+ JsonString bus_id = (JsonString) b_id;
+ String bid = bus_id.getString();
+ busIds.add(bid);
+ }
+ }
+ System.out.println("Size of resIds set :" + busIds.size());
+ }
+
+ public static void putReviewDatatoFile(String folder) throws IOException{
+ BufferedReader br = new BufferedReader(new FileReader(System.getProperty("user.dir")+"/../Dataset/json/"+folder+"/reviews"));
+ BufferedWriter bw = new BufferedWriter(new FileWriter(System.getProperty("user.dir") + "/../Dataset/data/"+folder+"/reviews.txt"));
+ String line;
+ int count = 0;
+ while((line = br.readLine()) != null){
+ JsonReader reader = Json.createReader(new StringReader(line));
+ JsonObject object = reader.readObject();
+
+ JsonValue b_id = object.get("business_id");
+ JsonString bus_id = (JsonString) b_id;
+ String bId = bus_id.getString();
+ bw.write("bus_id : " + bId + "\n");
+
+ JsonValue u_id = object.get("user_id");
+ JsonString user_id = (JsonString) u_id;
+ String uId = user_id.getString();
+ bw.write("user_id : " + uId + "\n");
+
+
+ JsonValue star = object.get("stars");
+ JsonNumber s = (JsonNumber) star;
+ Double st = s.doubleValue();
+ bw.write("star : " + st + "\n");
+
+ JsonValue t = object.get("text");
+ JsonString te = (JsonString) t;
+ String text = te.getString();
+ bw.write("text: " + t + "\n\n");
+
+ count++;
+ }
+ br.close();
+ bw.close();
+ System.out.println("Reviews Written : " + count);
+ }
+
+ public static void main(String [] args) throws Exception{
+ String yelpDataset = "yelp_dataset";
+ String State = "NV";
+
+ ProcessYelpJson yelp = new ProcessYelpJson();
+
+
+ //yelp.createCompleteBusinessJson(yelpDataset);
+
+ //yelp.createRestaurantJson("complete");
+ yelp.createStateBusinessJson("complete", State);
+
+ //yelp.createCompleteReviewJson(yelpDataset);
+
+ yelp.createBusReviewJson("complete", State);
+
+ yelp.putReviewDatatoFile(State);
+
+
+
+ }
+}
diff --git a/Project/src/yelpDataProcessing/reviewData.java b/Project/src/yelpDataProcessing/reviewData.java
new file mode 100644
index 0000000..f9ac3e0
--- /dev/null
+++ b/Project/src/yelpDataProcessing/reviewData.java
@@ -0,0 +1,141 @@
+package yelpDataProcessing;
+import java.io.BufferedReader;
+import java.io.FileReader;
+import java.io.IOException;
+import java.util.*;
+
+import logisticCMF.Cell;
+
+
+public class reviewData {
+
+ public static Map<String, Integer> wordCount = new HashMap<String, Integer>();
+ public static Map<String, Set<String>> resWord = new HashMap<String, Set<String>>();
+ public static Map<String, Set<String>> userWord = new HashMap<String, Set<String>>();
+ public static HashSet<String> words = new HashSet<String>();
+
+ public static void addWordInMap(String word){
+ if(!wordCount.containsKey(word))
+ wordCount.put(word, 0);
+ int wcount = wordCount.get(word);
+ wcount++;
+ wordCount.put(word, wcount);
+ }
+
+ public static void readData(String reviewData) throws IOException{
+ BufferedReader br = new BufferedReader(new FileReader(reviewData));
+ String line; String bId = null; String userId = null; int countl = 0, countR = 0;
+
+ while((line = br.readLine()) != null){
+ countl++;
+ String[] array = line.split(":");
+
+ if( array[0].trim().equals("user_id")){
+ countR++;
+ userId = array[1].trim();
+ if(!userWord.containsKey(userId))
+ userWord.put(userId, new HashSet<String>());
+ }
+
+ if( array[0].trim().equals("bus_id")){
+ bId = array[1].trim();
+ if(!resWord.containsKey(bId))
+ resWord.put(bId, new HashSet<String>());
+ }
+
+ if(array[0].trim().equals("text")){
+ String [] tokens = array[1].trim().split(" ");
+ if(tokens.length > 0){
+ for(String word : tokens){
+ word = word.trim();
+ if(word.length() >= 3){
+ addWordInMap(word);
+ resWord.get(bId).add(word);
+ userWord.get(userId).add(word);
+ }
+ }
+ }
+ }
+ if(countl % 100000 == 0)
+ System.out.println("line : "+countl);
+ }
+ System.out.println("Total No. of Reviews : " + countR);
+ }
+
+ public static void getMapStats(Map<String, Set<String>> enWord){
+ int potentialResWordCells = 0;
+ int min=100000000;
+ System.out.println("No. of entities = " + enWord.keySet().size());
+ System.out.println("No. of total words in Vocab = " + words.size());
+
+ for(String en : enWord.keySet()){
+ min = (enWord.get(en).size() < min) ? enWord.get(en).size() : min;
+ for(String word : enWord.get(en)){
+ potentialResWordCells++;
+ }
+
+ }
+ System.out.println("Potential Entity-Word Cells : " + potentialResWordCells);
+ System.out.println("Min No. of Words in Entity : " + min);
+ }
+
+ // Remove words from Map[Entity -> Set[words]] that are not in the pruned vocabulary. If an entity's word set becomes empty, remove the entity from the Map.
+ public static void pruneEntityWordMap(Map<String, Set<String>> enWord){
+ Iterator<String> it = enWord.keySet().iterator();
+ while(it.hasNext()){
+ String en = it.next();
+ Iterator<String> itr = enWord.get(en).iterator();
+ while(itr.hasNext()){
+ String word = itr.next();
+ if(!words.contains(word))
+ itr.remove();
+ }
+ if(enWord.get(en).size() == 0)
+ it.remove();
+ }
+ }
+
+
+ // Build the set of words whose frequency exceeds the given threshold.
+ public static void makePrunedWordList(int occThresh){
+ words = new HashSet<String>(); int count = 0;
+ for(String word : wordCount.keySet()){
+ if(wordCount.get(word) > occThresh){
+ count++;
+ words.add(word);
+ }
+ }
+ System.out.println("Words with greater than occurence of " + occThresh + " : " + words.size());
+ }
+
+ public static int countLines(String file) throws IOException{
+ BufferedReader br = new BufferedReader(new FileReader(file));
+ String line; int count = 0;
+ while((line = br.readLine()) != null){
+ count++;
+ }
+ return count;
+ }
+
+
+ public static void main(String [] args) throws IOException{
+
+ String reviewData = System.getProperty("user.dir")+"/../Dataset/data/ON/reviews_textProc.txt";
+ readData(reviewData);
+
+ int occThresh = 1;
+ System.out.println("Total Words in Review Data : " + wordCount.keySet().size());
+ int count = 0;
+
+ /*for(occThresh = 0; occThresh <= 50; occThresh++){
+ makePrunedWordList(occThresh);
+ }*/
+
+ makePrunedWordList(4);
+ pruneEntityWordMap(resWord);
+ getMapStats(resWord);
+
+ //pruneEntityWordMap(userWord, occThresh);
+ //getMapStats(userWord);
+ }
+}
diff --git a/Project/src/yelpDataProcessing/reviewJson.java b/Project/src/yelpDataProcessing/reviewJson.java
new file mode 100644
index 0000000..8add377
--- /dev/null
+++ b/Project/src/yelpDataProcessing/reviewJson.java
@@ -0,0 +1,131 @@
+package yelpDataProcessing;
+
+import java.io.BufferedReader;
+import java.io.BufferedWriter;
+import java.io.FileNotFoundException;
+import java.io.FileReader;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.io.StringReader;
+import java.util.*;
+
+import javax.json.Json;
+import javax.json.JsonArray;
+import javax.json.JsonNumber;
+import javax.json.JsonObject;
+import javax.json.JsonReader;
+import javax.json.JsonString;
+import javax.json.JsonValue;
+
+import logisticCMF.Cell;
+
+public class reviewJson {
+
+ public static HashSet<String> resIds = new HashSet<String>(); // Set - [Restaurant Ids]
+
+ // Reads Yelp Dataset Json - extracts only Review Jsons
+ public static void extractReviewJson(String fileAddress) throws IOException{
+ BufferedReader br = new BufferedReader(new FileReader(fileAddress));
+ BufferedWriter bw = new BufferedWriter(new FileWriter(System.getProperty("user.dir") + "/Data/json/yelp_reviews.txt"));
+ String line;
+ int count=1;
+ while( ((line = br.readLine()) != null) && count < 1199228){
+ if(count <= 73770)
+ count++;
+ else{
+ bw.write(line+"\n");
+ count++;
+ }
+ }
+ bw.close();
+ br.close();
+ }
+
+ // Makes resIds Set - Then extracts json objects for restaurant reviews and stores them in file
+ public static void extractResReviewJson() throws IOException{
+ makeResIds();
+ BufferedReader br = new BufferedReader(new FileReader(System.getProperty("user.dir") + "/Data/json/yelp_reviews.txt"));
+ BufferedWriter bw = new BufferedWriter(new FileWriter(System.getProperty("user.dir") + "/Data/json/yelp_reviews_restaurants.txt"));
+ String line;
+ int count = 0;
+ while(( (line = br.readLine()) != null) ){
+ JsonReader reader = Json.createReader(new StringReader(line));
+ JsonObject object = reader.readObject();
+ if(object.get("type").toString().equals("\"review\"")){
+ JsonValue b_id = object.get("business_id");
+ JsonString bus_id = (JsonString) b_id;
+ String bId = bus_id.getString();
+ if(resIds.contains(bId)){
+ bw.write(line.trim() + "\n");
+ count++;
+ }
+ }
+ }
+
+ System.out.println("Count : " + count);
+ br.close();
+ bw.close();
+ }
+
+ // Reads file that contains Restaurant Ids and stores in Set - resIds
+ public static void makeResIds() throws IOException{
+ String fA = System.getProperty("user.dir") + "/Data/new/restaurant_ids";
+ BufferedReader br = new BufferedReader(new FileReader(fA));
+ String line;
+ while((line = br.readLine()) != null){
+ resIds.add(line.trim());
+ }
+ br.close();
+ System.out.println("No. of res : " + resIds.size());
+ }
+
+ public static void putReviewDatatoFile(String fileAddress ) throws IOException{
+ BufferedReader br = new BufferedReader(new FileReader(fileAddress));
+ BufferedWriter bw = new BufferedWriter(new FileWriter(System.getProperty("user.dir") + "/../Data/new/res_review_data.txt"));
+ String line;
+ int count = 0; String relation = "restaurant-user";
+ while((line = br.readLine()) != null){
+ JsonReader reader = Json.createReader(new StringReader(line));
+ JsonObject object = reader.readObject();
+
+ JsonValue b_id = object.get("business_id");
+ JsonString bus_id = (JsonString) b_id;
+ String bId = bus_id.getString();
+ bw.write("bus_id : " + bId + "\n");
+
+ JsonValue u_id = object.get("user_id");
+ JsonString user_id = (JsonString) u_id;
+ String uId = user_id.getString();
+ bw.write("user_id : " + uId + "\n");
+
+
+ JsonValue star = object.get("stars");
+ JsonNumber s = (JsonNumber) star;
+ Double st = s.doubleValue();
+ bw.write("star : " + st + "\n");
+
+ JsonValue t = object.get("text");
+ JsonString te = (JsonString) t;
+ String text = te.getString();
+ bw.write("text: " + t + "\n\n");
+
+ count++;
+ }
+ br.close();
+ bw.close();
+ System.out.println(relation + " : " + count);
+ }
+
+
+
+ public static void main(String [] args) throws IOException{
+ //convertResAttLogisticReadableFile(System.getProperty("user.dir")+"/Data/yelp_dataset_restaurant_att");
+ //String fileAddress = System.getProperty("user.dir") + "/../Dataset/yelp_dataset";
+ String fileAddress = System.getProperty("user.dir") + "/../Data/json/yelp_reviews_restaurants.txt";
+ //putReviewDatatoFile(fileAddress);
+ //extractReviewJson(fileAddress);
+ //extractResReviewJson();
+
+ }
+
+}
diff --git a/PythonScript/clean_text.py b/PythonScript/clean_text.py
new file mode 100644
index 0000000..c247d04
--- /dev/null
+++ b/PythonScript/clean_text.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+## Inputs review file in format
+
+# businessId : asnhdjkasld
+# userId : kjhdjkfadfjbasdbjfbasdb
+# stars : 4.0
+# text : I was amazed with the quality of the food
+
+##### OUTPUTS the file in the same order, but tokenizes, removes punctuation, removes stop words and stems the words before outputting. Also, each word in a review occurs only once in the output.
+
+
+from nltk.corpus import stopwords
+from nltk.tokenize import word_tokenize
+import re
+import string
+from nltk.stem import PorterStemmer
+
+stemmer = PorterStemmer()
+
+def getTokens(doc):
+ doc = re.sub("\d+", "", doc)
+
+ tokenized = word_tokenize(doc.decode('utf-8'))
+
+ regex = re.compile('[%s]' % re.escape(string.punctuation)) #see documentation here: http://docs.python.org/2/library/string.html
+
+ tokenized_no_punctuation = []
+
+ new_review = []
+
+
+ ## REMOVING PUNCTUATION
+ #for token in tokenized:
+ # new_token = regex.sub(u'', token.lower())
+ # if not new_token == u'':
+ # tokenized_no_punctuation.append(new_token)
+
+ tokenized_no_punctuation = [re.sub(r'[^A-Za-z0-9]+', '', x.lower()) for x in tokenized]
+ tokenized_no_punctuation = [s for s in tokenized_no_punctuation if (len(s)>1)]
+
+ #print tokenized_no_punctuation
+
+ token_no_stop = []
+ ## REMOVING STOP WORDS
+ for word in tokenized_no_punctuation:
+ if not word in stopwords.words('english'):
+ try:
+ word = stemmer.stem(word.encode('utf-8'))
+ except UnicodeDecodeError:
+ word = word #.encode('utf-8')
+ token_no_stop.append(word.encode('utf-8'))
+
+
+ return token_no_stop
+
+
+
+fin = open('reviews.txt', 'r')
+fout = open('reviews_textProc.txt', 'w')
+count = 0
+for line in fin:
+# print count
+ if(count %100000 == 0):
+ print count
+ if(line.strip().split(':')[0] == 'text'):
+ tokens = getTokens(line.strip().split(':')[1])
+ tokenSet = set()
+ for i in tokens:
+ if (len(i) >= 3):
+ tokenSet.add(i)
+ fout.write("text : ")
+ fout.write(" ".join(tokenSet))
+ fout.write("\n\n")
+ else:
+ fout.write(line)
+ count = count + 1
diff --git a/PythonScript/combineAllPredData.py b/PythonScript/combineAllPredData.py
new file mode 100644
index 0000000..e652ecb
--- /dev/null
+++ b/PythonScript/combineAllPredData.py
@@ -0,0 +1,29 @@
+# From Logistic CMF : python PythonScript/combineAllPredData.py Embeddings_Prediction_Data HeldOut
+
+import sys
+import os
+from os import walk
+
+embeddingsPath = sys.argv[1];
+evToMerge = sys.argv[2];
+print embeddingsPath+"/All/pred-data/"+evToMerge
+if not os.path.exists(embeddingsPath+"/All/pred-data/"+evToMerge):
+ os.makedirs(embeddingsPath+"/All/pred-data/"+evToMerge)
+
+foldersToMerge = ['AZ', 'NV', 'WI', 'EDH']
+
+print embeddingsPath+"/WI/pred-data/"+evToMerge
+files = []
+for (dirpath, dirnames, filenames) in walk(embeddingsPath+"/WI/pred-data/"+evToMerge):
+ files.extend(filenames)
+ break
+
+print files;
+
+for folder in foldersToMerge:
+ for filename in files:
+ with open(embeddingsPath+"/All/pred-data/"+evToMerge+"/"+filename, 'a') as outfile:
+ with open(embeddingsPath+"/"+folder+"/pred-data/"+evToMerge+"/"+filename) as infile:
+ for line in infile:
+ outfile.write(line)
+
new file mode 100644
index 0000000..8f1ee6d
--- /dev/null
+++ b/PythonScript/getPRCurveData.py
@@ -0,0 +1,28 @@
+# From Logistic_CMF : python PythonScript/getPRCurveData.py Embeddings_Prediction_Data All HeldOut
+
+from os import walk
+import sys
+import os
+
+embeddingsPath = sys.argv[1]
+folderToTest = sys.argv[2]
+evToTest = sys.argv[3]
+
+path = embeddingsPath+"/"+folderToTest+"/pred-data/"+evToTest
+
+files = []
+for (dirpath, dirnames, filenames) in walk(path):
+ files.extend(filenames)
+ break
+
+for fileName in files:
+ f = open(path+"/"+fileName, 'r')
+ if not os.path.exists(path+"/PRCurve"):
+ os.makedirs(path+"/PRCurve")
+ o = open(path+"/PRCurve/"+fileName, 'w')
+ for line in f:
+ line = line.strip();
+ a = line.split("::")
+ o.write(a[2].strip() + "\t" + a[3].strip() + "\n");
+ o.close();
+ f.close();
diff --git a/PythonScript/getPRF.py b/PythonScript/getPRF.py
new file mode 100644
index 0000000..c51be2c
--- /dev/null
+++ b/PythonScript/getPRF.py
@@ -0,0 +1,22 @@
+# From Logistic CMF : python PythonScript/getPRF.py Embeddings_Prediction_Data WI HeldOut
+
+import prf;
+from os import walk
+import sys
+
+embeddingsPath = sys.argv[1]
+folderToTest = sys.argv[2]
+evToTest = sys.argv[3]
+
+path = embeddingsPath+"/"+folderToTest+"/pred-data/"+evToTest
+print path
+fs = []
+for (dirpath, dirnames, filenames) in walk(path):
+ fs.extend(filenames)
+ break
+fs.sort()
+for fileName in fs:
+ print fileName
+ print prf.getPRF(path+"/"+fileName)
+
+
diff --git a/PythonScript/prf.py b/PythonScript/prf.py
new file mode 100644
index 0000000..2373821
--- /dev/null
+++ b/PythonScript/prf.py
@@ -0,0 +1,39 @@
+## Call getPRF(fileName) to get [P, R, F]
+
+import numpy as np;
+from sklearn.metrics import precision_recall_fscore_support as prf
+
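+# Prediction lines look like "e1 :: e2 :: sigm(dot) :: truth" (see writeDataToFile.writePrediction); the score in field 3 is thresholded at 0.5 to get a hard 0/1 prediction.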
+def readFile(fileName):
+ y_pred = []
+ y_true = []
+ y_pred_true = []
+ f = open(fileName, 'r');
+ for line in f:
+ line = line.strip();
+ a = line.split("::")
+ if(float(a[2].strip()) >= 0.5):
+ pred = 1;
+ y_pred.append(pred);
+ else:
+ pred = 0;
+ y_pred.append(pred);
+
+ true = float(a[3].strip());
+ y_true.append(true);
+ y_pred_true.append(y_pred)
+ y_pred_true.append(y_true)
+ return y_pred_true
+
+def getPRF(fileName):
+ y_p_t = readFile(fileName)
+ y_pred = y_p_t[0];
+ y_true = y_p_t[1];
+ acc = prf(y_true, y_pred, average = 'micro');
+ p = round(acc[0]*100, 1);
+ r = round(acc[1]*100, 1);
+ f = round(acc[2]*100, 1);
+ return np.array([p,r,f])
+
+#fileName = "A-A"
+#print getPRF(fileName)
+
diff --git a/PythonScript/prf.pyc b/PythonScript/prf.pyc
new file mode 100644
index 0000000..c908aca
Binary files /dev/null and b/PythonScript/prf.pyc differ
diff --git a/PythonScript/writePRFTable.py b/PythonScript/writePRFTable.py
new file mode 100644
index 0000000..e89284e
--- /dev/null
+++ b/PythonScript/writePRFTable.py
@@ -0,0 +1,36 @@
+# From Logistic CMF : python PythonScript/writePRFTable.py Embeddings_Prediction_Data A HeldOut
+
+import sys
+import os
+from os import walk
+import prf
+
+embeddingsPath = sys.argv[1]
+relation = sys.argv[2]
+evToTest = sys.argv[3]
+
+writePath = embeddingsPath+"/Tables/"+relation+"-"+evToTest
+out = open(writePath, 'w')
+
+foldersToWrite = ['AZ', 'NV', 'WI', 'EDH', 'All']
+filesToWrite = []
+for (dirpath, dirnames, filenames) in walk(embeddingsPath+"/WI/pred-data/"+evToTest):
+ for filename in filenames:
+ r = filename.split("-")[0].strip()
+ if r == relation:
+ filesToWrite.append(filename)
+ break
+
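+# Each model becomes one LaTeX table row: a \textbf{name} cell followed by "& P & R & F" entries for every folder.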
+for model in filesToWrite:
+ out.write("\\textbf{"+model.split("-")[1].strip()+"}\n")
+ for folder in foldersToWrite:
+ PRF = prf.getPRF(embeddingsPath+"/"+folder+"/pred-data/"+evToTest+"/"+model)
+ out.write("& " + str(PRF[0]) + "\t & " + str(PRF[1]) + "\t & " + str(PRF[2]) + "\n")
+ out.write("\\\\ \n")
+
+
+
+
+
+
+
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..ada5b5b
--- /dev/null
+++ b/README.md
@@ -0,0 +1,21 @@
+# README #
+
+### What is this repository for? ###
+
+* Yelp Dataset Challenge
+ * Parse Yelp Json format data to get required data in the required format
+ * Perform logistic CMF on the parsed data to predict relations.
+
+### Details on Logistic CMF Code ###
+
+* This Java project currently has 3 packages :
+ * yelpDataProcessing - Contains classes/functions to read the yelp dataset in json format and parse it to get different data in the required format.
+ * logisticCMF - Contains classes/functions to read the parsed data, split it into train/validation/test sets, learn entity embeddings and print prediction evaluations.
+ * postProcessing - Contains classes/functions to read back the learned embeddings and compute nearest-neighbour similarities between entities.
+
+* The folder PythonScript contains a file clean_text.py that reads the yelp review data and pre-processes the review text.
+ * The text pre-processing consists of tokenization, stemming, and removal of stop words and punctuation.
+ * Each word is kept only once if it occurs multiple times in a review.
+#### Project Contributors ####
+* Nitish Gupta
+* Sameer Singh
\ No newline at end of file