diff --git a/onmt/bin/server.py b/onmt/bin/server.py index 44f8723f83..35d00b89fe 100755 --- a/onmt/bin/server.py +++ b/onmt/bin/server.py @@ -75,14 +75,15 @@ def translate(): inputs = request.get_json(force=True) out = {} try: - translation, scores, n_best, times = translation_server.run(inputs) - assert len(translation) == len(inputs) - assert len(scores) == len(inputs) - - out = [[{"src": inputs[i]['src'], "tgt": translation[i], - "n_best": n_best, - "pred_score": scores[i]} - for i in range(len(translation))]] + trans, scores, n_best, times = translation_server.run(inputs) + assert len(trans) == len(inputs) * n_best + assert len(scores) == len(inputs) * n_best + + out = [[] for _ in range(n_best)] + for i in range(len(trans)): + response = {"src": inputs[i // n_best]['src'], "tgt": trans[i], + "n_best": n_best, "pred_score": scores[i]} + out[i % n_best].append(response) except ServerModelError as e: out['error'] = str(e) out['status'] = STATUS_ERROR diff --git a/onmt/tests/test_translation_server.py b/onmt/tests/test_translation_server.py index b93898d89b..2c5baca506 100644 --- a/onmt/tests/test_translation_server.py +++ b/onmt/tests/test_translation_server.py @@ -120,19 +120,11 @@ def test_run(self): for elem in scores: self.assertIsInstance(elem, float) self.assertEqual(len(results), len(scores)) - self.assertEqual(len(scores), len(inp)) - self.assertEqual(n_best, 1) + self.assertEqual(len(scores), len(inp) * n_best) self.assertEqual(len(time), 1) self.assertIsInstance(time, dict) self.assertIn("translation", time) - def test_nbest_init_fails(self): - model_id = 0 - opt = {"models": ["test_model.pt"], "n_best": 2} - model_root = TEST_DIR - with self.assertRaises(ValueError): - ServerModel(opt, model_id, model_root=model_root, load=True) - class TestTranslationServer(unittest.TestCase): # this could be considered an integration test because it touches diff --git a/onmt/translate/translation_server.py b/onmt/translate/translation_server.py index e18024f3c7..ae05336300 100644 --- a/onmt/translate/translation_server.py +++ b/onmt/translate/translation_server.py @@ -205,8 +205,6 @@ def __init__(self, opt, model_id, preprocess_opt=None, tokenizer_opt=None, on_timeout="to_cpu", model_root="./"): self.model_root = model_root self.opt = self.parse_opt(opt) - if self.opt.n_best > 1: - raise ValueError("Values of n_best > 1 are not supported") self.model_id = model_id self.preprocess_opt = preprocess_opt @@ -441,8 +439,6 @@ def run(self, inputs): self.reset_unload_timer() # NOTE: translator returns lists of `n_best` list - # we can ignore that (i.e. flatten lists) only because - # we restrict `n_best=1` def flatten_list(_list): return sum(_list, []) results = flatten_list(predictions) scores = [score_tensor.item() @@ -455,9 +451,12 @@ def flatten_list(_list): return sum(_list, []) for item in results] # build back results with empty texts for i in empty_indices: - results.insert(i, "") - scores.insert(i, 0) + j = i * self.opt.n_best + results = results[:j] + [""] * self.opt.n_best + results[j:] + scores = scores[:j] + [0] * self.opt.n_best + scores[j:] + head_spaces = [h for h in head_spaces for i in range(self.opt.n_best)] + tail_spaces = [h for h in tail_spaces for i in range(self.opt.n_best)] results = ["".join(items) for items in zip(head_spaces, results, tail_spaces)]