Skip to content

Commit

Permalink
support server for nbest translation (OpenNMT#1631)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zenglinxiao authored and francoishernandez committed Nov 12, 2019
1 parent d314a48 commit 3e9c528
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 23 deletions.
17 changes: 9 additions & 8 deletions onmt/bin/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,15 @@ def translate():
inputs = request.get_json(force=True)
out = {}
try:
translation, scores, n_best, times = translation_server.run(inputs)
assert len(translation) == len(inputs)
assert len(scores) == len(inputs)

out = [[{"src": inputs[i]['src'], "tgt": translation[i],
"n_best": n_best,
"pred_score": scores[i]}
for i in range(len(translation))]]
trans, scores, n_best, times = translation_server.run(inputs)
assert len(trans) == len(inputs) * n_best
assert len(scores) == len(inputs) * n_best

out = [[] for _ in range(n_best)]
for i in range(len(trans)):
response = {"src": inputs[i // n_best]['src'], "tgt": trans[i],
"n_best": n_best, "pred_score": scores[i]}
out[i % n_best].append(response)
except ServerModelError as e:
out['error'] = str(e)
out['status'] = STATUS_ERROR
Expand Down
10 changes: 1 addition & 9 deletions onmt/tests/test_translation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,19 +120,11 @@ def test_run(self):
for elem in scores:
self.assertIsInstance(elem, float)
self.assertEqual(len(results), len(scores))
self.assertEqual(len(scores), len(inp))
self.assertEqual(n_best, 1)
self.assertEqual(len(scores), len(inp) * n_best)
self.assertEqual(len(time), 1)
self.assertIsInstance(time, dict)
self.assertIn("translation", time)

def test_nbest_init_fails(self):
model_id = 0
opt = {"models": ["test_model.pt"], "n_best": 2}
model_root = TEST_DIR
with self.assertRaises(ValueError):
ServerModel(opt, model_id, model_root=model_root, load=True)


class TestTranslationServer(unittest.TestCase):
# this could be considered an integration test because it touches
Expand Down
11 changes: 5 additions & 6 deletions onmt/translate/translation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,6 @@ def __init__(self, opt, model_id, preprocess_opt=None, tokenizer_opt=None,
on_timeout="to_cpu", model_root="./"):
self.model_root = model_root
self.opt = self.parse_opt(opt)
if self.opt.n_best > 1:
raise ValueError("Values of n_best > 1 are not supported")

self.model_id = model_id
self.preprocess_opt = preprocess_opt
Expand Down Expand Up @@ -441,8 +439,6 @@ def run(self, inputs):
self.reset_unload_timer()

# NOTE: translator returns lists of `n_best` list
# we can ignore that (i.e. flatten lists) only because
# we restrict `n_best=1`
def flatten_list(_list): return sum(_list, [])
results = flatten_list(predictions)
scores = [score_tensor.item()
Expand All @@ -455,9 +451,12 @@ def flatten_list(_list): return sum(_list, [])
for item in results]
# build back results with empty texts
for i in empty_indices:
results.insert(i, "")
scores.insert(i, 0)
j = i * self.opt.n_best
results = results[:j] + [""] * self.opt.n_best + results[j:]
scores = scores[:j] + [0] * self.opt.n_best + scores[j:]

head_spaces = [h for h in head_spaces for i in range(self.opt.n_best)]
tail_spaces = [h for h in tail_spaces for i in range(self.opt.n_best)]
results = ["".join(items)
for items in zip(head_spaces, results, tail_spaces)]

Expand Down

0 comments on commit 3e9c528

Please sign in to comment.