
Commit a665d03

add an unstructured prune demo of BERT on SQuAD1.1 (#101)
1 parent 3beb6a5 commit a665d03

File tree: 6 files changed (+677, -1 lines)


examples/unstructured_prune/GLUE/bert/main.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -16,7 +16,6 @@
 from transformers.utils import fx as hf_fx
 import torch.fx as torch_fx
 from sparsebit.sparse import parse_sconfig, SparseModel
-from tracer import my_trace
 from model import BertModel, BertForSequenceClassification
 from transformers import BertTokenizer, AdamW, get_linear_schedule_with_warmup
```

Lines changed: 19 additions & 0 deletions
```markdown
## Introduction
- We introduce a SQuAD demo that demonstrates how to apply an L1-norm weight sparser to BERT.

## Run

### Install Requirements
- `pip install -r requirements.txt`

### Fine-tuning
- `python main.py sconfig.yaml`

## Results
- sratio = #zeros / #total_params
- Only the BERT encoder is sparsified; the embeddings & heads are excluded.
- For convenience, we use the F1-score as the metric here.

model | sparser | sratio=0.0 | sratio=0.25 | sratio=0.5 | sratio=0.75 | sratio=1.0 |
--- | --- | --- | --- | --- | --- | --- |
bert-base-uncased | l1-norm | 88.57 | 88.09 | 86.98 | 75.42 | 10.46 |
```
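For reference, the effect of an l1-norm sparser and the sratio it reports can be sketched in plain PyTorch. This is a minimal illustration of unstructured L1-magnitude pruning, not sparsebit's actual implementation:

```python
import torch


def l1_unstructured_mask(weight: torch.Tensor, sratio: float) -> torch.Tensor:
    """Binary mask that zeroes the `sratio` fraction of entries with the
    smallest |w| (unstructured L1-magnitude pruning)."""
    k = int(sratio * weight.numel())
    if k == 0:
        return torch.ones_like(weight)
    # k-th smallest |w| is the pruning threshold
    threshold = weight.abs().flatten().kthvalue(k).values
    return (weight.abs() > threshold).float()


# prune a toy linear layer to sratio = 0.5
layer = torch.nn.Linear(768, 768)
with torch.no_grad():
    mask = l1_unstructured_mask(layer.weight, 0.5)
    layer.weight.mul_(mask)

# sratio = #zeros / #total_params, as defined in the README above
sratio = (layer.weight == 0).float().mean().item()
print(f"sratio = {sratio:.2f}")
```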
Lines changed: 277 additions & 0 deletions
```python
import collections
import numpy as np
from tqdm import tqdm

import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
import evaluate
from datasets import load_dataset
from transformers import BertTokenizerFast, default_data_collator, get_scheduler

from model import BertModel, BertForQuestionAnswering
from sparsebit.sparse import parse_sconfig, SparseModel


MAX_LENGTH = 384


def build_dataloader(args, raw_datasets):
    STRIDE = 128
    tokenizer = BertTokenizerFast.from_pretrained(args.architecture, do_lower_case=True)

    def preprocess_training_examples(examples):
        questions = [q.strip() for q in examples["question"]]
        inputs = tokenizer(
            questions,
            examples["context"],
            max_length=MAX_LENGTH,
            truncation="only_second",
            stride=STRIDE,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

        offset_mapping = inputs.pop("offset_mapping")
        sample_map = inputs.pop("overflow_to_sample_mapping")
        answers = examples["answers"]
        start_positions = []
        end_positions = []

        for i, offset in enumerate(offset_mapping):
            sample_idx = sample_map[i]
            answer = answers[sample_idx]
            start_char = answer["answer_start"][0]
            end_char = answer["answer_start"][0] + len(answer["text"][0])
            sequence_ids = inputs.sequence_ids(i)

            # Find the start and end of the context
            idx = 0
            while sequence_ids[idx] != 1:
                idx += 1
            context_start = idx
            while sequence_ids[idx] == 1:
                idx += 1
            context_end = idx - 1

            # If the answer is not fully inside the context, label is (0, 0)
            if (
                offset[context_start][0] > start_char
                or offset[context_end][1] < end_char
            ):
                start_positions.append(0)
                end_positions.append(0)
            else:
                # Otherwise it's the start and end token positions
                idx = context_start
                while idx <= context_end and offset[idx][0] <= start_char:
                    idx += 1
                start_positions.append(idx - 1)

                idx = context_end
                while idx >= context_start and offset[idx][1] >= end_char:
                    idx -= 1
                end_positions.append(idx + 1)

        inputs["start_positions"] = start_positions
        inputs["end_positions"] = end_positions
        return inputs

    def preprocess_validation_examples(examples):
        questions = [q.strip() for q in examples["question"]]
        inputs = tokenizer(
            questions,
            examples["context"],
            max_length=MAX_LENGTH,
            truncation="only_second",
            stride=STRIDE,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

        sample_map = inputs.pop("overflow_to_sample_mapping")
        example_ids = []

        for i in range(len(inputs["input_ids"])):
            sample_idx = sample_map[i]
            example_ids.append(examples["id"][sample_idx])

            # Keep offsets only for context tokens so post-processing can
            # map predicted token spans back to character spans
            sequence_ids = inputs.sequence_ids(i)
            offset = inputs["offset_mapping"][i]
            inputs["offset_mapping"][i] = [
                o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
            ]

        inputs["example_id"] = example_ids
        return inputs

    train_dataset = raw_datasets["train"].map(
        preprocess_training_examples,
        batched=True,
        remove_columns=raw_datasets["train"].column_names,
    )
    validation_dataset = raw_datasets["validation"].map(
        preprocess_validation_examples,
        batched=True,
        remove_columns=raw_datasets["validation"].column_names,
    )

    train_dataset.set_format("torch")
    validation_set = validation_dataset.remove_columns(["example_id", "offset_mapping"])
    validation_set.set_format("torch")

    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=default_data_collator,
        batch_size=args.batch_size,
    )

    eval_dataloader = DataLoader(
        validation_set,
        collate_fn=default_data_collator,
        batch_size=2 * args.batch_size,
    )

    return train_dataloader, eval_dataloader, validation_dataset


def compute_metrics(start_logits, end_logits, features, examples):
    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    n_best = 20
    max_answer_length = 30
    predicted_answers = []
    for example in tqdm(examples):
        example_id = example["id"]
        context = example["context"]
        answers = []

        # Loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0 or > max_answer_length
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    answer = {
                        "text": context[
                            offsets[start_index][0] : offsets[end_index][1]
                        ],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                    answers.append(answer)

        # Select the answer with the best score
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            predicted_answers.append(
                {"id": example_id, "prediction_text": best_answer["text"]}
            )
        else:
            predicted_answers.append({"id": example_id, "prediction_text": ""})

    theoretical_answers = [
        {"id": ex["id"], "answers": ex["answers"]} for ex in examples
    ]
    metric = evaluate.load("squad")
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)


def main(args):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    raw_datasets = load_dataset("squad")
    train_dataloader, eval_dataloader, validation_dataset = build_dataloader(
        args, raw_datasets
    )

    # Wrap the encoder with sparsebit's SparseModel before attaching the QA head
    bert_model = BertModel.from_pretrained("bert-base-uncased", add_pooling_layer=False)
    bert_model.embeddings.seq_length = MAX_LENGTH
    bert_model.config.num_labels = 2
    sconfig = parse_sconfig(args.sconfig)
    sbert = SparseModel(bert_model, sconfig)
    sbert.calc_params()
    model = BertForQuestionAnswering(sbert, bert_model.config)
    model.to(device)

    total_training_steps = args.epochs * len(train_dataloader)
    optimizer = AdamW(model.parameters(), lr=args.lr)
    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=total_training_steps,
    )

    progress_bar = tqdm(range(total_training_steps))

    print(sconfig)
    for epoch in range(args.epochs):
        # training
        model.train()
        for step, batch in enumerate(train_dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)
        # evaluation
        model.eval()
        start_logits, end_logits = [], []
        for batch in tqdm(eval_dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad():
                outputs = model(**batch)
            start_logits.append(outputs.start_logits.cpu().numpy())
            end_logits.append(outputs.end_logits.cpu().numpy())
        # concatenate all results to evaluate the F1-score
        start_logits = np.concatenate(start_logits)[: len(validation_dataset)]
        end_logits = np.concatenate(end_logits)[: len(validation_dataset)]

        metrics = compute_metrics(
            start_logits, end_logits, validation_dataset, raw_datasets["validation"]
        )
        print(f"epoch {epoch}:", metrics)
        print(sconfig)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("sconfig")
    parser.add_argument(
        "--architecture",
        type=str,
        help="the architecture of BERT",
        default="bert-base-uncased",
    )
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--lr", type=float, default=2e-5)
    args = parser.parse_args()

    main(args)
```
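With the defaults above, a full run is `python main.py sconfig.yaml --batch-size 8 --epochs 3 --lr 2e-5`; the positional `sconfig` argument is the sparse-config YAML consumed by `parse_sconfig`. The mask bookkeeping itself lives inside `SparseModel`. As a rough mental model only (a hypothetical sketch, not sparsebit's API), unstructured sparse fine-tuning keeps pruned weights at zero by re-applying the masks after each optimizer step:

```python
import torch


def reapply_masks(model: torch.nn.Module, masks: dict) -> None:
    """Re-zero pruned weights so the gradient update cannot revive them.
    Hypothetical helper for illustration; sparsebit's SparseModel manages
    its masks internally."""
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in masks:
                param.mul_(masks[name])


# in a training loop like the one above, this would run right after
# optimizer.step():
#     reapply_masks(model, masks)
```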

0 commit comments