Skip to content

Commit

Permalink
Added checks in comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sroy9 committed Nov 3, 2016
1 parent 43beb53 commit 717392f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 21 deletions.
5 changes: 5 additions & 0 deletions src/main/java/numoccur/NumoccurDriver.java
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ public static double testModel(String modelPath, SLProblem sp)
NumoccurY gold = (NumoccurY) sp.goldStructureList.get(i);
NumoccurY pred = (NumoccurY) model.infSolver.getBestStructure(
model.wv, prob);
// List<Integer> allTrue = new ArrayList<>();
// for(int j=0; j<prob.quantities.size(); ++j) {
// allTrue.add(1);
// }
// pred = new NumoccurY(allTrue);
total.add(prob.problemIndex);
double goldWt = model.wv.dotProduct(
model.featureGenerator.getFeatureVector(prob, gold));
Expand Down
32 changes: 11 additions & 21 deletions src/main/java/pipeline/PipelineDriver.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,46 +50,36 @@ public static double doTest(int testFold) throws Exception {
public static double testModel(SLModel numOccurModel, SLModel varModel,
SLModel treeModel, SLProblem sp, boolean printMistakes) throws Exception {
double acc = 0.0;
int numQuant = 0, numQuantCorrect = 0, numTokens = 0, numTokensCorrect = 0;
for (int i = 0; i < sp.instanceList.size(); i++) {
JointX prob = (JointX) sp.instanceList.get(i);
JointY gold = (JointY) sp.goldStructureList.get(i);
JointY pred = PipelineInfSolver.getBestStructure(
prob, numOccurModel, varModel, treeModel);
System.out.println("----------------------------------------------------------");
// if(Equation.getLoss(gold.equation, pred.equation, true) < 0.0001 ||
numQuant += prob.quantities.size();
numTokens += prob.ta.size();
// if(Equation.getLoss(gold.equation, pred.equation, true) < 0.0001 ||
// Equation.getLoss(gold.equation, pred.equation, false) < 0.0001) {
if(JointY.getLoss(gold, pred) < 0.0001) {
acc += 1;
numQuantCorrect += prob.quantities.size();
numTokensCorrect += prob.ta.size();
} else if(printMistakes) {
System.out.println(prob.problemIndex+" : "+prob.ta.getText());
System.out.println("Quantities : "+prob.quantities);
System.out.println("Gold : \n"+gold);
System.out.println("Pred : \n"+pred);
System.out.println("Loss : "+JointY.getLoss(gold, pred));
} else if(printMistakes) {
// System.out.println(prob.problemIndex+" : "+prob.ta.getText());
//// System.out.println("Quantities : "+prob.quantities);
// System.out.println("\nGold : \nEquation : "+gold.equation+"\nVaraiable : ");
// for(String key : gold.varTokens.keySet()) {
// System.out.print(key + "= {");
// for(Integer index : gold.varTokens.get(key)) {
// System.out.print(VarFeatGen.getString(prob.ta, prob.candidateVars.get(index))+" , ");
// }
// System.out.print(" } ");
// }
// System.out.println("\n\nPred : : \nEquation : "+pred.equation+"\nVaraiable : ");
// for(String key : pred.varTokens.keySet()) {
// System.out.print(key + "= {");
// for(Integer index : pred.varTokens.get(key)) {
// System.out.print(VarFeatGen.getString(prob.ta, prob.candidateVars.get(index))+" , ");
// }
// System.out.print(" } ");
// }
// System.out.println("\n\nLoss : "+JointY.getLoss(gold, pred));
}
System.out.println("----------------------------------------------------------");
}
System.out.println("Accuracy : = " + acc + " / " + sp.instanceList.size()
+ " = " + (acc/sp.instanceList.size()));
// System.out.println("Average : Tokens : " + (numTokens*1.0/sp.instanceList.size()));
// System.out.println("Average : Quantities : " + (numQuant*1.0/sp.instanceList.size()));
// System.out.println("Correct : Tokens : " + (numTokensCorrect*1.0/acc));
// System.out.println("Correct : Quantities : " +(numQuantCorrect*1.0/acc));
return (acc/sp.instanceList.size());
}

Expand Down

0 comments on commit 717392f

Please sign in to comment.