|
| 1 | +# python3 autolabel.py /path/to/target_dir \ |
| 2 | +# --config_file GroundingDINO/groundingdino/config/GroundingDINO_SwinB_cfg.py \ |
| 3 | +# --checkpoint GroundingDINO/weights/groundingdino_swinb_cogcoor.pth \ |
| 4 | +# --device cuda \ |
| 5 | +# --text_prompt "cauliflower . broccoli . zucchini" \ |
| 6 | +# --box_threshold 0.30 \ |
| 7 | +# --text_threshold 0.25 \ |
| 8 | +# --iou_threshold 0.8 |
| 9 | + |
| 10 | +#!/usr/bin/env python3 |
| 11 | +import argparse |
| 12 | +import os |
| 13 | +import copy |
| 14 | +import json |
| 15 | +import numpy as np |
| 16 | +import torch |
| 17 | +from PIL import Image, ImageDraw, ImageFont |
| 18 | + |
| 19 | +# Grounding DINO 관련 모듈 |
| 20 | +import GroundingDINO.groundingdino.datasets.transforms as T |
| 21 | +from GroundingDINO.groundingdino.models import build_model |
| 22 | +from GroundingDINO.groundingdino.util import box_ops |
| 23 | +from GroundingDINO.groundingdino.util.slconfig import SLConfig |
| 24 | +from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap |
| 25 | + |
| 26 | +# 그 외 모듈 |
| 27 | +import cv2 |
| 28 | +import matplotlib.pyplot as plt |
| 29 | +import torchvision.transforms as TT |
| 30 | + |
| 31 | +# COCO Annotation 관련 (segmentation, bbox 계산) |
| 32 | +from skimage import measure |
| 33 | +from shapely.geometry import Polygon, MultiPolygon |
| 34 | +import datetime |
| 35 | +# (pycocotools는 COCO 평가 시 사용 – 여기서는 json 저장만 하므로 직접 사용하지 않음) |
| 36 | + |
| 37 | +# GPU 설정 |
| 38 | +os.environ["CUDA_VISIBLE_DEVICES"] = "0" |
| 39 | + |
| 40 | + |
| 41 | +######################################## |
| 42 | +# 1. 이미지 로드 및 전처리 함수 (블럭 2) |
| 43 | +######################################## |
| 44 | +def load_image(image_path): |
| 45 | + """ |
| 46 | + 이미지 파일을 로드하여 PIL 이미지와 Grounding DINO 전처리에 맞는 텐서를 반환합니다. |
| 47 | + """ |
| 48 | + image_pil = Image.open(image_path).convert("RGB") |
| 49 | + transform = T.Compose([ |
| 50 | + T.RandomResize([800], max_size=1333), |
| 51 | + T.ToTensor(), |
| 52 | + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), |
| 53 | + ]) |
| 54 | + image_tensor, _ = transform(image_pil, None) # 3 x H x W 텐서 |
| 55 | + return image_pil, image_tensor |
| 56 | + |
| 57 | + |
| 58 | +######################################## |
| 59 | +# 2. 모델 로드 함수 (블럭 3) |
| 60 | +######################################## |
| 61 | +def load_model(model_config_path, model_checkpoint_path, device): |
| 62 | + """ |
| 63 | + 모델 config 파일과 체크포인트를 이용해 모델을 로드합니다. |
| 64 | + """ |
| 65 | + args = SLConfig.fromfile(model_config_path) |
| 66 | + args.device = device |
| 67 | + model = build_model(args) |
| 68 | + checkpoint = torch.load(model_checkpoint_path, map_location="cpu") |
| 69 | + load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False) |
| 70 | + print("모델 로드 결과:", load_res) |
| 71 | + model.eval() |
| 72 | + return model |
| 73 | + |
| 74 | + |
| 75 | +######################################## |
| 76 | +# 3. 추론 함수 (블럭 4) |
| 77 | +######################################## |
| 78 | +def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"): |
| 79 | + """ |
| 80 | + 이미지와 텍스트 프롬프트(caption)를 입력받아 모델 추론을 수행한 후, |
| 81 | + 임계값을 적용하여 박스와 예측된 phrase 리스트를 반환합니다. |
| 82 | + """ |
| 83 | + caption = caption.lower().strip() |
| 84 | + if not caption.endswith("."): |
| 85 | + caption = caption + "." |
| 86 | + model = model.to(device) |
| 87 | + image = image.to(device) |
| 88 | + with torch.no_grad(): |
| 89 | + outputs = model(image[None], captions=[caption]) |
| 90 | + logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256) |
| 91 | + boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4) |
| 92 | + |
| 93 | + # 임계값으로 filtering |
| 94 | + filt_mask = logits.max(dim=1)[0] > box_threshold |
| 95 | + logits_filt = logits[filt_mask] # num_filt x 256 |
| 96 | + boxes_filt = boxes[filt_mask] # num_filt x 4 |
| 97 | + |
| 98 | + # 모델 내 tokenizer를 사용하여 phrase 추출 |
| 99 | + tokenlizer = model.tokenizer |
| 100 | + tokenized = tokenlizer(caption) |
| 101 | + pred_phrases = [] |
| 102 | + for logit in logits_filt: |
| 103 | + pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer) |
| 104 | + if with_logits: |
| 105 | + pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})") |
| 106 | + else: |
| 107 | + pred_phrases.append(pred_phrase) |
| 108 | + |
| 109 | + return boxes_filt, pred_phrases |
| 110 | + |
| 111 | + |
| 112 | +######################################## |
| 113 | +# 4. 시각화를 위한 함수 (블럭 5, 6) |
| 114 | +######################################## |
| 115 | +def show_mask(mask, ax, random_color=True): |
| 116 | + """ |
| 117 | + 디버깅/시각화를 위해 마스크를 표시합니다. |
| 118 | + """ |
| 119 | + if random_color: |
| 120 | + color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) |
| 121 | + else: |
| 122 | + color = np.array([30/255, 144/255, 255/255, 0.6]) |
| 123 | + h, w = mask.shape[-2:] |
| 124 | + mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) |
| 125 | + ax.imshow(mask_image) |
| 126 | + |
| 127 | + |
| 128 | +def show_box(box, ax, label): |
| 129 | + """ |
| 130 | + 디버깅/시각화를 위해 박스를 그리고 label을 표시합니다. |
| 131 | + """ |
| 132 | + x0, y0 = box[0], box[1] |
| 133 | + w, h = box[2] - box[0], box[3] - box[1] |
| 134 | + ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2)) |
| 135 | + ax.text(x0, y0, label, fontsize=12, color='green') |
| 136 | + |
| 137 | + |
| 138 | +######################################## |
| 139 | +# 5. COCO Annotation 관련 함수 (블럭 13) |
| 140 | +######################################## |
| 141 | +def create_sub_mask_annotation(sub_mask, image_id, category_id, annotation_id, is_crowd): |
| 142 | + """ |
| 143 | + 단일 객체의 서브 마스크(2D binary mask)를 입력받아 COCO segmentation annotation을 생성합니다. |
| 144 | + """ |
| 145 | + # sub_mask 에서 contour 추출 (객체 경계) |
| 146 | + contours = measure.find_contours(sub_mask, 0.5, positive_orientation='low') |
| 147 | + segmentations = [] |
| 148 | + polygons = [] |
| 149 | + for contour in contours: |
| 150 | + # (row, col) -> (x, y) 변환 및 좌표 보정 |
| 151 | + for i in range(len(contour)): |
| 152 | + row, col = contour[i] |
| 153 | + contour[i] = (col - 1, row - 1) |
| 154 | + poly = Polygon(contour) |
| 155 | + poly = poly.simplify(1.0, preserve_topology=False) |
| 156 | + polygons.append(poly) |
| 157 | + segmentation = np.array(poly.exterior.coords).ravel().tolist() |
| 158 | + segmentations.append(segmentation) |
| 159 | + # 다수의 polygon 결합하여 bbox 및 area 계산 |
| 160 | + multi_poly = MultiPolygon(polygons) |
| 161 | + x, y, max_x, max_y = multi_poly.bounds |
| 162 | + width = max_x - x |
| 163 | + height = max_y - y |
| 164 | + bbox = [x, y, width, height] |
| 165 | + area = multi_poly.area |
| 166 | + |
| 167 | + annotation = { |
| 168 | + 'segmentation': segmentations, |
| 169 | + 'iscrowd': is_crowd, |
| 170 | + 'image_id': image_id, |
| 171 | + 'category_id': category_id, |
| 172 | + 'id': annotation_id, |
| 173 | + 'bbox': bbox, |
| 174 | + 'area': area |
| 175 | + } |
| 176 | + return annotation |
| 177 | + |
| 178 | + |
| 179 | +def create_image_info(image_id, file_name, image_size, |
| 180 | + date_captured=datetime.datetime.utcnow().isoformat(' '), |
| 181 | + license_id=1, coco_url="", flickr_url=""): |
| 182 | + """ |
| 183 | + 단일 이미지에 대한 COCO image 정보(dict)를 생성합니다. |
| 184 | + image_size는 (width, height) 형식이어야 합니다. |
| 185 | + """ |
| 186 | + image_info = { |
| 187 | + "id": image_id, |
| 188 | + "file_name": file_name, |
| 189 | + "width": image_size[0], |
| 190 | + "height": image_size[1], |
| 191 | + "date_captured": date_captured, |
| 192 | + "license": license_id, |
| 193 | + "coco_url": coco_url, |
| 194 | + "flickr_url": flickr_url |
| 195 | + } |
| 196 | + return image_info |
| 197 | + |
| 198 | + |
| 199 | +def IOUcalc(registered, cand_area, thresh): |
| 200 | + """ |
| 201 | + 이미 등록된 bbox들과 IoU를 계산하여, 임계값보다 크면 False (중복) 반환. |
| 202 | + """ |
| 203 | + for bbox in registered: |
| 204 | + iou = get_iou(bbox, cand_area) |
| 205 | + if iou >= float(thresh): |
| 206 | + return False |
| 207 | + return True |
| 208 | + |
| 209 | + |
| 210 | +def get_iou(bb1, bb2): |
| 211 | + """ |
| 212 | + 두 bbox (dict, keys: 'x1','x2','y1','y2')의 IoU를 계산합니다. |
| 213 | + """ |
| 214 | + assert bb1['x1'] < bb1['x2'] |
| 215 | + assert bb1['y1'] < bb1['y2'] |
| 216 | + assert bb2['x1'] < bb2['x2'] |
| 217 | + assert bb2['y1'] < bb2['y2'] |
| 218 | + |
| 219 | + x_left = max(bb1['x1'], bb2['x1']) |
| 220 | + y_top = max(bb1['y1'], bb2['y1']) |
| 221 | + x_right = min(bb1['x2'], bb2['x2']) |
| 222 | + y_bottom = min(bb1['y2'], bb2['y2']) |
| 223 | + |
| 224 | + if x_right < x_left or y_bottom < y_top: |
| 225 | + return 0.0 |
| 226 | + |
| 227 | + intersection_area = (x_right - x_left) * (y_bottom - y_top) |
| 228 | + bb1_area = (bb1['x2'] - bb1['x1']) * (bb1['y2'] - bb1['y1']) |
| 229 | + bb2_area = (bb2['x2'] - bb2['x1']) * (bb2['y2'] - bb2['y1']) |
| 230 | + iou = intersection_area / float(bb1_area + bb2_area - intersection_area) |
| 231 | + return iou |
| 232 | + |
| 233 | + |
| 234 | +######################################## |
| 235 | +# 6. Main 함수: 디렉토리 내 이미지 순회 및 처리 |
| 236 | +######################################## |
| 237 | +def main(): |
| 238 | + parser = argparse.ArgumentParser( |
| 239 | + description="디렉토리 내 이미지들을 대상으로 Grounding DINO를 이용해 객체 검출 및 COCO annotation JSON 생성" |
| 240 | + ) |
| 241 | + parser.add_argument("target_dir", type=str, help="이미지 파일들이 위치한 디렉토리") |
| 242 | + parser.add_argument("--config_file", type=str, |
| 243 | + default="GroundingDINO/groundingdino/config/GroundingDINO_SwinB_cfg.py", |
| 244 | + help="모델 config 파일 경로") |
| 245 | + parser.add_argument("--checkpoint", type=str, |
| 246 | + default="GroundingDINO/weights/groundingdino_swinb_cogcoor.pth", |
| 247 | + help="모델 checkpoint 파일 경로") |
| 248 | + parser.add_argument("--device", type=str, default="cuda", |
| 249 | + help="사용할 device (cuda 또는 cpu)") |
| 250 | + parser.add_argument("--text_prompt", type=str, |
| 251 | + default="cauliflower . broccoli . zucchini", |
| 252 | + help="텍스트 프롬프트") |
| 253 | + parser.add_argument("--box_threshold", type=float, default=0.30, |
| 254 | + help="박스 임계값") |
| 255 | + parser.add_argument("--text_threshold", type=float, default=0.25, |
| 256 | + help="텍스트 임계값") |
| 257 | + parser.add_argument("--iou_threshold", type=float, default=0.8, |
| 258 | + help="IOU 임계값 (중복 bbox 제거용)") |
| 259 | + args = parser.parse_args() |
| 260 | + |
| 261 | + # 모델 로드 (한 번만 로드) |
| 262 | + print("모델 로드 중...") |
| 263 | + model = load_model(args.config_file, args.checkpoint, args.device) |
| 264 | + print("모델 로드 완료.") |
| 265 | + |
| 266 | + # COCO annotation에 사용할 카테고리 사전 (블럭 14) |
| 267 | + CAT_ID = {'asparagus': 1, 'broccoli': 2, 'carrot': 3, 'cauliflower': 4, 'potato': 5, 'zucchini': 6} |
| 268 | + |
| 269 | + # 대상 디렉토리 내의 이미지 파일 순회 |
| 270 | + for file in os.listdir(args.target_dir): |
| 271 | + if file.lower().endswith((".jpg", ".jpeg", ".png", ".bmp")): |
| 272 | + image_path = os.path.join(args.target_dir, file) |
| 273 | + print("\n처리 중:", image_path) |
| 274 | + try: |
| 275 | + image_pil, image_tensor = load_image(image_path) |
| 276 | + except Exception as e: |
| 277 | + print(f"이미지 로드 실패: {e}") |
| 278 | + continue |
| 279 | + width, height = image_pil.size |
| 280 | + |
| 281 | + # Grounding DINO 모델 추론 |
| 282 | + boxes_filt, pred_phrases = get_grounding_output( |
| 283 | + model, image_tensor, args.text_prompt, |
| 284 | + args.box_threshold, args.text_threshold, |
| 285 | + with_logits=True, device=args.device |
| 286 | + ) |
| 287 | + if boxes_filt.shape[0] == 0: |
| 288 | + print("검출된 박스 없음.") |
| 289 | + continue |
| 290 | + |
| 291 | + # (블럭 11) 각 박스 영역에 대해 binary mask 생성 |
| 292 | + masks = [] |
| 293 | + for box in boxes_filt: |
| 294 | + # box: [x_min, y_min, x_max, y_max] (정수형 변환) |
| 295 | + box_list = list(map(int, box.tolist())) |
| 296 | + x_min, y_min, x_max, y_max = box_list |
| 297 | + mask = np.zeros((height, width), dtype=np.uint8) |
| 298 | + # 이미지 범위 내로 보정 |
| 299 | + x_min = max(0, x_min) |
| 300 | + y_min = max(0, y_min) |
| 301 | + x_max = min(width, x_max) |
| 302 | + y_max = min(height, y_max) |
| 303 | + mask[y_min:y_max, x_min:x_max] = 1 |
| 304 | + masks.append(mask) |
| 305 | + |
| 306 | + if len(masks) == 0: |
| 307 | + print("생성된 mask가 없습니다.") |
| 308 | + continue |
| 309 | + |
| 310 | + ######################################## |
| 311 | + # COCO annotation 구성 (블럭 14) |
| 312 | + ######################################## |
| 313 | + coco_annotation = {"images": [], "annotations": [], "categories": []} |
| 314 | + # 배경 카테고리 추가 |
| 315 | + coco_annotation["categories"].append({"supercategory": None, "id": 0, "name": "_background_"}) |
| 316 | + # CAT_ID에 정의된 카테고리 추가 |
| 317 | + for i, category in enumerate(CAT_ID.keys()): |
| 318 | + coco_annotation["categories"].append({"supercategory": None, "id": i+1, "name": category}) |
| 319 | + print("사용 카테고리:", coco_annotation["categories"]) |
| 320 | + |
| 321 | + image_id = 0 |
| 322 | + annotation_id = 0 |
| 323 | + is_crowd = 0 |
| 324 | + registered_regions = [] # 중복 bbox 제거용 |
| 325 | + |
| 326 | + # 각 검출된 객체에 대해 annotation 생성 |
| 327 | + for i, (box, label) in enumerate(zip(boxes_filt, pred_phrases)): |
| 328 | + # box는 tensor → dict 변환 (x1, y1, x2, y2) |
| 329 | + bbox = {'x1': box[0].item(), 'y1': box[1].item(), |
| 330 | + 'x2': box[2].item(), 'y2': box[3].item()} |
| 331 | + # label에서 카테고리 이름 추출 (예: "cauliflower(0.935)" → "cauliflower") |
| 332 | + cat_str = label.split('(')[0].strip() |
| 333 | + if cat_str in CAT_ID: |
| 334 | + category_id = CAT_ID[cat_str] |
| 335 | + else: |
| 336 | + print(f"카테고리 '{cat_str}'가 CAT_ID에 없으므로 해당 객체는 건너뜁니다.") |
| 337 | + continue |
| 338 | + |
| 339 | + mask = masks[i] |
| 340 | + # (옵션) 디버그: mask의 shape 출력 |
| 341 | + # print("mask shape:", mask.shape) |
| 342 | + |
| 343 | + # 중복 bbox (IoU 임계값 기반) 체크 후 annotation 생성 |
| 344 | + if IOUcalc(registered_regions, bbox, args.iou_threshold): |
| 345 | + registered_regions.append(bbox) |
| 346 | + annotation = create_sub_mask_annotation(np.asarray(mask), image_id, category_id, annotation_id, is_crowd) |
| 347 | + coco_annotation["annotations"].append(annotation) |
| 348 | + annotation_id += 1 |
| 349 | + |
| 350 | + # image 정보 생성 (PIL 이미지의 size는 (width, height)) |
| 351 | + image_info = create_image_info(image_id, file, (width, height)) |
| 352 | + coco_annotation["images"].append(image_info) |
| 353 | + |
| 354 | + # 결과 COCO JSON 파일 저장 (예: "image.jpg" → "image.json") |
| 355 | + output_json_path = os.path.join(args.target_dir, os.path.splitext(file)[0] + ".json") |
| 356 | + with open(output_json_path, 'w') as f: |
| 357 | + json.dump(coco_annotation, f, indent=4) |
| 358 | + print(f"COCO annotation 저장 완료: {output_json_path}") |
| 359 | + |
| 360 | + # (옵션) 검출 결과 시각화 – plt 창으로 띄우고자 하면 주석 해제 |
| 361 | + # fig, ax = plt.subplots(1, 1, figsize=(10, 10)) |
| 362 | + # ax.imshow(image_pil) |
| 363 | + # for box, label in zip(boxes_filt, pred_phrases): |
| 364 | + # show_box(box.numpy(), ax, label) |
| 365 | + # for mask in masks: |
| 366 | + # show_mask(mask, ax, random_color=True) |
| 367 | + # plt.axis('off') |
| 368 | + # plt.show() |
| 369 | + |
| 370 | + |
| 371 | +if __name__ == "__main__": |
| 372 | + main() |
0 commit comments