Skip to content

Commit

Permalink
added pr curves and bar chart visualizations
Browse files Browse the repository at this point in the history
  • Loading branch information
sameersingh committed Dec 26, 2014
1 parent 304eac5 commit 633ce71
Show file tree
Hide file tree
Showing 3 changed files with 253 additions and 0 deletions.
79 changes: 79 additions & 0 deletions visualization/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven build for the "visualization" module: compiles the Scala sources under src/main/scala and pulls in scalaplot for chart rendering. -->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>

<!-- Artifact coordinates; packaged as a plain jar. -->
<groupId>visualization</groupId>
<artifactId>visualization</artifactId>
<version>1.0-SNAPSHOT</version>
<packaging>jar</packaging>

<!-- Shared version properties: java.version feeds maven-compiler-plugin's source/target, scala.version pins scala-library, and scala.version.tools is the suffix used in the scalacheck/scalatest artifact ids below. -->
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<java.version>1.7</java.version>
<scala.version>2.10.3</scala.version>
<scala.version.tools>2.10</scala.version.tools>
</properties>

<build>
<!-- Sources live in the Scala directories rather than Maven's default src/main/java and src/test/java. -->
<sourceDirectory>src/main/scala</sourceDirectory>
<testSourceDirectory>src/test/scala</testSourceDirectory>
<plugins>
<!-- Compiles main and test Scala sources (goals compile and testCompile), bound to the process-sources phase. -->
<plugin>
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
<version>3.2.0</version>
<executions>
<execution>
<!-- this is so we don't end with a compile error in maven-compiler-plugin -->
<phase>process-sources</phase>
<goals>
<goal>compile</goal>
<goal>testCompile</goal>
</goals>
<!-- No plugin-specific settings are configured for this execution. -->
<configuration>
</configuration>
</execution>
</executions>
</plugin>
<!-- Java compiler settings: source and target levels both come from the java.version property. -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.1</version>
<configuration>
<source>${java.version}</source>
<target>${java.version}</target>
</configuration>
</plugin>
</plugins>
</build>
<!-- Compile-scope dependencies: the Scala standard library and scalaplot. junit, scalacheck and scalatest are test scope only. -->
<dependencies>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
<version>${scala.version}</version>
</dependency>
<!-- Plotting library imported by BarCharts.scala and PRCurves.scala (org.sameersingh.scalaplot). -->
<dependency>
<groupId>org.sameersingh.scalaplot</groupId>
<artifactId>scalaplot</artifactId>
<version>0.0.4</version>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.11</version>
<scope>test</scope>
</dependency>
<!-- The artifact id suffix must match the Scala version family set in scala.version.tools. -->
<dependency>
<groupId>org.scalacheck</groupId>
<artifactId>scalacheck_${scala.version.tools}</artifactId>
<version>1.11.4</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.version.tools}</artifactId>
<version>2.2.0</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>
69 changes: 69 additions & 0 deletions visualization/src/main/scala/visualization/BarCharts.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package visualization

import org.sameersingh.scalaplot.gnuplot.GnuplotPlotter
import org.sameersingh.scalaplot.{MemBarSeries, BarData, BarChart}

import scala.collection.mutable.ArrayBuffer

/**
 * Hard-coded held-out F1 numbers for the factorization model variants, exposed as
 * two scalaplot bar charts: one for the attribute models and one for the rating models.
 *
 * @author sameer
 * @since 12/26/14.
 */
class BarCharts {

  /** (model label, F1) pairs for the attribute models, in plotting order. */
  val attr: ArrayBuffer[(String, Double)] = ArrayBuffer[(String, Double)](
    "A" -> 81.5,
    "A+C" -> 83.9,
    "A+R" -> 80.8,
    "A+BW" -> 83.5,
    //"A+UW" -> 81.5,
    "A+C+R" -> 83.4,
    "A+C+BW" -> 83.6,
    "A+C+UW" -> 83.9,
    "A+R+BW" -> 83.5,
    "A+R+UW" -> 80.7,
    "A+C+R+BW" -> 83.7,
    "A+C+R+UW" -> 83.3
  )

  /** (model label, F1) pairs for the rating models, in plotting order. */
  val ratings: ArrayBuffer[(String, Double)] = ArrayBuffer[(String, Double)](
    "R" -> 71.3,
    "R+C" -> 73.6,
    "R+A" -> 72.3,
    "R+BW" -> 72.4,
    "R+UW" -> 79.2,
    "R+A+C" -> 72.5,
    "R+C+BW" -> 72.6,
    "R+C+UW" -> 80.4,
    "R+A+BW" -> 72.3,
    "R+A+UW" -> 79.2,
    "R+A+C+BW" -> 72.3,
    "R+A+C+UW" -> 79.9
  )

