|
| 1 | +import torch |
| 2 | + |
| 3 | +from torchdet3d.evaluation import (compute_average_distance, compute_metrics_per_cls, |
| 4 | + compute_2d_based_iou, compute_accuracy) |
| 5 | + |
| 6 | +from torchdet3d.losses import WingLoss, ADD_loss, DiagLoss |
| 7 | +from torchdet3d.builders import (build_loss, build_optimizer, build_scheduler, build_loader, |
| 8 | + build_model, AVAILABLE_LOSS, AVAILABLE_OPTIMS, AVAILABLE_SCHEDS) |
| 9 | +from torchdet3d.utils import read_py_config |
| 10 | + |
| 11 | + |
| 12 | +class TestCasesPipeline: |
| 13 | + gt_kps = torch.rand(128,9,2) |
| 14 | + test_kps = torch.rand(128,9,2, requires_grad=True) |
| 15 | + gt_cats = torch.randint(0,9,(128,)) |
| 16 | + test_cats = torch.rand(128,9) |
| 17 | + config = read_py_config("./configs/default_config.py") |
| 18 | + |
| 19 | + def test_metrics(self): |
| 20 | + ADD, SADD = compute_average_distance(self.test_kps, self.gt_kps) |
| 21 | + metrics = compute_metrics_per_cls(self.test_kps, self.gt_kps, self.test_cats, self.gt_cats) |
| 22 | + IOU = compute_2d_based_iou(self.test_kps, self.gt_kps) |
| 23 | + acc = compute_accuracy(self.test_cats, self.gt_cats) |
| 24 | + assert 0 <= ADD <= 1 and 0 <= SADD <= 1 and 0 <= IOU <= 1 and 0 <= acc <= 1 |
| 25 | + assert len(metrics) == 9 and len(metrics[0]) == 4 |
| 26 | + |
| 27 | + def test_losses(self): |
| 28 | + for loss in [WingLoss(), ADD_loss(), DiagLoss()]: |
| 29 | + input_ = torch.sigmoid(torch.randn(512, 9, 2, requires_grad=True)) |
| 30 | + target = torch.sigmoid(torch.randn(512, 9, 2)) |
| 31 | + output = loss(input_, target) |
| 32 | + assert not torch.any(torch.isnan(output)) |
| 33 | + output.backward() |
| 34 | + |
| 35 | + def test_builders(self): |
| 36 | + for loss_ in AVAILABLE_LOSS: |
| 37 | + if loss_ != 'cross_entropy': |
| 38 | + self.config['loss']['names']=[loss_, 'cross_entropy'] |
| 39 | + self.config.loss.coeffs=([1.],[1.]) |
| 40 | + regress_criterions, class_criterions = build_loss(self.config) |
| 41 | + assert len(regress_criterions) == 1 and len(class_criterions) == 1 |
| 42 | + model = build_model(self.config) |
| 43 | + assert model is not None |
| 44 | + for optim_ in AVAILABLE_OPTIMS: |
| 45 | + self.config['optim']['name'] = optim_ |
| 46 | + optimizer = build_optimizer(self.config, model) |
| 47 | + assert optimizer is not None |
| 48 | + for schd in AVAILABLE_SCHEDS: |
| 49 | + self.config['scheduler']['name'] = schd |
| 50 | + scheduler = build_scheduler(self.config, optimizer) |
| 51 | + assert scheduler is not None |
| 52 | + |
| 53 | + def test_random_inference(self): |
| 54 | + model = build_model(self.config) |
| 55 | + image = torch.rand(128,3,224,224) |
| 56 | + kp, cat = model(image, self.gt_cats) |
| 57 | + assert kp.shape == (128,9,2) |
| 58 | + assert cat.shape == (128,9) |
0 commit comments