Skip to content

Commit 94bfe19

Browse files
author
jim chong
committed
first script
1 parent 41fbb43 commit 94bfe19

File tree

1 file changed

+372
-0
lines changed

1 file changed

+372
-0
lines changed

autolabel.py

+372
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,372 @@
1+
# python3 autolabel.py /path/to/target_dir \
2+
# --config_file GroundingDINO/groundingdino/config/GroundingDINO_SwinB_cfg.py \
3+
# --checkpoint GroundingDINO/weights/groundingdino_swinb_cogcoor.pth \
4+
# --device cuda \
5+
# --text_prompt "cauliflower . broccoli . zucchini" \
6+
# --box_threshold 0.30 \
7+
# --text_threshold 0.25 \
8+
# --iou_threshold 0.8
9+
10+
#!/usr/bin/env python3
11+
import argparse
12+
import os
13+
import copy
14+
import json
15+
import numpy as np
16+
import torch
17+
from PIL import Image, ImageDraw, ImageFont
18+
19+
# Grounding DINO 관련 모듈
20+
import GroundingDINO.groundingdino.datasets.transforms as T
21+
from GroundingDINO.groundingdino.models import build_model
22+
from GroundingDINO.groundingdino.util import box_ops
23+
from GroundingDINO.groundingdino.util.slconfig import SLConfig
24+
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
25+
26+
# 그 외 모듈
27+
import cv2
28+
import matplotlib.pyplot as plt
29+
import torchvision.transforms as TT
30+
31+
# COCO Annotation 관련 (segmentation, bbox 계산)
32+
from skimage import measure
33+
from shapely.geometry import Polygon, MultiPolygon
34+
import datetime
35+
# (pycocotools는 COCO 평가 시 사용 – 여기서는 json 저장만 하므로 직접 사용하지 않음)
36+
37+
# GPU selection: expose only device "0" to CUDA for this process.
os.environ.update(CUDA_VISIBLE_DEVICES="0")
39+
40+
41+
########################################
42+
# 1. 이미지 로드 및 전처리 함수 (블럭 2)
43+
########################################
44+
def load_image(image_path):
    """Open an image file and prepare it for Grounding DINO.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A ``(image_pil, image_tensor)`` pair: the image converted to RGB, and
        the result of running it through the resize / to-tensor / normalize
        pipeline built below.
    """
    rgb_image = Image.open(image_path).convert("RGB")

    preprocess = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

    # The pipeline is called with an (image, target) pair; no target is
    # available here, so None is passed and the second result is discarded.
    # The original author notes the tensor is 3 x H x W.
    tensor, _ = preprocess(rgb_image, None)
    return rgb_image, tensor
56+
57+
58+
########################################
59+
# 2. 모델 로드 함수 (블럭 3)
60+
########################################
61+
def load_model(model_config_path, model_checkpoint_path, device):
    """Build the model from a config file and load checkpoint weights into it.

    Args:
        model_config_path: Path of the model config file read by ``SLConfig``.
        model_checkpoint_path: Path of the checkpoint file; its ``"model"``
            entry is used as the state dict.
        device: Value stored on the config as ``device`` before building.

    Returns:
        The built model, switched to eval mode. The model is not moved to
        ``device`` here.
    """
    config = SLConfig.fromfile(model_config_path)
    config.device = device
    net = build_model(config)

    # The checkpoint is always deserialized onto the CPU.
    state = torch.load(model_checkpoint_path, map_location="cpu")
    weights = clean_state_dict(state["model"])
    # strict=False: key mismatches are reported in the result rather than raised.
    report = net.load_state_dict(weights, strict=False)
    print("모델 로드 결과:", report)

    net.eval()
    return net
