Skip to content

Commit 6e4c236

Browse files
committed
Use pandas DataFrame
1 parent 2de08b7 commit 6e4c236

File tree

2 files changed

+133
-133
lines changed

2 files changed

+133
-133
lines changed

MTM/NMS.py

Lines changed: 104 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -91,112 +91,110 @@ def NMS(List_Hit, scoreThreshold=None, sortDescending=True, N_object=float("inf"
9191
This iteration is terminate once we have collected N best hit, or if there are no more hit left to test for overlap
9292
9393
INPUT
94-
- ListHit : a list of dictionnary, with each dictionnary being a hit following the formating {'TemplateIdx'= (int),'BBox'=(x,y,width,height),'Score'=(float)}
95-
the TemplateIdx is the row index in the panda/Knime table
96-
97-
- scoreThreshold : Float (or None), used to remove hit with too low prediction score.
98-
If sortDescending=True (ie we use a correlation measure so we want to keep large scores) the scores above that threshold are kept
99-
While if we use sortDescending=False (we use a difference measure ie we want to keep low score), the scores below that threshold are kept
100-
101-
- N_object : number of best hit to return (by increasing score). Min=1, eventhough it does not really make sense to do NMS with only 1 hit
102-
- maxOverlap : float between 0 and 1, the maximal overlap authorised between 2 bounding boxes, above this value, the bounding box of lower score is deleted
103-
- sortDescending : use True when high score means better prediction, False otherwise (ex : if score is a difference measure, then the best prediction are low difference and we sort by ascending order)
104-
105-
OUTPUT
106-
List_nHit : List of the best detection after NMS, it contains max N detection (but potentially less)
107-
'''
108-
109-
# Apply threshold on prediction score
110-
if scoreThreshold==None :
111-
List_ThreshHit = List_Hit[:] # copy to avoid modifying the input list in place
112-
113-
elif sortDescending : # We keep hit above the threshold
114-
List_ThreshHit = [dico for dico in List_Hit if dico['Score']>=scoreThreshold]
115-
116-
elif not sortDescending : # We keep hit below the threshold
117-
List_ThreshHit = [dico for dico in List_Hit if dico['Score']<=scoreThreshold]
118-
119-
120-
# Sort score to have best predictions first (important as we loop testing the best boxes against the other boxes)
121-
if sortDescending:
122-
List_ThreshHit.sort(key=lambda dico: dico['Score'], reverse=True) # Hit = [list of (x,y),score] - sort according to descending (best = high correlation)
123-
else:
124-
List_ThreshHit.sort(key=lambda dico: dico['Score']) # sort according to ascending score (best = small difference)
125-
126-
127-
# Split the inital pool into Final Hit that are kept and restHit that can be tested
128-
# Initialisation : 1st keep is kept for sure, restHit is the rest of the list
129-
#print("\nInitialise final hit list with first best hit")
130-
FinalHit = [List_ThreshHit[0]]
131-
restHit = List_ThreshHit[1:]
132-
133-
134-
# Loop to compute overlap
135-
while len(FinalHit)<N_object and restHit : # second condition is restHit is not empty
136-
137-
# Report state of the loop
138-
#print("\n\n\nNext while iteration")
139-
140-
#print("-> Final hit list")
141-
#for hit in FinalHit: print(hit)
142-
143-
#print("\n-> Remaining hit list")
144-
#for hit in restHit: print(hit)
145-
146-
# pick the next best peak in the rest of peak
147-
test_hit = restHit[0]
148-
test_bbox = test_hit['BBox']
149-
#print("\nTest BBox:{} for overlap against higher score bboxes".format(test_bbox))
150-
151-
# Loop over hit in FinalHit to compute successively overlap with test_peak
152-
for hit in FinalHit:
153-
154-
# Recover Bbox from hit
155-
bbox2 = hit['BBox']
156-
157-
# Compute the Intersection over Union between test_peak and current peak
158-
IoU = computeIoU(test_bbox, bbox2)
159-
160-
# Initialise the boolean value to true before test of overlap
161-
ToAppend = True
162-
163-
if IoU>maxOverlap:
164-
ToAppend = False
165-
#print("IoU above threshold\n")
166-
break # no need to test overlap with the other peaks
167-
168-
else:
169-
#print("IoU below threshold\n")
170-
# no overlap for this particular (test_peak,peak) pair, keep looping to test the other (test_peak,peak)
171-
continue
172-
173-
174-
# After testing against all peaks (for loop is over), append or not the peak to final
175-
if ToAppend:
176-
# Move the test_hit from restHit to FinalHit
177-
#print("Append {} to list of final hits, remove it from Remaining hit list".format(test_hit))
178-
FinalHit.append(test_hit)
179-
restHit.remove(test_hit)
180-
181-
else:
182-
# only remove the test_peak from restHit
183-
#print("Remove {} from Remaining hit list".format(test_hit))
184-
restHit.remove(test_hit)
185-
186-
187-
# Once function execution is done, return list of hit without overlap
188-
#print("\nCollected N expected hit, or no hit left to test")
189-
#print("NMS over\n")
190-
return FinalHit
191-
192-
94+
- tableHit : (Panda DataFrame) Each row is a hit, with columns "TemplateName"(String),"BBox"(x,y,width,height),"Score"(float)
95+
96+
- scoreThreshold : Float (or None), used to remove hit with too low prediction score.
97+
If sortDescending=True (ie we use a correlation measure so we want to keep large scores) the scores above that threshold are kept
98+
While if we use sortDescending=False (we use a difference measure ie we want to keep low score), the scores below that threshold are kept
99+
100+
- N_object : number of best hit to return (by increasing score). Min=1, eventhough it does not really make sense to do NMS with only 1 hit
101+
- maxOverlap : float between 0 and 1, the maximal overlap authorised between 2 bounding boxes, above this value, the bounding box of lower score is deleted
102+
- sortDescending : use True when high score means better prediction, False otherwise (ex : if score is a difference measure, then the best prediction are low difference and we sort by ascending order)
103+
104+
OUTPUT
105+
Panda DataFrame with best detection after NMS, it contains max N detection (but potentially less)
106+
'''
107+
108+
# Apply threshold on prediction score
109+
if scoreThreshold==None :
110+
threshTable = tableHit.copy() # copy to avoid modifying the input list in place
111+
112+
elif sortDescending : # We keep rows above the threshold
113+
threshTable = tableHit[ tableHit['Score']>=scoreThreshold ]
114+
115+
elif not sortDescending : # We keep hit below the threshold
116+
threshTable = tableHit[ tableHit['Score']<=scoreThreshold ]
117+
118+
# Sort score to have best predictions first (important as we loop testing the best boxes against the other boxes)
119+
if sortDescending:
120+
threshTable.sort_values("Score", ascending=False, inplace=True) # Hit = [list of (x,y),score] - sort according to descending (best = high correlation)
121+
else:
122+
threshTable.sort_values("Score", ascending=True, inplace=True) # sort according to ascending score (best = small difference)
123+
124+
125+
# Split the inital pool into Final Hit that are kept and restTable that can be tested
126+
# Initialisation : 1st keep is kept for sure, restTable is the rest of the list
127+
#print("\nInitialise final hit list with first best hit")
128+
outTable = threshTable.iloc[[0]].to_dict('records') # double square bracket to recover a DataFrame
129+
restTable = threshTable.iloc[1:].to_dict('records')
130+
131+
132+
# Loop to compute overlap
133+
while len(outTable)<N_object and restTable: # second condition is restTable is not empty
134+
135+
# Report state of the loop
136+
#print("\n\n\nNext while iteration")
137+
138+
#print("-> Final hit list")
139+
#for hit in outTable: print(hit)
140+
141+
#print("\n-> Remaining hit list")
142+
#for hit in restTable: print(hit)
143+
144+
# pick the next best peak in the rest of peak
145+
testHit_dico = restTable[0] # dico
146+
test_bbox = testHit_dico['BBox']
147+
#print("\nTest BBox:{} for overlap against higher score bboxes".format(test_bbox))
148+
149+
# Loop over hit in outTable to compute successively overlap with testHit
150+
for hit_dico in outTable:
151+
152+
# Recover Bbox from hit
153+
bbox2 = hit_dico['BBox']
154+
155+
# Compute the Intersection over Union between test_peak and current peak
156+
IoU = computeIoU(test_bbox, bbox2)
157+
158+
# Initialise the boolean value to true before test of overlap
159+
ToAppend = True
160+
161+
if IoU>maxOverlap:
162+
ToAppend = False
163+
#print("IoU above threshold\n")
164+
break # no need to test overlap with the other peaks
165+
166+
else:
167+
#print("IoU below threshold\n")
168+
# no overlap for this particular (test_peak,peak) pair, keep looping to test the other (test_peak,peak)
169+
continue
170+
171+
172+
# After testing against all peaks (for loop is over), append or not the peak to final
173+
if ToAppend:
174+
# Move the test_hit from restTable to outTable
175+
#print("Append {} to list of final hits, remove it from Remaining hit list".format(test_hit))
176+
outTable.append(testHit_dico)
177+
restTable.remove(testHit_dico)
178+
179+
else:
180+
# only remove the test_peak from restTable
181+
#print("Remove {} from Remaining hit list".format(test_hit))
182+
restTable.remove(testHit_dico)
183+
184+
185+
# Once function execution is done, return list of hit without overlap
186+
#print("\nCollected N expected hit, or no hit left to test")
187+
#print("NMS over\n")
188+
return pd.DataFrame(outTable)
189+
190+
193191
if __name__ == "__main__":
194-
Hit1 = {'TemplateIdx':1,'BBox':(780, 350, 700, 480), 'Score':0.8}
195-
Hit2 = {'TemplateIdx':1,'BBox':(806, 416, 716, 442), 'Score':0.6}
196-
Hit3 = {'TemplateIdx':1,'BBox':(1074, 530, 680, 390), 'Score':0.4}
197-
198-
ListHit = [Hit1, Hit2, Hit3]
192+
ListHit =[
193+
{'TemplateName':1,'BBox':(780, 350, 700, 480), 'Score':0.8},
194+
{'TemplateName':1,'BBox':(806, 416, 716, 442), 'Score':0.6},
195+
{'TemplateName':1,'BBox':(1074, 530, 680, 390), 'Score':0.4}
196+
]
199197

200-
ListFinalHit = NMS(ListHit)
198+
FinalHits = NMS( pd.DataFrame(ListHit), scoreThreshold=0.7, sortDescending=True, maxOverlap=0.5 )
201199

202-
print(ListFinalHit)
200+
print(FinalHits)

MTM/__init__.py

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import cv2
2-
import numpy as np
2+
import numpy as np
3+
import pandas as pd
34
from skimage.feature import peak_local_max
45
from scipy.signal import find_peaks
56

67
from MTM.NMS import NMS
8+
#from NMS import NMS # for test purpose (should be commented then)
79

810
__all__ = ['NMS']
911

@@ -66,7 +68,7 @@ def findMatches(listTemplates, image, method=cv2.TM_CCOEFF_NORMED, N_object=floa
6668
6769
Returns
6870
-------
69-
- listHit: list of match as dictionaries {"TemplateName":string, "BBox":(X, Y, Width, Height), "Score":float}
71+
- Pandas DataFrame with 1 row per hit and column "TemplateName"(string), "BBox":(X, Y, Width, Height), "Score":float
7072
'''
7173
if N_object!=float("inf") and type(N_object)!=int:
7274
raise TypeError("N_object must be an integer")
@@ -128,7 +130,7 @@ def findMatches(listTemplates, image, method=cv2.TM_CCOEFF_NORMED, N_object=floa
128130
# append to list of potential hit before Non maxima suppression
129131
listHit.append(newHit)
130132

131-
return listHit # All possible hit before Non-Maxima Supression
133+
return pd.DataFrame(listHit) # All possible hits before Non-Maxima Supression
132134

133135

134136
def matchTemplates(listTemplates, image, method=cv2.TM_CCOEFF_NORMED, N_object=float("inf"), score_threshold=0.5, maxOverlap=0.25, searchBox=None):
@@ -154,31 +156,31 @@ def matchTemplates(listTemplates, image, method=cv2.TM_CCOEFF_NORMED, N_object=f
154156
155157
Returns
156158
-------
157-
- bestHits: list of match as dictionaries {"TemplateName":string, "BBox":(X, Y, Width, Height), "Score":float}
158-
if N=1, return the best matches independently of the score_threshold
159-
if N<inf, returns up to N best matches that passed the score_threshold
160-
if N=inf, returns all matches that passed the score_threshold
159+
Pandas DataFrame with 1 row per hit and column "TemplateName"(string), "BBox":(X, Y, Width, Height), "Score":float
160+
if N=1, return the best matches independently of the score_threshold
161+
if N<inf, returns up to N best matches that passed the score_threshold
162+
if N=inf, returns all matches that passed the score_threshold
161163
'''
162164
if maxOverlap<0 or maxOverlap>1:
163165
raise ValueError("Maximal overlap between bounding box is in range [0-1]")
164166

165-
listHit = findMatches(listTemplates, image, method, N_object, score_threshold, searchBox)
167+
tableHit = findMatches(listTemplates, image, method, N_object, score_threshold, searchBox)
166168

167-
if method == 1: bestHits = NMS(listHit, N_object=N_object, maxOverlap=maxOverlap, sortDescending=False)
169+
if method == 1: bestHits = NMS(tableHit, N_object=N_object, maxOverlap=maxOverlap, sortDescending=False)
168170

169-
elif method in (3,5): bestHits = NMS(listHit, N_object=N_object, maxOverlap=maxOverlap, sortDescending=True)
171+
elif method in (3,5): bestHits = NMS(tableHit, N_object=N_object, maxOverlap=maxOverlap, sortDescending=True)
170172

171173
return bestHits
172174

173175

174-
def drawBoxesOnRGB(image, listHit, boxThickness=2, boxColor=(255, 255, 00), showLabel=False, labelColor=(255, 255, 0), labelScale=0.5 ):
176+
def drawBoxesOnRGB(image, tableHit, boxThickness=2, boxColor=(255, 255, 00), showLabel=False, labelColor=(255, 255, 0), labelScale=0.5 ):
175177
'''
176178
Return a copy of the image with predicted template locations as bounding boxes overlaid on the image
177179
The name of the template can also be displayed on top of the bounding box with showLabel=True
178180
Parameters
179181
----------
180182
- image : image in which the search was performed
181-
- listHit: list of hit as returned by matchTemplates or findMatches
183+
- tableHit: list of hit as returned by matchTemplates or findMatches
182184
- boxThickness: int
183185
thickness of bounding box contour in pixels
184186
- boxColor: (int, int, int)
@@ -197,22 +199,22 @@ def drawBoxesOnRGB(image, listHit, boxThickness=2, boxColor=(255, 255, 00), show
197199
if image.ndim == 2: outImage = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) # convert to RGB to be able to show detections as color box on grayscale image
198200
else: outImage = image.copy()
199201

200-
for hit in listHit:
201-
x,y,w,h = hit['BBox']
202+
for index, row in tableHit.iterrows():
203+
x,y,w,h = row['BBox']
202204
cv2.rectangle(outImage, (x, y), (x+w, y+h), color=boxColor, thickness=boxThickness)
203-
if showLabel: cv2.putText(outImage, text=hit['TemplateName'], org=(x, y), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=labelScale, color=labelColor, lineType=cv2.LINE_AA)
205+
if showLabel: cv2.putText(outImage, text=row['TemplateName'], org=(x, y), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=labelScale, color=labelColor, lineType=cv2.LINE_AA)
204206

205207
return outImage
206208

207209

208-
def drawBoxesOnGray(image, listHit, boxThickness=2, boxColor=255, showLabel=False, labelColor=255, labelScale=0.5):
210+
def drawBoxesOnGray(image, tableHit, boxThickness=2, boxColor=255, showLabel=False, labelColor=255, labelScale=0.5):
209211
'''
210212
Same as drawBoxesOnRGB but with Graylevel.
211213
If a RGB image is provided, the output image will be a grayscale image
212214
Parameters
213215
----------
214216
- image : image in which the search was performed
215-
- listHit: list of hit as returned by matchTemplates or findMatches
217+
- tableHit: list of hit as returned by matchTemplates or findMatches
216218
- boxThickness: int
217219
thickness of bounding box contour in pixels
218220
- boxColor: int
@@ -231,10 +233,10 @@ def drawBoxesOnGray(image, listHit, boxThickness=2, boxColor=255, showLabel=Fals
231233
if image.ndim == 3: outImage = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) # convert to RGB to be able to show detections as color box on grayscale image
232234
else: outImage = image.copy()
233235

234-
for hit in listHit:
235-
x,y,w,h = hit['BBox']
236+
for index, row in tableHit.iterrows():
237+
x,y,w,h = row['BBox']
236238
cv2.rectangle(outImage, (x, y), (x+w, y+h), color=boxColor, thickness=boxThickness)
237-
if showLabel: cv2.putText(outImage, text=hit['TemplateName'], org=(x, y), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=labelScale, color=labelColor, lineType=cv2.LINE_AA)
239+
if showLabel: cv2.putText(outImage, text=row['TemplateName'], org=(x, y), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=labelScale, color=labelColor, lineType=cv2.LINE_AA)
238240

239241
return outImage
240242

@@ -243,20 +245,20 @@ def drawBoxesOnGray(image, listHit, boxThickness=2, boxColor=255, showLabel=Fals
243245
from skimage.data import coins
244246
import matplotlib.pyplot as plt
245247

246-
## Get image and template
248+
## Get image and templates by cropping
247249
smallCoin = coins()[37:37+38, 80:80+41]
248250
bigCoin = coins()[14:14+59,302:302+65]
249251
image = coins()
250252

251253
## Perform matching
252-
listHit = matchTemplates([('small', smallCoin), ('big', bigCoin)], image, score_threshold=0.4, method=cv2.TM_CCOEFF_NORMED, maxOverlap=0)
253-
#listHit = matchTemplates([('small', smallCoin), ('big', bigCoin)], image, N_object=1, score_threshold=0.4, method=cv2.TM_CCOEFF_NORMED, maxOverlap=0)
254+
#tableHit = matchTemplates([('small', smallCoin), ('big', bigCoin)], image, score_threshold=0.4, method=cv2.TM_CCOEFF_NORMED, maxOverlap=0)
255+
tableHit = matchTemplates([('small', smallCoin), ('big', bigCoin)], image, N_object=1, score_threshold=0.4, method=cv2.TM_CCOEFF_NORMED, maxOverlap=0)
254256

255-
print("Found {} coins".format(len(listHit)))
257+
print("Found {} coins".format(len(tableHit)))
256258

257-
for hit in listHit:
258-
print(hit)
259+
print(tableHit)
260+
259261

260262
## Display matches
261-
Overlay = drawBoxesOnRGB(image, listHit, showLabel=True)
263+
Overlay = drawBoxesOnRGB(image, tableHit, showLabel=True)
262264
plt.imshow(Overlay)

0 commit comments

Comments
 (0)