-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconll.py
115 lines (96 loc) · 4.56 KB
/
conll.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import re
import tempfile
import subprocess
import operator
import collections
import logging
import json
import sys
# Module-level logger; handlers and levels are left to the importing application.
logger = logging.getLogger(__name__)
# Matches a "#begin document ..." header line. Group 1 is the document id and
# group 2 the part number (digits). When the line has no "); part N" suffix the
# second alternative (`$`) matches instead: group 2 is None and group 1 runs to
# the end of the line.
BEGIN_DOCUMENT_REGEX = re.compile(r"#begin document \(?(.*?)(?:\); part (\d+)|$)") # First line at each document
# Extracts the recall, precision and F1 percentages (groups 1-3) from the
# "Coreference: Recall: ... Precision: ... F1: ..." summary in the scorer output.
# re.DOTALL lets the leading/trailing `.*` span the other lines of that output.
COREF_RESULTS_REGEX = re.compile(r".*Coreference: Recall: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tPrecision: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tF1: ([0-9.]+)%.*", re.DOTALL)
def get_doc_key(doc_id, part):
    """Return the key that identifies one document (part).

    Args:
        doc_id: Document identifier.
        part: Part number (anything ``int()`` accepts, e.g. ``"003"``), or
            ``None`` when the document has no part.

    Returns:
        ``doc_id`` unchanged when ``part`` is ``None``; otherwise
        ``"<doc_id>_<part>"`` with the part normalised through ``int()`` so
        leading zeros are dropped.
    """
    if part is None:
        return doc_id
    part_number = int(part)
    return "{}_{}".format(doc_id, part_number)
def output_conll(input_file, output_file, predictions):
    """Copy a CoNLL file, replacing the last column with predicted coreference.

    Blank lines are written as a bare newline and lines whose first field
    starts with ``#`` are copied verbatim. Every other line is treated as a
    token row: its last column is overwritten with the predicted coreference
    annotation (``-`` when the token is in no mention) and the row is written
    with its fields joined by a single space.

    Args:
        input_file: Iterable of text lines (e.g. an open gold CoNLL file).
        output_file: Object with a ``write(str)`` method.
        predictions: Mapping from document key (as produced by
            ``get_doc_key``) to a list of clusters. Each cluster is a list of
            ``(start, end)`` mentions with inclusive token indices that count
            token rows from 0 within the document.

    Raises:
        KeyError: A ``#begin document`` header names a document that is not
            in ``predictions``.
        ValueError: A token row appears before any ``#begin document`` header,
            or its document columns do not match the current header.
    """
    by_other_boundary = operator.itemgetter(1)

    # Index every document's mentions by the token they open at, close at, or
    # (for single-token mentions) fully cover.
    prediction_map = {}
    for pred_key, clusters in predictions.items():
        opening = collections.defaultdict(list)
        closing = collections.defaultdict(list)
        single = collections.defaultdict(list)
        for cluster_id, mentions in enumerate(clusters):
            for start, end in mentions:
                if start == end:
                    single[start].append(cluster_id)
                else:
                    opening[start].append((cluster_id, end))
                    closing[end].append((cluster_id, start))
        # Mentions opening on the same token are listed longest first (latest
        # end first); mentions closing on the same token are listed
        # latest-start first.
        opening_ids = {
            index: [cid for cid, _ in sorted(pairs, key=by_other_boundary, reverse=True)]
            for index, pairs in opening.items()
        }
        closing_ids = {
            index: [cid for cid, _ in sorted(pairs, key=by_other_boundary, reverse=True)]
            for index, pairs in closing.items()
        }
        prediction_map[pred_key] = (opening_ids, closing_ids, dict(single))

    # Per-document state, only valid once a "#begin document" header was seen.
    # It is initialised explicitly so nothing leaks in from the loop above.
    doc_key = None
    start_map, end_map, word_map = {}, {}, {}
    word_index = 0

    for line in input_file:
        row = line.split()
        if not row:
            output_file.write("\n")
            continue

        if row[0].startswith("#"):
            begin_match = re.match(BEGIN_DOCUMENT_REGEX, line)
            if begin_match:
                doc_key = get_doc_key(begin_match.group(1), begin_match.group(2))
                start_map, end_map, word_map = prediction_map[doc_key]
                word_index = 0
            # Header/footer lines are copied as-is (they keep their newline).
            output_file.write(line)
            continue

        if doc_key is None:
            raise ValueError(
                "Token row found before any '#begin document' line: {!r}".format(line)
            )
        # Accept rows keyed either by the bare document id or by id + part.
        # The part-less form is tried first so a non-numeric second column
        # does not fail in int() when the bare id already matches.
        if (get_doc_key(row[0], None) != doc_key
                and get_doc_key(row[0], row[1]) != doc_key):
            raise ValueError(
                "Token row does not belong to document {!r}: {!r}".format(doc_key, line)
            )

        # Order within one cell: closing brackets, single-token mentions,
        # opening brackets.
        coref_list = ["{})".format(cid) for cid in end_map.get(word_index, ())]
        coref_list.extend("({})".format(cid) for cid in word_map.get(word_index, ()))
        coref_list.extend("({}".format(cid) for cid in start_map.get(word_index, ()))

        row[-1] = "|".join(coref_list) if coref_list else "-"
        output_file.write(" ".join(row))
        output_file.write("\n")
        word_index += 1