73+
74+
75+
########################################
76+
# 3. 추론 함수 (블럭 4)
77+
########################################
78+
def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
    """Run the model on one image/prompt pair and return the confident detections.

    Args:
        model: Model exposing ``tokenizer`` and callable as
            ``model(batch, captions=[...])``.
        image: Image tensor for a single image (a batch axis is added here).
        caption: Text prompt; it is lower-cased, stripped, and given a
            trailing "." if it lacks one.
        box_threshold: A query is kept when its highest token score exceeds this.
        text_threshold: Token scores above this are used to build the phrase.
        with_logits: When true, the highest score (first four characters of its
            string form) is appended to each phrase in parentheses.
        device: Device the model and the image are moved to for inference.

    Returns:
        ``(boxes_filt, pred_phrases)``: the kept rows of ``pred_boxes`` exactly
        as the model produced them, and one phrase string per kept row.
    """
    prompt = caption.lower().strip()
    if not prompt.endswith("."):
        prompt += "."

    model = model.to(device)
    image = image.to(device)
    with torch.no_grad():
        outputs = model(image[None], captions=[prompt])

    # Original comments give these shapes as (nq, 256) and (nq, 4).
    scores = outputs["pred_logits"].cpu().sigmoid()[0]
    boxes = outputs["pred_boxes"].cpu()[0]

    # Keep only queries whose best token score clears the box threshold.
    keep = scores.max(dim=1)[0] > box_threshold
    kept_scores = scores[keep]
    boxes_filt = boxes[keep]

    # Turn each kept row of token scores into a phrase with the model's tokenizer.
    tokenizer = model.tokenizer
    tokenized = tokenizer(prompt)
    pred_phrases = []
    for row in kept_scores:
        phrase = get_phrases_from_posmap(row > text_threshold, tokenized, tokenizer)
        if with_logits:
            phrase = phrase + f"({str(row.max().item())[:4]})"
        pred_phrases.append(phrase)

    return boxes_filt, pred_phrases
