Skip to content

Commit 6cd38ef

Browse files
committed
fix adad
1 parent 8ec00c2 commit 6cd38ef

File tree

2 files changed

+31
-36
lines changed

2 files changed

+31
-36
lines changed

induce_alignments_AdAd.py

+14-19
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from collections import defaultdict
77
import numpy as np
88

9-
from utils import LOG, load_editions
9+
from utils import LOG, load_editions, load_gold
1010

1111
class MyWG:
1212
def __init__(self, nodes=[]):
@@ -65,14 +65,6 @@ def calc_wadar(self, edges, verbose=False):
6565
return scores
6666

6767

68-
def load_gold(gold_path="golds/eng-fra-new.gold"):
69-
golds = {}
70-
with open(gold_path, "r") as fi:
71-
for l in fi:
72-
l = l.split("\t")
73-
golds[l[0]] = list(set(l[1].split()))
74-
return golds
75-
7668
def load_texts_and_alignments(editions_file, lang_files_path, verse_alignments_path, aligner="inter", golds=None):
7769
# Get languages and editions
7870
editions, langs = load_editions(editions_file)
@@ -83,7 +75,10 @@ def load_texts_and_alignments(editions_file, lang_files_path, verse_alignments_p
8375
texts = {}
8476
for langl in langs:
8577
verses = {}
86-
lang_path = lang_files_path + "/" + l[1] + ".txt"
78+
if langl == "eng":
79+
lang_path = os.path.join("/mounts/work/mjalili/projects/graph_align_base/data/pbc/", editions[langl] + ".txt")
80+
else:
81+
lang_path = os.path.join(lang_files_path, editions[langl] + ".txt")
8782
with codecs.open(lang_path, "r", "utf-8") as fi:
8883
for l in fi:
8984
if l[0] == "#": continue
@@ -102,7 +97,7 @@ def load_texts_and_alignments(editions_file, lang_files_path, verse_alignments_p
10297
v_path = F"{verse_alignments_path}/{verse}_{aligner}.txt"
10398
if not os.path.exists(v_path):
10499
LOG.info(v_path)
105-
LOG.info(f"================================== dos not exist ==================================")
100+
LOG.info(f"================================== does not exist ==================================")
106101
return None
107102
with open(v_path, "r") as f_al:
108103
for vl in f_al:
@@ -120,7 +115,7 @@ def load_texts_and_alignments(editions_file, lang_files_path, verse_alignments_p
120115
else:
121116
init_aligns[(l1, l2)][verse] = [[int(alp.split("-")[1]), int(alp.split("-")[0]), 1.0] for alp in vl[2].strip().split()]
122117

123-
return langs, texts, lang_pairs, init_aligns
118+
return lang_code_map, langs, texts, lang_pairs, init_aligns
124119

125120
def get_alignment_matrix(sim_matrix):
126121
m, n = sim_matrix.shape
@@ -232,18 +227,18 @@ def add_edges_to_align_argmax(texts, waligns, out_path="", target_pair=("eng", "
232227
return all_cnt
233228

234229
def main(args):
235-
target_pair = (args.source_lang, args.target_lang)
236230
if args.gold_file != "":
237231
pros, surs = load_gold(args.gold_file)
238232
all_verses = list(pros.keys())
239233
else:
240234
all_verses = None
241235

242236
# Get languages and initial alignments
243-
langs, texts, lang_pairs, init_aligns = load_texts_and_alignments(args.editions_file, args.lang_files_path, args.verse_alignments_path, args.aligner, golds=all_verses)
237+
lang_code_map, langs, texts, lang_pairs, init_aligns = load_texts_and_alignments(args.editions_file, args.lang_files_path, args.verse_alignments_path, args.aligner, golds=all_verses)
238+
target_pair = (lang_code_map[args.source_edition], lang_code_map[args.target_edition])
244239

245240
# print some info
246-
LOG.info(f"Inferring alignments from {args.source_lang} to {args.target_lang}")
241+
LOG.info(f"Inferring alignments from {args.source_edition} to {args.target_edition}")
247242
LOG.info(f"Number of verses whose alignments will be inferred: {len(all_verses)}")
248243
LOG.info(f"Number of editions to use for the graph algorithms: {len(langs)}")
249244

@@ -259,13 +254,13 @@ def main(args):
259254
if __name__ == "__main__":
260255
current_path = os.path.dirname(os.path.realpath(__file__))
261256
parser = argparse.ArgumentParser()
262-
257+
263258
parser.add_argument('--save_path', default=os.path.join(current_path, "predicted_alignments"), type=str)
264-
parser.add_argument('--gold_file', default=os.path.join(current_path, "data/gold-standards/blinker/eng-fra.gold"), type=str)
259+
parser.add_argument('--gold_file', default=os.path.join(current_path, "data/gold-standards/blinker/eng-fra.gold"), type=str)
265260
parser.add_argument('--verse_alignments_path', default="/mounts/data/proj/ayyoob/align_induction/verse_alignments/", type=str)
266261
parser.add_argument('--lang_files_path', default="/nfs/datc/pbc/", type=str)
267-
parser.add_argument('--source_lang', default="eng", type=str)
268-
parser.add_argument('--target_lang', default="fra", type=str)
262+
parser.add_argument('--source_edition', default="eng-x-bible-mixed", type=str)
263+
parser.add_argument('--target_edition', default="fra-x-bible-louissegond", type=str)
269264
parser.add_argument('--editions_file', default=os.path.join(current_path, "data/edition_lists/blinker_edition_list.txt" ), type=str)
270265
parser.add_argument('--aligner', default="inter", type=str)
271266

induce_alignments_NMF.py

+17-17
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def get_aligns(rf, cf, alignments):
5555
res.append( ( int(x[0]), int(x[1]) ) )
5656
else:
5757
return None
58-
58+
5959
return res
6060

6161
def add_aligns(aligns, aligns_dict, token_counts, re, ce, existing_items):
@@ -69,7 +69,7 @@ def add_aligns(aligns, aligns_dict, token_counts, re, ce, existing_items):
6969
token_counts[re] = align[0]
7070
if align[1] > token_counts[ce]:
7171
token_counts[ce] = align[1]
72-
72+
7373
existing_items[re][ce].append(f"{align[0]},{align[1]}")
7474

7575
def add_negative_samples(aligns_dict, existing_items, token_counts, verse_id):
@@ -89,7 +89,7 @@ def add_negative_samples(aligns_dict, existing_items, token_counts, verse_id):
8989
aligns_dict['userID'].append(re + str(i))
9090
aligns_dict['itemID'].append(ce + str(jp))
9191
aligns_dict['rating'].append(1)
92-
92+
9393
ip %= (token_counts[re] + 1)
9494
aligns_dict['userID'].append(re + str(ip))
9595
aligns_dict['itemID'].append(ce + str(j))
@@ -112,11 +112,11 @@ def get_alignments_df(row_editions, col_editions, verse_alignments,
112112

113113
if not aligns is None:
114114
add_aligns(aligns, aligns_dict, token_counts, re, ce, existing_items)
115-
115+
116116
add_negative_samples(aligns_dict, existing_items, token_counts, verse_id)
117117

118118
return pd.DataFrame(aligns_dict), token_counts[source_edition], token_counts[target_edition]
119-
119+
120120
def iter_max(sim_matrix: np.ndarray, max_count: int=2, alpha_ratio = 0.7) -> np.ndarray:
121121
m, n = sim_matrix.shape
122122
forward = np.eye(n)[sim_matrix.argmax(axis=1)] # m x n
@@ -156,14 +156,14 @@ def get_itermax_predictions(raw_s_predictions, max_count=2, alpha_ratio=0.9):
156156
for i in raw_s_predictions:
157157
for j, s in raw_s_predictions[i]:
158158
matrix[i,j] = s
159-
159+
160160
itermax_res = iter_max(matrix, max_count, alpha_ratio)
161161
res = []
162162
for i in range(rows):
163163
for j in range(cols):
164164
if itermax_res[i,j] != 0:
165165
res.append((i,j))
166-
166+
167167
return res
168168

169169
def predict_alignments(algo, source_edition, target_edition):
@@ -197,13 +197,13 @@ def train_model(df, s_tok_count, t_tok_count, row_editions, col_editions):
197197
algo.row_editions = row_editions
198198
algo.col_editions = col_editions
199199
algo.df = df
200-
200+
201201
return algo
202202

203203
def get_induced_alignments(source_edition, target_edition, verse_alignments_path, verse_id, all_editions):
204204

205205
verse_alignments = get_verse_alignments(verse_alignments_path, verse_id, editions=all_editions)
206-
206+
207207
# this is only for saving the gdfa alignments from source to target for the evaluation
208208
verse_alignments_gdfa = get_verse_alignments(verse_alignments_path, verse_id, editions=[source_edition, target_edition], gdfa=True)
209209

@@ -213,17 +213,17 @@ def get_induced_alignments(source_edition, target_edition, verse_alignments_path
213213
df, s_tok_count, t_tok_count = get_alignments_df(row_editions, col_editions, verse_alignments, source_edition, target_edition, verse_id)
214214

215215
algo = train_model(df, s_tok_count, t_tok_count, row_editions, col_editions)
216-
216+
217217
predicted_alignments = predict_alignments(algo, source_edition, target_edition)
218218
base_inter_alignments = verse_alignments[source_edition][target_edition]
219219
base_gdfa_alignments = verse_alignments_gdfa[source_edition][target_edition]
220-
220+
221221
with cnt.get_lock():
222222
cnt.value += 1
223223
if cnt.value % 20 == 0:
224224
LOG.info(f"Done inferring alignments for {cnt.value} verses")
225225

226-
return predicted_alignments, base_inter_alignments, base_gdfa_alignments, len(algo.col_editions)+1
226+
return predicted_alignments, base_inter_alignments, base_gdfa_alignments, len(algo.col_editions) + 1
227227

228228

229229
def init_globals(counter):
@@ -255,7 +255,7 @@ def main(args):
255255

256256
# get predicted alignments using parallel processing
257257
cnt = Value('i', 0)
258-
with Pool(processes=args.core_count, initializer=init_globals, initargs=(cnt,)) as p:
258+
with Pool(processes=args.core_count, initializer=init_globals, initargs=(cnt,)) as p:
259259
all_alignments = p.starmap(get_induced_alignments, starmap_args)
260260

261261
out_NMF_f_name = f"predicted_alignments_from_{args.source_edition}_to_{args.target_edition}_with_max_{len(all_editions)}_editions_for_{len(all_verses)}_verses_NMF.txt"
@@ -283,14 +283,14 @@ def main(args):
283283
parser = argparse.ArgumentParser()
284284

285285
parser.add_argument('--save_path', default=os.path.join(current_path, "predicted_alignments"), type=str)
286-
parser.add_argument('--gold_file', default=os.path.join(current_path, "data/gold-standards/blinker/eng-fra.gold"), type=str)
286+
parser.add_argument('--gold_file', default=os.path.join(current_path, "data/gold-standards/blinker/eng-fra.gold"), type=str)
287287
parser.add_argument('--verse_alignments_path', default="/mounts/data/proj/ayyoob/align_induction/verse_alignments/", type=str)
288-
parser.add_argument('--source_edition', default="eng-x-bible-mixed", type=str)
289-
parser.add_argument('--target_edition', default="fra-x-bible-louissegond", type=str)
288+
parser.add_argument('--source_edition', default="eng-x-bible-mixed", type=str)
289+
parser.add_argument('--target_edition', default="fra-x-bible-louissegond", type=str)
290290
parser.add_argument('--editions_file', default=os.path.join(current_path, "data/edition_lists/blinker_edition_list.txt" ), type=str)
291291
parser.add_argument('--core_count', default=80, type=int)
292292
parser.add_argument('--seed', default=42, type=int)
293293

294294
args = parser.parse_args()
295295
main(args)
296-
296+

0 commit comments

Comments (0)