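"""Induce word alignments between a source and a target Bible edition.

For each gold verse, token alignments between many editions are loaded and
encoded as a (userID, itemID, rating) table: tokens of the row editions act as
"users", tokens of the column editions as "items", observed alignment links get
the highest rating and randomly sampled non-links the lowest. A Surprise NMF
model is fit on this table, and its rating predictions between source and
target tokens are decoded into alignments with an (iter)argmax step. The
induced alignments are written out alongside the precomputed intersection and
grow-diag-final-and (GDFA) alignments for evaluation.
"""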
import argparse
import math
import os
import random
from multiprocessing import Pool, Value

import numpy as np
import pandas as pd
from nltk.translate.gdfa import grow_diag_final_and
from surprise import NMF, Dataset, Reader

from utils import LOG, setup_dict_entry, load_editions, load_gold, get_verse_alignments


def get_row_col_editions(source_edition, target_edition, all_editions=None):
    # Rows/columns of the alignment matrix: every other edition appears in both,
    # with the source edition as the last row and the target edition as the last column.
    row_editions = []
    col_editions = []
    for edition in all_editions:
        if edition != source_edition and edition != target_edition:
            row_editions.append(edition)
            col_editions.append(edition)

    row_editions.append(source_edition)
    col_editions.append(target_edition)

    return row_editions, col_editions

def get_aligns(rf, cf, alignments):
    """Return the alignment between editions rf (row) and cf (column) as a list of
    (row_token, col_token) index pairs, or None if no alignment is available."""
    raw_align = ''

    if rf in alignments and cf in alignments[rf]:
        raw_align = alignments[rf][cf]
        alignment_line = [x.split('-') for x in raw_align.split()]
        res = []
        for x in alignment_line:
            res.append((int(x[0]), int(x[1])))
    elif cf in alignments and rf in alignments[cf]:  # alignment stored in the reverse direction, so swap the indices
        raw_align = alignments[cf][rf]
        alignment_line = [x.split('-') for x in raw_align.split()]
        res = []
        for x in alignment_line:
            res.append((int(x[1]), int(x[0])))
    elif rf in alignments and rf == cf:  # source and target are the same edition: build an identity alignment
        keys = list(alignments[rf].keys())
        max_count = 0
        for key in keys:
            align = alignments[rf][key]
            for x in align.split():
                count = int(x.split('-')[0])
                if count > max_count:
                    max_count = count
        raw_align = "0-0"
        for i in range(1, max_count):
            raw_align += f" {i}-{i}"

        alignment_line = [x.split('-') for x in raw_align.split()]
        res = []
        for x in alignment_line:
            res.append((int(x[0]), int(x[1])))
    else:
        return None

    return res

def add_aligns(aligns, aligns_dict, token_counts, re, ce, existing_items):
    for align in aligns:
        # Observed alignment links get the highest rating (3).
        aligns_dict['userID'].append(re + str(align[0]))
        aligns_dict['itemID'].append(ce + str(align[1]))
        aligns_dict['rating'].append(3)

        # token_counts keeps the largest token index seen per edition.
        if align[0] > token_counts[re]:
            token_counts[re] = align[0]
        if align[1] > token_counts[ce]:
            token_counts[ce] = align[1]

        existing_items[re][ce].append(f"{align[0]},{align[1]}")

def add_negative_samples(aligns_dict, existing_items, token_counts, verse_id):
    # For every observed link (i, j), sample one non-aligned column token and one
    # non-aligned row token and give those pairs the lowest rating (1).
    for re in existing_items:
        if token_counts[re] < 2:
            continue
        for ce in existing_items[re]:
            if token_counts[ce] < 2:
                continue
            for item in existing_items[re][ce]:
                i, j = (int(v) for v in item.split(","))
                # A random offset in [1, token_count] wrapped around the index range,
                # so the sampled index always differs from j (resp. i).
                jp = random.randint(j + 1, j + token_counts[ce])
                ip = random.randint(i + 1, i + token_counts[re])

                jp %= (token_counts[ce] + 1)
                aligns_dict['userID'].append(re + str(i))
                aligns_dict['itemID'].append(ce + str(jp))
                aligns_dict['rating'].append(1)

                ip %= (token_counts[re] + 1)
                aligns_dict['userID'].append(re + str(ip))
                aligns_dict['itemID'].append(ce + str(j))
                aligns_dict['rating'].append(1)

def get_alignments_df(row_editions, col_editions, verse_alignments,
                      source_edition, target_edition, verse_id):  # TODO can be improved a lot
    token_counts = {}
    existing_items = {}
    aligns_dict = {'itemID': [], 'userID': [], 'rating': []}
    for re in row_editions:
        token_counts[re] = 0
        existing_items[re] = {}

        for ce in col_editions:
            setup_dict_entry(token_counts, ce, 0)
            existing_items[re][ce] = []
            aligns = get_aligns(re, ce, verse_alignments)

            if aligns is not None:
                add_aligns(aligns, aligns_dict, token_counts, re, ce, existing_items)

    add_negative_samples(aligns_dict, existing_items, token_counts, verse_id)

    return pd.DataFrame(aligns_dict), token_counts[source_edition], token_counts[target_edition]

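# For illustration only (values are made up): the training table built above encodes
# token positions as userID = row edition + token index, itemID = column edition +
# token index, e.g.
#
#   userID                itemID                       rating
#   eng-x-bible-mixed0    fra-x-bible-louissegond0     3   # observed link
#   eng-x-bible-mixed0    fra-x-bible-louissegond3     1   # sampled negative
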
def iter_max(sim_matrix: np.ndarray, max_count: int = 2, alpha_ratio=0.7) -> np.ndarray:
    """Extract alignments from a similarity matrix: start from the mutual
    (forward/backward) argmax pairs and, for up to max_count iterations, add
    further argmax pairs for rows and columns that are still uncovered."""
    m, n = sim_matrix.shape
    forward = np.eye(n)[sim_matrix.argmax(axis=1)]  # m x n
    backward = np.eye(m)[sim_matrix.argmax(axis=0)]  # n x m
    inter = forward * backward.transpose()

    if min(m, n) <= 2:
        return inter

    new_inter = np.zeros((m, n))
    count = 1
    while count < max_count:
        mask_x = 1.0 - np.tile(inter.sum(1)[:, np.newaxis], (1, n)).clip(0.0, 1.0)
        mask_y = 1.0 - np.tile(inter.sum(0)[np.newaxis, :], (m, 1)).clip(0.0, 1.0)
        mask = ((alpha_ratio * mask_x) + (alpha_ratio * mask_y)).clip(0.0, 1.0)
        mask_zeros = 1.0 - ((1.0 - mask_x) * (1.0 - mask_y))
        if mask_x.sum() < 1.0 or mask_y.sum() < 1.0:
            mask *= 0.0
            mask_zeros *= 0.0

        new_sim = sim_matrix * mask
        fwd = np.eye(n)[new_sim.argmax(axis=1)] * mask_zeros
        bac = np.eye(m)[new_sim.argmax(axis=0)].transpose() * mask_zeros
        new_inter = fwd * bac

        if np.array_equal(inter + new_inter, inter):
            break
        inter = inter + new_inter
        count += 1
    return inter

def get_itermax_predictions(raw_s_predictions, max_count=2, alpha_ratio=0.9):
    rows = len(raw_s_predictions)
    cols = len(raw_s_predictions[0])
    # Build the dense source-by-target score matrix from the per-token predictions.
    matrix = np.zeros(shape=(rows, cols), dtype=float)

    for i in raw_s_predictions:
        for j, s in raw_s_predictions[i]:
            matrix[i, j] = s

    itermax_res = iter_max(matrix, max_count, alpha_ratio)
    res = []
    for i in range(rows):
        for j in range(cols):
            if itermax_res[i, j] != 0:
                res.append((i, j))

    return res

def predict_alignments(algo, source_edition, target_edition):
    raw_s_predictions = {}
    raw_t_predictions = {}

    # Token counts store the largest observed token index, hence the +1 in the ranges.
    for i in range(algo.s_tok_count + 1):
        for j in range(algo.t_tok_count + 1):
            pred = algo.predict(source_edition + str(i), target_edition + str(j))

            setup_dict_entry(raw_s_predictions, i, [])
            setup_dict_entry(raw_t_predictions, j, [])

            raw_s_predictions[i].append((j, pred.est))
            raw_t_predictions[j].append((i, pred.est))

    # Get predicted alignments from argmax (max_count=1 means plain argmax).
    res = get_itermax_predictions(raw_s_predictions, max_count=1)

    return res

def train_model(df, s_tok_count, t_tok_count, row_editions, col_editions):
    algo = NMF()
    reader = Reader(rating_scale=(1, 3))
    data = Dataset.load_from_df(df[['userID', 'itemID', 'rating']], reader)
    trainset = data.build_full_trainset()
    algo.fit(trainset)

    # Stash the metadata needed at prediction time on the fitted model.
    algo.s_tok_count = s_tok_count
    algo.t_tok_count = t_tok_count
    algo.row_editions = row_editions
    algo.col_editions = col_editions
    algo.df = df

    return algo

def get_induced_alignments(source_edition, target_edition, verse_alignments_path, verse_id, all_editions):

    verse_alignments = get_verse_alignments(verse_alignments_path, verse_id, editions=all_editions)

    # Only needed for saving the GDFA alignments from source to target for the evaluation.
    verse_alignments_gdfa = get_verse_alignments(verse_alignments_path, verse_id, editions=[source_edition, target_edition], gdfa=True)

    # source -> row, target -> col
    row_editions, col_editions = get_row_col_editions(source_edition, target_edition, all_editions)
    # itemID -> col, userID -> row
    df, s_tok_count, t_tok_count = get_alignments_df(row_editions, col_editions, verse_alignments, source_edition, target_edition, verse_id)

    algo = train_model(df, s_tok_count, t_tok_count, row_editions, col_editions)

    predicted_alignments = predict_alignments(algo, source_edition, target_edition)
    base_inter_alignments = verse_alignments[source_edition][target_edition]
    base_gdfa_alignments = verse_alignments_gdfa[source_edition][target_edition]

    # cnt is the shared progress counter installed by init_globals in each worker process.
    with cnt.get_lock():
        cnt.value += 1
        if cnt.value % 20 == 0:
            LOG.info(f"Done inferring alignments for {cnt.value} verses")

    return predicted_alignments, base_inter_alignments, base_gdfa_alignments, len(algo.col_editions) + 1


def init_globals(counter):
    global cnt
    cnt = counter

def main(args):
    random.seed(args.seed)

    pros, surs = load_gold(args.gold_file)
    all_verses = list(pros.keys())

    # Get languages and editions
    editions, langs = load_editions(args.editions_file)
    all_editions = [editions[lang] for lang in langs]

    # Print some info
    LOG.info(f"Inferring alignments from {args.source_edition} to {args.target_edition}")
    LOG.info(f"Number of verses whose alignments will be inferred: {len(all_verses)}")
    LOG.info(f"Number of editions to use for the graph algorithms: {len(all_editions)}")
    LOG.info(f"Number of cores to be used for processing: {args.core_count}")

    # Prepare arguments for parallel processing
    starmap_args = []
    for verse_id in all_verses:
        starmap_args.append((args.source_edition, args.target_edition, args.verse_alignments_path, verse_id, all_editions))

    # Get predicted alignments using parallel processing
    cnt = Value('i', 0)
    with Pool(processes=args.core_count, initializer=init_globals, initargs=(cnt,)) as p:
        all_alignments = p.starmap(get_induced_alignments, starmap_args)

    out_NMF_f_name = f"predicted_alignments_from_{args.source_edition}_to_{args.target_edition}_with_max_{len(all_editions)}_editions_for_{len(all_verses)}_verses_NMF.txt"
    out_NMF_file = open(os.path.join(args.save_path, out_NMF_f_name), 'w')
    out_inter_f_name = f"intersection_alignments_from_{args.source_edition}_to_{args.target_edition}_for_{len(all_verses)}_verses.txt"
    out_inter_file = open(os.path.join(args.save_path, out_inter_f_name), 'w')
    out_gdfa_f_name = f"gdfa_alignments_from_{args.source_edition}_to_{args.target_edition}_for_{len(all_verses)}_verses.txt"
    out_gdfa_file = open(os.path.join(args.save_path, out_gdfa_f_name), 'w')

    for idx, verse_id in enumerate(all_verses):
        aligns_predicted, inter_aligns, gdfa_aligns, used_edition_count = all_alignments[idx]

        # Convert predicted alignments to a string and write them to a file
        aligns_predicted = ' '.join([f"{align[0]}-{align[1]}" for align in aligns_predicted])
        out_NMF_file.write(f"{verse_id}\t{aligns_predicted}\n")
        out_inter_file.write(f"{verse_id}\t{inter_aligns.strip()}\n")
        out_gdfa_file.write(f"{verse_id}\t{gdfa_aligns.strip()}\n")

    out_NMF_file.close()
    out_inter_file.close()
    out_gdfa_file.close()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument('--save_path', default="/mounts/Users/cisintern/lksenel/Projects/pbc/graph-align/predicted_alignments", type=str)
    parser.add_argument('--gold_file', default="/mounts/Users/cisintern/lksenel/Projects/pbc/pbc_utils/data/eng_fra_pbc/eng-fra.gold", type=str)
    parser.add_argument('--verse_alignments_path', default="/mounts/data/proj/ayyoob/align_induction/verse_alignments/", type=str)
    parser.add_argument('--source_edition', default="eng-x-bible-mixed", type=str)
    parser.add_argument('--target_edition', default="fra-x-bible-louissegond", type=str)
    parser.add_argument('--editions_file', default="/mounts/Users/cisintern/lksenel/Projects/pbc/pbc_utils/data/eng_fra_pbc/lang_list.txt", type=str)
    parser.add_argument('--core_count', default=80, type=int)
    parser.add_argument('--seed', default=42, type=int)

    args = parser.parse_args()
    main(args)

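# Example invocation (illustrative; the script name, relative paths, and core count
# are placeholders, and every argument also has a default set above):
#
#   python induce_alignments_nmf.py \
#       --source_edition eng-x-bible-mixed \
#       --target_edition fra-x-bible-louissegond \
#       --editions_file lang_list.txt \
#       --gold_file eng-fra.gold \
#       --verse_alignments_path verse_alignments/ \
#       --save_path predicted_alignments/ \
#       --core_count 8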