Skip to content

Commit dda89d1

Browse files
committed
merge conflicts fix
2 parents 8069adc + c4b7fb8 commit dda89d1

File tree

15 files changed

+241
-19
lines changed

15 files changed

+241
-19
lines changed

.github/workflows/main.yml

+10-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@ jobs:
1414
python-version: [3.7]
1515

1616
steps:
17-
- uses: actions/checkout@v2
17+
- name: Checkout repository and submodules
18+
uses: actions/checkout@v2
19+
with:
20+
submodules: recursive
1821
- name: Set up Python ${{ matrix.python-version }}
1922
uses: actions/setup-python@v2
2023
with:
@@ -33,6 +36,12 @@ jobs:
3336
run: |
3437
python -m pip install --upgrade pip pytest pylint
3538
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
39+
- name: Install torchdet3d
40+
run: |
41+
python setup.py develop
42+
- name: Testing with pytest
43+
run: |
44+
python -m pytest . -s
3645
- name: Linting with pylint
3746
run: |
3847
python tests/run_pylint.py

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ Training includes the following stages:
77
- Training a 2d bounding box detection model
88
- Training a 3d bounding box regression model
99

10-
Trained models can be deployed on CPU using [OpenVINO](https://docs.openvinotoolkit.org) framework and then run in [live demo]().
10+
Trained models can be deployed on CPU using the [OpenVINO](https://docs.openvinotoolkit.org) framework and then run in the [live demo](demo/demo.py).
1111

1212
## Installation guide
1313
```bash

annotation_converters/objectron_2_coco.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,9 @@ def save_2_coco(output_root, subset_name, data_info, obj_classes, fps_divisor,
8787
ann_folder = osp.join(output_root, 'annotations')
8888
img_folder = osp.join(output_root, 'images')
8989
if not osp.isdir(ann_folder):
90-
os.mkdir(ann_folder)
90+
os.makedirs(ann_folder, exist_ok=True)
9191
if not osp.isdir(img_folder):
92-
os.mkdir(img_folder)
92+
os.makedirs(img_folder, exist_ok=True)
9393

9494
img_id = 0
9595
ann_id = 0
@@ -142,8 +142,9 @@ def save_2_coco(output_root, subset_name, data_info, obj_classes, fps_divisor,
142142
frames[frame_idx] = cv.resize(frames[frame_idx], (w, h))
143143
for kp_pixel in keypoints[0]:
144144
cv.circle(frames[frame_idx], (kp_pixel[0], kp_pixel[1]), 5, (255, 0, 0), -1)
145-
for kp_pixel in keypoints[1]:
146-
cv.circle(frames[frame_idx], (kp_pixel[0], kp_pixel[1]), 5, (0, 0, 255), -1)
145+
if len(keypoints) > 1:
146+
for kp_pixel in keypoints[1]:
147+
cv.circle(frames[frame_idx], (kp_pixel[0], kp_pixel[1]), 5, (0, 0, 255), -1)
147148
for bbox in bboxes:
148149
if bbox is not None:
149150
cv.rectangle(frames[frame_idx], (bbox[0], bbox[1]),

annotation_converters/objectron_helpers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def grab_frames(video_file, frame_ids, use_opencv=True):
9191
'-pix_fmt', 'rgb24', '-vcodec', 'rawvideo', '-vsync', 'vfr', '-'
9292
]
9393
pipe = subprocess.Popen(
94-
command, stdout=subprocess.PIPE, bufsize=151 * frame_size)
94+
command, stdout=subprocess.PIPE, bufsize=151 * frame_size, stderr=subprocess.DEVNULL)
9595
current_frame = np.frombuffer(
9696
pipe.stdout.read(frame_size), dtype='uint8').reshape(height, width, 3)
9797
pipe.stdout.flush()

demo/demo.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def main():
177177
help='MKLDNN (CPU)-targeted custom layers.Absolute path to a shared library with the kernels '
178178
'impl.', type=str, default=None)
179179
parser.add_argument('--write_video', action='store_true',
180-
help='wether or not to record demo video')
180+
help='whether to save a demo video or not')
181181
args = parser.parse_args()
182182

183183
if args.cam_id >= 0:

requirements.txt

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
albumentations
2-
attrdict
2+
addict
33
opencv-python
44
numpy
55
sklearn
@@ -18,4 +18,5 @@ efficientnet_lite1_pytorch_model
1818
efficientnet_lite2_pytorch_model
1919
optuna
2020
pylint
21-
isort
21+
isort
22+
pytest

tests/__init__.py

Whitespace-only changes.

tests/run_pylint.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
'configs/detection',
1313
'torchdet3d/models',
1414
'build',
15-
'deprecated'
16-
'.history/',
17-
'torchdet3d/models'
15+
'deprecated',
16+
'.history',
17+
'torchdet3d/models',
1818
]
1919

2020
to_pylint = []

tests/test_geometry.py

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import numpy as np
2+
3+
from torchdet3d.utils import (lift_2d, get_default_camera_matrix,
4+
convert_camera_matrix_2_ndc, project_3d_points,
5+
convert_2d_to_ndc)
6+
7+
8+
from objectron.dataset import iou
9+
from objectron.dataset import box
10+
11+
12+
class TestCasesGeometry:
    """Sanity checks for the 2D -> 3D lifting helpers in ``torchdet3d.utils``."""

    # Nine 2D keypoints in normalized [0, 1] image coordinates.
    # `lift_2d` skips the first one when it builds its equations and uses the
    # remaining eight.
    test_kps = np.array([
        [0.47714591, 0.47491544],
        [0.73884577, 0.39749265],
        [0.18508956, 0.40002537],
        [0.74114597, 0.48664019],
        [0.18273196, 0.48833901],
        [0.64639187, 0.46719882],
        [0.32766378, 0.46827659],
        [0.64726073, 0.51853681],
        [0.32699507, 0.51933688],
    ])
    # Maximum allowed reprojection distance (in NDC units).
    EPS = 1e-5
    # Minimum 3D IoU expected between the clean and the perturbed box.
    IOU_THR = 0.5

    def test_reprojection_error(self):
        """Lift the keypoints to 3D, project them back and compare in NDC space."""
        lifted = lift_2d([self.test_kps], portrait=True)[0]
        ndc_camera = convert_camera_matrix_2_ndc(get_default_camera_matrix())
        reprojected = project_3d_points(lifted, ndc_camera)
        expected = convert_2d_to_ndc(self.test_kps, portrait=True)
        per_point_error = np.linalg.norm(expected - reprojected, axis=1)
        # NOTE(review): `np.any` passes as soon as a single keypoint reprojects
        # within EPS; confirm whether `np.all` (with a suitable tolerance) was
        # the intended check.
        assert np.any(per_point_error < self.EPS)

    def test_3d_iou_stability(self):
        """A small perturbation of the 2D keypoints must keep a high 3D IoU."""
        np.random.seed(10)  # fixed seed keeps the noise reproducible
        noise = 0.01 * np.random.rand(*self.test_kps.shape)
        noisy_kps = np.clip(self.test_kps + noise, 0, 1)
        clean_3d, noisy_3d = lift_2d([self.test_kps, noisy_kps], portrait=True)

        clean_box = box.Box(vertices=clean_3d)
        noisy_box = box.Box(vertices=noisy_3d)

        overlap = iou.IoU(clean_box, noisy_box)
        assert overlap.iou() > self.IOU_THR

torchdet3d/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
import sys
44
import os
55

6-
from .version import __version__
7-
86
module_path = os.path.abspath(os.path.join(os.path.dirname('__init__.py'), '3rdparty/Objectron'))
97
if module_path not in sys.path:
108
sys.path.append(module_path)
119

10+
#pylint: disable = wrong-import-position
1211
from torchdet3d import builders, evaluation, dataloaders, trainer, models, utils, losses
12+
from .version import __version__
1313

1414
__author__ = 'Sovrasov Vladislav, Prokofiev Kirill'
1515
__description__ = 'A library for deep learning 3D object detection in PyTorch'

torchdet3d/evaluation/evaluate.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from tqdm import tqdm
88
from copy import deepcopy
99

10-
from .metrics import compute_accuracy, compute_average_distance, compute_metrics_per_cls
10+
from .metrics import compute_accuracy, compute_average_distance, compute_metrics_per_cls, compute_2d_based_iou
1111
from torchdet3d.utils import (AverageMeter, mkdir_if_missing, draw_kp, OBJECTRON_CLASSES)
1212
from torchdet3d.builders import build_augmentations
1313
from torchdet3d.dataloaders import Objectron
@@ -75,15 +75,17 @@ def visual_test(self):
7575
RGB=False,
7676
normalized=False,
7777
label=label)
78-
78+
@torch.no_grad()
7979
def val(self, epoch=None):
8080
''' procedure launching main validation '''
8181
ADD_meter = AverageMeter()
8282
SADD_meter = AverageMeter()
8383
ACC_meter = AverageMeter()
84+
IOU_meter = AverageMeter()
8485
ADD_cls_meter = [AverageMeter() for cl in range(self.num_classes)]
8586
SADD_cls_meter = [AverageMeter() for cl in range(self.num_classes)]
8687
acc_cls_meter = [AverageMeter() for cl in range(self.num_classes)]
88+
IOU__cls_meter = [AverageMeter() for cl in range(self.num_classes)]
8789

8890
# switch to eval mode
8991
self.model.eval()
@@ -95,16 +97,19 @@ def val(self, epoch=None):
9597
pred_kp, pred_cats = self.model(imgs, gt_cats)
9698
# measure metrics
9799
ADD, SADD = compute_average_distance(pred_kp, gt_kp)
100+
IOU = compute_2d_based_iou(pred_kp, gt_kp)
101+
acc = compute_accuracy(pred_cats, gt_cats)
102+
98103
for cl, ADD_cls, SADD_cls, acc_cls in compute_metrics_per_cls(pred_kp, gt_kp, gt_cats, pred_cats):
99104
ADD_cls_meter[cl].update(ADD_cls, imgs.size(0))
100105
SADD_cls_meter[cl].update(SADD_cls, imgs.size(0))
101106
acc_cls_meter[cl].update(acc_cls, imgs.size(0))
102107

103-
acc = compute_accuracy(pred_cats, gt_cats)
104108
# record loss
105109
ADD_meter.update(ADD, imgs.size(0))
106110
SADD_meter.update(SADD, imgs.size(0))
107111
ACC_meter.update(acc, imgs.size(0))
112+
IOU_meter.update(IOU)
108113
if epoch is not None:
109114
# update progress bar
110115
loop.set_description(f'Val Epoch [{epoch}/{self.max_epoch}]')
@@ -136,6 +141,7 @@ def val(self, epoch=None):
136141
f"{ep_mess}"
137142
f"ADD overall ---> {ADD_meter.avg}\n"
138143
f"SADD overall ---> {SADD_meter.avg}\n"
144+
f"IOU ---> {IOU_meter.avg}\n"
139145
f"classification accuracy overall ---> {ACC_meter.avg}\n"
140146
f"{per_class_metr_message}")
141147

torchdet3d/evaluation/metrics.py

+26
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
11
import torch
2+
import scipy
3+
import numpy as np
4+
5+
from objectron.dataset import iou
6+
from objectron.dataset import box
7+
8+
from torchdet3d.utils import lift_2d
9+
210

311
def compute_average_distance(pred_kp, gt_kp, num_keypoint=9, **kwargs):
412
"""Computes Average Distance (ADD) metric."""
@@ -41,3 +49,21 @@ def compute_metrics_per_cls(pred_kp, gt_kp, gt_cats, pred_cats, **kwargs):
4149
computed_metrics.append((cl, ADD, SADD, acc))
4250

4351
return computed_metrics
52+
53+
def compute_2d_based_iou(pred_kp: torch.Tensor, gt_kp: torch.Tensor):
    """Compute the mean 3D IoU between boxes lifted from 2D keypoints.

    Every predicted / ground-truth keypoint set in the batch is lifted to 3D
    with ``lift_2d`` (portrait orientation), wrapped into an Objectron ``Box``
    and compared with the Objectron ``IoU`` helper.

    Args:
        pred_kp: predicted keypoints; must be a 3-dimensional tensor whose
            first dimension is the batch.
        gt_kp: ground-truth keypoints, indexed per sample like ``pred_kp``.

    Returns:
        The sum of per-sample IoU values divided by the batch size. Samples
        whose IoU cannot be computed contribute zero. An empty batch yields
        ``0.0`` instead of raising ``ZeroDivisionError``.
    """
    # Public location of the exception; the `scipy.spatial.qhull` module path
    # is deprecated, and a bare `import scipy` does not guarantee that the
    # `scipy.spatial` submodule is loaded.
    from scipy.spatial import QhullError

    assert len(pred_kp.shape) == 3
    bs = pred_kp.shape[0]
    if bs == 0:
        return 0.0

    # detach() keeps the conversion valid for tensors that still track gradients.
    pred_kp_np = pred_kp.detach().cpu().numpy()
    gt_kp_np = gt_kp.detach().cpu().numpy()

    total_iou = 0
    for sample_pred, sample_gt in zip(pred_kp_np, gt_kp_np):
        kps_3d = lift_2d([sample_pred, sample_gt], portrait=True)
        b_pred = box.Box(vertices=kps_3d[0])
        b_gt = box.Box(vertices=kps_3d[1])
        try:
            total_iou += iou.IoU(b_pred, b_gt).iou()
        except (QhullError, np.linalg.LinAlgError):
            # Deliberate best effort: a sample whose IoU computation fails
            # counts as zero overlap rather than aborting the evaluation.
            continue
    return total_iou / bs

torchdet3d/utils/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .utils import *
22
from .transforms import *
3+
from .geometry import *

torchdet3d/utils/geometry.py

+138
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
from typing import List, Tuple
2+
3+
import numpy as np
4+
5+
6+
# Weights that express each of the 8 box vertices (rows) as a linear
# combination of 4 control points (columns). Every row sums to 1.
# `lift_2d` uses the table both to build its linear system and to rebuild
# the vertices from the recovered control points.
__epnp_alpha__ = np.array([
    [4, -1, -1, -1],
    [2, -1, -1, 1],
    [2, -1, 1, -1],
    [0, -1, 1, 1],
    [2, 1, -1, -1],
    [0, 1, -1, 1],
    [0, 1, 1, -1],
    [-2, 1, 1, 1],
])
14+
15+
16+
def get_default_camera_matrix():
    """Return a fresh 3x3 camera matrix with unit focal lengths and principal point (0.5, 0.5)."""
    intrinsics = np.eye(3)
    # The third column of the first two rows holds the principal point.
    intrinsics[:2, 2] = 0.5
    return intrinsics
20+
21+
22+
def project_3d_points(points: np.ndarray, camera_matrix: np.ndarray):
    """Project 3D points onto the image plane.

    Each point is multiplied by ``camera_matrix`` and the first two resulting
    coordinates are divided by the negated third one.

    Args:
        points: 2-dimensional array (or array-like) with one point per row.
        camera_matrix: matrix applied to every point.

    Returns:
        Float array with one projected 2D point per input row.

    Raises:
        AssertionError: if ``points`` is not 2-dimensional.
    """
    points = np.asarray(points)
    assert len(points.shape) == 2
    # Work in floating point: dividing an integer product in place would raise
    # a casting error.
    projection = np.matmul(np.asarray(camera_matrix), points.T).T.astype(float)
    depth = -projection[:, 2].reshape(-1, 1)
    return projection[:, :-1] / depth
27+
28+
29+
def convert_camera_matrix_2_ndc(matrix: np.ndarray, img_shape: Tuple[int, int]=(1, 1)):
    """Return a copy of a camera matrix rescaled to NDC ([-1, 1]) coordinates.

    The focal lengths are scaled by ``2 / img_shape`` and the principal point
    is mapped with ``-c * 2 / img_shape + 1``. The input is left untouched.

    Args:
        matrix: 3x3 camera matrix (array or array-like).
        img_shape: pair of image sizes; element 0 scales the first row and
            element 1 scales the second row.

    Returns:
        New floating-point matrix in NDC space.
    """
    ndc_mat = np.array(matrix)
    # An integer matrix would silently truncate the rescaled values on
    # assignment, so promote it to float first.
    if not np.issubdtype(ndc_mat.dtype, np.floating):
        ndc_mat = ndc_mat.astype(float)

    scale_x = 2.0 / img_shape[0]
    scale_y = 2.0 / img_shape[1]

    ndc_mat[0, 0] *= scale_x
    ndc_mat[1, 1] *= scale_y

    ndc_mat[0, 2] = -ndc_mat[0, 2] * scale_x + 1.0
    ndc_mat[1, 2] = -ndc_mat[1, 2] * scale_y + 1.0

    return ndc_mat
38+
39+
40+
def convert_2d_to_ndc(points: np.ndarray, portrait: bool=False):
    """Convert normalized 2D points from [0, 1] to NDC ([-1, 1]) coordinates.

    Args:
        points: 2-dimensional array (or array-like, e.g. a list of pairs)
            with one point per row.
        portrait: when True the two coordinates are swapped and both are
            mapped with ``2 * c - 1``; otherwise the first coordinate is
            mapped with ``2 * c - 1`` and the second with ``1 - 2 * c``.

    Returns:
        New array shaped like ``points``. Columns beyond the first two are
        left at zero.
    """
    # Accept plain sequences as well as ndarrays.
    points = np.asarray(points)
    converted_points = np.zeros_like(points)
    if portrait:
        converted_points[:, 0] = points[:, 1] * 2 - 1
        converted_points[:, 1] = points[:, 0] * 2 - 1
    else:
        converted_points[:, 0] = points[:, 0] * 2 - 1
        converted_points[:, 1] = 1 - points[:, 1] * 2
    return converted_points
49+
50+
51+
def lift_2d(keypoint_sets: List[np.array],
            camera_matrix: np.array=get_default_camera_matrix(),
            portrait: bool=False) -> List[np.array]:
    """
    Lift sets of normalized 2d keypoints to 3d points in camera coordinates.

    Each set must contain 9 keypoints given in normalized image coordinates.
    The camera matrix is expected in normalized image space and is converted
    to NDC internally. The lifted points are defined only up to an unknown
    scale factor. One (9, 3) array is returned per input set.
    """
    intrinsics = convert_camera_matrix_2_ndc(camera_matrix)
    focal_x, focal_y = intrinsics[0, 0], intrinsics[1, 1]
    center_x, center_y = intrinsics[0, 2], intrinsics[1, 2]

    lifted_sets = []

    for keypoints in keypoint_sets:
        system = np.zeros((16, 12))
        assert len(keypoints) == 9

        # The first keypoint is skipped; only the other eight enter the system.
        for vertex_idx, point in enumerate(keypoints[1:9]):
            # Convert 2d point from normalized screen coordinates [0, 1] to NDC coordinates([-1, 1]).
            if portrait:
                ndc_u, ndc_v = point[1] * 2 - 1, point[0] * 2 - 1
            else:
                ndc_u, ndc_v = point[0] * 2 - 1, 1 - point[1] * 2

            row_u = vertex_idx * 2
            row_v = row_u + 1
            # Each vertex contributes two equations, with three unknowns per
            # control point.
            for ctrl_idx, weight in enumerate(__epnp_alpha__[vertex_idx]):
                col = ctrl_idx * 3
                system[row_u, col] = focal_x * weight
                system[row_u, col + 2] = (center_x + ndc_u) * weight
                system[row_v, col + 1] = focal_y * weight
                system[row_v, col + 2] = (center_y + ndc_v) * weight

        eigenvalues, eigenvectors = np.linalg.eigh(np.matmul(system.T, system))
        assert eigenvalues.shape[0] == 12
        # eigh sorts eigenvalues in ascending order, so column 0 belongs to the
        # smallest one.
        control_points = eigenvectors[:, 0].reshape(4, 3)
        # All 3d points should be in front of camera (z < 0).
        if control_points[0, 2] > 0:
            control_points = -control_points

        box_vertices = np.matmul(__epnp_alpha__, control_points)
        # Output layout: the first control point followed by the 8 vertices.
        lifted_sets.append(np.vstack([control_points[0, :], box_vertices]))

    return lifted_sets
109+
110+
111+
def draw_boxes(boxes=(), clips=(), colors=('r', 'b', 'g', 'k')):
    """Draw a list of boxes and save the plot to ``3d_boxes.png``.

    The boxes are defined as a list of vertices.

    Args:
        boxes: iterable of arrays with one vertex per row and x, y, z in the
            first three columns. Edges are taken from Objectron's ``box.EDGES``.
        clips: optional extra 3D points, drawn as large black markers.
        colors: edge colors, cycled over the boxes.
    """
    import matplotlib.pyplot as plt
    from objectron.dataset import box

    fig = plt.figure(figsize=(10, 10))
    try:
        ax = fig.add_subplot(111, projection='3d')
        for i, b in enumerate(boxes):
            x, y, z = b[:, 0], b[:, 1], b[:, 2]
            ax.scatter(x, y, z, c='r')
            edge_color = colors[i % len(colors)]
            for e in box.EDGES:
                ax.plot(x[e], y[e], z[e], linewidth=2, c=edge_color)

        # len() also works for ndarray input, where plain truthiness is ambiguous.
        if len(clips) > 0:
            points = np.array(clips)
            ax.scatter(points[:, 0], points[:, 1], points[:, 2], s=100, c='k')

        ax.patch.set_facecolor('white')
        # `ax.w_xaxis` and friends were removed from Matplotlib; the plain
        # axis attributes provide the same `set_pane_color`.
        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            axis.set_pane_color((0.8, 0.8, 0.8, 1.0))

        # rotate the axes and update
        ax.view_init(30, 12)
        plt.draw()
        fig.savefig('3d_boxes.png')
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

0 commit comments

Comments
 (0)