Skip to content

Commit 8318714

Browse files
author
rajdeephu
committed
added comments
1 parent a05a7d8 commit 8318714

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

metrics.py

+10
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,25 @@ def getBatchDetectionAcc(label_mask, pred_mask):
88
assert ((label_mask == 0) | (label_mask == 1)).all()
99
assert ((pred_mask == 0) | (pred_mask == 1)).all()
1010

11+
# get individual masks
1112
masks = getIndividualMasks(label_mask[0])
1213
detection = []
14+
15+
# find the prediction for each individual mask
1316
for mask in masks:
17+
# make individual mask same shape as pred mask
1418
mask = mask.reshape((1, mask.shape[0], mask.shape[1]))
1519
mask = np.repeat(mask, label_mask.shape[0], axis=0)
20+
21+
# get intersection for individual mask
1622
intersection = mask * pred_mask
23+
24+
# get prediction for individual mask
1725
num_ones = (intersection == 1).sum(axis=(1,2))
1826
num_ones[num_ones > 0] = 1
1927
detection.append(num_ones)
28+
29+
# combine individual masks prediction to find acc
2030
detection = np.column_stack(detection)
2131
acc = detection.mean(axis=-1)
2232
batch_acc = acc.mean(axis=-1)

0 commit comments

Comments
 (0)