-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added pr curves and bar chart visualizations
- Loading branch information
1 parent
304eac5
commit 633ce71
Showing
3 changed files
with
253 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
<?xml version="1.0" encoding="UTF-8"?> | ||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> | ||
<modelVersion>4.0.0</modelVersion> | ||
|
||
<groupId>visualization</groupId> | ||
<artifactId>visualization</artifactId> | ||
<version>1.0-SNAPSHOT</version> | ||
<packaging>jar</packaging> | ||
|
||
<properties> | ||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> | ||
<java.version>1.7</java.version> | ||
<scala.version>2.10.3</scala.version> | ||
<scala.version.tools>2.10</scala.version.tools> | ||
</properties> | ||
|
||
<build> | ||
<sourceDirectory>src/main/scala</sourceDirectory> | ||
<testSourceDirectory>src/test/scala</testSourceDirectory> | ||
<plugins> | ||
<plugin> | ||
<groupId>net.alchim31.maven</groupId> | ||
<artifactId>scala-maven-plugin</artifactId> | ||
<version>3.2.0</version> | ||
<executions> | ||
<execution> | ||
<!-- Bind the Scala compile goals to process-sources so they run before maven-compiler-plugin; otherwise maven-compiler-plugin ends with a compile error. --> | ||
<phase>process-sources</phase> | ||
<goals> | ||
<goal>compile</goal> | ||
<goal>testCompile</goal> | ||
</goals> | ||
<configuration> | ||
</configuration> | ||
</execution> | ||
</executions> | ||
</plugin> | ||
<plugin> | ||
<groupId>org.apache.maven.plugins</groupId> | ||
<artifactId>maven-compiler-plugin</artifactId> | ||
<version>3.1</version> | ||
<configuration> | ||
<source>${java.version}</source> | ||
<target>${java.version}</target> | ||
</configuration> | ||
</plugin> | ||
</plugins> | ||
</build> | ||
<dependencies> | ||
<dependency> | ||
<groupId>org.scala-lang</groupId> | ||
<artifactId>scala-library</artifactId> | ||
<version>${scala.version}</version> | ||
</dependency> | ||
<dependency> | ||
<groupId>org.sameersingh.scalaplot</groupId> | ||
<artifactId>scalaplot</artifactId> | ||
<version>0.0.4</version> | ||
</dependency> | ||
<dependency> | ||
<groupId>junit</groupId> | ||
<artifactId>junit</artifactId> | ||
<version>4.11</version> | ||
<scope>test</scope> | ||
</dependency> | ||
<dependency> | ||
<groupId>org.scalacheck</groupId> | ||
<artifactId>scalacheck_${scala.version.tools}</artifactId> | ||
<version>1.11.4</version> | ||
<scope>test</scope> | ||
</dependency> | ||
<dependency> | ||
<groupId>org.scalatest</groupId> | ||
<artifactId>scalatest_${scala.version.tools}</artifactId> | ||
<version>2.2.0</version> | ||
<scope>test</scope> | ||
</dependency> | ||
</dependencies> | ||
</project> |
69 changes: 69 additions & 0 deletions
69
visualization/src/main/scala/visualization/BarCharts.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
package visualization | ||
|
||
import org.sameersingh.scalaplot.gnuplot.GnuplotPlotter | ||
import org.sameersingh.scalaplot.{MemBarSeries, BarData, BarChart} | ||
|
||
import scala.collection.mutable.ArrayBuffer | ||
|
||
/**
 * Bar charts of hard-coded held-out F1 scores, one chart for the attribute
 * results and one for the rating results.
 *
 * @author sameer
 * @since 12/26/14.
 */
class BarCharts {

  /** (model label, F1) pairs shown in the attributes chart, in display order. */
  val attr = ArrayBuffer[(String, Double)](
    "A" -> 81.5,
    "A+C" -> 83.9,
    "A+R" -> 80.8,
    "A+BW" -> 83.5,
    //"A+UW" -> 81.5,
    "A+C+R" -> 83.4,
    "A+C+BW" -> 83.6,
    "A+C+UW" -> 83.9,
    "A+R+BW" -> 83.5,
    "A+R+UW" -> 80.7,
    "A+C+R+BW" -> 83.7,
    "A+C+R+UW" -> 83.3
  )

  /** (model label, F1) pairs shown in the ratings chart, in display order. */
  val ratings = ArrayBuffer[(String, Double)](
    "R" -> 71.3,
    "R+C" -> 73.6,
    "R+A" -> 72.3,
    "R+BW" -> 72.4,
    "R+UW" -> 79.2,
    "R+A+C" -> 72.5,
    "R+C+BW" -> 72.6,
    "R+C+UW" -> 80.4,
    "R+A+BW" -> 72.3,
    "R+A+UW" -> 79.2,
    "R+A+C+BW" -> 72.3,
    "R+A+C+UW" -> 79.9
  )

