|
| 1 | +import collections |
| 2 | +import numpy as np |
| 3 | +from tqdm import tqdm |
| 4 | +from functools import partial |
| 5 | + |
| 6 | +import torch |
| 7 | +from torch.optim import AdamW |
| 8 | +from torch.utils.data import DataLoader |
| 9 | +import evaluate |
| 10 | +from datasets import load_dataset |
| 11 | +from transformers import BertTokenizerFast, default_data_collator, get_scheduler |
| 12 | + |
| 13 | +from model import BertModel, BertForQuestionAnswering |
| 14 | +from sparsebit.sparse import parse_sconfig, SparseModel |
| 15 | + |
| 16 | + |
| 17 | +MAX_LENGTH = 384 |
| 18 | + |
| 19 | + |
def build_dataloader(args, raw_datasets):
    """Tokenize the question-answering splits and wrap them in DataLoaders.

    Every (question, context) pair is encoded to ``MAX_LENGTH`` tokens. A
    context that does not fit is split into overlapping windows, so a single
    example can yield several features.

    Args:
        args: Namespace providing ``architecture`` (tokenizer checkpoint name)
            and ``batch_size``.
        raw_datasets: Mapping with ``"train"`` and ``"validation"`` splits.
            The train split is read for ``question``, ``context`` and
            ``answers``; the validation split for ``question``, ``context``
            and ``id``.

    Returns:
        ``(train_dataloader, eval_dataloader, validation_dataset)``. The
        returned ``validation_dataset`` still carries the ``example_id`` and
        ``offset_mapping`` columns that the eval loader's dataset drops.
    """
    window_overlap = 128
    tokenizer = BertTokenizerFast.from_pretrained(args.architecture, do_lower_case=True)

    def encode_pairs(examples):
        # Shared tokenizer call for both splits; only the context is truncated.
        stripped = [question.strip() for question in examples["question"]]
        return tokenizer(
            stripped,
            examples["context"],
            max_length=MAX_LENGTH,
            truncation="only_second",
            stride=window_overlap,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

    def context_bounds(sequence_ids):
        # First and last token index whose sequence id is 1 (the context).
        cursor = 0
        while sequence_ids[cursor] != 1:
            cursor += 1
        first = cursor
        while sequence_ids[cursor] == 1:
            cursor += 1
        return first, cursor - 1

    def locate_answer(offsets, first, last, start_char, end_char):
        # An answer that is not fully inside this window is labelled (0, 0).
        if offsets[first][0] > start_char or offsets[last][1] < end_char:
            return 0, 0

        # Walk right past every token starting at or before the answer start.
        left = first
        while left <= last and offsets[left][0] <= start_char:
            left += 1

        # Walk left past every token ending at or after the answer end.
        right = last
        while right >= first and offsets[right][1] >= end_char:
            right -= 1

        return left - 1, right + 1

    def preprocess_training_examples(examples):
        encoded = encode_pairs(examples)
        all_offsets = encoded.pop("offset_mapping")
        feature_to_example = encoded.pop("overflow_to_sample_mapping")
        answers = examples["answers"]

        starts, ends = [], []
        for feature_idx, (offsets, example_idx) in enumerate(
            zip(all_offsets, feature_to_example)
        ):
            # Only the first listed answer is used as the training target.
            gold = answers[example_idx]
            answer_begin = gold["answer_start"][0]
            answer_stop = answer_begin + len(gold["text"][0])

            first, last = context_bounds(encoded.sequence_ids(feature_idx))
            token_start, token_end = locate_answer(
                offsets, first, last, answer_begin, answer_stop
            )
            starts.append(token_start)
            ends.append(token_end)

        encoded["start_positions"] = starts
        encoded["end_positions"] = ends
        return encoded

    def preprocess_validation_examples(examples):
        encoded = encode_pairs(examples)
        feature_to_example = encoded.pop("overflow_to_sample_mapping")

        for feature_idx, _ in enumerate(encoded["input_ids"]):
            # Blank out offsets of non-context tokens so that post-processing
            # can recognise and skip them.
            seq_ids = encoded.sequence_ids(feature_idx)
            encoded["offset_mapping"][feature_idx] = [
                span if seq_id == 1 else None
                for span, seq_id in zip(encoded["offset_mapping"][feature_idx], seq_ids)
            ]

        encoded["example_id"] = [
            examples["id"][feature_to_example[feature_idx]]
            for feature_idx in range(len(encoded["input_ids"]))
        ]
        return encoded

    train_split = raw_datasets["train"]
    train_dataset = train_split.map(
        preprocess_training_examples,
        batched=True,
        remove_columns=train_split.column_names,
    )
    validation_split = raw_datasets["validation"]
    validation_dataset = validation_split.map(
        preprocess_validation_examples,
        batched=True,
        remove_columns=validation_split.column_names,
    )

    train_dataset.set_format("torch")
    # The model only takes tensor columns; the id/offset columns stay on
    # ``validation_dataset`` for metric computation.
    validation_set = validation_dataset.remove_columns(["example_id", "offset_mapping"])
    validation_set.set_format("torch")

    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=default_data_collator,
        batch_size=args.batch_size,
    )
    eval_dataloader = DataLoader(
        validation_set,
        collate_fn=default_data_collator,
        batch_size=2 * args.batch_size,
    )

    return train_dataloader, eval_dataloader, validation_dataset