  /**
   * Builds a single-series bar chart from labelled scores.
   *
   * @param title  chart title
   * @param scores (bar label, bar height) pairs; bar `i` is labelled with `scores(i)._1`
   * @return the chart with its y axis labelled "F1" and x axis "Factorization Models"
   */
  private def f1Chart(title: String, scores: ArrayBuffer[(String, Double)]): BarChart = {
    val heights = new MemBarSeries(scores.map(_._2))
    val barData = new BarData(idx => scores(idx)._1, Seq(heights))
    val chart = new BarChart(title, barData)
    chart.y.label = "F1"
    chart.x.label = "Factorization Models"
    chart
  }

  /** Bar chart of the held-out F1 values in `attr`; a new chart is built on every call. */
  def attributesPlot: BarChart = f1Chart("Heldout F1 (Attributes)", attr)

  /** Bar chart of the held-out F1 values in `ratings`; a new chart is built on every call. */
  def ratingsPlot: BarChart = f1Chart("Heldout F1 (Ratings)", ratings)
}

/** Command-line entry point that writes both bar charts as PDFs. */
object BarCharts {

  /**
   * Renders the attribute and rating bar charts through gnuplot.
   *
   * @param args exactly one element: the directory to write the output into
   */
  def main(args: Array[String]): Unit = {
    assert(args.length == 1, "Please include the directory to write output in as an argument.")
    val targetDir = args(0)
    val plots = new BarCharts

    // Writes one chart into the target directory under the given file name prefix.
    def render(chart: BarChart, fileName: String): Unit =
      GnuplotPlotter.pdf(chart, targetDir, fileName)

    render(plots.attributesPlot, "Attributes")
    render(plots.ratingsPlot, "Ratings")
  }
}
105 changes: 105 additions & 0 deletions visualization/src/main/scala/visualization/PRCurves.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package visualization

import java.io.{FilenameFilter, File}

import org.sameersingh.scalaplot.gnuplot.GnuplotPlotter
import org.sameersingh.scalaplot._

import scala.collection.mutable.ArrayBuffer

/**
 * Builds precision/recall curve charts from tab-separated prediction files and renders
 * them to PDF through gnuplot.
 *
 * Input files are picked up by name prefix ("A-" for attributes, "R-" for ratings); the
 * rest of the file name is the series name and must match an entry of the corresponding
 * order list below.
 *
 * @author sameer
 * @since 12/25/14.
 */
class PRCurves {

  /** Series names, in legend order, plotted on the attributes chart. */
  val attrsOrder = Seq("A", "A+R", "A+BW", "A+C", "A+C+R", "A+C+BW")
  //Excluded: "A+R+UW", "A+C+UW", "A+R+BW", "A+C+R+UW", "A+UW", "A+C+R+BW"
  /** Series names, in legend order, plotted on the ratings chart. */
  val ratingsOrder = Seq("R", "A+R", "R+UW", "R+BW", "R+C", "A+R+UW", "R+C+UW", "A+C+R+UW")
  // Excluded: , "A+C+R+BW", "A+C+R", "R+C+BW", "A+R+BW"

  /**
   * Parses a prediction file with one `<double>TAB<0|1>` record per line.
   *
   * @param file the file to read
   * @return one (value, flag) pair per line, in file order; the flag is true for "1".
   *         Presumably the value is a prediction score and the flag the gold label
   *         (confirm against the code that writes these files).
   * @throws AssertionError if a line does not have exactly two tab-separated fields or
   *                        its second field is neither "0" nor "1"
   */
  def readFile(file: File): Seq[(Double, Boolean)] = {
    val source = io.Source.fromFile(file)
    // try/finally so the file handle is released even when a malformed line aborts parsing.
    try {
      val result = new ArrayBuffer[(Double, Boolean)]
      for (l <- source.getLines()) {
        val split = l.split("\t")
        assert(split.length == 2,
          "Expected 2 tab-separated fields in " + file + " but got: " + l)
        assert(split(1) == "0" || split(1) == "1",
          "Expected label 0 or 1 in " + file + " but got: " + split(1))
        result += split(0).toDouble -> (split(1) == "1")
      }
      result //.grouped(10).map(_.head).toSeq
    } finally {
      source.close()
    }
  }

  /**
   * Turns one prediction file into a plottable precision/recall series.
   *
   * @param file prediction file; the series name is the file name minus its first two
   *             characters (the "A-" / "R-" prefix that generateCharts filters on)
   * @return the series, drawn with lines and points
   */
  def seriesFromFile(file: File): XYSeries = {
    val data = readFile(file)
    val m = new org.sameersingh.scalaplot.metrics.PrecRecallCurve(data)
    val s = new MemXYSeries(m.prChart("").data.serieses.head.points.toSeq, file.getName.drop(2))
    // NOTE(review): presumably thins the point markers to every 7500th sample; confirm in scalaplot.
    s.every = Some(7500)
    s.plotStyle = XYPlotStyle.LinesPoints
    s
  }