110+
111+
112+
########################################
113+
# 4. 시각화를 위한 함수 (블럭 5, 6)
114+
########################################
115+
def show_mask(mask, ax, random_color=True):
    """Draw a mask on ``ax`` as a semi-transparent RGBA overlay (debug helper).

    Args:
        mask: Array whose last two dimensions are (height, width).
        ax: Object with an ``imshow`` method (e.g. a matplotlib axes).
        random_color: Use a random RGB colour when true, otherwise a fixed blue.
            Alpha is 0.6 in both cases.
    """
    if random_color:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])

    rows, cols = mask.shape[-2:]
    # Broadcast (H, W, 1) against (1, 1, 4) to get an (H, W, 4) overlay.
    overlay = mask.reshape(rows, cols, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
126+
127+
128+
def show_box(box, ax, label):
    """Draw a green box outline with its label on ``ax`` (debug helper).

    Args:
        box: Indexable as ``[x0, y0, x1, y1]``.
        ax: Object with ``add_patch`` and ``text`` methods (e.g. a matplotlib axes).
        label: Text written at the box's (x0, y0) corner.
    """
    left, top = box[0], box[1]
    box_w = box[2] - box[0]
    box_h = box[3] - box[1]
    # Fully transparent face so only the outline is visible.
    outline = plt.Rectangle((left, top), box_w, box_h, edgecolor='green', facecolor=(0,0,0,0), lw=2)
    ax.add_patch(outline)
    ax.text(left, top, label, fontsize=12, color='green')
136+
137+
138+
########################################
139+
# 5. COCO Annotation 관련 함수 (블럭 13)
140+
########################################
141+
def create_sub_mask_annotation(sub_mask, image_id, category_id, annotation_id, is_crowd):
    """Build a COCO polygon annotation from one object's 2D binary mask.

    Args:
        sub_mask: 2D binary mask of a single object, in image coordinates
            (unpadded; padding is added internally).
        image_id: Value stored under ``image_id``.
        category_id: Value stored under ``category_id``.
        annotation_id: Value stored under ``id``.
        is_crowd: Value stored under ``iscrowd``.

    Returns:
        Dict with ``segmentation`` (list of flat ``[x0, y0, x1, y1, ...]``
        polygons), ``iscrowd``, ``image_id``, ``category_id``, ``id``,
        ``bbox`` (``[x, y, width, height]``) and ``area``.

    Raises:
        ValueError: If the mask yields no usable polygon (for example, it is
            entirely zero).
    """
    # Pad with a one-pixel zero border so objects touching the image edge still
    # produce closed contours. The "- 1" below removes this border again, so
    # coordinates end up in the original image frame.
    padded = np.pad(np.asarray(sub_mask), 1, mode="constant", constant_values=0)
    contours = measure.find_contours(padded, 0.5, positive_orientation='low')

    segmentations = []
    polygons = []
    for contour in contours:
        # A polygon needs at least three vertices.
        if len(contour) < 3:
            continue
        # (row, col) -> (x, y), and undo the padding offset.
        points = contour[:, ::-1] - 1
        simplified = Polygon(points).simplify(1.0, preserve_topology=False)
        # Without topology preservation the result may be empty or split into
        # several parts; handle each polygonal part separately.
        parts = list(simplified.geoms) if hasattr(simplified, "geoms") else [simplified]
        for part in parts:
            if part.is_empty or part.geom_type != "Polygon":
                continue
            flat = np.array(part.exterior.coords).ravel().tolist()
            # COCO polygons need at least three (x, y) pairs.
            if len(flat) < 6:
                continue
            polygons.append(part)
            segmentations.append(flat)

    if not polygons:
        raise ValueError("sub_mask does not contain any polygon to annotate")

    # Combine all polygons to compute the overall bounding box and area.
    multi_poly = MultiPolygon(polygons)
    x, y, max_x, max_y = multi_poly.bounds
    bbox = [x, y, max_x - x, max_y - y]

    return {
        'segmentation': segmentations,
        'iscrowd': is_crowd,
        'image_id': image_id,
        'category_id': category_id,
        'id': annotation_id,
        'bbox': bbox,
        'area': multi_poly.area,
    }
177+
178+
179+
def create_image_info(image_id, file_name, image_size,
                      date_captured=None,
                      license_id=1, coco_url="", flickr_url=""):
    """Build the COCO ``images`` entry for a single image.

    Args:
        image_id: Value stored under ``id``.
        file_name: Value stored under ``file_name``.
        image_size: ``(width, height)`` pair.
        date_captured: Timestamp string. When omitted (``None``), the current
            UTC time is used, formatted as a naive ISO string with a space
            separator. It is resolved on every call rather than once at
            import time, so each image gets its own timestamp.
        license_id: Value stored under ``license``.
        coco_url: Value stored under ``coco_url``.
        flickr_url: Value stored under ``flickr_url``.

    Returns:
        Dict with keys ``id``, ``file_name``, ``width``, ``height``,
        ``date_captured``, ``license``, ``coco_url`` and ``flickr_url``.
    """
    if date_captured is None:
        # Timezone-aware "now", then dropped to naive UTC so the string format
        # matches what utcnow().isoformat(' ') used to produce.
        now_utc = datetime.datetime.now(datetime.timezone.utc)
        date_captured = now_utc.replace(tzinfo=None).isoformat(' ')

    return {
        "id": image_id,
        "file_name": file_name,
        "width": image_size[0],
        "height": image_size[1],
        "date_captured": date_captured,
        "license": license_id,
        "coco_url": coco_url,
        "flickr_url": flickr_url,
    }
197+
198+
199+
def IOUcalc(registered, cand_area, thresh):
    """Tell whether a candidate box is new relative to the registered boxes.

    Args:
        registered: Iterable of already accepted boxes (as taken by ``get_iou``).
        cand_area: Candidate box to compare against each registered box.
        thresh: IoU limit; converted with ``float`` for each comparison.

    Returns:
        ``False`` as soon as a registered box has IoU >= ``thresh`` with the
        candidate (a duplicate), otherwise ``True``.
    """
    is_duplicate = any(
        get_iou(known, cand_area) >= float(thresh) for known in registered
    )
    return not is_duplicate
208+
209+
210+
def get_iou(bb1, bb2):
    """Compute the intersection-over-union of two axis-aligned boxes.

    Args:
        bb1: Dict with keys ``'x1'``, ``'x2'``, ``'y1'``, ``'y2'``; ``x1 < x2``
            and ``y1 < y2`` are asserted.
        bb2: Second box, same layout and constraints.

    Returns:
        ``0.0`` when the boxes do not overlap, otherwise intersection area
        divided by union area.
    """
    for rect in (bb1, bb2):
        assert rect['x1'] < rect['x2']
        assert rect['y1'] < rect['y2']

    # Corners of the overlapping rectangle, if there is one.
    left = max(bb1['x1'], bb2['x1'])
    top = max(bb1['y1'], bb2['y1'])
    right = min(bb1['x2'], bb2['x2'])
    bottom = min(bb1['y2'], bb2['y2'])

    if right < left or bottom < top:
        return 0.0

    overlap = (right - left) * (bottom - top)
    area_first = (bb1['x2'] - bb1['x1']) * (bb1['y2'] - bb1['y1'])
    area_second = (bb2['x2'] - bb2['x1']) * (bb2['y2'] - bb2['y1'])
    union = float(area_first + area_second - overlap)
    return overlap / union
232+
233+
234+
########################################
235+
# 6. Main 함수: 디렉토리 내 이미지 순회 및 처리
236+
########################################
237+
def _boxes_to_pixel_xyxy(boxes, width, height):
    """Convert normalized (cx, cy, w, h) boxes to pixel (x1, y1, x2, y2) boxes."""
    scale = torch.tensor([width, height, width, height], dtype=boxes.dtype)
    boxes_px = boxes * scale
    top_left = boxes_px[:, :2] - boxes_px[:, 2:] / 2
    bottom_right = top_left + boxes_px[:, 2:]
    return torch.cat([top_left, bottom_right], dim=1)


def _rect_mask(box_xyxy, width, height):
    """Return a uint8 mask set to 1 inside the box, or None if nothing is left after clipping."""
    x_min, y_min, x_max, y_max = (int(v) for v in box_xyxy)
    # Clip to the image bounds.
    x_min, y_min = max(0, x_min), max(0, y_min)
    x_max, y_max = min(width, x_max), min(height, y_max)
    if x_max <= x_min or y_max <= y_min:
        return None
    mask = np.zeros((height, width), dtype=np.uint8)
    mask[y_min:y_max, x_min:x_max] = 1
    return mask


def main():
    """Label every image in a directory with Grounding DINO and write COCO JSON.

    Parses the command line, loads the model once, then for each
    ``.jpg``/``.jpeg``/``.png``/``.bmp`` file in ``target_dir`` runs detection,
    drops detections whose phrase is not a known category or whose box
    duplicates an already accepted one (IoU >= ``--iou_threshold``), and writes
    ``<image name>.json`` next to the image. Images that fail to load or yield
    no boxes are skipped with a message.
    """
    parser = argparse.ArgumentParser(
        description="디렉토리 내 이미지들을 대상으로 Grounding DINO를 이용해 객체 검출 및 COCO annotation JSON 생성"
    )
    parser.add_argument("target_dir", type=str, help="이미지 파일들이 위치한 디렉토리")
    parser.add_argument("--config_file", type=str,
                        default="GroundingDINO/groundingdino/config/GroundingDINO_SwinB_cfg.py",
                        help="모델 config 파일 경로")
    parser.add_argument("--checkpoint", type=str,
                        default="GroundingDINO/weights/groundingdino_swinb_cogcoor.pth",
                        help="모델 checkpoint 파일 경로")
    parser.add_argument("--device", type=str, default="cuda",
                        help="사용할 device (cuda 또는 cpu)")
    parser.add_argument("--text_prompt", type=str,
                        default="cauliflower . broccoli . zucchini",
                        help="텍스트 프롬프트")
    parser.add_argument("--box_threshold", type=float, default=0.30,
                        help="박스 임계값")
    parser.add_argument("--text_threshold", type=float, default=0.25,
                        help="텍스트 임계값")
    parser.add_argument("--iou_threshold", type=float, default=0.8,
                        help="IOU 임계값 (중복 bbox 제거용)")
    args = parser.parse_args()

    # Load the model once and reuse it for every image.
    print("모델 로드 중...")
    model = load_model(args.config_file, args.checkpoint, args.device)
    print("모델 로드 완료.")

    # Category name -> COCO category id.
    CAT_ID = {'asparagus': 1, 'broccoli': 2, 'carrot': 3, 'cauliflower': 4, 'potato': 5, 'zucchini': 6}

    # Sorted so the processing order is deterministic across runs.
    for file in sorted(os.listdir(args.target_dir)):
        if not file.lower().endswith((".jpg", ".jpeg", ".png", ".bmp")):
            continue

        image_path = os.path.join(args.target_dir, file)
        print("\n처리 중:", image_path)
        try:
            image_pil, image_tensor = load_image(image_path)
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole run.
            print(f"이미지 로드 실패: {e}")
            continue
        width, height = image_pil.size

        # Grounding DINO inference.
        boxes_filt, pred_phrases = get_grounding_output(
            model, image_tensor, args.text_prompt,
            args.box_threshold, args.text_threshold,
            with_logits=True, device=args.device
        )
        if boxes_filt.shape[0] == 0:
            print("검출된 박스 없음.")
            continue

        # get_grounding_output returns the model's raw `pred_boxes`, which
        # Grounding DINO emits as normalized (cx, cy, w, h). Everything below
        # (mask slicing, IoU) needs pixel (x1, y1, x2, y2).
        boxes_xyxy = _boxes_to_pixel_xyxy(boxes_filt, width, height)

        # COCO skeleton: a background category plus every CAT_ID entry.
        coco_annotation = {"images": [], "annotations": [], "categories": []}
        coco_annotation["categories"].append({"supercategory": None, "id": 0, "name": "_background_"})
        for category, category_id in CAT_ID.items():
            coco_annotation["categories"].append({"supercategory": None, "id": category_id, "name": category})
        print("사용 카테고리:", coco_annotation["categories"])

        # One JSON file per image, so the image id is always 0.
        image_id = 0
        annotation_id = 0
        is_crowd = 0
        registered_regions = []  # accepted boxes, used for duplicate removal

        for box, label in zip(boxes_xyxy.tolist(), pred_phrases):
            x1, y1, x2, y2 = box
            bbox = {'x1': x1, 'y1': y1, 'x2': x2, 'y2': y2}

            # The label looks like "<phrase>(<score>)"; keep only the phrase.
            cat_str = label.split('(')[0].strip()
            if cat_str not in CAT_ID:
                print(f"카테고리 '{cat_str}'가 CAT_ID에 없으므로 해당 객체는 건너뜁니다.")
                continue
            category_id = CAT_ID[cat_str]

            # Build the mask only for detections that got this far. A box with
            # no pixels left after clipping cannot be annotated, and would
            # also trip get_iou's x1 < x2 / y1 < y2 assertions.
            mask = _rect_mask(box, width, height)
            if mask is None:
                print(f"유효한 영역이 없는 박스는 건너뜁니다: {label}")
                continue

            # Keep the detection only if it does not duplicate an accepted box.
            if IOUcalc(registered_regions, bbox, args.iou_threshold):
                registered_regions.append(bbox)
                annotation = create_sub_mask_annotation(mask, image_id, category_id, annotation_id, is_crowd)
                coco_annotation["annotations"].append(annotation)
                annotation_id += 1

        # PIL's size is (width, height), which is what create_image_info expects.
        image_info = create_image_info(image_id, file, (width, height))
        coco_annotation["images"].append(image_info)

        # Save next to the image, e.g. "image.jpg" -> "image.json".
        output_json_path = os.path.join(args.target_dir, os.path.splitext(file)[0] + ".json")
        with open(output_json_path, 'w') as f:
            json.dump(coco_annotation, f, indent=4)
        print(f"COCO annotation 저장 완료: {output_json_path}")

        # Optional: uncomment to preview the detections in a matplotlib window.
        # fig, ax = plt.subplots(1, 1, figsize=(10, 10))
        # ax.imshow(image_pil)
        # for box, label in zip(boxes_xyxy, pred_phrases):
        #     show_box(box.numpy(), ax, label)
        # plt.axis('off')
        # plt.show()
369+
370+
371+
# Run the command-line workflow only when this file is executed as a script.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)