# create_vision_cot_data.py
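"""Generate vision Chain-of-Thought (CoT) training data with GPT-4o.

For every example in a saved Hugging Face dataset, the script sends the image together
with the original question/answer pair to an Azure OpenAI GPT-4o deployment, asking it
to rewrite the question so that answering requires step-by-step reasoning, with the
reasoning wrapped in <think>...</think> and an easy-to-verify answer in <answer>...</answer>.
The raw model reply is stored in a new `gpt4o_response` column.
"""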
import base64
import random
import re
import time
from io import BytesIO

import requests
from datasets import load_from_disk
from openai import AzureOpenAI
from PIL import Image
from pillow_avif import AvifImagePlugin  # imported for its side effect: registers AVIF support in Pillow
PROMPT_FORMAT = """I will provide you with an image, an original question, and its answer related to the image. Your task is to rewrite the question in such a way that answering it requires step-by-step Chain-of-Thought (CoT) reasoning with numerical or mathematical expressions where applicable. The reasoning process can include expressions like "let me think," "oh, I see," or other natural language thought expressions.
Please make sure your question is to ask for a certain answer with a certain value, do not ask for open-ended answer, and the answer is correct and easy to verify via simple protocol, like "2" or "A".
Please strictly do not include "Answer:" in the question part to avoid confusion and leakage.
Input Format:
Original Question: {original_question}
Original Answer: {original_answer}
Output Format:
Question: [rewrite the question if necessary]
Answer: [answer with reasoning steps, including calculations where applicable]
<think>step-by-step reasoning process</think>
<answer>easy to verify answer</answer>
"""


def get_image_data_url(image_input):
    """Convert a PIL image, local path, URL, or existing data URL into a base64 JPEG data URL."""
    if isinstance(image_input, str) and image_input.startswith("data:"):
        return image_input
    if isinstance(image_input, str) and image_input.startswith("http"):
        # Download remote images; the original code called an undefined `load_image` helper here.
        resp = requests.get(image_input, timeout=30)
        resp.raise_for_status()
        image_input = Image.open(BytesIO(resp.content))
    if isinstance(image_input, str):
        image_input = Image.open(image_input)
    if not isinstance(image_input, Image.Image):
        raise ValueError("Unsupported image input type")
    if image_input.mode != "RGB":
        image_input = image_input.convert("RGB")
    buffer = BytesIO()
    image_input.save(buffer, format="JPEG")
    img_bytes = buffer.getvalue()
    base64_data = base64.b64encode(img_bytes).decode("utf-8")
    return f"data:image/jpeg;base64,{base64_data}"


def gpt4o_query(image, prompt, max_retries=5, initial_delay=3):
    """Send an image and prompt to GPT-4o on Azure, retrying with exponential backoff and jitter."""
    if image is None:
        return None
    data_url_list = [get_image_data_url(image)]
    client = AzureOpenAI(
        azure_endpoint="YOUR_AZURE_ENDPOINT",  # replace with your Azure OpenAI endpoint
        api_version="2023-07-01-preview",
        api_key="YOUR_API_KEY",  # replace with your API key
    )
    for attempt in range(max_retries):
        try:
            messages = [
                {
                    "role": "system",
                    "content": "You are an expert at analyzing images and providing useful information for users.",
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                    ],
                },
            ]
            # Prepend each image to the user message content.
            for data_url in data_url_list:
                messages[1]["content"].insert(
                    0, {"type": "image_url", "image_url": {"url": data_url}}
                )
            response = client.chat.completions.create(
                model="gpt-4o-2024-08-06",
                messages=messages,
                temperature=0.2,
                max_tokens=8192,
            )
            return response.choices[0].message.content
        except Exception as e:
            if attempt == max_retries - 1:
                raise Exception(
                    f"Failed after {max_retries} attempts. Last error: {str(e)}"
                )
            # Exponential backoff with jitter before the next attempt.
            delay = initial_delay * (2**attempt) + random.uniform(
                0, 0.1 * initial_delay * (2**attempt)
            )
            time.sleep(delay)


def process_single_item(example):
    """Format the CoT rewrite prompt for one example and attach the GPT-4o response."""
    try:
        image_path = example["image_path"]
        formatted_prompt = PROMPT_FORMAT.format(
            original_question=example["question"], original_answer=example["answer"]
        )
        response = gpt4o_query(image_path, formatted_prompt)
        example["gpt4o_response"] = response
        return example
    except Exception as e:
        print(f"Error processing item: {str(e)}")
        example["gpt4o_response"] = None
        return example


def main():
    # Path to a dataset saved with `datasets.Dataset.save_to_disk`; update before running.
    dataset_path = "path/to/your/dataset"
    full_dataset = load_from_disk(dataset_path)
    processed_dataset = full_dataset.map(
        function=process_single_item,
        num_proc=256,  # worker processes issue API requests in parallel; tune to your rate limits
        desc="Processing dataset with GPT-4o",
        keep_in_memory=True,
    )
    output_path = f"{dataset_path}_processed"
    processed_dataset.save_to_disk(output_path)
    print(f"Processed dataset saved to: {output_path}")


if __name__ == "__main__":
    main()