| 140 | + |
| 141 | + |
def compute_metrics(start_logits, end_logits, features, examples):
    """Turn per-feature start/end logits into text answers and score them.

    For each example, the 20 highest start and end logits of every feature
    belonging to it are paired; the best-scoring valid span (inside the
    context, non-negative length, at most 30 tokens) becomes the prediction.
    An example with no valid span gets an empty-string prediction.

    Args:
        start_logits: Indexable by feature position, yielding start logits.
        end_logits: Indexable by feature position, yielding end logits.
        features: Iterable/indexable of features exposing ``example_id`` and
            ``offset_mapping`` (``None`` for non-context tokens).
        examples: Iterable of examples exposing ``id``, ``context`` and
            ``answers``.

    Returns:
        Whatever the ``"squad"`` metric's ``compute`` returns.
    """
    top_k = 20
    span_limit = 30

    # Group feature positions by the example they were cut from.
    features_by_example = collections.defaultdict(list)
    for position, feature in enumerate(features):
        features_by_example[feature["example_id"]].append(position)

    predicted_answers = []
    for example in tqdm(examples):
        example_id = example["id"]
        context = example["context"]

        best_text = ""
        best_score = None
        for feature_index in features_by_example[example_id]:
            feature_start = start_logits[feature_index]
            feature_end = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            # Highest-logit indices first.
            start_candidates = np.argsort(feature_start)[-1 : -top_k - 1 : -1].tolist()
            end_candidates = np.argsort(feature_end)[-1 : -top_k - 1 : -1].tolist()

            for start_index in start_candidates:
                for end_index in end_candidates:
                    inside_context = (
                        offsets[start_index] is not None
                        and offsets[end_index] is not None
                    )
                    if not inside_context:
                        continue

                    span_tokens = end_index - start_index + 1
                    if end_index < start_index or span_tokens > span_limit:
                        continue

                    score = feature_start[start_index] + feature_end[end_index]
                    # Strict ">" keeps the earliest candidate on ties.
                    if best_score is None or score > best_score:
                        best_score = score
                        best_text = context[
                            offsets[start_index][0] : offsets[end_index][1]
                        ]

        predicted_answers.append({"id": example_id, "prediction_text": best_text})

    theoretical_answers = [
        {"id": reference["id"], "answers": reference["answers"]}
        for reference in examples
    ]
    metric = evaluate.load("squad")
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)
| 197 | + |
| 198 | + |
def main(args):
    """Fine-tune a sparsified BERT on SQuAD, evaluating after every epoch.

    Args:
        args: Parsed command-line namespace. This function reads ``sconfig``,
            ``architecture``, ``epochs`` and ``lr``; ``build_dataloader``
            additionally reads ``batch_size``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    raw_datasets = load_dataset("squad")
    train_dataloader, eval_dataloader, validation_dataset = build_dataloader(
        args, raw_datasets
    )

    # Load the same checkpoint the tokenizer was built from, so that
    # ``--architecture`` selects both consistently.
    bert_model = BertModel.from_pretrained(args.architecture, add_pooling_layer=False)
    bert_model.embeddings.seq_length = MAX_LENGTH
    bert_model.config.num_labels = 2

    sconfig = parse_sconfig(args.sconfig)
    sbert = SparseModel(bert_model, sconfig)
    # NOTE(review): presumably computes/reports parameter statistics for the
    # sparse model — confirm against sparsebit.
    sbert.calc_params()
    model = BertForQuestionAnswering(sbert, bert_model.config)
    # Follow the detected device instead of calling .cuda() unconditionally,
    # which fails on hosts without a GPU.
    model.to(device)

    total_training_steps = args.epochs * len(train_dataloader)
    optimizer = AdamW(model.parameters(), lr=args.lr)
    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=total_training_steps,
    )

    progress_bar = tqdm(range(total_training_steps))

    print(sconfig)
    for epoch in range(args.epochs):
        # training
        model.train()
        for batch in train_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            outputs.loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)

        # evaluation
        model.eval()
        start_logits, end_logits = [], []
        with torch.no_grad():
            for batch in tqdm(eval_dataloader):
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                start_logits.append(outputs.start_logits.cpu().numpy())
                end_logits.append(outputs.end_logits.cpu().numpy())

        # Concatenate all batches so logits line up with validation features.
        start_logits = np.concatenate(start_logits)[: len(validation_dataset)]
        end_logits = np.concatenate(end_logits)[: len(validation_dataset)]

        metrics = compute_metrics(
            start_logits, end_logits, validation_dataset, raw_datasets["validation"]
        )
        print(f"epoch {epoch}:", metrics)

    # Echo the sparse config again so it sits next to the final metrics.
    print(sconfig)
| 259 | + |
| 260 | + |
def build_arg_parser():
    """Create the command-line parser for this script.

    Positional:
        sconfig: Passed to ``parse_sconfig`` by ``main``.

    Options:
        --architecture: BERT checkpoint name (default ``bert-base-uncased``).
        --batch-size: Training batch size (default 8).
        --epochs: Number of training epochs (default 3).
        --lr: Learning rate (default 2e-5).

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("sconfig")
    parser.add_argument(
        "--architecture",
        type=str,
        help="the architecture of BERT",
        default="bert-base-uncased",
    )
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--epochs", type=int, default=3)
    # Learning rates are fractional: with ``type=int`` argparse rejected any
    # value given on the command line (e.g. "1e-5"), leaving only the default
    # usable.
    parser.add_argument("--lr", type=float, default=2e-5)
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    main(args)
0 commit comments