-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmeteor.py
87 lines (73 loc) · 3.16 KB
/
meteor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#!/usr/bin/env python
# Python wrapper for METEOR implementation, by Xinlei Chen
# Acknowledge Michael Denkowski for the generous discussion and help
import os
import sys
import subprocess
import threading
import pdb
# Python 2 only: make UTF-8 the implicit str/unicode codec, which resolves
# "UnicodeEncodeError: 'ascii' codec can't encode character" for non-ASCII
# captions. reload(sys) is needed first because sys.setdefaultencoding() is
# removed from the sys module at interpreter start-up.
# Python 3 has neither the reload() builtin nor sys.setdefaultencoding(), and
# its default encoding is already UTF-8, so the hack is skipped there instead
# of failing at import time with a NameError.
if sys.version_info[0] < 3:
    reload(sys)  # noqa: F821 - builtin on Python 2
    sys.setdefaultencoding('utf-8')

# Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed.
METEOR_JAR = 'meteor-1.5.jar'
class Meteor:
    """Score captions with a long-running METEOR 1.5 Java subprocess.

    The jar is started once in ``-stdio`` mode and driven through a simple
    line protocol over its stdin/stdout pipes:

    * ``SCORE ||| ref 1 ||| ... ||| ref n ||| hypothesis`` is answered with
      one line of sufficient statistics for that segment;
    * ``EVAL ||| stats 1 ||| ... ||| stats n`` is answered with one score
      line per segment followed by one final aggregate score line.

    Every request/response exchange happens while holding ``self.lock`` so
    that concurrent callers cannot interleave lines on the shared pipes.
    """

    def __init__(self):
        """Launch the METEOR jar from the directory containing this file."""
        # Used to guarantee thread safety. Created before the subprocess so
        # the attribute exists even if Popen raises (e.g. java is missing).
        self.lock = threading.Lock()
        self.meteor_cmd = ['java', '-jar', '-Xmx10G', METEOR_JAR,
                           '-', '-', '-stdio', '-l', 'en', '-norm']
        # NOTE(review): stderr is piped but never read by this class; if the
        # jar writes a lot to stderr the pipe could fill up - confirm.
        self.meteor_p = subprocess.Popen(
            self.meteor_cmd,
            cwd=os.path.dirname(os.path.abspath(__file__)),
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE)

    def compute_score(self, gts, res):
        """Score every hypothesis in ``res`` against its references in ``gts``.

        Args:
            gts: mapping of image id -> list of reference strings.
            res: mapping of image id -> list of hypothesis strings; only the
                first hypothesis of each list is scored. The mapping is not
                modified (earlier versions truncated the lists in place).

        Returns:
            ``(score, scores)`` where ``scores`` holds one float per image id,
            in the iteration order of ``gts``, and ``score`` is the final
            aggregate value reported by the jar.

        Raises:
            AssertionError: if the two mappings do not share the same keys or
                an image has no hypothesis.
            ValueError: if the jar answers with something that is not a float.
        """
        assert sorted(gts.keys()) == sorted(res.keys())
        img_ids = list(gts.keys())

        # ``with`` releases the lock even when an assertion or a float()
        # conversion fails; a lock left held would deadlock the next call.
        with self.lock:
            stats = []
            for img_id in img_ids:
                candidates = res[img_id]
                assert len(candidates) >= 1
                stats.append(self._stat(candidates[0], gts[img_id]))

            self._write_line(' ||| '.join(['EVAL'] + stats))
            # One score line per segment, then the aggregate score.
            scores = [float(self._read_line()) for _ in img_ids]
            score = float(self._read_line())
        return score, scores

    def method(self):
        """Return the name of this metric."""
        return "METEOR"

    @staticmethod
    def _collapse_whitespace(text):
        """Return ``text`` with every whitespace run reduced to one space.

        This also removes embedded newlines, which would otherwise split one
        request into several lines and desynchronise the line protocol.
        """
        return ' '.join(text.split())

    def _score_line(self, hypothesis_str, reference_list):
        """Build the ``SCORE`` request line for one hypothesis."""
        # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
        # '|||' is the field separator, so it must not occur inside the
        # hypothesis; removing it leaves a double space that is collapsed too.
        hypothesis = self._collapse_whitespace(hypothesis_str.replace('|||', ''))
        # References keep any '|||' they contain (unchanged behaviour); only
        # their whitespace is normalised.
        references = [self._collapse_whitespace(ref) for ref in reference_list]
        return ' ||| '.join(('SCORE', ' ||| '.join(references), hypothesis))

    def _write_line(self, line):
        """Send one request line to the jar and flush it immediately."""
        data = line + '\n'
        # The pipe is a byte stream: text must be encoded explicitly on
        # Python 3 (and for unicode objects on Python 2).
        if not isinstance(data, bytes):
            data = data.encode('utf-8')
        self.meteor_p.stdin.write(data)
        # The pipe may be buffered; without a flush the jar would never see
        # the request and the following readline() would block forever.
        self.meteor_p.stdin.flush()

    def _read_line(self):
        """Read one response line from the jar as stripped text."""
        raw = self.meteor_p.stdout.readline()
        if not isinstance(raw, str):  # bytes on Python 3
            raw = raw.decode('utf-8')
        return raw.strip()

    def _stat(self, hypothesis_str, reference_list):
        """Return the jar's statistics line for one hypothesis.

        The caller must already hold ``self.lock``.
        """
        self._write_line(self._score_line(hypothesis_str, reference_list))
        return self._read_line()

    def _score(self, hypothesis_str, reference_list):
        """Return the METEOR score of a single hypothesis as a float."""
        with self.lock:
            stats = self._stat(hypothesis_str, reference_list)
            # EVAL ||| stats
            self._write_line('EVAL ||| {}'.format(stats))
            # The jar returns two values: the first line is read and
            # discarded, the second one is the score that is returned
            # (thanks to Andrej for pointing this out).
            float(self._read_line())
            return float(self._read_line())

    def __del__(self):
        """Shut the Java subprocess down when the wrapper is collected."""
        proc = getattr(self, 'meteor_p', None)
        lock = getattr(self, 'lock', None)
        if proc is None or lock is None:
            # __init__ did not finish (e.g. Popen raised): nothing to clean up.
            return
        with lock:
            proc.stdin.close()
            proc.kill()
            proc.wait()