evaluate_ppl.py
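"""Compute the mean perplexity of a LLaVA-NeXT model's answers over a JSONL dataset.

Each line of the input file is expected to carry the fields read below
('file_name', 'question', and the answer field named via --answer_field);
the exact schema beyond those fields is an assumption. A minimal example line:

    {"file_name": "0001.jpg", "question": "What is shown?", "answer": "A cat."}

Only the answer tokens contribute to the loss (prompt tokens are masked out),
so the reported number is the mean perplexity of the answers themselves.
"""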
import os
import json
import torch
import argparse
import numpy as np
from PIL import Image
from tqdm import tqdm
from transformers import AutoProcessor, LlavaNextForConditionalGeneration
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str)
    parser.add_argument("--data_path", type=str)
    parser.add_argument("--images_dir", type=str)
    parser.add_argument("--answer_field", type=str)
    return parser.parse_args()
def main(model_id, data_path, images_dir, answer_field):
    model = LlavaNextForConditionalGeneration.from_pretrained(
        model_id,
        device_map="cuda:0",
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
    )
    model.eval()
    processor = AutoProcessor.from_pretrained(model_id)

    ppls = []
    with open(data_path, 'r', encoding='utf8') as f:
        for line in tqdm(f):
            line_data = json.loads(line)
            # Llama-3-style chat template, as used by the Llama-3-based LLaVA-NeXT checkpoints.
            prompt = f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{line_data['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
            completion = f"{line_data[answer_field]}<|eot_id|>"
            image = Image.open(os.path.join(images_dir, line_data['file_name'])).convert("RGB")

            inputs = processor(text=prompt + completion, images=image, return_tensors="pt")
            # Tokenize the completion alone (no BOS) to count how many trailing tokens
            # belong to the answer. This assumes the prompt/completion boundary
            # tokenizes the same way in isolation as it does in the full sequence.
            response_token_ids = processor.tokenizer(completion, add_special_tokens=False)['input_ids']

            # Mask out everything except the answer tokens so the loss (and hence
            # the perplexity) is computed over the answer only.
            ignore_index = -100
            labels = inputs['input_ids'].clone()
            labels[0, :labels.shape[1] - len(response_token_ids)] = ignore_index
            inputs['labels'] = labels

            with torch.no_grad():
                outputs = model(**inputs.to('cuda:0'), return_dict=True)
            ppl = torch.exp(outputs.loss).item()
            # Skip pathological outliers so a few degenerate examples
            # do not dominate the mean.
            if ppl > 1000:
                continue
            ppls.append(ppl)

    print(f"PERPLEXITY FOR MODEL {model_id}: {np.mean(ppls)}")
if __name__ == "__main__":
    args = get_args()
    main(**vars(args))
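# Example invocation (the model id and paths below are illustrative
# placeholders, not prescribed by this repo):
#
#   python evaluate_ppl.py \
#       --model_id llava-hf/llama3-llava-next-8b-hf \
#       --data_path data/eval.jsonl \
#       --images_dir data/images \
#       --answer_field answer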