  /** Builds a fresh bar chart of the attribute scores on every call. */
  def attributesPlot: BarChart = f1Chart("Heldout F1 (Attributes)", attr)

  /** Builds a fresh bar chart of the rating scores on every call. */
  def ratingsPlot: BarChart = f1Chart("Heldout F1 (Ratings)", ratings)

  /**
   * Shared chart construction: a single bar series holding the scores, with
   * the bar at index `i` named after the label of the `i`-th entry.
   *
   * @param title   chart title
   * @param entries (label, score) pairs in display order
   */
  private def f1Chart(title: String, entries: ArrayBuffer[(String, Double)]): BarChart = {
    val scores = new MemBarSeries(entries.map(_._2))
    val barData = new BarData((idx) => entries(idx)._1, Seq(scores))
    val chart = new BarChart(title, barData)
    chart.y.label = "F1"
    chart.x.label = "Factorization Models"
    chart
  }
}
|
||
/** Command-line entry point that writes both bar charts through `GnuplotPlotter.pdf`. */
object BarCharts {
  /**
   * @param args exactly one element: the directory to write output in
   * @throws IllegalArgumentException if the number of arguments is not one
   */
  def main(args: Array[String]): Unit = {
    // `require` rather than `assert`: assertions are elidable at compile time,
    // and this check guards user input, not an internal invariant.
    require(args.length == 1, "Please include the directory to write output in as an argument.")
    val outputDir = args(0)
    val charts = new BarCharts
    GnuplotPlotter.pdf(charts.attributesPlot, outputDir, "Attributes")
    GnuplotPlotter.pdf(charts.ratingsPlot, outputDir, "Ratings")
  }
}
105 changes: 105 additions & 0 deletions
105
visualization/src/main/scala/visualization/PRCurves.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
package visualization | ||
|
||
import java.io.{FilenameFilter, File} | ||
|
||
import org.sameersingh.scalaplot.gnuplot.GnuplotPlotter | ||
import org.sameersingh.scalaplot._ | ||
|
||
import scala.collection.mutable.ArrayBuffer | ||
|
||
/**
 * Builds precision/recall charts from tab-separated prediction files and
 * writes them through `GnuplotPlotter.pdf`.
 *
 * Files whose names start with `A-` feed the attributes chart and files whose
 * names start with `R-` feed the ratings chart; the rest of the file name
 * (after the two-character prefix) becomes the series name.
 *
 * @author sameer
 * @since 12/25/14.
 */
class PRCurves {

  /** Series names, in plotting order, for the attributes chart. */
  val attrsOrder = Seq("A", "A+R", "A+BW", "A+C", "A+C+R", "A+C+BW")
  //Excluded: "A+R+UW", "A+C+UW", "A+R+BW", "A+C+R+UW", "A+UW", "A+C+R+BW"

  /** Series names, in plotting order, for the ratings chart. */
  val ratingsOrder = Seq("R", "A+R", "R+UW", "R+BW", "R+C", "A+R+UW", "R+C+UW", "A+C+R+UW")
  // Excluded: , "A+C+R+BW", "A+C+R", "R+C+BW", "A+R+BW"

  /**
   * Reads a prediction file with one `value<TAB>flag` record per line, where
   * `flag` is `0` or `1`.
   *
   * @param file the file to read
   * @return one (value, flag == "1") pair per line, in file order
   */
  def readFile(file: File): Seq[(Double, Boolean)] = {
    val source = io.Source.fromFile(file)
    // try/finally so the file handle is released even when a line is malformed.
    try {
      val result = new ArrayBuffer[(Double, Boolean)]
      for (line <- source.getLines()) {
        val fields = line.split("\t")
        assert(fields.length == 2,
          s"Expected 2 tab-separated fields in ${file.getName} but found ${fields.length}: $line")
        assert(fields(1) == "0" || fields(1) == "1",
          s"Expected flag 0 or 1 in ${file.getName} but found: ${fields(1)}")
        result += fields(0).toDouble -> (fields(1) == "1")
      }
      result //.grouped(10).map(_.head).toSeq
    } finally {
      source.close()
    }
  }

