Skip to content

Commit 08a11c1

Browse files
authored
add a jupyter notebook demo (open-mmlab#1158)
1 parent 63b9d10 commit 08a11c1

File tree

6 files changed

+134
-6
lines changed

6 files changed

+134
-6
lines changed

GETTING_STARTED.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,13 @@ python tools/test.py configs/mask_rcnn_r50_fpn_1x.py \
6262
We provide a webcam demo to illustrate the results.
6363

6464
```shell
65-
python tools/webcam_demo.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--device ${GPU_ID}] [--camera-id ${CAMERA-ID}] [--score-thr ${CAMERA-ID}]
65+
python demo/webcam_demo.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--device ${GPU_ID}] [--camera-id ${CAMERA-ID}] [--score-thr ${CAMERA-ID}]
6666
```
6767

6868
Examples:
6969

7070
```shell
71-
python tools/webcam_demo.py configs/faster_rcnn_r50_fpn_1x.py \
71+
python demo/webcam_demo.py configs/faster_rcnn_r50_fpn_1x.py \
7272
checkpoints/faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth
7373
```
7474

@@ -103,6 +103,8 @@ for frame in video:
103103
show_result(frame, result, model.CLASSES, wait_time=1)
104104
```
105105

106+
A notebook demo can be found in [demo/inference_demo.ipynb](demo/inference_demo.ipynb).
107+
106108

107109
## Train a model
108110

demo/demo.jpg

254 KB
Loading

demo/inference_demo.ipynb

Lines changed: 92 additions & 0 deletions
Large diffs are not rendered by default.
File renamed without changes.

mmdet/apis/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from .env import get_root_logger, init_dist, set_random_seed
2-
from .inference import inference_detector, init_detector, show_result
2+
from .inference import (inference_detector, init_detector, show_result,
3+
show_result_pyplot)
34
from .train import train_detector
45

56
__all__ = [
67
'init_dist', 'get_root_logger', 'set_random_seed', 'train_detector',
7-
'init_detector', 'inference_detector', 'show_result'
8+
'init_detector', 'inference_detector', 'show_result', 'show_result_pyplot'
89
]

mmdet/apis/inference.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import warnings
22

3+
import matplotlib.pyplot as plt
34
import mmcv
45
import numpy as np
56
import pycocotools.mask as maskUtils
@@ -105,6 +106,7 @@ def show_result(img,
105106
class_names,
106107
score_thr=0.3,
107108
wait_time=0,
109+
show=True,
108110
out_file=None):
109111
"""Visualize the detection results on the image.
110112
@@ -115,11 +117,17 @@ def show_result(img,
115117
class_names (list[str] or tuple[str]): A list of class names.
116118
score_thr (float): The threshold to visualize the bboxes and masks.
117119
wait_time (int): Value of waitKey param.
120+
show (bool, optional): Whether to show the image with opencv or not.
118121
out_file (str, optional): If specified, the visualization result will
119122
be written to the out file instead of shown in a window.
123+
124+
Returns:
125+
np.ndarray or None: If neither `show` nor `out_file` is specified, the
126+
visualized image is returned, otherwise None is returned.
120127
"""
121128
assert isinstance(class_names, (tuple, list))
122129
img = mmcv.imread(img)
130+
img = img.copy()
123131
if isinstance(result, tuple):
124132
bbox_result, segm_result = result
125133
else:
@@ -140,11 +148,36 @@ def show_result(img,
140148
]
141149
labels = np.concatenate(labels)
142150
mmcv.imshow_det_bboxes(
143-
img.copy(),
151+
img,
144152
bboxes,
145153
labels,
146154
class_names=class_names,
147155
score_thr=score_thr,
148-
show=out_file is None,
156+
show=show,
149157
wait_time=wait_time,
150158
out_file=out_file)
159+
if not (show or out_file):
160+
return img
161+
162+
163+
def show_result_pyplot(img,
164+
result,
165+
class_names,
166+
score_thr=0.3,
167+
fig_size=(15, 10)):
168+
"""Visualize the detection results on the image.
169+
170+
Args:
171+
img (str or np.ndarray): Image filename or loaded image.
172+
result (tuple[list] or list): The detection result, can be either
173+
(bbox, segm) or just bbox.
174+
class_names (list[str] or tuple[str]): A list of class names.
175+
score_thr (float): The threshold to visualize the bboxes and masks.
176+
fig_size (tuple): Figure size of the pyplot figure.
177+
out_file (str, optional): If specified, the visualization result will
178+
be written to the out file instead of shown in a window.
179+
"""
180+
img = show_result(
181+
img, result, class_names, score_thr=score_thr, show=False)
182+
plt.figure(figsize=fig_size)
183+
plt.imshow(mmcv.bgr2rgb(img))

0 commit comments

Comments
 (0)