File tree 1 file changed +10
-0
lines changed
1 file changed +10
-0
lines changed Original file line number Diff line number Diff line change @@ -8,15 +8,25 @@ def getBatchDetectionAcc(label_mask, pred_mask):
8
8
assert ((label_mask == 0 ) | (label_mask == 1 )).all ()
9
9
assert ((pred_mask == 0 ) | (pred_mask == 1 )).all ()
10
10
11
+ # get individual masks
11
12
masks = getIndividualMasks (label_mask [0 ])
12
13
detection = []
14
+
15
+ # find the prediction for each individual mask
13
16
for mask in masks :
17
+ # make individual mask same shape as pred mask
14
18
mask = mask .reshape ((1 , mask .shape [0 ], mask .shape [1 ]))
15
19
mask = np .repeat (mask , label_mask .shape [0 ], axis = 0 )
20
+
21
+ # get intersection for individual mask
16
22
intersection = mask * pred_mask
23
+
24
+ # get prediction for individual mask
17
25
num_ones = (intersection == 1 ).sum (axis = (1 ,2 ))
18
26
num_ones [num_ones > 0 ] = 1
19
27
detection .append (num_ones )
28
+
29
+ # combine individual masks prediction to find acc
20
30
detection = np .column_stack (detection )
21
31
acc = detection .mean (axis = - 1 )
22
32
batch_acc = acc .mean (axis = - 1 )
You can’t perform that action at this time.
0 commit comments