def official_conll_eval(gold_path, predicted_path, metric, official_stdout=True,
                        *, scorer_path="conll-2012/scorer/v8.01/scorer.pl"):
    """Run the external CoNLL scorer for one metric and parse its summary.

    Args:
        gold_path: Path of the gold CoNLL file.
        predicted_path: Path of the predicted CoNLL file.
        metric: Metric name passed to the scorer as its first argument.
        official_stdout: When true, log the scorer's full standard output at
            INFO level.
        scorer_path: Scorer executable to run. The default is relative, so it
            is resolved against the current working directory.

    Returns:
        Dict with keys ``"r"``, ``"p"`` and ``"f"`` holding recall, precision
        and F1 as floats, taken verbatim from the percentages in the scorer's
        "Coreference:" summary.

    Raises:
        RuntimeError: The scorer output contains no parseable "Coreference:"
            summary (for example because the scorer failed).
    """
    cmd = [scorer_path, metric, gold_path, predicted_path, "none"]
    # stderr must be piped explicitly; otherwise communicate() returns None for
    # it and nothing the scorer reports there can be logged.
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, stderr = process.communicate()  # also waits for the process to exit

    stdout = stdout.decode("utf-8")
    if stderr:
        logger.error(stderr.decode("utf-8", errors="replace"))

    if official_stdout:
        logger.info("Official result for {}".format(metric))
        logger.info(stdout)

    coref_results_match = re.match(COREF_RESULTS_REGEX, stdout)
    if coref_results_match is None:
        raise RuntimeError(
            "Could not find coreference results for metric {!r} in scorer output "
            "(exit code {})".format(metric, process.returncode)
        )

    recall = float(coref_results_match.group(1))
    precision = float(coref_results_match.group(2))
    f1 = float(coref_results_match.group(3))
    return {"r": recall, "p": precision, "f": f1}
def evaluate_conll(gold_path, predictions, subtoken_maps, official_stdout=True):
    """Score predicted clusters against a gold CoNLL file with the official scorer.

    The predictions are written to a temporary CoNLL file via ``output_conll``
    and that file is scored once per metric with ``official_conll_eval``.

    Args:
        gold_path: Path of the gold CoNLL file.
        predictions: Mapping from document key to a list of clusters, each a
            list of ``(start, end)`` mentions.
        subtoken_maps: Mapping from document key to an index-translation
            sequence, or ``None`` when ``predictions`` already use the token
            indices of the gold file.
        official_stdout: Forwarded to ``official_conll_eval``.

    Returns:
        Dict mapping each of ``"muc"``, ``"bcub"`` and ``"ceafe"`` to the
        result dict returned by ``official_conll_eval``.

    Raises:
        KeyError: ``subtoken_maps`` has no entry for a predicted document.
    """
    if subtoken_maps is not None:
        # output_conll takes exactly (input_file, output_file, predictions), so
        # the index translation has to happen here rather than being passed on.
        # NOTE(review): assumes subtoken_maps[doc_key][i] is the gold-file token
        # index of predicted position i -- confirm against the caller.
        predictions = {
            doc_key: [
                [
                    (subtoken_maps[doc_key][start], subtoken_maps[doc_key][end])
                    for start, end in mentions
                ]
                for mentions in clusters
            ]
            for doc_key, clusters in predictions.items()
        }

    metrics = ("muc", "bcub", "ceafe")
    # delete=True removes the file when the block exits, so scoring has to
    # happen while the temporary file is still open.
    with tempfile.NamedTemporaryFile(delete=True, mode="w") as prediction_file:
        with open(gold_path, "r") as gold_file:
            output_conll(gold_file, prediction_file, predictions)
        # The scorer is a separate process that reads the file by name; flush
        # buffered writes so it sees the complete content.
        prediction_file.flush()
        results = {
            m: official_conll_eval(gold_path, prediction_file.name, m, official_stdout)
            for m in metrics
        }
    return results
def convert_to_conll(conll_gold_path, json_pred_path):
    """Print a gold CoNLL file to stdout with predicted coreference columns.

    Args:
        conll_gold_path: Path of the gold CoNLL file used as the template.
        json_pred_path: Path of a file with one JSON object per line; each
            object must provide ``"doc_key"`` and ``"predict_clusters"``.

    The last two characters of every ``"doc_key"`` are dropped before it is
    used as the prediction key (presumably a two-character suffix such as
    ``"_0"`` -- verify against the producer of the JSON file).
    """
    clusters_by_doc = {}
    with open(json_pred_path, "r") as pred_stream:
        for raw_line in pred_stream:
            record = json.loads(raw_line)
            clusters = record["predict_clusters"]
            clusters_by_doc[record["doc_key"][:-2]] = clusters

    with open(conll_gold_path, "r") as gold_stream:
        output_conll(gold_stream, sys.stdout, clusters_by_doc)
if __name__ == "__main__":
    # Usage: python conll.py <gold CoNLL file> <JSON-lines predictions file>
    gold_conll_path, predictions_json_path = sys.argv[1], sys.argv[2]
    convert_to_conll(gold_conll_path, predictions_json_path)