11package edu .stanford .nlp .sentiment ;
22import edu .stanford .nlp .util .logging .Redwood ;
33
4+ import java .text .DecimalFormat ;
45import java .util .List ;
56
7+ import edu .stanford .nlp .ling .Label ;
8+ import edu .stanford .nlp .neural .rnn .RNNCoreAnnotations ;
69import edu .stanford .nlp .trees .Tree ;
710import edu .stanford .nlp .util .Generics ;
811
12+
913/** @author John Bauer */
1014public class Evaluate extends AbstractEvaluate {
1115
@@ -15,12 +19,72 @@ public class Evaluate extends AbstractEvaluate {
1519 final SentimentCostAndGradient cag ;
1620 final SentimentModel model ;
1721
22+ // Count how many trees are unknown to the model
23+ // The alternate version, ExternalEvaluate, has no concept of
24+ // unknown, so this is exclusive to the evaluate which uses a model
25+ int treesWithUnks ;
26+ int treesWithUnksCorrect ;
27+
1828 public Evaluate (SentimentModel model ) {
1929 super (model .op );
2030 this .model = model ;
2131 this .cag = new SentimentCostAndGradient (model , null );
2232 }
2333
34+ @ Override
35+ public void reset () {
36+ super .reset ();
37+
38+ treesWithUnks = 0 ;
39+ treesWithUnksCorrect = 0 ;
40+ }
41+
42+ @ Override
43+ public void eval (Tree tree ) {
44+ super .eval (tree );
45+
46+ countUnks (tree );
47+ }
48+
49+ /**
50+ * Keep track of how many trees have at least one unknown, and how
51+ * many of those have the top level annotation correct.
52+ */
53+ protected void countUnks (Tree tree ) {
54+ List <Label > labels = tree .yield ();
55+ boolean hasUnk = false ;
56+ for (Label label : labels ) {
57+ if (!model .wordVectors .containsKey (label .value ())) {
58+ hasUnk = true ;
59+ break ;
60+ }
61+ }
62+
63+ if (hasUnk ) {
64+ int gold = RNNCoreAnnotations .getGoldClass (tree );
65+ int guess = RNNCoreAnnotations .getPredictedClass (tree );
66+
67+ treesWithUnks += 1 ;
68+ if (gold == guess )
69+ treesWithUnksCorrect += 1 ;
70+ }
71+ }
72+
73+ private static final String FORMAT = "#.##" ;
74+ protected DecimalFormat format = new DecimalFormat (FORMAT );
75+
76+ @ Override
77+ public void printSummary () {
78+ super .printSummary ();
79+
80+ log .info ("Saw " + treesWithUnks + " trees with at least one unknown token." );
81+ if (treesWithUnks > 0 ) {
82+ double percent = (float ) treesWithUnksCorrect / treesWithUnks * 100.0 ;
83+ log .info (treesWithUnksCorrect + " / " + treesWithUnks + " trees (" + format .format (percent ) +
84+ "%) with at least one unknown token were classified correctly at the top level." );
85+ }
86+ }
87+
2488 @ Override
2589 public void populatePredictedLabels (List <Tree > trees ) {
2690 for (Tree tree : trees ) {
0 commit comments