
Commit 79a94ee

added clean code for NMF
1 parent 183fd32 commit 79a94ee

8 files changed: +1279 −0 lines

__pycache__/utils.cpython-38.pyc

3.92 KB
Binary file not shown.

evaluate_induced_alignments.py

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
from utils import LOG, setup_dict_entry, calc_and_update_score, load_gold, get_verse_alignments
import argparse
import os


def main(args):
    pros, surs = load_gold(args.gold_file)

    save_name = os.path.basename(args.predicted_alignments_file)
    save_name = save_name[:-4] + "_results.txt"

    result_names = [
        "base_intersection",
        "base_gdfa",
        "NMF",
        "NMF + intersection",
        "NMF + gdfa"
    ]

    results_all = {}
    for name in result_names:
        setup_dict_entry(results_all, name, {"p_hit_count": 0, "s_hit_count": 0, "total_hit_count": 0,
                                             "gold_s_hit_count": 0, "prec": 0, "rec": 0, "f1": 0, "aer": 0})

    with open(args.predicted_alignments_file, 'r') as f_pred, \
            open(args.intersection_alignments_file, 'r') as f_inter, \
            open(args.gdfa_alignments_file, 'r') as f_gdfa:

        lines_predicted = f_pred.read().splitlines()
        lines_intersection = f_inter.read().splitlines()
        lines_gdfa = f_gdfa.read().splitlines()

        for line_pred, line_inter, line_gdfa in zip(lines_predicted, lines_intersection, lines_gdfa):
            # Each line is "<verse_id>\t<space-separated i-j alignment edges>"
            verse_id, aligns_predicted = line_pred.split('\t')
            _, aligns_intersection = line_inter.split('\t')
            _, aligns_gdfa = line_gdfa.split('\t')

            # Convert alignment strings to sets of "i-j" edges
            aligns_predicted = set(aligns_predicted.split(' '))
            aligns_intersection = set(aligns_intersection.split(' '))
            aligns_gdfa = set(aligns_gdfa.split(' '))

            # Combine the base alignments with the NMF predictions
            aligns_NMF_plus_intersection = aligns_predicted.union(aligns_intersection)
            aligns_NMF_plus_gdfa = aligns_predicted.union(aligns_gdfa)

            # Update the running scores for every alignment variant
            calc_and_update_score(aligns_intersection, pros[verse_id], surs[verse_id], results_all["base_intersection"])
            calc_and_update_score(aligns_gdfa, pros[verse_id], surs[verse_id], results_all["base_gdfa"])
            calc_and_update_score(aligns_predicted, pros[verse_id], surs[verse_id], results_all["NMF"])
            calc_and_update_score(aligns_NMF_plus_intersection, pros[verse_id], surs[verse_id], results_all["NMF + intersection"])
            calc_and_update_score(aligns_NMF_plus_gdfa, pros[verse_id], surs[verse_id], results_all["NMF + gdfa"])

    with open(os.path.join(args.save_path, save_name), 'w') as f_out:
        for name, res in results_all.items():
            report = (f'----{name}----\nPrecision: {res["prec"]}\nRecall: {res["rec"]}\n'
                      f'F1: {res["f1"]}\nAER: {res["aer"]}\nHits: {res["total_hit_count"]}\n\n')
            f_out.write(report)
            print(report)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument('--save_path', default="/mounts/Users/cisintern/lksenel/Projects/pbc/graph-align/results/", type=str)
    parser.add_argument('--gold_file', default="/mounts/Users/cisintern/lksenel/Projects/pbc/pbc_utils/data/eng_fra_pbc/eng-fra.gold", type=str)
    parser.add_argument('--predicted_alignments_file', default="/mounts/Users/cisintern/lksenel/Projects/pbc/graph-align/predicted_alignments/predicted_alignments_from_eng-x-bible-mixed_to_fra-x-bible-louissegond_with_max_83_editions_for_250_verses_NMF.txt", type=str)
    parser.add_argument('--intersection_alignments_file', default="/mounts/Users/cisintern/lksenel/Projects/pbc/graph-align/predicted_alignments/intersection_alignments_from_eng-x-bible-mixed_to_fra-x-bible-louissegond_for_250_verses.txt", type=str)
    parser.add_argument('--gdfa_alignments_file', default="/mounts/Users/cisintern/lksenel/Projects/pbc/graph-align/predicted_alignments/gdfa_alignments_from_eng-x-bible-mixed_to_fra-x-bible-louissegond_for_250_verses.txt", type=str)
    parser.add_argument('--source_edition', default="eng-x-bible-mixed", type=str)
    parser.add_argument('--target_edition', default="fra-x-bible-louissegond", type=str)

    args = parser.parse_args()
    main(args)
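
Note: load_gold, setup_dict_entry and calc_and_update_score live in utils.py, which this diff does not show. Purely as a point of reference, here is a hypothetical per-verse sketch of the standard gold-alignment metrics such a helper typically accumulates (precision against possible edges P, recall against sure edges S, AER as in Och & Ney 2003), assuming every argument is a set of "i-j" edge strings — not the actual utils.py implementation:

def score(predicted, pros, surs):
    # Hypothetical stand-in: pros = possible (P) edges, surs = sure (S) edges
    p_hits = len(predicted & pros)   # hits against possible edges
    s_hits = len(predicted & surs)   # hits against sure edges
    prec = p_hits / len(predicted) if predicted else 0.0
    rec = s_hits / len(surs) if surs else 0.0
    f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
    aer = 1 - (p_hits + s_hits) / (len(predicted) + len(surs))
    return prec, rec, f1, aer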

induce_alignments_NMF.py

Lines changed: 295 additions & 0 deletions
@@ -0,0 +1,295 @@
from surprise import Dataset, Reader, NMF
import pandas as pd
from utils import LOG, setup_dict_entry, load_editions, load_gold, get_verse_alignments
import os, argparse, random
from nltk.translate.gdfa import grow_diag_final_and
import numpy as np
from multiprocessing import Pool, Value

def get_row_col_editions(source_edition, target_edition, all_editions=None):
    # Rows and columns share all helper editions; the source edition is added
    # only as a row and the target edition only as a column.
    row_editions = []
    col_editions = []
    for edition in all_editions:
        if edition != source_edition and edition != target_edition:
            row_editions.append(edition)
            col_editions.append(edition)

    row_editions.append(source_edition)
    col_editions.append(target_edition)

    return row_editions, col_editions

def get_aligns(rf, cf, alignments):
    raw_align = ''

    if rf in alignments and cf in alignments[rf]:
        raw_align = alignments[rf][cf]
        alignment_line = [x.split('-') for x in raw_align.split()]
        res = [(int(x[0]), int(x[1])) for x in alignment_line]
    elif cf in alignments and rf in alignments[cf]:
        # Each pair is stored in one direction only, so flip the edge
        # order when the requested pair is stored as (cf, rf).
        raw_align = alignments[cf][rf]
        alignment_line = [x.split('-') for x in raw_align.split()]
        res = [(int(x[1]), int(x[0])) for x in alignment_line]
    elif rf in alignments and rf == cf:
        # Source and target editions are identical: build a diagonal
        # alignment 0-0, 1-1, ... up to the largest token index seen
        # for this edition in any stored alignment.
        keys = list(alignments[rf].keys())
        max_count = 0
        for key in keys:
            align = alignments[rf][key]
            for x in align.split():
                count = int(x.split('-')[0])
                if count > max_count:
                    max_count = count
        raw_align = "0-0"
        for i in range(1, max_count + 1):  # include max_count itself
            raw_align += f" {i}-{i}"

        alignment_line = [x.split('-') for x in raw_align.split()]
        res = [(int(x[0]), int(x[1])) for x in alignment_line]
    else:
        return None

    return res

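The stored alignments use the usual "i-j" token-index strings, and each edition pair is stored in only one direction, so get_aligns flips the edge order when the pair is stored the other way round. A toy call (edition names and edges invented for illustration):

# Hypothetical alignments dict holding one stored direction only
alignments = {'aak-x-bible': {'aai-x-bible': '0-0 1-2 2-1'}}
print(get_aligns('aak-x-bible', 'aai-x-bible', alignments))  # [(0, 0), (1, 2), (2, 1)]
print(get_aligns('aai-x-bible', 'aak-x-bible', alignments))  # flipped: [(0, 0), (2, 1), (1, 2)]
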
def add_aligns(aligns, aligns_dict, token_counts, re, ce, existing_items):
    for align in aligns:
        # Each aligned token pair becomes a (user, item) observation with the
        # maximum rating; token indices are appended to the edition name.
        aligns_dict['userID'].append(re + str(align[0]))
        aligns_dict['itemID'].append(ce + str(align[1]))
        aligns_dict['rating'].append(3)

        # Track the largest token index seen per edition
        if align[0] > token_counts[re]:
            token_counts[re] = align[0]
        if align[1] > token_counts[ce]:
            token_counts[ce] = align[1]

        existing_items[re][ce].append(f"{align[0]},{align[1]}")

def add_negative_samples(aligns_dict, existing_items, token_counts, verse_id):
    for re in existing_items:
        if token_counts[re] < 2:
            continue
        for ce in existing_items[re]:
            if token_counts[ce] < 2:
                continue
            for item in existing_items[re][ce]:
                i, j = (int(v) for v in item.split(","))

                # For every observed pair (i, j), sample one random column j'
                # and one random row i' as low-rated negative examples; the
                # modulo wraps the offset into range while excluding j (resp. i)
                # itself. Note the sample may still coincide with another
                # observed edge.
                jp = random.randint(j + 1, j + token_counts[ce])
                ip = random.randint(i + 1, i + token_counts[re])

                jp %= (token_counts[ce] + 1)
                aligns_dict['userID'].append(re + str(i))
                aligns_dict['itemID'].append(ce + str(jp))
                aligns_dict['rating'].append(1)

                ip %= (token_counts[re] + 1)
                aligns_dict['userID'].append(re + str(ip))
                aligns_dict['itemID'].append(ce + str(j))
                aligns_dict['rating'].append(1)

def get_alignments_df(row_editions, col_editions, verse_alignments,
                      source_edition, target_edition, verse_id):  # TODO: can be improved a lot
    token_counts = {}
    existing_items = {}
    aligns_dict = {'itemID': [], 'userID': [], 'rating': []}
    for re in row_editions:
        token_counts[re] = 0
        existing_items[re] = {}

        for ce in col_editions:
            setup_dict_entry(token_counts, ce, 0)
            existing_items[re][ce] = []
            aligns = get_aligns(re, ce, verse_alignments)

            if aligns is not None:
                add_aligns(aligns, aligns_dict, token_counts, re, ce, existing_items)

    add_negative_samples(aligns_dict, existing_items, token_counts, verse_id)

    return pd.DataFrame(aligns_dict), token_counts[source_edition], token_counts[target_edition]

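To make the matrix-completion framing concrete: row-edition tokens become surprise "users", column-edition tokens become "items", observed alignment edges get rating 3 and sampled negatives rating 1. A hypothetical miniature of the resulting DataFrame for one verse (edition token IDs invented):

import pandas as pd

toy_df = pd.DataFrame({
    'userID': ['eng-x-bible-mixed0', 'eng-x-bible-mixed1', 'eng-x-bible-mixed1'],
    'itemID': ['fra-x-bible-louissegond0', 'fra-x-bible-louissegond1', 'fra-x-bible-louissegond2'],
    'rating': [3, 3, 1],  # 3 = observed edge, 1 = negative sample
})
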
def iter_max(sim_matrix: np.ndarray, max_count: int = 2, alpha_ratio: float = 0.7) -> np.ndarray:
    # IterMax-style iterative argmax extraction (cf. SimAlign): start from the
    # intersection of row-wise and column-wise argmaxes, then repeatedly align
    # the still-uncovered rows and columns.
    m, n = sim_matrix.shape
    forward = np.eye(n)[sim_matrix.argmax(axis=1)]  # m x n
    backward = np.eye(m)[sim_matrix.argmax(axis=0)]  # n x m
    inter = forward * backward.transpose()

    if min(m, n) <= 2:
        return inter

    new_inter = np.zeros((m, n))
    count = 1
    while count < max_count:
        mask_x = 1.0 - np.tile(inter.sum(1)[:, np.newaxis], (1, n)).clip(0.0, 1.0)
        mask_y = 1.0 - np.tile(inter.sum(0)[np.newaxis, :], (m, 1)).clip(0.0, 1.0)
        mask = ((alpha_ratio * mask_x) + (alpha_ratio * mask_y)).clip(0.0, 1.0)
        mask_zeros = 1.0 - ((1.0 - mask_x) * (1.0 - mask_y))
        if mask_x.sum() < 1.0 or mask_y.sum() < 1.0:
            mask *= 0.0
            mask_zeros *= 0.0

        new_sim = sim_matrix * mask
        fwd = np.eye(n)[new_sim.argmax(axis=1)] * mask_zeros
        bac = np.eye(m)[new_sim.argmax(axis=0)].transpose() * mask_zeros
        new_inter = fwd * bac

        if np.array_equal(inter + new_inter, inter):
            break
        inter = inter + new_inter
        count += 1
    return inter

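A quick sanity check of iter_max on an invented 3x3 similarity matrix: every row and column argmax is mutual, so the first pass already yields the full alignment and the loop stops immediately.

import numpy as np

sim = np.array([[0.9, 0.1, 0.2],
                [0.2, 0.8, 0.3],
                [0.1, 0.3, 0.7]])
print(iter_max(sim, max_count=2))  # identity matrix: token i aligns to token i
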
def get_itermax_predictions(raw_s_predictions, max_count=2, alpha_ratio=0.9):
    # Build the dense score matrix from the per-source-token prediction lists,
    # then extract alignment edges with iter_max.
    rows = len(raw_s_predictions)
    cols = len(raw_s_predictions[0])
    matrix = np.zeros((rows, cols), dtype=float)

    for i in raw_s_predictions:
        for j, s in raw_s_predictions[i]:
            matrix[i, j] = s

    itermax_res = iter_max(matrix, max_count, alpha_ratio)
    res = []
    for i in range(rows):
        for j in range(cols):
            if itermax_res[i, j] != 0:
                res.append((i, j))

    return res

def predict_alignments(algo, source_edition, target_edition):
    raw_s_predictions = {}
    raw_t_predictions = {}

    # Query the trained model for every (source token, target token) pair
    for i in range(algo.s_tok_count + 1):
        for j in range(algo.t_tok_count + 1):
            pred = algo.predict(source_edition + str(i), target_edition + str(j))

            setup_dict_entry(raw_s_predictions, i, [])
            setup_dict_entry(raw_t_predictions, j, [])

            raw_s_predictions[i].append((j, pred.est))
            raw_t_predictions[j].append((i, pred.est))

    # max_count=1 keeps only the mutual argmaxes (no IterMax iterations)
    res = get_itermax_predictions(raw_s_predictions, max_count=1)

    return res

def train_model(df, s_tok_count, t_tok_count, row_editions, col_editions):
    algo = NMF()
    reader = Reader(rating_scale=(1, 3))
    data = Dataset.load_from_df(df[['userID', 'itemID', 'rating']], reader)
    trainset = data.build_full_trainset()
    algo.fit(trainset)

    # Stash metadata on the model object for later prediction and bookkeeping
    algo.s_tok_count = s_tok_count
    algo.t_tok_count = t_tok_count
    algo.row_editions = row_editions
    algo.col_editions = col_editions
    algo.df = df

    return algo

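For readers unfamiliar with the surprise API, a minimal sketch of this train-then-predict pattern on the invented toy_df from the note above (a real run trains one such model per verse over all editions):

import pandas as pd
from surprise import Dataset, Reader, NMF

toy_df = pd.DataFrame({  # same invented miniature as above
    'userID': ['eng-x-bible-mixed0', 'eng-x-bible-mixed1', 'eng-x-bible-mixed1'],
    'itemID': ['fra-x-bible-louissegond0', 'fra-x-bible-louissegond1', 'fra-x-bible-louissegond2'],
    'rating': [3, 3, 1],
})
data = Dataset.load_from_df(toy_df[['userID', 'itemID', 'rating']], Reader(rating_scale=(1, 3)))
algo = NMF()
algo.fit(data.build_full_trainset())
# predict() takes raw IDs; pairs unseen in training fall back to the global mean rating
print(algo.predict('eng-x-bible-mixed0', 'fra-x-bible-louissegond0').est)
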
def get_induced_alignments(source_edition, target_edition, verse_alignments_path, verse_id, all_editions):
    verse_alignments = get_verse_alignments(verse_alignments_path, verse_id, editions=all_editions)

    # This is only needed for saving the gdfa alignments from source to target for the evaluation
    verse_alignments_gdfa = get_verse_alignments(verse_alignments_path, verse_id, editions=[source_edition, target_edition], gdfa=True)

    # source -> row, target -> col
    row_editions, col_editions = get_row_col_editions(source_edition, target_edition, all_editions)
    # userID -> row, itemID -> col
    df, s_tok_count, t_tok_count = get_alignments_df(row_editions, col_editions, verse_alignments, source_edition, target_edition, verse_id)

    algo = train_model(df, s_tok_count, t_tok_count, row_editions, col_editions)

    predicted_alignments = predict_alignments(algo, source_edition, target_edition)
    base_inter_alignments = verse_alignments[source_edition][target_edition]
    base_gdfa_alignments = verse_alignments_gdfa[source_edition][target_edition]

    # cnt is the shared progress counter set up by init_globals in each worker
    with cnt.get_lock():
        cnt.value += 1
        if cnt.value % 20 == 0:
            LOG.info(f"Done inferring alignments for {cnt.value} verses")

    return predicted_alignments, base_inter_alignments, base_gdfa_alignments, len(algo.col_editions) + 1


def init_globals(counter):
    global cnt
    cnt = counter

def main(args):
    random.seed(args.seed)

    pros, surs = load_gold(args.gold_file)
    all_verses = list(pros.keys())

    # Get languages and editions
    editions, langs = load_editions(args.editions_file)
    all_editions = [editions[lang] for lang in langs]

    # Print some info
    LOG.info(f"Inferring alignments from {args.source_edition} to {args.target_edition}")
    LOG.info(f"Number of verses whose alignments will be inferred: {len(all_verses)}")
    LOG.info(f"Number of editions to use for the graph algorithms: {len(all_editions)}")
    LOG.info(f"Number of cores to be used for processing: {args.core_count}")

    # Prepare arguments for parallel processing
    starmap_args = []
    for verse_id in all_verses:
        starmap_args.append((args.source_edition, args.target_edition, args.verse_alignments_path, verse_id, all_editions))

    # Get predicted alignments using parallel processing
    cnt = Value('i', 0)
    with Pool(processes=args.core_count, initializer=init_globals, initargs=(cnt,)) as p:
        all_alignments = p.starmap(get_induced_alignments, starmap_args)

    out_NMF_f_name = f"predicted_alignments_from_{args.source_edition}_to_{args.target_edition}_with_max_{len(all_editions)}_editions_for_{len(all_verses)}_verses_NMF.txt"
    out_NMF_file = open(os.path.join(args.save_path, out_NMF_f_name), 'w')
    out_inter_f_name = f"intersection_alignments_from_{args.source_edition}_to_{args.target_edition}_for_{len(all_verses)}_verses.txt"
    out_inter_file = open(os.path.join(args.save_path, out_inter_f_name), 'w')
    out_gdfa_f_name = f"gdfa_alignments_from_{args.source_edition}_to_{args.target_edition}_for_{len(all_verses)}_verses.txt"
    out_gdfa_file = open(os.path.join(args.save_path, out_gdfa_f_name), 'w')

    for idx, verse_id in enumerate(all_verses):
        aligns_predicted, inter_aligns, gdfa_aligns, used_edition_count = all_alignments[idx]

        # Convert predicted alignments to a string and write one line per verse
        aligns_predicted = ' '.join([f"{align[0]}-{align[1]}" for align in aligns_predicted])
        out_NMF_file.write(f"{verse_id}\t{aligns_predicted}\n")
        out_inter_file.write(f"{verse_id}\t{inter_aligns.strip()}\n")
        out_gdfa_file.write(f"{verse_id}\t{gdfa_aligns.strip()}\n")

    out_NMF_file.close()
    out_inter_file.close()
    out_gdfa_file.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument('--save_path', default="/mounts/Users/cisintern/lksenel/Projects/pbc/graph-align/predicted_alignments", type=str)
    parser.add_argument('--gold_file', default="/mounts/Users/cisintern/lksenel/Projects/pbc/pbc_utils/data/eng_fra_pbc/eng-fra.gold", type=str)
    parser.add_argument('--verse_alignments_path', default="/mounts/data/proj/ayyoob/align_induction/verse_alignments/", type=str)
    parser.add_argument('--source_edition', default="eng-x-bible-mixed", type=str)
    parser.add_argument('--target_edition', default="fra-x-bible-louissegond", type=str)
    parser.add_argument('--editions_file', default="/mounts/Users/cisintern/lksenel/Projects/pbc/pbc_utils/data/eng_fra_pbc/lang_list.txt", type=str)
    parser.add_argument('--core_count', default=80, type=int)
    parser.add_argument('--seed', default=42, type=int)

    args = parser.parse_args()
    main(args)