  /**
   * Selects and orders series by name.
   *
   * @param serieses available series
   * @param order    names of the series to keep, in the desired order
   * @return the first series matching each name in `order`; series not named are dropped
   * @throws NoSuchElementException if a requested name has no matching series
   */
  def reorderSerieses(serieses: Seq[XYSeries], order: Seq[String]): Seq[XYSeries] = {
    order.map { n =>
      serieses.find(_.name == n).getOrElse(
        // Same exception type as the bare Option.get, but says which input is missing.
        throw new NoSuchElementException(
          "No series named '" + n + "'; available: " + serieses.map(_.name).mkString(", ")))
    }
  }

  /**
   * Shared chart setup: loads every file as a series, orders the series, and applies the
   * axis labels, legend visibility and size common to both charts.
   */
  private def prChart(title: String, files: Seq[File], order: Seq[String]): XYChart = {
    val series = files.map(f => seriesFromFile(f))
    val data = new XYData(reorderSerieses(series, order): _*)
    val chart = new XYChart(title, data)
    chart.x.label = "Recall"
    chart.y.label = "Precision"
    //chart.monochrome = true
    chart.showLegend = true
    // Explicit tuple instead of relying on argument auto-tupling.
    chart.size = Some((3.5, 2.5))
    chart
  }

  /**
   * Builds the attributes PR chart (legend at bottom centre).
   *
   * @param files prediction files; must include one file for every name in `attrsOrder`
   * @throws NoSuchElementException if a series listed in `attrsOrder` is missing
   */
  def generateAttrChart(files: Seq[File]): XYChart = {
    val chart = prChart("PR Curve (Attributes)", files, attrsOrder)
    chart.legendPosX = LegendPosX.Center
    chart.legendPosY = LegendPosY.Bottom
    chart
  }

  /**
   * Builds the ratings PR chart (legend at top right).
   *
   * @param files prediction files; must include one file for every name in `ratingsOrder`
   * @throws NoSuchElementException if a series listed in `ratingsOrder` is missing
   */
  def generateRatingsChart(files: Seq[File]): XYChart = {
    val chart = prChart("PR Curve (Ratings)", files, ratingsOrder)
    chart.legendPosX = LegendPosX.Right
    chart.legendPosY = LegendPosY.Top
    chart
  }

  /** Lists the entries of `dir` whose names start with `prefix`, sorted by name. */
  private def filesWithPrefix(dir: File, prefix: String): Seq[File] = {
    val matches = dir.listFiles(new FilenameFilter {
      override def accept(dir: File, name: String): Boolean = name.startsWith(prefix)
    })
    // File.listFiles returns null (not an empty array) on an I/O error; treat that as
    // "no files" so the failure surfaces as a descriptive missing-series error, not an NPE.
    Option(matches).map(_.toSeq).getOrElse(Seq.empty[File]).sortBy(_.getName)
  }

  /**
   * Renders both PR charts for a directory of prediction files.
   *
   * PDFs are written to `dirName + "/output/"` with file name prefixes
   * `<directory name>A` and `<directory name>R`.
   *
   * @param dirName directory holding the "A-*" and "R-*" prediction files
   * @throws AssertionError if `dirName` is not a directory
   */
  def generateCharts(dirName: String) = {
    val dir = new File(dirName)
    assert(dir.isDirectory, dirName + " is not a directory.")
    // generate output dir
    val outDirName = dirName + "/output/"
    new File(outDirName).mkdir()
    // Attribute PR Curve
    val attrFiles = filesWithPrefix(dir, "A-")
    println("attr: " + attrFiles.map(_.getName).mkString(", "))
    val attrChart = generateAttrChart(attrFiles)
    GnuplotPlotter.pdf(attrChart, outDirName, dir.getName + "A")
    // Ratings PR Curve
    val ratingsFiles = filesWithPrefix(dir, "R-")
    println("ratings: " + ratingsFiles.map(_.getName).mkString(", "))
    val ratingsChart = generateRatingsChart(ratingsFiles)
    GnuplotPlotter.pdf(ratingsChart, outDirName, dir.getName + "R")
  }
}

/** Command-line entry point that renders the PR curve PDFs for one predictions directory. */
object PRCurves {

  /**
   * Generates the attribute and rating PR charts.
   *
   * @param args exactly one element: the directory containing the generated predictions
   */
  def main(args: Array[String]): Unit = {
    assert(args.length == 1, "Please include the directory containing the generated predictions as an argument.")
    val inputDir = args(0)
    val curves = new PRCurves
    // Pass the named local rather than re-reading args(0); it was previously unused.
    curves.generateCharts(inputDir)
  }
}

0 comments on commit 633ce71

Please sign in to comment.