diff --git a/visualization/pom.xml b/visualization/pom.xml new file mode 100644 index 0000000..6796214 --- /dev/null +++ b/visualization/pom.xml @@ -0,0 +1,79 @@ + + + 4.0.0 + + visualization + visualization + 1.0-SNAPSHOT + jar + + + UTF-8 + 1.7 + 2.10.3 + 2.10 + + + + src/main/scala + src/test/scala + + + net.alchim31.maven + scala-maven-plugin + 3.2.0 + + + + process-sources + + compile + testCompile + + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.1 + + ${java.version} + ${java.version} + + + + + + + org.scala-lang + scala-library + ${scala.version} + + + org.sameersingh.scalaplot + scalaplot + 0.0.4 + + + junit + junit + 4.11 + test + + + org.scalacheck + scalacheck_${scala.version.tools} + 1.11.4 + test + + + org.scalatest + scalatest_${scala.version.tools} + 2.2.0 + test + + + diff --git a/visualization/src/main/scala/visualization/BarCharts.scala b/visualization/src/main/scala/visualization/BarCharts.scala new file mode 100644 index 0000000..229f2ec --- /dev/null +++ b/visualization/src/main/scala/visualization/BarCharts.scala @@ -0,0 +1,69 @@ +package visualization + +import org.sameersingh.scalaplot.gnuplot.GnuplotPlotter +import org.sameersingh.scalaplot.{MemBarSeries, BarData, BarChart} + +import scala.collection.mutable.ArrayBuffer + +/** + * @author sameer + * @since 12/26/14. + */ +class BarCharts { + + val attr = new ArrayBuffer[(String, Double)]() + attr += "A" -> 81.5 + attr += "A+C" -> 83.9 + attr += "A+R" -> 80.8 + attr += "A+BW" -> 83.5 + //attr += "A+UW" -> 81.5 + attr += "A+C+R" -> 83.4 + attr += "A+C+BW" -> 83.6 + attr += "A+C+UW" -> 83.9 + attr += "A+R+BW" -> 83.5 + attr += "A+R+UW" -> 80.7 + attr += "A+C+R+BW" -> 83.7 + attr += "A+C+R+UW" -> 83.3 + + val ratings = new ArrayBuffer[(String, Double)]() + ratings += "R" -> 71.3 + ratings += "R+C" -> 73.6 + ratings += "R+A" -> 72.3 + ratings += "R+BW" -> 72.4 + ratings += "R+UW" -> 79.2 + ratings += "R+A+C" -> 72.5 + ratings += "R+C+BW" -> 72.6 + ratings += "R+C+UW" -> 80.4 + ratings += "R+A+BW" -> 72.3 + ratings += "R+A+UW" -> 79.2 + ratings += "R+A+C+BW" -> 72.3 + ratings += "R+A+C+UW" -> 79.9 + + def attributesPlot: BarChart = { + val ss = new MemBarSeries(attr.map(_._2)) + val d = new BarData((i) => attr(i)._1, Seq(ss)) + val c = new BarChart("Heldout F1 (Attributes)", d) + c.y.label = "F1" + c.x.label = "Factorization Models" + c + } + + def ratingsPlot: BarChart = { + val ss = new MemBarSeries(ratings.map(_._2)) + val d = new BarData((i) => ratings(i)._1, Seq(ss)) + val c = new BarChart("Heldout F1 (Ratings)", d) + c.y.label = "F1" + c.x.label = "Factorization Models" + c + } +} + +object BarCharts { + def main(args: Array[String]): Unit = { + assert(args.length == 1, "Please include the directory to write output in as an argument.") + val outputDir = args(0) + val charts = new BarCharts + GnuplotPlotter.pdf(charts.attributesPlot, outputDir, "Attributes") + GnuplotPlotter.pdf(charts.ratingsPlot, outputDir, "Ratings") + } +} diff --git a/visualization/src/main/scala/visualization/PRCurves.scala b/visualization/src/main/scala/visualization/PRCurves.scala new file mode 100644 index 0000000..6fb35db --- /dev/null +++ b/visualization/src/main/scala/visualization/PRCurves.scala @@ -0,0 +1,105 @@ +package visualization + +import java.io.{FilenameFilter, File} + +import org.sameersingh.scalaplot.gnuplot.GnuplotPlotter +import org.sameersingh.scalaplot._ + +import scala.collection.mutable.ArrayBuffer + +/** + * @author sameer + * @since 12/25/14. + */ +class PRCurves { + + val attrsOrder = Seq("A", "A+R", "A+BW", "A+C", "A+C+R", "A+C+BW") + //Excluded: "A+R+UW", "A+C+UW", "A+R+BW", "A+C+R+UW", "A+UW", "A+C+R+BW" + val ratingsOrder = Seq("R", "A+R", "R+UW", "R+BW", "R+C", "A+R+UW", "R+C+UW", "A+C+R+UW") + // Excluded: , "A+C+R+BW", "A+C+R", "R+C+BW", "A+R+BW" + + def readFile(file: File): Seq[(Double, Boolean)] = { + val source = io.Source.fromFile(file) + val result = new ArrayBuffer[(Double, Boolean)] + for (l <- source.getLines()) { + val split = l.split("\t") + assert(split.length == 2) + assert(split(1) == "0" || split(1) == "1") + result += split(0).toDouble -> (split(1) == "1") + } + source.close() + result //.grouped(10).map(_.head).toSeq + } + + def seriesFromFile(file: File): XYSeries = { + val data = readFile(file) + val m = new org.sameersingh.scalaplot.metrics.PrecRecallCurve(data) + val s = new MemXYSeries(m.prChart("").data.serieses.head.points.toSeq, file.getName.drop(2)) + s.every = Some(7500) + s.plotStyle = XYPlotStyle.LinesPoints + s + } + + def reorderSerieses(serieses: Seq[XYSeries], order: Seq[String]): Seq[XYSeries] = { + order.map(n => serieses.find(_.name == n).get) + } + + def generateAttrChart(files: Seq[File]): XYChart = { + val series = files.map(f => seriesFromFile(f)) + val data = new XYData(reorderSerieses(series, attrsOrder): _*) + val chart = new XYChart("PR Curve (Attributes)", data) + chart.x.label = "Recall" + chart.y.label = "Precision" + //chart.monochrome = true + chart.showLegend = true + chart.legendPosX = LegendPosX.Center + chart.legendPosY = LegendPosY.Bottom + chart.size = Some(3.5,2.5) + chart + } + + def generateRatingsChart(files: Seq[File]): XYChart = { + val series = files.map(f => seriesFromFile(f)) + val data = new XYData(reorderSerieses(series, ratingsOrder): _*) + val chart = new XYChart("PR Curve (Ratings)", data) + chart.x.label = "Recall" + chart.y.label = "Precision" + //chart.monochrome = true + chart.showLegend = true + chart.legendPosX = LegendPosX.Right + chart.legendPosY = LegendPosY.Top + chart.size = Some(3.5,2.5) + chart + } + + def generateCharts(dirName: String) = { + val dir = new File(dirName) + assert(dir.isDirectory, dirName + " is not a directory.") + // generate output dir + val outDirName = dirName + "/output/" + new File(outDirName).mkdir() + // Attribute PR Curve + val attrFiles = dir.listFiles(new FilenameFilter { + override def accept(dir: File, name: String): Boolean = name.startsWith("A-") + }).toSeq.sortBy(_.getName) + println("attr: " + attrFiles.map(_.getName).mkString(", ")) + val attrChart = generateAttrChart(attrFiles) + GnuplotPlotter.pdf(attrChart, outDirName, dir.getName + "A") + // Ratings PR Curve + val ratingsFiles = dir.listFiles(new FilenameFilter { + override def accept(dir: File, name: String): Boolean = name.startsWith("R-") + }).toSeq.sortBy(_.getName) + println("ratings: " + ratingsFiles.map(_.getName).mkString(", ")) + val ratingsChart = generateRatingsChart(ratingsFiles) + GnuplotPlotter.pdf(ratingsChart, outDirName, dir.getName + "R") + } +} + +object PRCurves { + def main(args: Array[String]): Unit = { + assert(args.length == 1, "Please include the directory containing the generated predictions as an argument.") + val inputDir = args(0) + val curves = new PRCurves + curves.generateCharts(args(0)) + } +}