  /**
   * Turns one prediction file into a lines-and-points series.
   *
   * The points are taken from the first series of the chart produced by
   * `PrecRecallCurve.prChart`; the series is named after the file name with
   * its first two characters dropped.
   *
   * @param file prediction file in the format accepted by [[readFile]]
   */
  def seriesFromFile(file: File): XYSeries = {
    val data = readFile(file)
    val m = new org.sameersingh.scalaplot.metrics.PrecRecallCurve(data)
    // The intermediate chart's title is irrelevant: only its points are reused.
    val s = new MemXYSeries(m.prChart("").data.serieses.head.points.toSeq, file.getName.drop(2))
    s.every = Some(7500)
    s.plotStyle = XYPlotStyle.LinesPoints
    s
  }

  /**
   * Selects and orders series by name; series not named in `order` are dropped.
   *
   * @param serieses the available series
   * @param order    the series names to keep, in the desired order
   * @throws NoSuchElementException if a name in `order` matches no series
   */
  def reorderSerieses(serieses: Seq[XYSeries], order: Seq[String]): Seq[XYSeries] = {
    order.map { wanted =>
      serieses.find(_.name == wanted).getOrElse {
        val available = serieses.map(_.name).mkString(", ")
        throw new NoSuchElementException(s"No series named '$wanted'; available: $available")
      }
    }
  }

  /** Builds the attributes PR chart (legend at bottom centre) from the `A-` files. */
  def generateAttrChart(files: Seq[File]): XYChart = {
    val chart = prChart("PR Curve (Attributes)", files, attrsOrder)
    chart.legendPosX = LegendPosX.Center
    chart.legendPosY = LegendPosY.Bottom
    chart
  }

  /** Builds the ratings PR chart (legend at top right) from the `R-` files. */
  def generateRatingsChart(files: Seq[File]): XYChart = {
    val chart = prChart("PR Curve (Ratings)", files, ratingsOrder)
    chart.legendPosX = LegendPosX.Right
    chart.legendPosY = LegendPosY.Top
    chart
  }

  /**
   * Reads the prediction files under `dirName`, prints which files were picked
   * up for each chart, and writes both charts into `dirName/output/`.
   *
   * @param dirName directory containing the `A-*` and `R-*` prediction files
   */
  def generateCharts(dirName: String) = {
    val dir = new File(dirName)
    assert(dir.isDirectory, dirName + " is not a directory.")
    // generate output dir
    val outDirName = dirName + "/output/"
    new File(outDirName).mkdir()
    // Attribute PR Curve
    val attrFiles = predictionFiles(dir, "A-")
    println("attr: " + attrFiles.map(_.getName).mkString(", "))
    val attrChart = generateAttrChart(attrFiles)
    GnuplotPlotter.pdf(attrChart, outDirName, dir.getName + "A")
    // Ratings PR Curve
    val ratingsFiles = predictionFiles(dir, "R-")
    println("ratings: " + ratingsFiles.map(_.getName).mkString(", "))
    val ratingsChart = generateRatingsChart(ratingsFiles)
    GnuplotPlotter.pdf(ratingsChart, outDirName, dir.getName + "R")
  }

  /** Shared construction of a PR chart; callers set the legend position. */
  private def prChart(title: String, files: Seq[File], order: Seq[String]): XYChart = {
    val series = files.map(f => seriesFromFile(f))
    val data = new XYData(reorderSerieses(series, order): _*)
    val chart = new XYChart(title, data)
    chart.x.label = "Recall"
    chart.y.label = "Precision"
    //chart.monochrome = true
    chart.showLegend = true
    chart.size = Some((3.5, 2.5))
    chart
  }

  /** Lists the files in `dir` whose names start with `prefix`, sorted by name. */
  private def predictionFiles(dir: File, prefix: String): Seq[File] = {
    val listing = dir.listFiles(new FilenameFilter {
      override def accept(parent: File, name: String): Boolean = name.startsWith(prefix)
    })
    // File.listFiles returns null on an I/O error; treat that as "no files" so
    // the failure surfaces as a missing-series message instead of an NPE.
    Option(listing).map(_.toSeq).getOrElse(Seq.empty[File]).sortBy(_.getName)
  }
}
|
||
/** Command-line entry point that generates the PR-curve charts for one directory. */
object PRCurves {
  /**
   * @param args exactly one element: the directory containing the generated predictions
   * @throws IllegalArgumentException if the number of arguments is not one
   */
  def main(args: Array[String]): Unit = {
    // `require` rather than `assert`: assertions are elidable at compile time,
    // and this check guards user input, not an internal invariant.
    require(args.length == 1, "Please include the directory containing the generated predictions as an argument.")
    val inputDir = args(0)
    val curves = new PRCurves
    curves.generateCharts(inputDir)
  }
}