|
| 1 | +import time |
| 2 | + |
1 | 3 | import torch
|
2 | 4 | import numpy as np
|
| 5 | +import torchvision |
3 | 6 | import tqdm
|
4 | 7 | from torchvision.ops.boxes import batched_nms
|
5 | 8 |
|
@@ -46,6 +49,94 @@ def rescale_boxes(boxes, current_dim, original_shape):
|
46 | 49 | return boxes
|
47 | 50 |
|
48 | 51 |
|
def soft_non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False, classes=None, agnostic=False,
                             is_p1p2=False):
    """Run Non-Maximum Suppression (NMS) on raw inference results.

    Despite the name, suppression is done with the hard ``torchvision`` NMS
    kernel; no "soft" score decay is applied.

    Args:
        prediction: tensor indexed as ``[image, box, channel]`` where channels
            0-3 hold the box, channel 4 the objectness score and channels 5+
            the per-class scores.
        conf_thres: minimum objectness, and minimum objectness * class score,
            for a box to be kept.
        iou_thres: IoU threshold passed to NMS (and used by the merge step).
        merge: if true, replace each kept box by the score-weighted mean of
            the boxes overlapping it by more than ``iou_thres``.
        classes: optional sequence of class indices; when non-empty, only
            detections of these classes are kept.
        agnostic: if true, boxes of different classes suppress each other.
        is_p1p2: if true, the box channels are used as-is; otherwise they are
            converted with ``xywh2p1p2`` first.

    Returns:
        A list with one entry per image: a tensor of shape nx6
        (x1, y1, x2, y2, conf, cls), or ``None`` for images without any
        detection or not reached before the time limit expired.
    """
    if prediction.dtype is torch.float16:
        prediction = prediction.float()  # to FP32

    # Read the class count from the tensor shape instead of prediction[0] so
    # that an empty batch returns an empty list rather than raising IndexError.
    nc = prediction.shape[2] - 5  # number of classes
    candidates = prediction[..., 4] > conf_thres

    # Settings
    max_wh = 4096  # per-class coordinate offset used to keep classes apart in NMS
    max_det = 300  # maximum number of detections per image
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label = nc > 1  # multiple labels per box

    t = time.time()
    output = [None] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Boolean-mask indexing copies, so the in-place ops below do not
        # modify the caller's tensor.
        x = x[candidates[xi]]

        # If none remain process next image
        if not x.shape[0]:
            continue

        # conf = obj_conf * cls_conf
        x[:, 5:] *= x[:, 4:5]

        # Box channels to (x1, y1, x2, y2)
        if not is_p1p2:
            box = xywh2p1p2(x[:, :4])
        else:
            box = x[:, :4]

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            # One row per (box, class) pair above the threshold
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # If none remain process next image
        n = x.shape[0]  # number of boxes
        if not n:
            continue

        # Batched NMS: shifting boxes by class index * max_wh keeps boxes of
        # different classes from overlapping, so one NMS call handles all classes.
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            try:  # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
                # torch.mm below needs a pairwise (len(i), n) matrix, which is
                # what torchvision's box_iou returns.
                iou = torchvision.ops.boxes.box_iou(boxes[i], boxes) > iou_thres  # iou matrix
                weights = iou * scores[None]  # box weights
                x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
                if redundant:
                    i = i[iou.sum(1) > 1]  # require redundancy
            except Exception:  # possible CUDA error https://github.com/ultralytics/yolov3/issues/1139
                # Best effort: keep the plain NMS result, but do not swallow
                # KeyboardInterrupt/SystemExit as a bare except would.
                print(x, i, x.shape, i.shape)

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            break  # time limit exceeded

    return output
| 138 | + |
| 139 | + |
49 | 140 | def non_max_suppression(prediction, thres=0.5, nms_thres=0.4, is_p1p2=False):
|
50 | 141 | """
|
51 | 142 | Removes detections with lower object confidence score than 'conf_thres' and performs
|
@@ -73,7 +164,6 @@ def non_max_suppression(prediction, thres=0.5, nms_thres=0.4, is_p1p2=False):
|
73 | 164 | # If none anchor are remaining => process next image
|
74 | 165 | if not image_pred.size(0):
|
75 | 166 | continue
|
76 |
| - |
77 | 167 |
|
78 | 168 | detections = torch.cat((image_pred[:, :5],
|
79 | 169 | class_confs.type(prediction.dtype),
|
|
0 commit comments