From ca4eac4df05d24cb52712a5f08c6d13384525fe2 Mon Sep 17 00:00:00 2001 From: Austin Matthews Date: Thu, 23 May 2019 13:46:46 -0400 Subject: [PATCH] fix to beam search stopping criteria (#572) --- xnmt/search_strategies.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/xnmt/search_strategies.py b/xnmt/search_strategies.py index 2c10425c2..14c08de69 100644 --- a/xnmt/search_strategies.py +++ b/xnmt/search_strategies.py @@ -142,10 +142,18 @@ def generate_output(self, completed_hyp = [] for length in range(self.max_len): if len(completed_hyp) >= self.beam_size: - break + completed_hyp = sorted(completed_hyp, key=lambda hyp: hyp.score, reverse=True) + completed_hyp = completed_hyp[:self.beam_size] + worst_complete_hyp_score = completed_hyp[-1].score + active_hyp = [hyp for hyp in active_hyp if hyp.score >= worst_complete_hyp_score] + # Assumption: each additional word will always *decrease* the total score. + if len(active_hyp) == 0: + break + # Expand hyp new_set = [] for hyp in active_hyp: + # Note: prev_word has *not* yet been added to prev_state if length > 0: prev_word = hyp.word prev_state = hyp.output.state