diff --git a/Project/.classpath b/Project/.classpath new file mode 100644 index 0000000..e197418 --- /dev/null +++ b/Project/.classpath @@ -0,0 +1,19 @@ + + + + + + + + + + + + + + + + + + + diff --git a/Project/.project b/Project/.project new file mode 100644 index 0000000..772bcfe --- /dev/null +++ b/Project/.project @@ -0,0 +1,17 @@ + + + logisitic_cmf_yelp + + + + + + org.eclipse.jdt.core.javabuilder + + + + + + org.eclipse.jdt.core.javanature + + diff --git a/Project/bin/logisticCMF/Cell.class b/Project/bin/logisticCMF/Cell.class new file mode 100644 index 0000000..4da479f Binary files /dev/null and b/Project/bin/logisticCMF/Cell.class differ diff --git a/Project/bin/logisticCMF/Eval.class b/Project/bin/logisticCMF/Eval.class new file mode 100644 index 0000000..f9ea524 Binary files /dev/null and b/Project/bin/logisticCMF/Eval.class differ diff --git a/Project/bin/logisticCMF/Rating.class b/Project/bin/logisticCMF/Rating.class new file mode 100644 index 0000000..2688baa Binary files /dev/null and b/Project/bin/logisticCMF/Rating.class differ diff --git a/Project/bin/logisticCMF/Util.class b/Project/bin/logisticCMF/Util.class new file mode 100644 index 0000000..3815de8 Binary files /dev/null and b/Project/bin/logisticCMF/Util.class differ diff --git a/Project/bin/logisticCMF/codeTest.class b/Project/bin/logisticCMF/codeTest.class new file mode 100644 index 0000000..f745351 Binary files /dev/null and b/Project/bin/logisticCMF/codeTest.class differ diff --git a/Project/bin/logisticCMF/data.class b/Project/bin/logisticCMF/data.class new file mode 100644 index 0000000..54f8f0e Binary files /dev/null and b/Project/bin/logisticCMF/data.class differ diff --git a/Project/bin/logisticCMF/embedding.class b/Project/bin/logisticCMF/embedding.class new file mode 100644 index 0000000..2a99242 Binary files /dev/null and b/Project/bin/logisticCMF/embedding.class differ diff --git a/Project/bin/logisticCMF/embeddings.class b/Project/bin/logisticCMF/embeddings.class new file mode 
100644 index 0000000..2af64dd Binary files /dev/null and b/Project/bin/logisticCMF/embeddings.class differ diff --git a/Project/bin/logisticCMF/learner.class b/Project/bin/logisticCMF/learner.class new file mode 100644 index 0000000..5dae12a Binary files /dev/null and b/Project/bin/logisticCMF/learner.class differ diff --git a/Project/bin/logisticCMF/writeDataToFile.class b/Project/bin/logisticCMF/writeDataToFile.class new file mode 100644 index 0000000..5ab457e Binary files /dev/null and b/Project/bin/logisticCMF/writeDataToFile.class differ diff --git a/Project/bin/postProcessing/EntityEmbeddings.class b/Project/bin/postProcessing/EntityEmbeddings.class new file mode 100644 index 0000000..24cea7f Binary files /dev/null and b/Project/bin/postProcessing/EntityEmbeddings.class differ diff --git a/Project/bin/postProcessing/Similarity.class b/Project/bin/postProcessing/Similarity.class new file mode 100644 index 0000000..58ce3b5 Binary files /dev/null and b/Project/bin/postProcessing/Similarity.class differ diff --git a/Project/bin/postProcessing/Util$1.class b/Project/bin/postProcessing/Util$1.class new file mode 100644 index 0000000..c4a5b10 Binary files /dev/null and b/Project/bin/postProcessing/Util$1.class differ diff --git a/Project/bin/postProcessing/Util.class b/Project/bin/postProcessing/Util.class new file mode 100644 index 0000000..aa47689 Binary files /dev/null and b/Project/bin/postProcessing/Util.class differ diff --git a/Project/bin/yelpDataProcessing/AttributeCategory.class b/Project/bin/yelpDataProcessing/AttributeCategory.class new file mode 100644 index 0000000..e786bb7 Binary files /dev/null and b/Project/bin/yelpDataProcessing/AttributeCategory.class differ diff --git a/Project/bin/yelpDataProcessing/ProcessYelpJson.class b/Project/bin/yelpDataProcessing/ProcessYelpJson.class new file mode 100644 index 0000000..18cb83e Binary files /dev/null and b/Project/bin/yelpDataProcessing/ProcessYelpJson.class differ diff --git 
a/Project/bin/yelpDataProcessing/reviewData.class b/Project/bin/yelpDataProcessing/reviewData.class new file mode 100644 index 0000000..9fffa50 Binary files /dev/null and b/Project/bin/yelpDataProcessing/reviewData.class differ diff --git a/Project/bin/yelpDataProcessing/reviewJson.class b/Project/bin/yelpDataProcessing/reviewJson.class new file mode 100644 index 0000000..86e2bfb Binary files /dev/null and b/Project/bin/yelpDataProcessing/reviewJson.class differ diff --git a/Project/src/logisticCMF/Cell.java b/Project/src/logisticCMF/Cell.java new file mode 100644 index 0000000..e7c851e --- /dev/null +++ b/Project/src/logisticCMF/Cell.java @@ -0,0 +1,22 @@ +package logisticCMF; +import java.util.*; + + +public class Cell { + String relation_id; // Id of relation in this data element + ArrayList entity_ids; // List of entities participating in Relation ' relation_id ' + boolean truth; // Truth Value for this data element + + public Cell(){ + entity_ids = new ArrayList(); + } + + public Cell(String r, String e1, String e2, boolean t){ + entity_ids = new ArrayList(); + relation_id = r; + entity_ids.add(e1); + entity_ids.add(e2); + truth = t; + } + +} diff --git a/Project/src/logisticCMF/Eval.java b/Project/src/logisticCMF/Eval.java new file mode 100644 index 0000000..5fd2542 --- /dev/null +++ b/Project/src/logisticCMF/Eval.java @@ -0,0 +1,157 @@ +package logisticCMF; + +import java.io.IOException; +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; + +public class Eval { + static Map relTrue = new HashMap(); // Correct predictions per relation + static Map relCount = new HashMap(); // Total test size per relation + static Map relL2 = new HashMap(); + + static Map relActualTruth = new HashMap(); // Actual Truth values in each relation + static Map relPredTruth = new HashMap(); // Truth predicted in each relation + static Map relTruthCorrect = new HashMap(); // Correct Truth 
predictions in each relation + + static Map> relEvalMap = new HashMap>(); //Map [Rel,] Also contains "weighted average" as relation + + + public static Map> getEvalMap(data Data, embeddings e, String set){ + refreshMaps(); + if(set.equals("test")){ + for(Cell cell : Data.testData) + updateTestMaps(cell, e); + } + else{ + for(Cell cell : Data.valData) + updateTestMaps(cell, e); + } + makeRelationEvalMap(); // Final Map [Relation, ] + return relEvalMap; + } + + public static void refreshMaps(){ + relTrue = new HashMap(); + relCount = new HashMap(); + relL2 = new HashMap(); + relActualTruth = new HashMap(); // Actual Truth values in each relation + relPredTruth = new HashMap(); // Truth predicted in each relation + relTruthCorrect = new HashMap(); // Correct Truth predictions in each relation + relEvalMap = new HashMap>(); + } + + public static void addInRelCountMap(Cell cell){ + if(!relCount.containsKey(cell.relation_id)) + relCount.put(cell.relation_id, 1); + else + relCount.put(cell.relation_id, relCount.get(cell.relation_id)+1); + } + + public static void addInprfRelationMap(String relation, Integer actual, Integer pred, Integer correct){ + if(!relActualTruth.containsKey(relation)){ + relActualTruth.put(relation, 0); + relPredTruth.put(relation, 0); + relTruthCorrect.put(relation, 0); + } + + else{ + if(!(actual == 0 && pred == 0)){ + relActualTruth.put(relation, relActualTruth.get(relation) + actual); + relPredTruth.put(relation, relPredTruth.get(relation) + pred); + relTruthCorrect.put(relation, relTruthCorrect.get(relation) + correct); + } + } + } + + public static void addInRelAccuracyMap(Cell cell, Integer correct){ + if(!relTrue.containsKey(cell.relation_id)) + relTrue.put(cell.relation_id, correct); + else + relTrue.put(cell.relation_id, relTrue.get(cell.relation_id)+correct); + } + + public static void addInRelL2Map(Cell cell, double l2){ + if(!relL2.containsKey(cell.relation_id)) + relL2.put(cell.relation_id, l2); + else + relL2.put(cell.relation_id, 
relTrue.get(cell.relation_id) + l2); + } + + //Map [Rel, ] + public static void makeRelationEvalMap(){ + //makeTestMaps(Data, e); + double f = 0; double wf1=0.0, waccuracy=0.0, wp = 0.0, wr = 0.0; int total = 0; + for(String rel : relActualTruth.keySet()){ + double accuracy = ((double)relTrue.get(rel))/relCount.get(rel); + double precision = (double)relTruthCorrect.get(rel) / relPredTruth.get(rel) ; + double recall = (double) relTruthCorrect.get(rel) / relActualTruth.get(rel) ; + double f1 = 2*precision*recall / (precision + recall) ; + //System.out.println("a : " + accuracy + " p : " + precision + " r : " + recall + " f1 : " +f1); + relEvalMap.put(rel, new ArrayList()); + relEvalMap.get(rel).add(round(accuracy, 3)); // Accuracy + relEvalMap.get(rel).add(round(precision, 3)); // Precision + relEvalMap.get(rel).add(round(recall, 3)); // Recall + relEvalMap.get(rel).add(round(f1, 3)); // F1 + wf1 += (relCount.get(rel)*f1); + wp += (relCount.get(rel)*precision); + wr += (relCount.get(rel)*recall); + waccuracy += (relCount.get(rel)*accuracy); + total += relCount.get(rel); + } + wf1 = round(wf1/total, 3); waccuracy = round(waccuracy/total, 3); wp = round(wp/total, 3); wr = round(wr/total, 3); + relEvalMap.put("average", new ArrayList()); + relEvalMap.get("average").add(round(waccuracy, 3)); // Accuracy + relEvalMap.get("average").add(round(wp, 3)); // Precision + relEvalMap.get("average").add(round(wr, 3)); // Recall + relEvalMap.get("average").add(round(wf1, 3)); // F1 + + } + + public static void updateTestMaps(Cell cell, embeddings e){ + int correct = 0; int t =0, f=0; int c = 0; double l2Sum = 0.0; + double dot = e.dot(cell, learner.enableBias, e.K, learner.ealpha, learner.onlyAlpha); + double sigmdot = learner.sigm(dot); + int pred = (sigmdot >= 0.5) ? 1 : 0; + int truth = (cell.truth) ? 
1 : 0; + double l2 = (truth - sigmdot)*(truth-sigmdot); + l2Sum += l2; + //System.out.println(sigmdot + " " + pred + " " + truth + " " + l2); + + if(pred == truth) + correct = 1; + else + correct = 0; + c += correct; + //System.out.println("rel : " + truth + " pred : " + pred); + //System.out.println(cell.relation_id + " : " + pred); + addInprfRelationMap(cell.relation_id, truth, pred, correct); + addInRelAccuracyMap(cell, correct); + addInRelCountMap(cell); + addInRelL2Map(cell, l2); + } + + public static void printEval(){ + for(String rel : relEvalMap.keySet()){ + ArrayList eval = relEvalMap.get(rel); + System.out.print(rel + " : "); + System.out.println("P : " + eval.get(1) + " R : " + eval.get(2) + " F1 : " + eval.get(3) + " Accuracy : " + eval.get(0)); + } + } + + + public static double round(double value, int places) { + if (places < 0) throw new IllegalArgumentException(); + if(Double.isNaN(value)) + return Double.NaN; + + BigDecimal bd = new BigDecimal(value); + bd = bd.setScale(places, RoundingMode.HALF_UP); + return bd.doubleValue(); + } + + + +} diff --git a/Project/src/logisticCMF/Rating.java b/Project/src/logisticCMF/Rating.java new file mode 100644 index 0000000..00d1b26 --- /dev/null +++ b/Project/src/logisticCMF/Rating.java @@ -0,0 +1,177 @@ +package logisticCMF; +import java.io.*; +import java.util.*; + + + +public class Rating { + + public Map> ratings; // Map[busId, Map[User, Rating]] + public String loc; + data [] busRate = new data[4]; + + public Rating(String folder, double valp, double testp) throws IOException{ + ratings = new HashMap>(); + loc = folder; + busRate[0] = new data(); + busRate[1] = new data(); + busRate[2] = new data(); + busRate[3] = new data(); + readRatingData(folder); + makeTwoRatingDataCells(valp, testp); + /*for(data rd : busRate){ + rd.dataStats(); + }*/ + } + + public void testRatings(){ + data mergeData = new data(); + ArrayList tomerge = new ArrayList(); + for(data d : busRate) + tomerge.add(d); + 
mergeData.addDataAfterSplit(tomerge); + mergeData.dataStats(); + codeTest.learnAndTest(mergeData, 30, false, 0, false, 0); + } + + public void readRatingData(String folder) throws NumberFormatException, IOException{ + String address = System.getProperty("user.dir")+"/../Dataset/data/"+ folder +"/reviews.txt"; + BufferedReader br = new BufferedReader(new FileReader(address)); + String line; int countl = 0, countcell = 0; + String busid = null, userid = null; boolean value = false; + while((line = br.readLine()) != null){ + countl++; + String[] array = line.split(":"); + if( array[0].trim().equals("bus_id")){ + busid = array[1].trim(); + if(!ratings.containsKey(busid)) + ratings.put(busid, new HashMap()); + } + if( array[0].trim().equals("user_id")) + userid = array[1].trim(); + if( array[0].trim().equals("star")){ + double t = Double.parseDouble(array[1].trim()); + /*if(ratings.get(busid).containsKey(userid)){ + int r = ratings.get(busid).get(userid); + System.out.println("user already exists " + busid + " " + userid); + + }*/ + ratings.get(busid).put(userid, (int)t); + countcell++; + } + } + br.close(); + //System.out.println("Ratings : " + countcell); + + } + + public void makeTwoRatingDataCells(double valPerc, double testPerc){ + String relation = "busrate-"+loc+"-"; + int count2 = 0; int count = 0; + // Read all rating data from Map and create busRate data object for threshold = 2 + for(String bus : ratings.keySet()){ + for(String user : ratings.get(bus).keySet()){ + int rate = ratings.get(bus).get(user); + count++; + Cell cell = new Cell(); + cell.relation_id = relation+"2"; + cell.entity_ids.add(bus); + cell.entity_ids.add(user); + if(rate >= 2){ + cell.truth = true; + busRate[0].Data.add(cell); + } + else{ + cell.truth = false; + busRate[0].Data.add(cell); + } + } + } + busRate[0].splitTrainTestValidation(valPerc, testPerc); + //busRate[0].dataStats(); + makeRestRatingDataCells(); + + System.out.println("Ratings : " + busRate[0].Data.size()); + } + + public 
void makeRestRatingDataCells(){ + String r = "busrate-"+loc+"-"; + for(Cell cell : busRate[0].trainData){ + String b = cell.entity_ids.get(0), u = cell.entity_ids.get(1); + int rate = ratings.get(b).get(u); + if(rate <= 2) + addCellsinRestTrain(r, b, u, false, false, false); + + if(rate == 3) + addCellsinRestTrain(r, b, u, true, false, false); + + if(rate == 4) + addCellsinRestTrain(r, b, u, true, true, false); + + if(rate == 5) + addCellsinRestTrain(r, b, u, true, true, true); + } + + for(Cell cell : busRate[0].testData){ + String b = cell.entity_ids.get(0), u = cell.entity_ids.get(1); + int rate = ratings.get(b).get(u); + if(rate <= 2) + addCellsinRestTest(r, b, u, false, false, false); + if(rate == 3) + addCellsinRestTest(r, b, u, true, false, false); + if(rate == 4) + addCellsinRestTest(r, b, u, true, true, false); + if(rate == 5) + addCellsinRestTest(r, b, u, true, true, true); + } + + for(Cell cell : busRate[0].valData){ + String b = cell.entity_ids.get(0), u = cell.entity_ids.get(1); + int rate = ratings.get(b).get(u); + if(rate <= 2) + addCellsinRestVal(r, b, u, false, false, false); + if(rate == 3) + addCellsinRestVal(r, b, u, true, false, false); + if(rate == 4) + addCellsinRestVal(r, b, u, true, true, false); + if(rate == 5) + addCellsinRestVal(r, b, u, true, true, true); + } + + } + + public void addCellsinRestTrain(String r, String b, String u, boolean t3, boolean t4, boolean t5){ + busRate[1].Data.add(new Cell(r+3, b, u, t3)); + busRate[1].trainData.add(new Cell(r+3, b, u, t3)); + + busRate[2].Data.add(new Cell(r+4, b, u, t4)); + busRate[2].trainData.add(new Cell(r+4, b, u, t4)); + + busRate[3].Data.add(new Cell(r+5, b, u, t5)); + busRate[3].trainData.add(new Cell(r+5, b, u, t5)); + } + + public void addCellsinRestTest(String r, String b, String u, boolean t3, boolean t4, boolean t5){ + busRate[1].Data.add(new Cell(r+3, b, u, t3)); + busRate[1].testData.add(new Cell(r+3, b, u, t3)); + + busRate[2].Data.add(new Cell(r+4, b, u, t4)); + 
busRate[2].testData.add(new Cell(r+4, b, u, t4)); + + busRate[3].Data.add(new Cell(r+5, b, u, t5)); + busRate[3].testData.add(new Cell(r+5, b, u, t5)); + } + + public void addCellsinRestVal(String r, String b, String u, boolean t3, boolean t4, boolean t5){ + busRate[1].Data.add(new Cell(r+3, b, u, t3)); + busRate[1].valData.add(new Cell(r+3, b, u, t3)); + + busRate[2].Data.add(new Cell(r+4, b, u, t4)); + busRate[2].valData.add(new Cell(r+4, b, u, t4)); + + busRate[3].Data.add(new Cell(r+5, b, u, t5)); + busRate[3].valData.add(new Cell(r+5, b, u, t5)); + } + +} + diff --git a/Project/src/logisticCMF/Util.java b/Project/src/logisticCMF/Util.java new file mode 100644 index 0000000..4e4ef05 --- /dev/null +++ b/Project/src/logisticCMF/Util.java @@ -0,0 +1,156 @@ +package logisticCMF; + +import java.io.IOException; +import java.util.ArrayList; + + + +/*import org.jfree.chart.ChartFactory; +import org.jfree.chart.ChartFrame; +import org.jfree.chart.JFreeChart; +import org.jfree.chart.axis.ValueAxis; +import org.jfree.chart.plot.PlotOrientation; +import org.jfree.chart.plot.XYPlot; +import org.jfree.data.xy.XYSeries; +import org.jfree.data.xy.XYSeriesCollection;*/ +import java.util.*; + +//import postProcessing.data; + + +public class Util { + + static Random seed = new Random(100); + + public static HashSet getColdEntites(ArrayList data, int index, double perc){ + HashSet entities = new HashSet(); + for(Cell cell : data) + entities.add(cell.entity_ids.get(index)); + + ArrayList ents = new ArrayList(entities); + Collections.shuffle(ents, seed); + HashSet coldEntities = new HashSet(ents.subList(0, (int)(perc*entities.size()/100.0))); + return coldEntities; + + } + + public static int getNegSampleSize(data Data){ + return (int)(((double)Data.Data.size())/Data.entityIds.get(0).size()); + } + + public static double getMatrixDetails(data Data){ + System.out.println("rows : " + Data.entityIds.get(0).size()); + System.out.println("cols : " + Data.entityIds.get(1).size()); + 
System.out.println("size : " + Data.Data.size()); + return ((double)Data.Data.size())/Data.entityIds.get(0).size(); + + } + + /*public static void plotGraph(ArrayList x, ArrayList y, String filename) throws IOException{ + XYSeries series = new XYSeries(filename); + for(int i =0; i entities = new HashSet(); + int index = 0, flag = 0; + for(Cell cell : D.trainData){ + entities.add(cell.entity_ids.get(index)); + } + for(Cell cell : D.valData){ + entities.add(cell.entity_ids.get(index)); + } + for(Cell cell : D.testData){ + if(entities.contains(cell.entity_ids.get(index))){ + System.out.print("WRONG COLD START SPLIT, Index : " + index + ", "); + flag = 1; + break; + } + } + if(flag == 0){ + System.out.print("Index ColdStart : " + index + ", "); + } + + index = 1; flag = 0; + entities = new HashSet(); + for(Cell cell : D.trainData){ + entities.add(cell.entity_ids.get(index)); + } + for(Cell cell : D.valData){ + entities.add(cell.entity_ids.get(index)); + } + for(Cell cell : D.testData){ + if(entities.contains(cell.entity_ids.get(index))){ + System.out.println("NO COLD START SPLIT, Index : " + index); + flag = 1; + break; + } + } + if(flag == 0) + System.out.println("Index ColdStart : " + index); + + } + + public static void countEntities(ArrayList data){ + HashSet e1 = new HashSet(); + HashSet e2 = new HashSet(); + + for(Cell cell : data){ + e1.add(cell.entity_ids.get(0)); + e2.add(cell.entity_ids.get(1)); + } + + System.out.println("e1 : " + e1.size() + " e2 : "+e2.size()); + + } + + public static void implicitColdStart(ArrayList trdata, ArrayList testData){ + HashSet tre1 = new HashSet(); + HashSet tre2 = new HashSet(); + + for(Cell cell : trdata){ + tre1.add(cell.entity_ids.get(0)); + tre2.add(cell.entity_ids.get(1)); + } + + System.out.println("tre1 : " + tre1.size() + " tre2 : "+ tre2.size()); + + HashSet tee1 = new HashSet(); + HashSet tee2 = new HashSet(); + for(Cell cell : testData){ + tee1.add(cell.entity_ids.get(0)); + tee2.add(cell.entity_ids.get(1)); + } + + 
System.out.println("tee1 : " + tee1.size() + " tee2 : "+ tee2.size()); + tee1.removeAll(tre1); + tee2.removeAll(tre2); + + + System.out.println("e1 cold : " + tee1.size() + " e2 cold : " + tee2.size() ); + + } +} diff --git a/Project/src/logisticCMF/codeTest.java b/Project/src/logisticCMF/codeTest.java new file mode 100644 index 0000000..e3460a1 --- /dev/null +++ b/Project/src/logisticCMF/codeTest.java @@ -0,0 +1,676 @@ +package logisticCMF; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Map; +import java.util.Random; + +public class codeTest { + + static Random rand = new Random(); + + private static final long MEGABYTE = 1024L * 1024L; + public static long bytesToMegabytes(long bytes) { + return bytes / MEGABYTE; + } + public static void getMemoryDetails(){ + // Get the Java runtime + Runtime runtime = Runtime.getRuntime(); + // Run the garbage collector + //runtime.gc(); + // Calculate the used memory + System.out.println("Total memory : " + bytesToMegabytes(runtime.totalMemory()) + " Free memory : " + bytesToMegabytes(runtime.freeMemory()) + + " Used memory is megabytes: " + bytesToMegabytes(runtime.totalMemory() - runtime.freeMemory())); + + } + + // Input Data, that is split in Train, Validation and Test. Input K. Learn and stop on convergence. Then test with learnt no. of epochs. 
+ public static embeddings learnAndTest(data Data, int K, boolean busWord, int busWordNegSamSize, boolean userWord, int userWordNegSamSize){ + embeddings e = new embeddings(Data, K); + learner l = new learner(); + embeddings eBest = l.learnAndStop1(Data, e, false, false, false, busWord, busWordNegSamSize, userWord, userWordNegSamSize); + System.out.println("learning done"); + Eval.getEvalMap(Data, eBest, "test"); + Eval.printEval(); + return eBest; + } + + public static data readAttributes(String folder, double valPerc, double testPerc, boolean coldStart, int index) throws IOException{ + data att = new data(); + att.readBusAtt(System.getProperty("user.dir")+"/../Dataset/data/"+ folder +"/busAtt.txt", folder); + if(coldStart) + att.splitColdStart(valPerc, testPerc, index); + else + att.splitTrainTestValidation(valPerc, testPerc); + + return att; + } + + public static data readCategories(String folder, int pruneThresh) throws IOException{ + data cat = new data(); + cat.readAndCompleteCategoryData(System.getProperty("user.dir")+"/../Dataset/data/"+ folder +"/busCat.txt", pruneThresh, folder); + cat.splitTrainTestValidation(0.0, 0.0); + return cat; + } + + public static data readRatings(String folder, double valPerc, double testPerc, boolean coldStart, int index) throws IOException{ + data rate = new data(); + rate.readRatingData(System.getProperty("user.dir")+"/../Dataset/data/"+ folder +"/reviews.txt", folder); + if(coldStart) + rate.splitColdStart(valPerc, testPerc, index); + else + rate.splitTrainTestValidation(valPerc, testPerc); + return rate; + } + + public static data readReviewData(String folder, int occThresh, boolean busWord, boolean userWord, double valPerc, double testPerc) throws IOException{ + data rD = new data(); + rD.readReviewData(System.getProperty("user.dir")+"/../Dataset/data/"+ folder +"/reviews_textProc.txt"); // Makes EnWord maps (Business and User Word maps) and Word-Count map + rD.pruneVocab_EntityMap(occThresh); + if(busWord) + 
rD.makeEnWordCells("b-word"); + if(userWord) + rD.makeEnWordCells("u-word"); + rD.splitTrainTestValidation(valPerc, testPerc); + return rD; + } + + public static void completeEvaluation(String folder, data A, data C, data R, data W, boolean busWord, int bwNS, boolean userWord, int uwNS, + boolean attCold, int coldIndexAtt, boolean rateCold, int coldIndexRate, String folderToWriteData) throws IOException{ + + System.out.print("Attribute Cold Start, "); + Util.checkColdStartSanity(A); + Util.implicitColdStart(A.trainData, A.testData); + System.out.print("Rate Cold Start, "); + Util.checkColdStartSanity(R); + Util.implicitColdStart(R.trainData, R.testData); + + ArrayList tomerge = new ArrayList(); + data mergeData = new data(); + + getMemoryDetails(); + + System.out.println("################################################### "+folder+" - Att-Word ######################################"); + tomerge.clear(); + tomerge.add(A); + tomerge.add(W); + mergeData = new data(); + mergeData.busWord = W.busWord; + mergeData.words = W.words; + mergeData.wordCount = W.wordCount; + mergeData.userWord = W.userWord; + mergeData.addDataAfterSplit(tomerge); + mergeData.dataStats(); + embeddings e = learnAndTest(mergeData, 30, busWord, bwNS, userWord, uwNS); + if(!(busWord || userWord)) + writeDataToFile.writePrediction(folderToWriteData+"A-A", folder, A, e); + if(busWord && !userWord) + writeDataToFile.writePrediction(folderToWriteData+"A-A+BW", folder, A, e); + if(userWord && !busWord) + writeDataToFile.writePrediction(folderToWriteData+"A-A+UW", folder, A, e); + if(userWord && busWord) + writeDataToFile.writePrediction(folderToWriteData+"A-A+BUW", folder, A, e); + getMemoryDetails(); + + System.out.println("################################################### "+folder+" - Att-Cat-Word ######################################"); + //rD = readReviewData(folder, 10, busWord, false, 0.0, 0.0); + tomerge = new ArrayList(); + tomerge.clear(); + tomerge.add(A); + tomerge.add(C); + 
tomerge.add(W); + mergeData = new data(); + mergeData.busWord = W.busWord; + mergeData.words = W.words; + mergeData.wordCount = W.wordCount; + mergeData.userWord = W.userWord; + mergeData.addDataAfterSplit(tomerge); + mergeData.dataStats(); + e = learnAndTest(mergeData, 30, busWord, bwNS, userWord, uwNS); + if(!(busWord || userWord)) + writeDataToFile.writePrediction(folderToWriteData+"A-A+C", folder, A, e); + if(busWord && !userWord) + writeDataToFile.writePrediction(folderToWriteData+"A-A+C+BW", folder, A, e); + if(userWord && !busWord) + writeDataToFile.writePrediction(folderToWriteData+"A-A+C+UW", folder, A, e); + if(busWord && userWord) + writeDataToFile.writePrediction(folderToWriteData+"A-A+C+BUW", folder, A, e); + System.gc(); + + System.out.println("################################################### "+folder+" - Rating - Word ######################################"); + tomerge = new ArrayList(); + tomerge.clear(); + tomerge.add(R); + tomerge.add(W); + mergeData = new data(); + mergeData.busWord = W.busWord; + mergeData.words = W.words; + mergeData.wordCount = W.wordCount; + mergeData.userWord = W.userWord; + mergeData.addDataAfterSplit(tomerge); + mergeData.dataStats(); + e = learnAndTest(mergeData, 30, busWord, bwNS, userWord, uwNS); + if(!(busWord || userWord)) + writeDataToFile.writePrediction(folderToWriteData+"R-R", folder, R, e); + if(busWord && !userWord) + writeDataToFile.writePrediction(folderToWriteData+"R-R+BW", folder, R, e); + if(userWord && !busWord) + writeDataToFile.writePrediction(folderToWriteData+"R-R+UW", folder, R, e); + if(busWord && userWord) + writeDataToFile.writePrediction(folderToWriteData+"R-R+BUW", folder, R, e); + System.gc(); + + System.out.println("################################################### "+folder+" - Rating - Cat - Word ######################################"); + tomerge = new ArrayList(); + tomerge.clear(); + tomerge.add(R); + tomerge.add(C); + tomerge.add(W); + mergeData = new data(); + mergeData.busWord = 
W.busWord; + mergeData.words = W.words; + mergeData.wordCount = W.wordCount; + mergeData.userWord = W.userWord; + mergeData.addDataAfterSplit(tomerge); + mergeData.dataStats(); + e = learnAndTest(mergeData, 30, busWord, bwNS, userWord, uwNS); + if(!(busWord || userWord)) + writeDataToFile.writePrediction(folderToWriteData+"R-R+C", folder, R, e); + if(busWord && !userWord) + writeDataToFile.writePrediction(folderToWriteData+"R-R+C+BW", folder, R, e); + if(userWord && !busWord) + writeDataToFile.writePrediction(folderToWriteData+"R-R+C+UW", folder, R, e); + if(busWord && userWord) + writeDataToFile.writePrediction(folderToWriteData+"R-R+C+BUW", folder, R, e); + System.gc(); + + System.out.println("################################################### "+folder+" - Att Rate - Word ######################################"); + tomerge = new ArrayList(); + tomerge.clear(); + tomerge.add(A); + tomerge.add(R); + tomerge.add(W); + mergeData = new data(); + mergeData.busWord = W.busWord; + mergeData.words = W.words; + mergeData.wordCount = W.wordCount; + mergeData.userWord = W.userWord; + mergeData.addDataAfterSplit(tomerge); + mergeData.dataStats(); + e = learnAndTest(mergeData, 30, busWord, bwNS, userWord, uwNS); + if(!(busWord || userWord)){ + writeDataToFile.writePrediction(folderToWriteData+"A-A+R", folder, A, e); + writeDataToFile.writePrediction(folderToWriteData+"R-A+R", folder, R, e); + } + if(busWord && !userWord){ + writeDataToFile.writePrediction(folderToWriteData+"A-A+R+BW", folder, A, e); + writeDataToFile.writePrediction(folderToWriteData+"R-A+R+BW", folder, R, e); + } + if(userWord && !busWord){ + writeDataToFile.writePrediction(folderToWriteData+"A-A+R+UW", folder, A, e); + writeDataToFile.writePrediction(folderToWriteData+"R-A+R+UW", folder, R, e); + } + if(userWord && busWord){ + writeDataToFile.writePrediction(folderToWriteData+"A-A+R+BUW", folder, A, e); + writeDataToFile.writePrediction(folderToWriteData+"R-A+R+BUW", folder, R, e); + } + System.gc(); + + + 
System.out.println("################################################### "+folder+" - Att Cat Rate Word ######################################"); + tomerge = new ArrayList(); + tomerge.clear(); + tomerge.add(A); + tomerge.add(R); + tomerge.add(C); + tomerge.add(W); + mergeData = new data(); + mergeData.busWord = W.busWord; + mergeData.words = W.words; + mergeData.wordCount = W.wordCount; + mergeData.userWord = W.userWord; + mergeData.addDataAfterSplit(tomerge); + mergeData.dataStats(); + e = learnAndTest(mergeData, 30, busWord, bwNS, userWord, uwNS); + if(!(busWord || userWord)){ + writeDataToFile.writePrediction(folderToWriteData+"A-A+C+R", folder, A, e); + writeDataToFile.writePrediction(folderToWriteData+"R-A+C+R", folder, R, e); + } + if(busWord && !userWord){ + writeDataToFile.writePrediction(folderToWriteData+"A-A+C+R+BW", folder, A, e); + writeDataToFile.writePrediction(folderToWriteData+"R-A+C+R+BW", folder, R, e); + } + if(userWord && !busWord){ + writeDataToFile.writePrediction(folderToWriteData+"A-A+C+R+UW", folder, A, e); + writeDataToFile.writePrediction(folderToWriteData+"R-A+C+R+UW", folder, R, e); + } + if(busWord && userWord){ + writeDataToFile.writePrediction(folderToWriteData+"A-A+C+R+BUW", folder, A, e); + writeDataToFile.writePrediction(folderToWriteData+"R-A+C+R+BUW", folder, R, e); + } + getMemoryDetails(); + + if(busWord || userWord){ + if(busWord && !userWord){ + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"words-bw", W, 1, e); + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"attributes-bw", A, 1, e); + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"categories-bw", C, 1, e); + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"business-bw", R, 0, e); + } + if(userWord && !busWord){ + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"words-uw", W, 1, e); + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"attributes-uw", A, 1, e); + writeDataToFile.writeEmbeddings(folder, 
folderToWriteData+"categories-uw", C, 1, e); + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"business-uw", R, 0, e); + } + if(userWord && busWord){ + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"words-buw", W, 1, e); + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"attributes-buw", A, 1, e); + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"categories-buw", C, 1, e); + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"business-buw", R, 0, e); + } + } + + } + + public static void AttBusColdCompleteEvaluation(String folder, data A, data C, data R, data W, boolean busWord, int bwNS, boolean userWord, int uwNS, String folderToWriteData) throws IOException{ + System.out.print("Attribute Cold Start, "); + Util.checkColdStartSanity(A); + Util.implicitColdStart(A.trainData, A.testData); + + ArrayList tomerge = new ArrayList(); + data mergeData = new data(); + + getMemoryDetails(); + + System.out.println("################################################### "+folder+" - Att-Word ######################################"); + tomerge.clear(); + tomerge.add(A); + tomerge.add(W); + mergeData = new data(W); + mergeData.addDataAfterSplit(tomerge); + mergeData.dataStats(); + embeddings e = learnAndTest(mergeData, 30, busWord, bwNS, userWord, uwNS); + if(!(busWord || userWord)) + writeDataToFile.writePrediction(folderToWriteData+"A-A", folder, A, e); + if(busWord && !userWord) + writeDataToFile.writePrediction(folderToWriteData+"A-A+BW", folder, A, e); + if(userWord && !busWord) + writeDataToFile.writePrediction(folderToWriteData+"A-A+UW", folder, A, e); + if(userWord && busWord) + writeDataToFile.writePrediction(folderToWriteData+"A-A+BUW", folder, A, e); + getMemoryDetails(); + + System.out.println("################################################### "+folder+" - Att-Cat-Word ######################################"); + tomerge = new ArrayList(); + tomerge.clear(); + tomerge.add(A); + tomerge.add(C); + tomerge.add(W); + 
mergeData = new data(W); + mergeData.addDataAfterSplit(tomerge); + mergeData.dataStats(); + e = learnAndTest(mergeData, 30, busWord, bwNS, userWord, uwNS); + if(!(busWord || userWord)) + writeDataToFile.writePrediction(folderToWriteData+"A-A+C", folder, A, e); + if(busWord && !userWord) + writeDataToFile.writePrediction(folderToWriteData+"A-A+C+BW", folder, A, e); + if(userWord && !busWord) + writeDataToFile.writePrediction(folderToWriteData+"A-A+C+UW", folder, A, e); + if(busWord && userWord) + writeDataToFile.writePrediction(folderToWriteData+"A-A+C+BUW", folder, A, e); + System.gc(); + + System.out.println("################################################### "+folder+" - Att Rate - Word ######################################"); + tomerge = new ArrayList(); + tomerge.clear(); + tomerge.add(A); + tomerge.add(R); + tomerge.add(W); + mergeData = new data(W); + mergeData.addDataAfterSplit(tomerge); + mergeData.dataStats(); + e = learnAndTest(mergeData, 30, busWord, bwNS, userWord, uwNS); + if(!(busWord || userWord)){ + writeDataToFile.writePrediction(folderToWriteData+"A-A+R", folder, A, e); + } + if(busWord && !userWord){ + writeDataToFile.writePrediction(folderToWriteData+"A-A+R+BW", folder, A, e); + } + if(userWord && !busWord){ + writeDataToFile.writePrediction(folderToWriteData+"A-A+R+UW", folder, A, e); + } + if(userWord && busWord){ + writeDataToFile.writePrediction(folderToWriteData+"A-A+R+BUW", folder, A, e); + } + System.gc(); + + + System.out.println("################################################### "+folder+" - Att Cat Rate Word ######################################"); + tomerge = new ArrayList(); + tomerge.clear(); + tomerge.add(A); + tomerge.add(R); + tomerge.add(C); + tomerge.add(W); + mergeData = new data(W); + mergeData.addDataAfterSplit(tomerge); + mergeData.dataStats(); + e = learnAndTest(mergeData, 30, busWord, bwNS, userWord, uwNS); + if(!(busWord || userWord)){ + writeDataToFile.writePrediction(folderToWriteData+"A-A+C+R", folder, A, e); + } 
+ if(busWord && !userWord){ + writeDataToFile.writePrediction(folderToWriteData+"A-A+C+R+BW", folder, A, e); + } + if(userWord && !busWord){ + writeDataToFile.writePrediction(folderToWriteData+"A-A+C+R+UW", folder, A, e); + } + if(busWord && userWord){ + writeDataToFile.writePrediction(folderToWriteData+"A-A+C+R+BUW", folder, A, e); + } + getMemoryDetails(); + + if(busWord || userWord){ + if(busWord && !userWord){ + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"words-bw", W, 1, e); + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"attributes-bw", A, 1, e); + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"categories-bw", C, 1, e); + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"business-bw", R, 0, e); + } + if(userWord && !busWord){ + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"words-uw", W, 1, e); + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"attributes-uw", A, 1, e); + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"categories-uw", C, 1, e); + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"business-uw", R, 0, e); + } + if(userWord && busWord){ + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"words-buw", W, 1, e); + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"attributes-buw", A, 1, e); + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"categories-buw", C, 1, e); + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"business-buw", R, 0, e); + } + } + + } + + public static void RateColdCompleteEvaluation(String folder, data A, data C, data R, data W, boolean busWord, int bwNS, boolean userWord, int uwNS, + String folderToWriteData) throws IOException{ + + System.out.print("Rate Cold Start, "); + Util.checkColdStartSanity(R); + Util.implicitColdStart(R.trainData, R.testData); + + ArrayList tomerge = new ArrayList(); + data mergeData = new data(); + + getMemoryDetails(); + + 
System.out.println("################################################### "+folder+" - Rating - Word ######################################"); + tomerge = new ArrayList(); + tomerge.clear(); + tomerge.add(R); + tomerge.add(W); + mergeData = new data(W); + mergeData.addDataAfterSplit(tomerge); + mergeData.dataStats(); + embeddings e = learnAndTest(mergeData, 30, busWord, bwNS, userWord, uwNS); + if(!(busWord || userWord)) + writeDataToFile.writePrediction(folderToWriteData+"R-R", folder, R, e); + if(busWord && !userWord) + writeDataToFile.writePrediction(folderToWriteData+"R-R+BW", folder, R, e); + if(userWord && !busWord) + writeDataToFile.writePrediction(folderToWriteData+"R-R+UW", folder, R, e); + if(busWord && userWord) + writeDataToFile.writePrediction(folderToWriteData+"R-R+BUW", folder, R, e); + System.gc(); + + System.out.println("################################################### "+folder+" - Rating - Cat - Word ######################################"); + tomerge = new ArrayList(); + tomerge.clear(); + tomerge.add(R); + tomerge.add(C); + tomerge.add(W); + mergeData = new data(W); + mergeData.addDataAfterSplit(tomerge); + mergeData.dataStats(); + e = learnAndTest(mergeData, 30, busWord, bwNS, userWord, uwNS); + if(!(busWord || userWord)) + writeDataToFile.writePrediction(folderToWriteData+"R-R+C", folder, R, e); + if(busWord && !userWord) + writeDataToFile.writePrediction(folderToWriteData+"R-R+C+BW", folder, R, e); + if(userWord && !busWord) + writeDataToFile.writePrediction(folderToWriteData+"R-R+C+UW", folder, R, e); + if(busWord && userWord) + writeDataToFile.writePrediction(folderToWriteData+"R-R+C+BUW", folder, R, e); + System.gc(); + + System.out.println("################################################### "+folder+" - Att Rate - Word ######################################"); + tomerge = new ArrayList(); + tomerge.clear(); + tomerge.add(A); + tomerge.add(R); + tomerge.add(W); + mergeData = new data(W); + mergeData.addDataAfterSplit(tomerge); + 
mergeData.dataStats(); + e = learnAndTest(mergeData, 30, busWord, bwNS, userWord, uwNS); + if(!(busWord || userWord)){ + writeDataToFile.writePrediction(folderToWriteData+"R-A+R", folder, R, e); + } + if(busWord && !userWord){ + writeDataToFile.writePrediction(folderToWriteData+"R-A+R+BW", folder, R, e); + } + if(userWord && !busWord){ + writeDataToFile.writePrediction(folderToWriteData+"R-A+R+UW", folder, R, e); + } + if(userWord && busWord){ + writeDataToFile.writePrediction(folderToWriteData+"R-A+R+BUW", folder, R, e); + } + System.gc(); + + + System.out.println("################################################### "+folder+" - Att Cat Rate Word ######################################"); + tomerge = new ArrayList(); + tomerge.clear(); + tomerge.add(A); + tomerge.add(R); + tomerge.add(C); + tomerge.add(W); + mergeData = new data(W); + mergeData.addDataAfterSplit(tomerge); + mergeData.dataStats(); + e = learnAndTest(mergeData, 30, busWord, bwNS, userWord, uwNS); + if(!(busWord || userWord)){ + writeDataToFile.writePrediction(folderToWriteData+"R-A+C+R", folder, R, e); + } + if(busWord && !userWord){ + writeDataToFile.writePrediction(folderToWriteData+"R-A+C+R+BW", folder, R, e); + } + if(userWord && !busWord){ + writeDataToFile.writePrediction(folderToWriteData+"R-A+C+R+UW", folder, R, e); + } + if(busWord && userWord){ + writeDataToFile.writePrediction(folderToWriteData+"R-A+C+R+BUW", folder, R, e); + } + getMemoryDetails(); + + if(busWord || userWord){ + if(busWord && !userWord){ + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"words-bw", W, 1, e); + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"attributes-bw", A, 1, e); + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"categories-bw", C, 1, e); + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"business-bw", R, 0, e); + } + if(userWord && !busWord){ + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"words-uw", W, 1, e); + 
writeDataToFile.writeEmbeddings(folder, folderToWriteData+"attributes-uw", A, 1, e); + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"categories-uw", C, 1, e); + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"business-uw", R, 0, e); + } + if(userWord && busWord){ + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"words-buw", W, 1, e); + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"attributes-buw", A, 1, e); + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"categories-buw", C, 1, e); + writeDataToFile.writeEmbeddings(folder, folderToWriteData+"business-buw", R, 0, e); + } + } + System.gc(); + } + + public static void performAttBusColdEvaluation(String folder) throws IOException{ + String folderToWriteData = "AttBusCold/"; + data A = readAttributes(folder, 15.0, 15.0, true, 0); + data C = readCategories(folder, 5); + data R = readRatings(folder, 0.0, 0.0, false, 1); + data W = new data(); + + AttBusColdCompleteEvaluation(folder, A, C, R, W, false, 0, false, 0, folderToWriteData); + System.gc(); + + // Business - Words + W = readReviewData(folder, 10, true, false, 0.0, 0.0); + int bwNS = Util.getNegSampleSize(W); + AttBusColdCompleteEvaluation(folder, A, C, R, W, true, bwNS, false, 0, folderToWriteData); + System.gc(); + + // User - Words + W = readReviewData(folder, 10, false, true, 0.0, 0.0); + int uwNS = Util.getNegSampleSize(W); + AttBusColdCompleteEvaluation(folder, A, C, R, W, false, 0, true, uwNS, folderToWriteData); + System.gc(); + + // BusWords and UserWords + /*W = readReviewData(folder, 10, true, true, 0.0, 0.0); + AttBusColdCompleteEvaluation(folder, A, C, R, W, true, bwNS, true, uwNS, folderToWriteData); + System.gc();*/ + } + + public static void performRateBusColdEvaluation(String folder) throws IOException{ + String folderToWriteData = "RateBusCold/"; + data A = readAttributes(folder, 0.0, 0.0, false, 0); + data C = readCategories(folder, 5); + data R = readRatings(folder, 15.0, 15.0, true, 
0); + data W = new data(); + + RateColdCompleteEvaluation(folder, A, C, R, W, false, 0, false, 0, folderToWriteData); + System.gc(); + + // Business - Words + W = readReviewData(folder, 10, true, false, 0.0, 0.0); + int bwNS = Util.getNegSampleSize(W); + RateColdCompleteEvaluation(folder, A, C, R, W, true, bwNS, false, 0, folderToWriteData); + System.gc(); + + // User - Words + W = readReviewData(folder, 10, false, true, 0.0, 0.0); + int uwNS = Util.getNegSampleSize(W); + RateColdCompleteEvaluation(folder, A, C, R, W, false, 0, true, uwNS, folderToWriteData); + System.gc(); + + // BusWords and UserWords + /*W = readReviewData(folder, 10, true, true, 0.0, 0.0); + RateColdCompleteEvaluation(folder, A, C, R, W, true, bwNS, true, uwNS, folderToWriteData); + System.gc();*/ + } + + public static void performRateUserColdEvaluation(String folder) throws IOException{ + String folderToWriteData = "RateUserCold/"; + data A = readAttributes(folder, 0.0, 0.0, false, 0); + data C = readCategories(folder, 5); + data R = readRatings(folder, 15.0, 15.0, true, 1); + data W = new data(); + + RateColdCompleteEvaluation(folder, A, C, R, W, false, 0, false, 0, folderToWriteData); + System.gc(); + + // Business - Words + W = readReviewData(folder, 10, true, false, 0.0, 0.0); + int bwNS = Util.getNegSampleSize(W); + RateColdCompleteEvaluation(folder, A, C, R, W, true, bwNS, false, 0, folderToWriteData); + System.gc(); + + // User - Words + W = readReviewData(folder, 10, false, true, 0.0, 0.0); + int uwNS = Util.getNegSampleSize(W); + RateColdCompleteEvaluation(folder, A, C, R, W, false, 0, true, uwNS, folderToWriteData); + System.gc(); + + // BusWords and UserWords + /*W = readReviewData(folder, 10, true, true, 0.0, 0.0); + RateColdCompleteEvaluation(folder, A, C, R, W, true, bwNS, true, uwNS, folderToWriteData); + System.gc();*/ + } + + public static void performHeldOutEvaluation(String folder) throws IOException{ + String folderToWriteData = "HeldOut/"; + data A = readAttributes(folder, 
15.0, 15.0, false, 0); + data C = readCategories(folder, 5); + data R = readRatings(folder, 15.0, 15.0, false, 1); + data W = new data(); + + // No Words + completeEvaluation(folder, A, C, R, W, false, 0, false, 0, false, 0, false, 0, folderToWriteData); + System.gc(); + + // Business - Words + W = readReviewData(folder, 10, true, false, 0.0, 0.0); + int bwNS = Util.getNegSampleSize(W); + completeEvaluation(folder, A, C, R, W, true, bwNS, false, 0, false, 0, false, 0, folderToWriteData); + System.gc(); + + // User - Words + W = readReviewData(folder, 10, false, true, 0.0, 0.0); + int uwNS = Util.getNegSampleSize(W); + completeEvaluation(folder, A, C, R, W, false, 0, true, uwNS, false, 0, false, 0, folderToWriteData); + System.gc(); + + // BusWords and UserWords + /*W = readReviewData(folder, 10, true, true, 0.0, 0.0); + getMemoryDetails(); + completeEvaluation(folder, A, C, R, W, true, bwNS, true, uwNS, false, 0, false, 0, folderToWriteData); + System.gc();*/ + } + + // To test one dataset completely and write embeddings for (A + R + C + W) + public static void main(String [] args) throws Exception { + String folder = args[0]; + String todo = args[1]; + todo = "heldOut"; + folder = "EDH"; + + if(todo.equals("heldOut")) + performHeldOutEvaluation(folder); + if(todo.equals("attBusCold")) + performAttBusColdEvaluation(folder); + if(todo.equals("rateBusCold")) + performRateBusColdEvaluation(folder); + if(todo.equals("rateUserCold")) + performRateUserColdEvaluation(folder); + //attBusColdEvaluations(folder); + //rateBusColdEvaluations(folder); + } + + // To make the sizes table + /*public static void main(String [] args) throws Exception { + + String folder = "WI"; + data A = readAttributes(folder, 15.0, 15.0, false, 0); + //data C = readCategories(folder, 5); + data R = readRatings(folder, 0.0, 0.0, false, 1); + //data BW = readReviewData(folder, 10, true, false, 0.0, 0.0); + //data UW = readReviewData(folder, 10, false, true, 0.0, 0.0); + + A.dataStats(); + + 
Util.getMatrixDetails(A); + Util.getMatrixDetails(C); + Util.getMatrixDetails(R); + //Util.getMatrixDetails(BW); + + + + }*/ + +} \ No newline at end of file diff --git a/Project/src/logisticCMF/data.java b/Project/src/logisticCMF/data.java new file mode 100644 index 0000000..a26f514 --- /dev/null +++ b/Project/src/logisticCMF/data.java @@ -0,0 +1,601 @@ +package logisticCMF; +import java.io.*; +import java.util.*; + +import javax.json.Json; +import javax.json.JsonNumber; +import javax.json.JsonObject; +import javax.json.JsonReader; +import javax.json.JsonString; +import javax.json.JsonValue; + +public class data { + + public ArrayList trainData; + public ArrayList testData; + public ArrayList valData; + public ArrayList Data; + public Map relDataCount; // Map [Relation Id, Count] + ArrayList> entityIds; + + public Map wordCount = new HashMap(); // Map [Words -> Count] - In original review Data + public Map> busWord = new HashMap>(); // Map [restaurant -> [words used for it]] - Can be pruned according to word occurrence threshold + public Map> userWord = new HashMap>(); // Map [user -> [words used for it]] - Can be pruned according to word occurrence threshold + public ArrayList words = new ArrayList(); // Set [words] - The vocab for our review data. 
Can be pruned according to word occurrence threshold + public static Random seed = new Random(50); // Also defined in Embedding & Learner Class + + + public data(){ + trainData = new ArrayList(); + valData = new ArrayList(); + testData = new ArrayList(); + Data = new ArrayList(); + relDataCount = new HashMap(); + wordCount = new HashMap(); + busWord = new HashMap>(); + userWord = new HashMap>(); + words = new ArrayList(); + entityIds = new ArrayList>(); + entityIds.add(new HashSet()); + entityIds.add(new HashSet()); + } + + public data(data D){ + busWord = D.busWord; + userWord = D.userWord; + words = D.words; + wordCount = D.wordCount; + trainData = new ArrayList(); + valData = new ArrayList(); + testData = new ArrayList(); + Data = new ArrayList(); + relDataCount = new HashMap(); + entityIds = new ArrayList>(); + entityIds.add(new HashSet()); + entityIds.add(new HashSet()); + } + + public void addToEntitySets(Cell cell){ + entityIds.get(0).add(cell.entity_ids.get(0)); + entityIds.get(1).add(cell.entity_ids.get(1)); + } + + public void readBusAtt(String fileAddress, String rel) throws IOException{ + BufferedReader br = new BufferedReader(new FileReader(fileAddress)); + String relation = "b-att-"+rel; + String line; + String bId = ""; + int count = 0; + while( (line = br.readLine()) != null ){ + //System.out.println(line); + String [] arr = line.split(":"); + if(arr.length >= 2){ + if(arr[0].trim().equals("business_id")) + bId = arr[1].trim(); + + else{ + String att = arr[0].trim(); + double t = Double.parseDouble(arr[1].trim()); + Cell cell = new Cell(); + cell.relation_id = relation; + cell.entity_ids.add(bId); + cell.entity_ids.add(att); + cell.truth = (t == 1.0) ? 
true : false; + addToEntitySets(cell); + Data.add(cell); + count++; + } + + + } + } + br.close(); + System.out.println(relation + " : " + count); + } + + public void printData(){ + for(Cell cell : Data){ + System.out.print("\n"+cell.relation_id + ", "); + for(String e : cell.entity_ids) + System.out.print(e + ", "); + System.out.print(cell.truth); + } + } + + public void makeRestaurantSet(String fileAddress) throws IOException{ + BufferedWriter bw = new BufferedWriter(new FileWriter(fileAddress+"1")); + HashMap restaurants = new HashMap(); + for(Cell cell : Data){ + if(restaurants.containsKey(cell.entity_ids.get(0))) + continue; + else{ + restaurants.put(cell.entity_ids.get(0), 1.0); + bw.write(cell.entity_ids.get(0) + "\n"); + } + } + bw.close(); + } + + public void refreshTrainTest(){ + trainData = new ArrayList(); + valData = new ArrayList(); + testData = new ArrayList(); + } + + public void readRatingData(String fileAddress, String rel) throws IOException{ + Map> ratings = new HashMap>(); + BufferedReader br = new BufferedReader(new FileReader(fileAddress)); + String line; int countl = 0, countcell = 0; + String relation = "b-u-rate-"+rel; + String busid = null, userid = null; boolean value = false; + while((line = br.readLine()) != null){ + countl++; + String[] array = line.split(":"); + if( array[0].trim().equals("bus_id")){ + busid = array[1].trim(); + if(!ratings.containsKey(busid)) + ratings.put(busid, new HashMap()); + } + if( array[0].trim().equals("user_id")) + userid = array[1].trim(); + if( array[0].trim().equals("star")){ + double t = Double.parseDouble(array[1].trim()); + ratings.get(busid).put(userid, (int)t); + } + } + br.close(); + for(String bus : ratings.keySet()){ + for(String user : ratings.get(bus).keySet()){ + int rate = ratings.get(bus).get(user); + Cell cell = new Cell(); + cell.relation_id = relation; + cell.entity_ids.add(bus); + cell.entity_ids.add(user); + cell.truth = (rate >= 4) ? 
true : false; + Data.add(cell); + addToEntitySets(cell); + countcell++; + } + } + System.out.println("Ratings read : " + countcell); + } + + public void countLines(String fileAddress) throws IOException{ + BufferedReader br = new BufferedReader(new FileReader(fileAddress)); + String line; String relation = "restaurant-user-word"; int count = 0; + String busid = null, userid = null; boolean value = false; + while((line = br.readLine()) != null){ + count ++; + } + System.out.println(count); + } + + public void splitTrainTestValidation(double valPercentage, double testPercentage){ + Collections.shuffle(Data, seed); + for(int j = 0; j coldEntities = Util.getColdEntites(Data, index, testPerc); + int countTest = 0, i=0; + for(Cell cell : Data){ + if(coldEntities.contains(cell.entity_ids.get(index))){ + testData.add(cell); + countTest++; + } + else + trainData.add(cell); + } + Collections.shuffle(trainData, seed); + Iterator it = trainData.iterator(); + int cellValidation = (int)((valPercentage/100)*Data.size()); + System.out.println(cellValidation); + while(i < cellValidation){ + Cell cell = it.next(); + valData.add(cell); + it.remove(); + i++; + } + } + + public void addDatainExisting(data d){ + Collections.shuffle(Data, seed); Collections.shuffle(d.Data, seed); + for(Cell cell : d.Data) + Data.add(cell); + + Collections.shuffle(Data, seed); + } + + public void addDataAfterSplit(ArrayList dataList){ + for(data d : dataList){ + for(Cell cell : d.Data) + Data.add(cell); + for (Cell cell : d.trainData) + trainData.add(cell); + for(Cell cell : d.valData) + valData.add(cell); + for(Cell cell : d.testData) + testData.add(cell); + } + + Collections.shuffle(Data, seed); + Collections.shuffle(trainData, seed); + Collections.shuffle(valData, seed); + Collections.shuffle(testData, seed); + } + + // Used by dataStats() function + public void countTrueFalse(ArrayList data, Map relTrue, Map relFalse){ + for(Cell cell : data){ + int t = (cell.truth) ? 1 : 0; int f = (cell.truth) ? 
0 : 1; + if(t == 1){ + if(!relTrue.containsKey(cell.relation_id)) + relTrue.put(cell.relation_id, 1); + else + relTrue.put(cell.relation_id, relTrue.get(cell.relation_id) + 1); + } + if(f == 1){ + if(!relFalse.containsKey(cell.relation_id)) + relFalse.put(cell.relation_id, 1); + else + relFalse.put(cell.relation_id, relFalse.get(cell.relation_id) + 1); + } + } + } + + public void dataStats(){ + Map relTrue = new HashMap(); + Map relFalse = new HashMap(); + countTrueFalse(Data, relTrue, relFalse); + System.out.println("\nData Stats"); + for(String rel : relTrue.keySet()) + System.out.println(rel + " : " + "t : " + relTrue.get(rel) + " f : " + relFalse.get(rel)); + + + + relTrue.clear(); + relFalse.clear(); + countTrueFalse(trainData, relTrue, relFalse); + System.out.println("Train Data Stats"); + for(String rel : relTrue.keySet()) + System.out.println(rel + " : " + "t : " + relTrue.get(rel) + " f : " + relFalse.get(rel)); + + relTrue.clear(); + relFalse.clear(); + countTrueFalse(valData, relTrue, relFalse); + System.out.println("Validation Data Stats"); + for(String rel : relTrue.keySet()) + System.out.println(rel + " : " + "t : " + relTrue.get(rel) + " f : " + relFalse.get(rel)); + + relTrue.clear(); + relFalse.clear(); + countTrueFalse(testData, relTrue, relFalse); + System.out.println("Test Data Stats"); + for(String rel : relTrue.keySet()) + System.out.println(rel + " : " + "t : " + relTrue.get(rel) + " f : " + relFalse.get(rel)); + + } + + public void reviewDataStats(int entityId, int thresh, boolean removeEntities){ + Map users = new HashMap(); // Map[EntityID, Count in Set] + Set e1 = new HashSet(); + Set e2 = new HashSet(); + Set setUsers = new HashSet(); + int count = 0, max = -1, min = 1000000000; + for(Cell cell : Data){ + String user = cell.entity_ids.get(entityId); /// CHECK WHICH ENTITY MAP IS CREATED + e1.add(cell.entity_ids.get(0)); + e2.add(cell.entity_ids.get(1)); + setUsers.add(cell.entity_ids.get(entityId)); + if(!users.containsKey(user)){ + 
users.put(user, 1); + } + else{ + int revCount = users.get(user); + int newRevCount = revCount+1 ; + users.put(user, newRevCount); + } + } + + for(String user : users.keySet()){ + if(users.get(user) > max) + max = users.get(user); + if(users.get(user) < min) + min = users.get(user); + if(users.get(user) <= thresh){ + count++; + setUsers.remove(user); + //System.out.println(user + " : " + users.get(user)); + } + } + int iterates = 0; + if(removeEntities){ + for(Iterator itr = Data.iterator();itr.hasNext();){ + iterates++; + Cell cell = itr.next(); + if(!setUsers.contains(cell.entity_ids.get(entityId))){ + itr.remove(); + } + } + } + System.out.println("Total iterates = "+iterates); + System.out.println(e1.size() + " : " + e2.size() + " : " +users.keySet().size() + " : " + setUsers.size()); + System.out.println("count : " + count + " max = " + max + " min : " + min); + } + + public void readAndCompleteCategoryData(String fileAddress, int thresh, String rel) throws NumberFormatException, IOException{ + BufferedReader br = new BufferedReader(new FileReader(fileAddress)); + Map> resCat = new HashMap>(); + Map catCount = new HashMap(); + Set categories = new HashSet(); + String line; String relation = "b-cat-"+rel; + int count = 0; + while( (line = br.readLine()) != null) { + + String[] array = line.split(":"); + resCat.put(array[0].trim(), new ArrayList()); + + String [] cats = array[1].trim().split(";"); // Delimiter used in resCat file. 
Refer to YelpChallenge-yeldData.java + for(String cat : cats){ + String category = cat.trim(); + if(category.length()>1){ + categories.add(category); + resCat.get(array[0].trim()).add(category); + + // To build Map[Category, Count] + if(!catCount.containsKey(category)) + catCount.put(category, 1); + else{ + int cat_count = catCount.get(category); + cat_count++; + catCount.put(category, cat_count); + } + } + } + } + br.close(); + + + int cC = 0; + for(String c : catCount.keySet()){ + if(catCount.get(c) > thresh) + cC++; + } + //System.out.println("Categories after pruning : " + cC); + + for(String res : resCat.keySet()){ + int categoriesConsidered = 0; + for(String cat : categories){ + if(catCount.get(cat) > thresh){ + categoriesConsidered++; + Cell cell = new Cell(); + cell.relation_id = relation; + cell.entity_ids.add(res); + cell.entity_ids.add(cat); + if(resCat.get(res).contains(cat)) + cell.truth = true; + else + cell.truth = false; + count++; + Data.add(cell); + addToEntitySets(cell); + } + cC = categoriesConsidered; + } + } + System.out.println(relation + " : " + count + " categories read : " + cC); + } + + public void reduceDataSize(double perc){ + Collections.shuffle(Data, seed); + int iterations = 0; int size = Data.size(); int stopAt = (int)((perc/100.0)*size); + System.out.println(stopAt); + for(Iterator itr = Data.iterator();itr.hasNext();){ + iterations++; + if(iterations <= stopAt) + itr.next(); + else{ + itr.next(); + itr.remove(); + } + } + } + + public void readReviewData(String fileAddress) throws IOException{ + BufferedReader br = new BufferedReader(new FileReader(fileAddress)); + String line; String bId = null; String userId = null; int countl = 0, countR = 0; + + while((line = br.readLine()) != null){ + countl++; + String[] array = line.split(":"); + + if( array[0].trim().equals("user_id")){ + countR++; + userId = array[1].trim(); + if(!userWord.containsKey(userId)) + userWord.put(userId, new HashSet()); + } + + if( 
array[0].trim().equals("bus_id")){ + bId = array[1].trim(); + if(!busWord.containsKey(bId)) + busWord.put(bId, new HashSet()); + } + + if(array[0].trim().equals("text")){ + String [] tokens = array[1].trim().split(" "); + if(tokens.length > 0){ + for(String word : tokens){ + word = word.trim(); + if(word.length() >= 3){ + addWordInMap(word); + busWord.get(bId).add(word); + userWord.get(userId).add(word); + } + } + } + } + /*if(countl % 100000 == 0) + System.out.println("line : "+countl);*/ + } + System.out.println("Total No. of Reviews : " + countR); + } + + public void addWordInMap(String word){ + if(!wordCount.containsKey(word)) + wordCount.put(word, 0); + int wcount = wordCount.get(word); + wcount++; + wordCount.put(word, wcount); + } + + public void getMapStats(Map> enWord){ + int potentialResWordCells = 0; + int min=100000000; + System.out.println("No. of entities = " + enWord.keySet().size()); + System.out.println("No. of total words in Vocab = " + words.size()); + + for(String en : enWord.keySet()){ + min = (enWord.get(en).size() < min) ? enWord.get(en).size() : min; + for(String word : enWord.get(en)){ + potentialResWordCells++; + } + + } + System.out.println("Potential Entity-Word Cells : " + potentialResWordCells); + System.out.println("Words per entiy : " + ((double) potentialResWordCells)/enWord.keySet().size()); + System.out.println("Min No. of Words in Entity : " + min); + + } + + public void pruneVocab_EntityMap(int occThresh){ + makePrunedWordList(occThresh); + pruneEntityWordMap(busWord, occThresh); + pruneEntityWordMap(userWord, occThresh); + } + + // Remove words from Map[Entity -> Set[words]] that occur few times in dictionary. If Set of words for entity go empty, remove Entity from Map. 
+ public void pruneEntityWordMap(Map> enWord, int occThresh){ + Iterator it = enWord.keySet().iterator(); + while(it.hasNext()){ + String en = it.next(); + Iterator itr = enWord.get(en).iterator(); + while(itr.hasNext()){ + String word = itr.next(); + if(wordCount.get(word) <= occThresh) + itr.remove(); + } + if(enWord.get(en).size() == 0) + it.remove(); + } + } + + public void getWordCountStats(int start, int end){ + int count = 0; + for(int i = start; i<=end; i++){ + for(String word : wordCount.keySet()){ + if(wordCount.get(word) > i) + count++; + } + System.out.println("Words with greater that " + i + " count : " + count); + count = 0; + } + } + + // Make a Array of Words that have frequency above the given threshold. + public void makePrunedWordList(int occThresh){ + words = new ArrayList(); int count = 0; + for(String word : wordCount.keySet()){ + if(wordCount.get(word) > occThresh){ + count++; + words.add(word); + } + } + System.out.println("Words with greater than occurence of " + occThresh + " : " + words.size()); + } + + public void makeEnWordCells(String relation){ + if(relation.equals("b-word")){ + enWordCells(busWord, relation); + } + else if(relation.equals("u-word")){ + enWordCells(userWord, relation); + } + } + + public void enWordCells(Map> enWord, String relation){ + for(String en : enWord.keySet()){ + for(String word : enWord.get(en)){ + Cell cell = new Cell(); + cell.relation_id = relation; + cell.entity_ids.add(en); + cell.entity_ids.add(word); + cell.truth = true; + Data.add(cell); + addToEntitySets(cell); + } + } + } + + public ArrayList getNegativeSamples(String relation, int negSamplesPerEntity){ + ArrayList negSamples = new ArrayList(); + int negSamplesDone = 0; + if(relation.equals("b-word")){ + for(String en : busWord.keySet()){ + while(negSamplesDone < negSamplesPerEntity){ + Cell cell = genNegSample(busWord, en, relation); + negSamplesDone++; + negSamples.add(cell); + } + negSamplesDone = 0; + } + + } + else 
if(relation.equals("u-word")){ + for(String en : userWord.keySet()){ + while(negSamplesDone < negSamplesPerEntity){ + Cell cell = genNegSample(userWord, en, relation); + negSamplesDone++; + negSamples.add(cell); + } + negSamplesDone = 0; + } + } + return negSamples; + } + + public Cell genNegSample(Map> enWord, String en, String relation){ + Cell cell = new Cell(); + cell.relation_id = relation; + boolean found = false; + while(!found){ + int pos = randInt(0, words.size() - 1); + if(!enWord.get(en).contains(words.get(pos))){ + cell.entity_ids.add(en); + cell.entity_ids.add(words.get(pos)); + cell.truth = false; + found = true; + } + } + return cell; + } + + public static int randInt(int min, int max) { + // nextInt is normally exclusive of the top value, + // so add 1 to make it inclusive + int randomNum = seed.nextInt((max - min) + 1) + min; + + return randomNum; + } + +} diff --git a/Project/src/logisticCMF/embedding.java b/Project/src/logisticCMF/embedding.java new file mode 100644 index 0000000..91f38b5 --- /dev/null +++ b/Project/src/logisticCMF/embedding.java @@ -0,0 +1,29 @@ +package logisticCMF; +import java.util.*; + +/* + * Embedding Class - Stores embedding for each entity which includes : + * Bias for each relation the entity belongs to + * K- Dimensional Latent Vector for each entity +*/ + +public class embedding { + Map bias; // Map : [Relation_id, bias] + double[] vector; // K-Dimensional Latent Vector + static Random rand = new Random(20); + + // For entity realized first time. 
Put Bias = 0.0 for relation and initialize latent vector to random + public embedding(int K){ + bias = new HashMap(); + vector = new double[K]; + for(int k=0; k embs; // Map[Entity_Id, Embedding] + Map alpha; // Map [Relation_Id, Alpha (matrix mean) ] + HashSet relIds; // Set of relation Ids + HashSet entityIds; // Set of Entity Ids + int K ; + static Random rand = new Random(); + + // Instantiating Embeddings for Data + public embeddings(data YData, int lK){ + K = lK; + alpha = new HashMap(); + embs = new HashMap(); + for(Cell cell : YData.Data){ + for(String entityId : cell.entity_ids){ + if(!embs.containsKey(entityId)) + embs.put(entityId, new embedding(K)); + + embs.get(entityId).addRelation(cell.relation_id); + } + } + System.out.println("Unique Entites in Database : " + embs.keySet().size()); + computeAlpha(YData); + } + + public embeddings(embeddings e){ + this.alpha = e.alpha; + this.embs = e.embs; + this.relIds = e.relIds; + this.entityIds = e.entityIds; + this.K = e.K; + } + + //Compute Alpha - Relation wise mean of values + public void computeAlpha(data D){ + Map relSum = new HashMap(); // Sum of truth values in each relation + Map relCount = new HashMap(); // Count of truth values in each relation + + for(Cell cell : D.trainData){ + if(!relSum.containsKey(cell.relation_id)){ + if(cell.truth == true) + relSum.put(cell.relation_id, 1); + else + relSum.put(cell.relation_id, 0); + relCount.put(cell.relation_id, 1); + } + else{ + int sum = relSum.get(cell.relation_id); + int count = relCount.get(cell.relation_id); + if(cell.truth == true){ + sum += 1; + } + count++; + relSum.put(cell.relation_id, sum); + relCount.put(cell.relation_id, count); + } + + } + + for(String relId : relSum.keySet()){ + double a = ((double)relSum.get(relId))/relCount.get(relId); + a = Math.log((a / (1-a))); + //alpha.put(relId, a); + alpha.put(relId, rand.nextGaussian()*0.001); + } + + + } + + + // Dot Product of vectors in cell, but leave out one entity, For fixed k + public double 
coeffVector(Cell cell, String leaveEntity, int k){ + double result=1.0; + for(String entityId : cell.entity_ids){ + if(!entityId.equals(leaveEntity)) + result *= embs.get(entityId).vector[k]; + } + return result; + } + + // Dot Product of vectors in cell, but leave out one entity, For fixed k + public double dot(Cell cell, boolean enableBias, int K, boolean ealpha, boolean onlyAlpha){ + double result=0.0; + + if(ealpha) + result += alpha.get(cell.relation_id); + if(!onlyAlpha){ + if(enableBias){ + for(String entityId : cell.entity_ids){ + if(!embs.containsKey(entityId)) + System.out.println("Entity not found : " + entityId); + if(!embs.get(entityId).bias.containsKey(cell.relation_id)) + System.out.println("Relation Not found : " + cell.relation_id); + result += embs.get(entityId).bias.get(cell.relation_id); + + + } + } + + for(int k = 0; k trainData = new ArrayList(); + embeddings eBest = new embeddings(embedings); + + Map> evalMap = Eval.getEvalMap(Data, eBest, "test"); + Eval.printEval(); + while(notConverged){ + + trainData.clear(); + trainData.addAll(Data.trainData); + System.out.println("Train Data Original : " + trainData.size()); + if(busWord){ + ArrayList negSamples = Data.getNegativeSamples("b-word", busWordNegSamSize); + trainData.addAll(negSamples); + } + if(userWord){ + ArrayList negSamples = Data.getNegativeSamples("u-word", userWordNegSamSize); + trainData.addAll(negSamples); + } + System.out.println("trainData : " + trainData.size()); + codeTest.getMemoryDetails(); + System.gc(); + Collections.shuffle(trainData, seed); // Shuffle List of Training Data before each iteration of learning parameters + for(Cell cell : trainData){ + update(cell, enableBias, embedings, ealpha, onlyAlpha); + } + //System.out.println("Train Data size :" + trainData.size()); + epoch++; + System.out.print(epoch + " "); + System.gc(); + if(epoch%5 == 0){ + System.out.println("################## Epoch : " + epoch + " ############"); + evalMap = Eval.getEvalMap(Data, embedings, 
"validation"); + Eval.printEval(); + double wf1 = evalMap.get("average").get(3), wacc = evalMap.get("average").get(0); + if(epoch == 5){ + //maxF1 = wf1; maxAcc = wacc; + maxF1 = 0.0; maxAcc = 0.0; + dropfor = 0; nochange = 0; + bestEpoch = epoch; + eBest = new embeddings(embedings); + } + else{ + //System.out.println("dF : " + dropfor + " nc : " + nochange + " maxF1 : " + maxF1 + " maxAcc : " + maxAcc); + if(wf1 > maxF1 || wacc > maxAcc){ + bestEpoch = epoch; + eBest = new embeddings(embedings); + maxF1 = wf1; + maxAcc = wacc; + dropfor = 0; nochange = 0; + } + else{ + if(wf1 == maxF1 || wacc == maxAcc) + nochange++; + else + dropfor++; + } + } + + } + if(dropfor >= 3 || nochange >= 4) /// CONDITIONS for STOPPING + notConverged = false; + } + System.out.println("TRAINING CONVEREGED, BEST EPOCH = " + bestEpoch); + return eBest; + } + + +} diff --git a/Project/src/logisticCMF/writeDataToFile.java b/Project/src/logisticCMF/writeDataToFile.java new file mode 100644 index 0000000..158173a --- /dev/null +++ b/Project/src/logisticCMF/writeDataToFile.java @@ -0,0 +1,33 @@ +package logisticCMF; + +import java.util.*; +import java.io.*; + +public class writeDataToFile { + public static void writePrediction(String fileName, String folder, data Data, embeddings e) throws IOException{ + String fileAddress = System.getProperty("user.dir")+"/../Embeddings_Prediction_Data/"+ folder +"/pred-data/" + fileName; + BufferedWriter bw = new BufferedWriter(new FileWriter(fileAddress)); + for(Cell cell : Data.testData){ + double dot = e.dot(cell, learner.enableBias, e.K, learner.ealpha, learner.onlyAlpha); + double sigmdot = learner.sigm(dot); + int truth = (cell.truth) ? 
1 : 0; + String e1 = cell.entity_ids.get(0), e2 = cell.entity_ids.get(1); + bw.write(e1 + " :: " + e2 + " :: " + sigmdot + " :: " + truth +"\n") ; + } + bw.close(); + } + + public static void writeEmbeddings(String folder, String fileName, data Data, int entityNumber, embeddings e) throws IOException{ + String fileAddress = System.getProperty("user.dir")+"/../Embeddings_Prediction_Data/"+ folder +"/embeddings/" + fileName; + BufferedWriter bw = new BufferedWriter(new FileWriter(fileAddress)); + for(String entity : Data.entityIds.get(entityNumber)){ + bw.write(entity + " :: "); + embedding em = e.embs.get(entity); + for(int i=0; i entityVector; + + public EntityEmbeddings(String folder, String fileName, int K) throws IOException{ + entityVector = new HashMap(); + readEmbeddings(folder, fileName, K); + } + + public void readEmbeddings(String folder, String fileName, int K) throws IOException{ + String fileAddress = System.getProperty("user.dir")+"/../Embeddings_Prediction_Data/"+ folder +"/embeddings/" + fileName; + BufferedReader br = new BufferedReader(new FileReader(fileAddress)); + String line; + while((line = br.readLine()) != null){ + String[] array = line.split("::"); + + // Reading entity from embeddings file and creating + String entity = array[0].trim(); + if(!entityVector.containsKey(entity)) + entityVector.put(entity, new double[K]); + else + System.out.println("Duplicate Entity. 
public class Util {

	// Logistic sigmoid.
	public static double sigm(double x){
		return 1.0 / (1.0 + Math.exp(-x));
	}

	// Euclidean (L2) norm of a vector.
	// NOTE(review): loop body was lost in extraction; L2 norm assumed -- confirm.
	public static double norm(double[] vec){
		double sumSq = 0.0;
		for(int i = 0; i < vec.length; i++){
			sumSq += vec[i] * vec[i];
		}
		return Math.sqrt(sumSq);
	}

	// Dot product of two equal-length vectors.
	// NOTE(review): reconstructed -- Similarity.KNN calls Util.dotProd; confirm
	// the original used a plain (unnormalised) dot product.
	public static double dotProd(double[] a, double[] b){
		double sum = 0.0;
		for(int i = 0; i < a.length; i++){
			sum += a[i] * b[i];
		}
		return sum;
	}

	// Returns the keys of the K largest values in the map, best first.
	// Fix: caps K at the map size instead of throwing
	// IndexOutOfBoundsException when fewer than K entries exist.
	public static ArrayList<String> getKNN(Map<String, Double> map, int K){
		List<Map.Entry<String, Double>> list = new ArrayList<Map.Entry<String, Double>>(map.entrySet());
		Collections.sort(list, new Comparator<Map.Entry<String, Double>>(){
			public int compare(Map.Entry<String, Double> o1, Map.Entry<String, Double> o2){
				return o2.getValue().compareTo(o1.getValue()); // descending by value
			}
		});

		ArrayList<String> top = new ArrayList<String>();
		int limit = Math.min(K, list.size());
		for(int i = 0; i < limit; i++){
			top.add(list.get(i).getKey());
		}
		return top;
	}

	public static void writeSelectEmbeddingsToFile(String folder, String selectionEmbeddings, Map<String, List<String>> simMap){
		// TODO: not implemented in the original either.
	}
}
of Category Count] + Set catReduced; // Set [Categories] - Thresholded on terms of occurrence + Map> busCat; // Map [Res, List[Categories]] + + + public void printCategories() { + for(Object ob : categories){ + System.out.println(ob); + } + } + + public void printAttributes() { + for(String att : attributes.keySet()){ + System.out.print(att + " : "); + for(String subatt : attributes.get(att).keySet()) + System.out.print(subatt+", "); + System.out.print("\n"); + } + System.out.println("Total Attributes : " + attributes.keySet().size()); + } + + private void buildCategorySet(String folder) throws IOException { + BufferedReader br = new BufferedReader(new FileReader(System.getProperty("user.dir")+"/../Dataset/json/"+folder+"/business")); + categories = new HashSet(); + catCount = new HashMap(); + busCatCount = new HashMap(); + busCat = new HashMap>(); + + String line; + int count = 0; + while(( (line = br.readLine()) != null) ){ + JsonReader reader = Json.createReader(new StringReader(line)); + JsonObject object = reader.readObject(); + if(object.get("type").toString().equals("\"business\"")){ + if(object.get("categories").getValueType().toString().equals("ARRAY")){ + JsonArray cat = (JsonArray) object.get("categories"); + JsonValue b_id = object.get("business_id"); + JsonString bus_id = (JsonString) b_id; + busCatCount.put(bus_id.getString(), cat.size()); + busCat.put(bus_id.getString(), new ArrayList()); + for(JsonValue s : cat){ + JsonString c = (JsonString) s; + String category = c.getString(); + categories.add(category); + busCat.get(bus_id.getString()).add(category); + if(!catCount.containsKey(category)) + catCount.put(category, 1); + else + catCount.put(category, catCount.get(category)+1 ); + } + } + count++; + } + } + //return categories; + System.out.println("Businesses Read : " + count); + } + + private void buildThresholdCatSet(int thresh){ + catReduced = new HashSet(); + for(String cat : catCount.keySet()){ + if(catCount.get(cat) >= thresh){ + 
catReduced.add(cat); + } + } + } + + /* Reads the Yelp JSON Dataset and makes attributes hashmap */ + private void readAttributes(String folder) throws IOException{ + BufferedReader br = new BufferedReader(new FileReader(System.getProperty("user.dir")+"/../Dataset/json/"+folder+"/business")); + attributes = new HashMap>(); + String line; + while(( (line = br.readLine()) != null) ){ + JsonReader reader = Json.createReader(new StringReader(line)); + JsonObject object = reader.readObject(); + //if(object.containsKey("attributes")){ + if(object.get("type").toString().equals("\"business\"")){ + if(object.get("attributes").getValueType().toString() == "OBJECT"){ + JsonObject attributeObject = (JsonObject) object.getJsonObject("attributes"); + for(String key : attributeObject.keySet()){ + if(!attributes.containsKey(key)){ // To Make Attributes Keys Set + attributes.put(key, new HashMap()); + getValueSet(attributeObject.get(key), key, attributes); + } + else + getValueSet(attributeObject.get(key), key, attributes); + } + } + } + } + br.close(); + } + + /* Creates a dataset with business Id and values for attributes and stores them in a file. Uses the attributes hashmap. 
*/ + private void buildBusiness_AttributeDataset(String folder) throws IOException{ + BufferedReader br = new BufferedReader(new FileReader(System.getProperty("user.dir")+"/../Dataset/json/"+folder+"/business")); + BufferedWriter bw = new BufferedWriter(new FileWriter(System.getProperty("user.dir")+"/../Dataset/data/" + folder + "/busAtt.txt")); + String line; + int count = 0; + + while(( (line = br.readLine()) != null) && count < 42151 ){ + StringBuilder str = new StringBuilder(); + JsonReader reader = Json.createReader(new StringReader(line)); + JsonObject object = reader.readObject(); + JsonValue b_id = object.get("business_id"); + JsonString bus_id = (JsonString) b_id; + bw.write("business_id : "+bus_id.getString()+"\n"); + + if(object.get("attributes").getValueType().toString() == "OBJECT"){ + JsonObject attributeObject = (JsonObject) object.getJsonObject("attributes"); + navigateObjectForValues(attributeObject, null, null, str); + //System.out.println(str.toString()); + bw.write(str.toString()+"\n\n"); + + } + count++; + } + System.out.println("No. 
of Business : "+count); + br.close(); + bw.close(); + } + + /* */ + public void navigateObjectForValues(JsonValue tree, String key, String prevKey, StringBuilder str) { + switch(tree.getValueType()) { + case OBJECT: + //System.out.println("OBJECT"); + JsonObject object = (JsonObject) tree; + for (String name : object.keySet()) + navigateObjectForValues(object.get(name), name, key, str); + break; + case ARRAY: + break; + case STRING: + JsonString st = (JsonString) tree; + if (key!= null){ + if(prevKey != null){ + //System.out.println(prevKey+"_" + key + "_"+st.getString()+" : "+1); + str.append(prevKey+"_" + key + "_"+st.getString()+" : "+1 + "\n"); + if(attributes.get(key).keySet().size() != 0) + for(String s: attributes.get(key).keySet()) + if(!s.equals(st.getString())) + //System.out.println(prevKey+"_" + key + "_"+s+" : "+0); + str.append(prevKey+"_" + key + "_"+s+" : "+0+"\n"); + } + else{ + //System.out.println(key + "_"+st.getString()+" : "+1); + str.append(key + "_"+st.getString()+" : "+1+"\n"); + + if(attributes.get(key).keySet().size() != 0) + for(String s: attributes.get(key).keySet()) + if(!s.equals(st.getString())) + //System.out.println(key + "_"+s+" : "+0); + str.append(key + "_"+s+" : "+0+"\n"); + } + } + + + break; + case NUMBER: + /*if (key!= null) + if(prevKey != null) + System.out.print(prevKey+"_" + key + " : "); + else + System.out.print(key + " : "); + JsonNumber num = (JsonNumber) tree; + System.out.println(num.toString()); + */ + break; + case TRUE: + if (key!= null) + if(prevKey != null) + //System.out.print(prevKey+"_" + key + " : "); + str.append(prevKey+"_" + key + " : "+1+"\n"); + else +// //System.out.print(key + " : " + 1); + str.append(key + " : " + 1+"\n"); + //System.out.println(1); + break; + case FALSE: + case NULL: + if (key!= null) + if(prevKey != null) + //System.out.print(prevKey+"_" + key + " : "); + str.append(prevKey+"_"+key + " : " + 0+"\n"); + else + //System.out.print(key + " : "); + str.append(key + " : " + 0+"\n"); + 
//System.out.println(0); + break; + } + } + + private void getValueSet(JsonValue attribute, String attributeName, Map> attributes){ + switch(attribute.getValueType()){ + + case STRING: + JsonString st = (JsonString) attribute; + if(!attributes.get(attributeName).containsKey(st)) + attributes.get(attributeName).put(st.getString(), 1); + break; + + case NUMBER: + if(!attributes.get(attributeName).containsKey("NUMBER")) + attributes.get(attributeName).put("NUMBER", 1); + break; + case TRUE: + if(!attributes.get(attributeName).containsKey("TRUE")) + attributes.get(attributeName).put("TRUE", 1); + break; + case FALSE: + if(!attributes.get(attributeName).containsKey("FALSE")) + attributes.get(attributeName).put("FALSE", 1); + break; + case OBJECT: + JsonObject att = (JsonObject) attribute; + for(String subAttributeName : att.keySet()) + if(!attributes.get(attributeName).containsKey(subAttributeName)) + attributes.get(attributeName).put(subAttributeName, 1); + break; + } + } + + // Writes Restaurant : Category1, Category2, ... 
- to file resCat.txt + private void writeResCatToFile(String folder) throws IOException{ + BufferedWriter bw = new BufferedWriter(new FileWriter(System.getProperty("user.dir")+"/../Dataset/data/" + folder +"/busCat.txt")); + for(String res : busCat.keySet()){ + bw.write(res + " : "); + for(String cat : busCat.get(res)) + bw.write(cat + "; "); + bw.write("\n"); + } + bw.close(); + } + + private void makeCitySet(String dataset) throws IOException { + BufferedReader br = new BufferedReader(new FileReader(System.getProperty("user.dir")+"/../Dataset/json"+dataset)); + String line; int count = 0; + Map states = new HashMap(); + while(( (line = br.readLine()) != null) && count < 42151){ + JsonReader reader = Json.createReader(new StringReader(line)); + JsonObject object = reader.readObject(); + + JsonString city = (JsonString) object.get("state"); + String ci = city.getString(); + if(!states.containsKey(ci)) + states.put(ci, 0); + int sc = states.get(ci); + sc++; + states.put(ci, sc); + + count++; + } + + for(String state : states.keySet()){ + System.out.println(state + " : " + states.get(state)); + } + System.out.println(states.keySet().size()); + } + + + public static void main(String[] args) throws Exception{ + System.out.println("Hello"); + String State = "NV"; + + AttributeCategory data = new AttributeCategory(); + + + // Read file for attributes and write to a file + data.readAttributes(State); + //data.printAttributes(); + //data.buildBusiness_AttributeDataset(State); + + + data.buildCategorySet(State); + data.printCategories(); + //data.writeResCatToFile(State); + + int c = 0; + + + //data.makeCitySet("yelp_dataset"); + + + + + + } + + +} diff --git a/Project/src/yelpDataProcessing/ProcessYelpJson.java b/Project/src/yelpDataProcessing/ProcessYelpJson.java new file mode 100644 index 0000000..94f52ad --- /dev/null +++ b/Project/src/yelpDataProcessing/ProcessYelpJson.java @@ -0,0 +1,220 @@ +package yelpDataProcessing; + +import java.io.*; +import java.util.*; + 
+import javax.json.Json; +import javax.json.JsonArray; +import javax.json.JsonNumber; +import javax.json.JsonObject; +import javax.json.JsonReader; +import javax.json.JsonString; +import javax.json.JsonValue; +import javax.json.stream.JsonParsingException; + + +/* + * Processes Yelp Dataset Json and creates different files + * - yelp_business - Business dataset in json format. + * - yelp_dataset_retaurant - Restaurant Json data from Yelp + * - + * + * + */ + +public class ProcessYelpJson { + + Set busIds = new HashSet(); + + // Create file - yelp_reviews - in dataset/json + public void createCompleteReviewJson(String yelpDataset) throws IOException{ + BufferedReader br = new BufferedReader(new FileReader(System.getProperty("user.dir")+"/../Dataset/json/"+yelpDataset)); + BufferedWriter bw = new BufferedWriter(new FileWriter(System.getProperty("user.dir")+"/../Dataset/json/complete/reviews")); + String line; + int count = 0; int cr = 0; + while( ((line = br.readLine()) != null) && count < 1199227){ + if(count <= 73770) + count++; + else{ + cr++; + bw.write(line+"\n"); + count++; + } + } + bw.close(); + br.close(); + System.out.println("Reviews : " + cr); + + + } + + // Creates file - yelp_business - in dataset/json folder + public void createCompleteBusinessJson(String yelpDataset) throws IOException{ + BufferedReader br = new BufferedReader(new FileReader(System.getProperty("user.dir")+"/../Dataset/json/"+yelpDataset)); + BufferedWriter bw = new BufferedWriter(new FileWriter(System.getProperty("user.dir")+"/../Dataset/json/complete/business")); + String line; int count = 0; + + while(( (line = br.readLine()) != null) && count < 42151 ){ + bw.write(line + "\n"); + count++; + } + bw.close(); + System.out.println("No. 
of Businesses : " + count); + + } + + public void createStateBusinessJson(String folder, String state) throws IOException{ + BufferedReader br = new BufferedReader(new FileReader(System.getProperty("user.dir")+"/../Dataset/json/"+folder+"/business")); + BufferedWriter bw = new BufferedWriter(new FileWriter(System.getProperty("user.dir")+"/../Dataset/json/"+state+"/business")); + + String line; + int count = 0; int cr = 0; + while(( (line = br.readLine()) != null) ){ + JsonReader reader = Json.createReader(new StringReader(line)); + JsonObject object = reader.readObject(); + if(object.get("type").toString().equals("\"business\"")){ + JsonValue s = object.get("state"); + JsonString st = (JsonString) s; + String place = st.getString(); + if(place.equals(state)){ + bw.write(line + "\n"); + cr++; + } + } + } + bw.close(); + System.out.println("No. of Businesses in "+state + " : " + cr); + } + + + // Creates file - yelp_dataset_restaurant - in dataset/json folder + public void createRestaurantJson(String folder) throws IOException{ + BufferedReader br = new BufferedReader(new FileReader(System.getProperty("user.dir")+"/../Dataset/json/"+folder+"/business")); + BufferedWriter bw = new BufferedWriter(new FileWriter(System.getProperty("user.dir")+"/../Dataset/json/Restaurant/business")); + + String line; + int count = 0; int cr = 0; + while(( (line = br.readLine()) != null) ){ + JsonReader reader = Json.createReader(new StringReader(line)); + JsonObject object = reader.readObject(); + if(object.get("type").toString().equals("\"business\"")){ + if(object.get("categories").getValueType().toString().equals("ARRAY")){ + JsonArray cat = (JsonArray) object.get("categories"); + for(JsonValue s : cat){ + JsonString c = (JsonString) s; + String category = c.getString(); + if(category.equals("Restaurants")){ + bw.write(line + "\n"); + cr ++; + } + } + } + } + } + bw.close(); + System.out.println("No. 
of Restaurants : " + cr); + } + + + // Create file - yelp_reviews_restaurant - in dataset/json + public void createBusReviewJson(String folder_complete, String folder) throws IOException{ + makeBusIdSet(folder); + BufferedReader br = new BufferedReader(new FileReader(System.getProperty("user.dir")+"/../Dataset/json/"+folder_complete+"/reviews")); + BufferedWriter bw = new BufferedWriter(new FileWriter(System.getProperty("user.dir")+"/../Dataset/json/"+folder+"/reviews")); + int count = 0; + String line; + while(( (line = br.readLine()) != null) ){ + JsonReader reader = Json.createReader(new StringReader(line)); + JsonObject object = reader.readObject(); + if(object.get("type").toString().equals("\"review\"")){ + JsonValue b_id = object.get("business_id"); + JsonString bus_id = (JsonString) b_id; + String bid = bus_id.getString(); + if(busIds.contains(bid)){ + bw.write(line+"\n"); + count++; + } + } + } + bw.close(); + System.out.println("Reviews Count : " + count); + + } + + // Make a Set of Res-Ids to extract reviews + public void makeBusIdSet(String folder) throws IOException{ + busIds = new HashSet(); + BufferedReader br = new BufferedReader(new FileReader(System.getProperty("user.dir")+"/../Dataset/json/"+folder+"/business")); + String line; + while(( (line = br.readLine()) != null)){ + JsonReader reader = Json.createReader(new StringReader(line)); + JsonObject object = reader.readObject(); + if(object.get("type").toString().equals("\"business\"")){ + JsonValue b_id = object.get("business_id"); + JsonString bus_id = (JsonString) b_id; + String bid = bus_id.getString(); + busIds.add(bid); + } + } + System.out.println("Size of resIds set :" + busIds.size()); + } + + public static void putReviewDatatoFile(String folder) throws IOException{ + BufferedReader br = new BufferedReader(new FileReader(System.getProperty("user.dir")+"/../Dataset/json/"+folder+"/reviews")); + BufferedWriter bw = new BufferedWriter(new FileWriter(System.getProperty("user.dir") + 
"/../Dataset/data/"+folder+"/reviews.txt")); + String line; + int count = 0; + while((line = br.readLine()) != null){ + JsonReader reader = Json.createReader(new StringReader(line)); + JsonObject object = reader.readObject(); + + JsonValue b_id = object.get("business_id"); + JsonString bus_id = (JsonString) b_id; + String bId = bus_id.getString(); + bw.write("bus_id : " + bId + "\n"); + + JsonValue u_id = object.get("user_id"); + JsonString user_id = (JsonString) u_id; + String uId = user_id.getString(); + bw.write("user_id : " + uId + "\n"); + + + JsonValue star = object.get("stars"); + JsonNumber s = (JsonNumber) star; + Double st = s.doubleValue(); + bw.write("star : " + st + "\n"); + + JsonValue t = object.get("text"); + JsonString te = (JsonString) t; + String text = te.getString(); + bw.write("text: " + t + "\n\n"); + + count++; + } + br.close(); + bw.close(); + System.out.println("Reviews Written : " + count); + } + + public static void main(String [] args) throws Exception{ + String yelpDataset = "yelp_dataset"; + String State = "NV"; + + ProcessYelpJson yelp = new ProcessYelpJson(); + + + //yelp.createCompleteBusinessJson(yelpDataset); + + //yelp.createRestaurantJson("complete"); + yelp.createStateBusinessJson("complete", State); + + //yelp.createCompleteReviewJson(yelpDataset); + + yelp.createBusReviewJson("complete", State); + + yelp.putReviewDatatoFile(State); + + + + } +} diff --git a/Project/src/yelpDataProcessing/reviewData.java b/Project/src/yelpDataProcessing/reviewData.java new file mode 100644 index 0000000..f9ac3e0 --- /dev/null +++ b/Project/src/yelpDataProcessing/reviewData.java @@ -0,0 +1,141 @@ +package yelpDataProcessing; +import java.io.BufferedReader; +import java.io.FileReader; +import java.io.IOException; +import java.util.*; + +import logisticCMF.Cell; + + +public class reviewData { + + public static Map wordCount = new HashMap(); + public static Map> resWord = new HashMap>(); + public static Map> userWord = new HashMap>(); + public 
static HashSet words = new HashSet(); + + public static void addWordInMap(String word){ + if(!wordCount.containsKey(word)) + wordCount.put(word, 0); + int wcount = wordCount.get(word); + wcount++; + wordCount.put(word, wcount); + } + + public static void readData(String reviewData) throws IOException{ + BufferedReader br = new BufferedReader(new FileReader(reviewData)); + String line; String bId = null; String userId = null; int countl = 0, countR = 0; + + while((line = br.readLine()) != null){ + countl++; + String[] array = line.split(":"); + + if( array[0].trim().equals("user_id")){ + countR++; + userId = array[1].trim(); + if(!userWord.containsKey(userId)) + userWord.put(userId, new HashSet()); + } + + if( array[0].trim().equals("bus_id")){ + bId = array[1].trim(); + if(!resWord.containsKey(bId)) + resWord.put(bId, new HashSet()); + } + + if(array[0].trim().equals("text")){ + String [] tokens = array[1].trim().split(" "); + if(tokens.length > 0){ + for(String word : tokens){ + word = word.trim(); + if(word.length() >= 3){ + addWordInMap(word); + resWord.get(bId).add(word); + userWord.get(userId).add(word); + } + } + } + } + if(countl % 100000 == 0) + System.out.println("line : "+countl); + } + System.out.println("Total No. of Reviews : " + countR); + } + + public static void getMapStats(Map> enWord){ + int potentialResWordCells = 0; + int min=100000000; + System.out.println("No. of entities = " + enWord.keySet().size()); + System.out.println("No. of total words in Vocab = " + words.size()); + + for(String en : enWord.keySet()){ + min = (enWord.get(en).size() < min) ? enWord.get(en).size() : min; + for(String word : enWord.get(en)){ + potentialResWordCells++; + } + + } + System.out.println("Potential Entity-Word Cells : " + potentialResWordCells); + System.out.println("Min No. of Words in Entity : " + min); + } + + // Remove words from Map[Entity -> Set[words]] that occur few times in dictionary. If Set of words for entity go empty, remove Entity from Map. 
+ public static void pruneEntityWordMap(Map> enWord){ + Iterator it = enWord.keySet().iterator(); + while(it.hasNext()){ + String en = it.next(); + Iterator itr = enWord.get(en).iterator(); + while(itr.hasNext()){ + String word = itr.next(); + if(!words.contains(word)) + itr.remove(); + } + if(enWord.get(en).size() == 0) + it.remove(); + } + } + + + // Make a Array of Words that have frequency above the given threshold. + public static void makePrunedWordList(int occThresh){ + words = new HashSet(); int count = 0; + for(String word : wordCount.keySet()){ + if(wordCount.get(word) > occThresh){ + count++; + words.add(word); + } + } + System.out.println("Words with greater than occurence of " + occThresh + " : " + words.size()); + } + + public static int countLines(String file) throws IOException{ + BufferedReader br = new BufferedReader(new FileReader(file)); + String line; int count = 0; + while((line = br.readLine()) != null){ + count++; + } + return count; + } + + + public static void main(String [] args) throws IOException{ + + String reviewData = System.getProperty("user.dir")+"/../Dataset/data/ON/reviews_textProc.txt"; + readData(reviewData); + + int occThresh = 1; + System.out.println("Total Words in Review Data : " + wordCount.keySet().size()); + int count = 0; + + /*for(occThresh = 0; occThresh <= 50; occThresh++){ + makePrunedWordList(occThresh); + }*/ + + makePrunedWordList(4); + pruneEntityWordMap(resWord); + getMapStats(resWord); + + //pruneEntityWordMap(userWord, occThresh); + //getMapStats(userWord); + } +} diff --git a/Project/src/yelpDataProcessing/reviewJson.java b/Project/src/yelpDataProcessing/reviewJson.java new file mode 100644 index 0000000..8add377 --- /dev/null +++ b/Project/src/yelpDataProcessing/reviewJson.java @@ -0,0 +1,131 @@ +package yelpDataProcessing; + +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.FileNotFoundException; +import java.io.FileReader; +import java.io.FileWriter; +import 
java.io.IOException; +import java.io.StringReader; +import java.util.*; + +import javax.json.Json; +import javax.json.JsonArray; +import javax.json.JsonNumber; +import javax.json.JsonObject; +import javax.json.JsonReader; +import javax.json.JsonString; +import javax.json.JsonValue; + +import logisticCMF.Cell; + +public class reviewJson { + + public static HashSet resIds = new HashSet(); // Set - [Restaurant Ids] + + // Reads Yelp Dataset Json - extracts only Review Jsons + public static void extractReviewJson(String fileAddress) throws IOException{ + BufferedReader br = new BufferedReader(new FileReader(fileAddress)); + BufferedWriter bw = new BufferedWriter(new FileWriter(System.getProperty("user.dir") + "/Data/json/yelp_reviews.txt")); + String line; + int count=1; + while( ((line = br.readLine()) != null) && count < 1199228){ + if(count <= 73770) + count++; + else{ + bw.write(line+"\n"); + count++; + } + } + bw.close(); + br.close(); + } + + // Makes resIds Set - Then extracts json objects for restaurant reviews and stores them in file + public static void extractResReviewJson() throws IOException{ + makeResIds(); + BufferedReader br = new BufferedReader(new FileReader(System.getProperty("user.dir") + "/Data/json/yelp_reviews.txt")); + BufferedWriter bw = new BufferedWriter(new FileWriter(System.getProperty("user.dir") + "/Data/json/yelp_reviews_restaurants.txt")); + String line; + int count = 0; + while(( (line = br.readLine()) != null) ){ + JsonReader reader = Json.createReader(new StringReader(line)); + JsonObject object = reader.readObject(); + if(object.get("type").toString().equals("\"review\"")){ + JsonValue b_id = object.get("business_id"); + JsonString bus_id = (JsonString) b_id; + String bId = bus_id.getString(); + if(resIds.contains(bId)){ + bw.write(line.trim() + "\n"); + count++; + } + } + } + + System.out.println("Count : " + count); + br.close(); + bw.close(); + } + + // Reads file that contains Restaurant Ids and stores in Set - resIds + public 
static void makeResIds() throws IOException{ + String fA = System.getProperty("user.dir") + "/Data/new/restaurant_ids"; + BufferedReader br = new BufferedReader(new FileReader(fA)); + String line; + while((line = br.readLine()) != null){ + resIds.add(line.trim()); + } + br.close(); + System.out.println("No. of res : " + resIds.size()); + } + + public static void putReviewDatatoFile(String fileAddress ) throws IOException{ + BufferedReader br = new BufferedReader(new FileReader(fileAddress)); + BufferedWriter bw = new BufferedWriter(new FileWriter(System.getProperty("user.dir") + "/../Data/new/res_review_data.txt")); + String line; + int count = 0; String relation = "restaurant-user"; + while((line = br.readLine()) != null){ + JsonReader reader = Json.createReader(new StringReader(line)); + JsonObject object = reader.readObject(); + + JsonValue b_id = object.get("business_id"); + JsonString bus_id = (JsonString) b_id; + String bId = bus_id.getString(); + bw.write("bus_id : " + bId + "\n"); + + JsonValue u_id = object.get("user_id"); + JsonString user_id = (JsonString) u_id; + String uId = user_id.getString(); + bw.write("user_id : " + uId + "\n"); + + + JsonValue star = object.get("stars"); + JsonNumber s = (JsonNumber) star; + Double st = s.doubleValue(); + bw.write("star : " + st + "\n"); + + JsonValue t = object.get("text"); + JsonString te = (JsonString) t; + String text = te.getString(); + bw.write("text: " + t + "\n\n"); + + count++; + } + br.close(); + bw.close(); + System.out.println(relation + " : " + count); + } + + + + public static void main(String [] args) throws IOException{ + //convertResAttLogisticReadableFile(System.getProperty("user.dir")+"/Data/yelp_dataset_restaurant_att"); + //String fileAddress = System.getProperty("user.dir") + "/../Dataset/yelp_dataset"; + String fileAddress = System.getProperty("user.dir") + "/../Data/json/yelp_reviews_restaurants.txt"; + //putReviewDatatoFile(fileAddress); + //extractReviewJson(fileAddress); + 
//extractResReviewJson(); + + } + +} diff --git a/PythonScript/clean_text.py b/PythonScript/clean_text.py new file mode 100644 index 0000000..c247d04 --- /dev/null +++ b/PythonScript/clean_text.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +## Inputs review file in format + +# businessId : asnhdjkasld +# userId : kjhdjkfadfjbasdbjfbasdb +# stars : 4.0 +# text : I was amazed with the quality of the food + +##### OUTPUTS the file in the same order, but tokenizes, removes punctuatuations, removes stop words and stems the words before outputing. Also each word in review occurs once in output. + + +from nltk.corpus import stopwords +from nltk.tokenize import word_tokenize +import re +import string +from nltk.stem import PorterStemmer + +stemmer = PorterStemmer() + +def getTokens(doc): + doc = re.sub("\d+", "", doc) + + tokenized = word_tokenize(doc.decode('utf-8')) + + regex = re.compile('[%s]' % re.escape(string.punctuation)) #see documentation here: http://docs.python.org/2/library/string.html + + tokenized_no_punctuation = [] + + new_review = [] + + + ## REMOVING PUNCTUATION + #for token in tokenized: + # new_token = regex.sub(u'', token.lower()) + # if not new_token == u'': + # tokenized_no_punctuation.append(new_token) + + tokenized_no_punctuation = [re.sub(r'[^A-Za-z0-9]+', '', x.lower()) for x in tokenized] + tokenized_no_punctuation = [s for s in tokenized_no_punctuation if (len(s)>1)] + + #print tokenized_no_punctuation + + token_no_stop = [] + ## REMOVING STOP WORDS + for word in tokenized_no_punctuation: + if not word in stopwords.words('english'): + try: + word = stemmer.stem(word.encode('utf-8')) + except UnicodeDecodeError: + word = word #.encode('utf-8') + token_no_stop.append(word.encode('utf-8')) + + + return token_no_stop + + + +fin = open('reviews.txt', 'r') +fout = open('reviews_textProc.txt', 'w') +count = 0 +for line in fin: +# print count + if(count %100000 == 0): + print count + if(line.strip().split(':')[0] == 'text'): + 
tokens = getTokens(line.strip().split(':')[1]) + tokenSet = set() + for i in tokens: + if (len(i) >= 3): + tokenSet.add(i) + fout.write("text : ") + fout.write(" ".join(tokenSet)) + fout.write("\n\n") + else: + fout.write(line) + count = count + 1 diff --git a/PythonScript/clean_text.py~ b/PythonScript/clean_text.py~ new file mode 100644 index 0000000..c247d04 --- /dev/null +++ b/PythonScript/clean_text.py~ @@ -0,0 +1,78 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +## Inputs review file in format + +# businessId : asnhdjkasld +# userId : kjhdjkfadfjbasdbjfbasdb +# stars : 4.0 +# text : I was amazed with the quality of the food + +##### OUTPUTS the file in the same order, but tokenizes, removes punctuatuations, removes stop words and stems the words before outputing. Also each word in review occurs once in output. + + +from nltk.corpus import stopwords +from nltk.tokenize import word_tokenize +import re +import string +from nltk.stem import PorterStemmer + +stemmer = PorterStemmer() + +def getTokens(doc): + doc = re.sub("\d+", "", doc) + + tokenized = word_tokenize(doc.decode('utf-8')) + + regex = re.compile('[%s]' % re.escape(string.punctuation)) #see documentation here: http://docs.python.org/2/library/string.html + + tokenized_no_punctuation = [] + + new_review = [] + + + ## REMOVING PUNCTUATION + #for token in tokenized: + # new_token = regex.sub(u'', token.lower()) + # if not new_token == u'': + # tokenized_no_punctuation.append(new_token) + + tokenized_no_punctuation = [re.sub(r'[^A-Za-z0-9]+', '', x.lower()) for x in tokenized] + tokenized_no_punctuation = [s for s in tokenized_no_punctuation if (len(s)>1)] + + #print tokenized_no_punctuation + + token_no_stop = [] + ## REMOVING STOP WORDS + for word in tokenized_no_punctuation: + if not word in stopwords.words('english'): + try: + word = stemmer.stem(word.encode('utf-8')) + except UnicodeDecodeError: + word = word #.encode('utf-8') + token_no_stop.append(word.encode('utf-8')) + + + return 
token_no_stop + + + +fin = open('reviews.txt', 'r') +fout = open('reviews_textProc.txt', 'w') +count = 0 +for line in fin: +# print count + if(count %100000 == 0): + print count + if(line.strip().split(':')[0] == 'text'): + tokens = getTokens(line.strip().split(':')[1]) + tokenSet = set() + for i in tokens: + if (len(i) >= 3): + tokenSet.add(i) + fout.write("text : ") + fout.write(" ".join(tokenSet)) + fout.write("\n\n") + else: + fout.write(line) + count = count + 1 diff --git a/PythonScript/combineAllPredData.py b/PythonScript/combineAllPredData.py new file mode 100644 index 0000000..e652ecb --- /dev/null +++ b/PythonScript/combineAllPredData.py @@ -0,0 +1,29 @@ +# From Logisitc CMF : python PythonScript/combineAllPredData.py Embeddings_Prediction_Data HeldOut + +import sys +import os +from os import walk + +embeddingsPath = sys.argv[1]; +evToMerge = sys.argv[2]; +print embeddingsPath+"/All/pred-data/"+evToMerge +if not os.path.exists(embeddingsPath+"/All/pred-data/"+evToMerge): + os.makedirs(embeddingsPath+"/All/pred-data/"+evToMerge) + +foldersToMerge = ['AZ', 'NV', 'WI', 'EDH'] + +print embeddingsPath+"/WI/pred-data/"+evToMerge +files = [] +for (dirpath, dirnames, filenames) in walk(embeddingsPath+"/WI/pred-data/"+evToMerge): + files.extend(filenames) + break + +print files; + +for folder in foldersToMerge: + for filename in files: + with open(embeddingsPath+"/All/pred-data/"+evToMerge+"/"+filename, 'a') as outfile: + with open(embeddingsPath+"/"+folder+"/pred-data/"+evToMerge+"/"+filename) as infile: + for line in infile: + outfile.write(line) + diff --git a/PythonScript/combineAllPredData.py~ b/PythonScript/combineAllPredData.py~ new file mode 100644 index 0000000..e652ecb --- /dev/null +++ b/PythonScript/combineAllPredData.py~ @@ -0,0 +1,29 @@ +# From Logisitc CMF : python PythonScript/combineAllPredData.py Embeddings_Prediction_Data HeldOut + +import sys +import os +from os import walk + +embeddingsPath = sys.argv[1]; +evToMerge = sys.argv[2]; +print 
embeddingsPath+"/All/pred-data/"+evToMerge +if not os.path.exists(embeddingsPath+"/All/pred-data/"+evToMerge): + os.makedirs(embeddingsPath+"/All/pred-data/"+evToMerge) + +foldersToMerge = ['AZ', 'NV', 'WI', 'EDH'] + +print embeddingsPath+"/WI/pred-data/"+evToMerge +files = [] +for (dirpath, dirnames, filenames) in walk(embeddingsPath+"/WI/pred-data/"+evToMerge): + files.extend(filenames) + break + +print files; + +for folder in foldersToMerge: + for filename in files: + with open(embeddingsPath+"/All/pred-data/"+evToMerge+"/"+filename, 'a') as outfile: + with open(embeddingsPath+"/"+folder+"/pred-data/"+evToMerge+"/"+filename) as infile: + for line in infile: + outfile.write(line) + diff --git a/PythonScript/getPRCurveData.py b/PythonScript/getPRCurveData.py new file mode 100644 index 0000000..8f1ee6d --- /dev/null +++ b/PythonScript/getPRCurveData.py @@ -0,0 +1,28 @@ +# From Logistic_CMF : python PythonScript/getPRCurveData.py Embeddings_Prediction_Data All HeldOut + +from os import walk +import sys +import os + +embeddingsPath = sys.argv[1] +folderToTest = sys.argv[2] +evToTest = sys.argv[3] + +path = embeddingsPath+"/"+folderToTest+"/pred-data/"+evToTest + +files = [] +for (dirpath, dirnames, filenames) in walk(path): + files.extend(filenames) + break + +for fileName in files: + f = open(path+"/"+fileName, 'r') + if not os.path.exists(path+"/PRCurve"): + os.makedirs(path+"/PRCurve") + o = open(path+"/PRCurve/"+fileName, 'w') + for line in f: + line.strip(); + a = line.split("::") + o.write(a[2].strip() + "\t" + a[3].strip() + "\n"); + o.close(); + f.close(); diff --git a/PythonScript/getPRCurveData.py~ b/PythonScript/getPRCurveData.py~ new file mode 100644 index 0000000..8f1ee6d --- /dev/null +++ b/PythonScript/getPRCurveData.py~ @@ -0,0 +1,28 @@ +# From Logistic_CMF : python PythonScript/getPRCurveData.py Embeddings_Prediction_Data All HeldOut + +from os import walk +import sys +import os + +embeddingsPath = sys.argv[1] +folderToTest = sys.argv[2] +evToTest = 
sys.argv[3] + +path = embeddingsPath+"/"+folderToTest+"/pred-data/"+evToTest + +files = [] +for (dirpath, dirnames, filenames) in walk(path): + files.extend(filenames) + break + +for fileName in files: + f = open(path+"/"+fileName, 'r') + if not os.path.exists(path+"/PRCurve"): + os.makedirs(path+"/PRCurve") + o = open(path+"/PRCurve/"+fileName, 'w') + for line in f: + line.strip(); + a = line.split("::") + o.write(a[2].strip() + "\t" + a[3].strip() + "\n"); + o.close(); + f.close(); diff --git a/PythonScript/getPRF.py b/PythonScript/getPRF.py new file mode 100644 index 0000000..c51be2c --- /dev/null +++ b/PythonScript/getPRF.py @@ -0,0 +1,22 @@ +# From Logistic CMF : python PythonScript/getPRF.py Embeddings_Prediction_Data WI HeldOut + +import prf; +from os import walk +import sys + +embeddingsPath = sys.argv[1] +folderToTest = sys.argv[2] +evToTest = sys.argv[3] + +path = embeddingsPath+"/"+folderToTest+"/pred-data/"+evToTest +print path +fs = [] +for (dirpath, dirnames, filenames) in walk(path): + fs.extend(filenames) + break +fs.sort() +for fileName in fs: + print fileName + print prf.getPRF(path+"/"+fileName) + + diff --git a/PythonScript/getPRF.py~ b/PythonScript/getPRF.py~ new file mode 100644 index 0000000..c51be2c --- /dev/null +++ b/PythonScript/getPRF.py~ @@ -0,0 +1,22 @@ +# From Logistic CMF : python PythonScript/getPRF.py Embeddings_Prediction_Data WI HeldOut + +import prf; +from os import walk +import sys + +embeddingsPath = sys.argv[1] +folderToTest = sys.argv[2] +evToTest = sys.argv[3] + +path = embeddingsPath+"/"+folderToTest+"/pred-data/"+evToTest +print path +fs = [] +for (dirpath, dirnames, filenames) in walk(path): + fs.extend(filenames) + break +fs.sort() +for fileName in fs: + print fileName + print prf.getPRF(path+"/"+fileName) + + diff --git a/PythonScript/prf.py b/PythonScript/prf.py new file mode 100644 index 0000000..2373821 --- /dev/null +++ b/PythonScript/prf.py @@ -0,0 +1,39 @@ +## Call getPRF(fileName) to get [P, R, F] + +import 
numpy as np; +from sklearn.metrics import precision_recall_fscore_support as prf + +def readFile(fileName): + y_pred = [] + y_true = [] + y_pred_true = [] + f = open(fileName, 'r'); + for line in f: + line.strip(); + a = line.split("::") + if(float(a[2].strip()) >= 0.5): + pred = 1; + y_pred.append(pred); + else: + pred = 0; + y_pred.append(pred); + + true = float(a[3].strip()); + y_true.append(true); + y_pred_true.append(y_pred) + y_pred_true.append(y_true) + return y_pred_true + +def getPRF(fileName): + y_p_t = readFile(fileName) + y_pred = y_p_t[0]; + y_true = y_p_t[1]; + acc = prf(y_true, y_pred, average = 'micro'); + p = round(acc[0]*100, 1); + r = round(acc[1]*100, 1); + f = round(acc[2]*100, 1); + return np.array([p,r,f]) + +#fileName = "A-A" +#print getPRF(fileName) + diff --git a/PythonScript/prf.pyc b/PythonScript/prf.pyc new file mode 100644 index 0000000..c908aca Binary files /dev/null and b/PythonScript/prf.pyc differ diff --git a/PythonScript/prf.py~ b/PythonScript/prf.py~ new file mode 100644 index 0000000..2373821 --- /dev/null +++ b/PythonScript/prf.py~ @@ -0,0 +1,39 @@ +## Call getPRF(fileName) to get [P, R, F] + +import numpy as np; +from sklearn.metrics import precision_recall_fscore_support as prf + +def readFile(fileName): + y_pred = [] + y_true = [] + y_pred_true = [] + f = open(fileName, 'r'); + for line in f: + line.strip(); + a = line.split("::") + if(float(a[2].strip()) >= 0.5): + pred = 1; + y_pred.append(pred); + else: + pred = 0; + y_pred.append(pred); + + true = float(a[3].strip()); + y_true.append(true); + y_pred_true.append(y_pred) + y_pred_true.append(y_true) + return y_pred_true + +def getPRF(fileName): + y_p_t = readFile(fileName) + y_pred = y_p_t[0]; + y_true = y_p_t[1]; + acc = prf(y_true, y_pred, average = 'micro'); + p = round(acc[0]*100, 1); + r = round(acc[1]*100, 1); + f = round(acc[2]*100, 1); + return np.array([p,r,f]) + +#fileName = "A-A" +#print getPRF(fileName) + diff --git a/PythonScript/writePRFTable.py 
b/PythonScript/writePRFTable.py new file mode 100644 index 0000000..e89284e --- /dev/null +++ b/PythonScript/writePRFTable.py @@ -0,0 +1,36 @@ +# From Logisitc CMF : python PythonScript/writePRFTable.py Embeddings_Prediction_Data A HeldOut + +import sys +import os +from os import walk +import prf + +embeddingsPath = sys.argv[1] +relation = sys.argv[2] +evToTest = sys.argv[3] + +writePath = embeddingsPath+"/Tables/"+relation+"-"+evToTest +out = open(writePath, 'w') + +foldersToWrite = ['AZ', 'NV', 'WI', 'EDH', 'All'] +filesToWrite = [] +for (dirpath, dirnames, filenames) in walk(embeddingsPath+"/WI/pred-data/"+evToTest): + for filename in filenames: + r = filename.split("-")[0].strip() + if r == relation: + filesToWrite.append(filename) + break + +for model in filesToWrite: + out.write("\\textbf{"+model.split("-")[1].strip()+"}\n") + for folder in foldersToWrite: + PRF = prf.getPRF(embeddingsPath+"/"+folder+"/pred-data/"+evToTest+"/"+model) + out.write("& " + str(PRF[0]) + "\t & " + str(PRF[1]) + "\t & " + str(PRF[2]) + "\n") + out.write("\\\\ \n") + + + + + + + diff --git a/PythonScript/writePRFTable.py~ b/PythonScript/writePRFTable.py~ new file mode 100644 index 0000000..e89284e --- /dev/null +++ b/PythonScript/writePRFTable.py~ @@ -0,0 +1,36 @@ +# From Logisitc CMF : python PythonScript/writePRFTable.py Embeddings_Prediction_Data A HeldOut + +import sys +import os +from os import walk +import prf + +embeddingsPath = sys.argv[1] +relation = sys.argv[2] +evToTest = sys.argv[3] + +writePath = embeddingsPath+"/Tables/"+relation+"-"+evToTest +out = open(writePath, 'w') + +foldersToWrite = ['AZ', 'NV', 'WI', 'EDH', 'All'] +filesToWrite = [] +for (dirpath, dirnames, filenames) in walk(embeddingsPath+"/WI/pred-data/"+evToTest): + for filename in filenames: + r = filename.split("-")[0].strip() + if r == relation: + filesToWrite.append(filename) + break + +for model in filesToWrite: + out.write("\\textbf{"+model.split("-")[1].strip()+"}\n") + for folder in foldersToWrite: 
+ PRF = prf.getPRF(embeddingsPath+"/"+folder+"/pred-data/"+evToTest+"/"+model) + out.write("& " + str(PRF[0]) + "\t & " + str(PRF[1]) + "\t & " + str(PRF[2]) + "\n") + out.write("\\\\ \n") + + + + + + + diff --git a/README.md b/README.md new file mode 100644 index 0000000..ada5b5b --- /dev/null +++ b/README.md @@ -0,0 +1,21 @@ +# README # + +### What is this repository for? ### + +* Yelp Dataset Challenge + * Parse Yelp Json format data to get required data in the required format + * Perform logistic CMF on the parsed data to predict relations. + +### Details on Logistic CMF Code ### + +* This Java project currently has 2 packages : + * yelpDataProcessing - Contains classes/functions to read the yelp dataset in json format and parse to get different data in required format. + * logisticCMF - Contains classes/functions to read data produced in required format and then split train/validation/test data. Learn the embeddings for entities and print prediction evaluation. + +* The folder PythonScript contains a file clean_text.py that reads the yelp review data in user format and pre-processes the text review. + * The text pre-processing contains tokenization, stemming, removal of stop words and punctuations. + * Each word is kept only once if it occurs multiple times in a review. + +#### Project Contributors #### +* Nitish Gupta +* Sameer Singh \ No newline at end of file