Skip to content

Commit bb93b48

Browse files
jinxianweixiexinch
andauthored
[Feature] huasdorff distance loss (#2820)
Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers. ## Motivation Add Huasdorff distance loss --------- Co-authored-by: xiexinch <[email protected]>
1 parent b2f4b4f commit bb93b48

File tree

3 files changed

+192
-1
lines changed

3 files changed

+192
-1
lines changed

mmseg/models/losses/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
cross_entropy, mask_cross_entropy)
66
from .dice_loss import DiceLoss
77
from .focal_loss import FocalLoss
8+
from .huasdorff_distance_loss import HuasdorffDisstanceLoss
89
from .lovasz_loss import LovaszLoss
910
from .ohem_cross_entropy_loss import OhemCrossEntropy
1011
from .tversky_loss import TverskyLoss
@@ -14,5 +15,6 @@
1415
'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
1516
'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
1617
'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss',
17-
'FocalLoss', 'TverskyLoss', 'OhemCrossEntropy', 'BoundaryLoss'
18+
'FocalLoss', 'TverskyLoss', 'OhemCrossEntropy', 'BoundaryLoss',
19+
'HuasdorffDisstanceLoss'
1820
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
"""Modified from https://github.com/JunMa11/SegWithDistMap/blob/
3+
master/code/train_LA_HD.py (Apache-2.0 License)"""
4+
import torch
5+
import torch.nn as nn
6+
import torch.nn.functional as F
7+
from scipy.ndimage import distance_transform_edt as distance
8+
from torch import Tensor
9+
10+
from mmseg.registry import MODELS
11+
from .utils import get_class_weight, weighted_loss
12+
13+
14+
def compute_dtm(img_gt: Tensor, pred: Tensor) -> Tensor:
15+
"""
16+
compute the distance transform map of foreground in mask
17+
Args:
18+
img_gt: Ground truth of the image, (b, h, w)
19+
pred: Predictions of the segmentation head after softmax, (b, c, h, w)
20+
21+
Returns:
22+
output: the foreground Distance Map (SDM)
23+
dtm(x) = 0; x in segmentation boundary
24+
inf|x-y|; x in segmentation
25+
"""
26+
27+
fg_dtm = torch.zeros_like(pred)
28+
out_shape = pred.shape
29+
for b in range(out_shape[0]): # batch size
30+
for c in range(1, out_shape[1]): # default 0 channel is background
31+
posmask = img_gt[b].byte()
32+
if posmask.any():
33+
posdis = distance(posmask)
34+
fg_dtm[b][c] = torch.from_numpy(posdis)
35+
36+
return fg_dtm
37+
38+
39+
@weighted_loss
40+
def hd_loss(seg_soft: Tensor,
41+
gt: Tensor,
42+
seg_dtm: Tensor,
43+
gt_dtm: Tensor,
44+
class_weight=None,
45+
ignore_index=255) -> Tensor:
46+
"""
47+
compute huasdorff distance loss for segmentation
48+
Args:
49+
seg_soft: softmax results, shape=(b,c,x,y)
50+
gt: ground truth, shape=(b,x,y)
51+
seg_dtm: segmentation distance transform map, shape=(b,c,x,y)
52+
gt_dtm: ground truth distance transform map, shape=(b,c,x,y)
53+
54+
Returns:
55+
output: hd_loss
56+
"""
57+
assert seg_soft.shape[0] == gt.shape[0]
58+
total_loss = 0
59+
num_class = seg_soft.shape[1]
60+
if class_weight is not None:
61+
assert class_weight.ndim == num_class
62+
for i in range(1, num_class):
63+
if i != ignore_index:
64+
delta_s = (seg_soft[:, i, ...] - gt.float())**2
65+
s_dtm = seg_dtm[:, i, ...]**2
66+
g_dtm = gt_dtm[:, i, ...]**2
67+
dtm = s_dtm + g_dtm
68+
multiplied = torch.einsum('bxy, bxy->bxy', delta_s, dtm)
69+
hd_loss = multiplied.mean()
70+
if class_weight is not None:
71+
hd_loss *= class_weight[i]
72+
total_loss += hd_loss
73+
74+
return total_loss / num_class
75+
76+
77+
@MODELS.register_module()
78+
class HuasdorffDisstanceLoss(nn.Module):
79+
"""HuasdorffDisstanceLoss. This loss is proposed in `How Distance Transform
80+
Maps Boost Segmentation CNNs: An Empirical Study.
81+
82+
<http://proceedings.mlr.press/v121/ma20b.html>`_.
83+
Args:
84+
reduction (str, optional): The method used to reduce the loss into
85+
a scalar. Defaults to 'mean'.
86+
class_weight (list[float] | str, optional): Weight of each class. If in
87+
str format, read them from a file. Defaults to None.
88+
loss_weight (float): Weight of the loss. Defaults to 1.0.
89+
ignore_index (int | None): The label index to be ignored. Default: 255.
90+
loss_name (str): Name of the loss item. If you want this loss
91+
item to be included into the backward graph, `loss_` must be the
92+
prefix of the name. Defaults to 'loss_boundary'.
93+
"""
94+
95+
def __init__(self,
96+
reduction='mean',
97+
class_weight=None,
98+
loss_weight=1.0,
99+
ignore_index=255,
100+
loss_name='loss_huasdorff_disstance',
101+
**kwargs):
102+
super().__init__()
103+
self.reduction = reduction
104+
self.loss_weight = loss_weight
105+
self.class_weight = get_class_weight(class_weight)
106+
self._loss_name = loss_name
107+
self.ignore_index = ignore_index
108+
109+
def forward(self,
110+
pred: Tensor,
111+
target: Tensor,
112+
avg_factor=None,
113+
reduction_override=None,
114+
**kwargs) -> Tensor:
115+
"""Forward function.
116+
117+
Args:
118+
pred (Tensor): Predictions of the segmentation head. (B, C, H, W)
119+
target (Tensor): Ground truth of the image. (B, H, W)
120+
avg_factor (int, optional): Average factor that is used to
121+
average the loss. Defaults to None.
122+
reduction_override (str, optional): The reduction method used
123+
to override the original reduction method of the loss.
124+
Options are "none", "mean" and "sum".
125+
Returns:
126+
Tensor: Loss tensor.
127+
"""
128+
assert reduction_override in (None, 'none', 'mean', 'sum')
129+
reduction = (
130+
reduction_override if reduction_override else self.reduction)
131+
if self.class_weight is not None:
132+
class_weight = pred.new_tensor(self.class_weight)
133+
else:
134+
class_weight = None
135+
136+
pred_soft = F.softmax(pred, dim=1)
137+
valid_mask = (target != self.ignore_index).long()
138+
target = target * valid_mask
139+
140+
with torch.no_grad():
141+
gt_dtm = compute_dtm(target.cpu(), pred_soft)
142+
gt_dtm = gt_dtm.float()
143+
seg_dtm2 = compute_dtm(
144+
pred_soft.argmax(dim=1, keepdim=False).cpu(), pred_soft)
145+
seg_dtm2 = seg_dtm2.float()
146+
147+
loss_hd = self.loss_weight * hd_loss(
148+
pred_soft,
149+
target,
150+
seg_dtm=seg_dtm2,
151+
gt_dtm=gt_dtm,
152+
reduction=reduction,
153+
avg_factor=avg_factor,
154+
class_weight=class_weight,
155+
ignore_index=self.ignore_index)
156+
return loss_hd
157+
158+
@property
159+
def loss_name(self):
160+
return self._loss_name
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import pytest
3+
import torch
4+
5+
from mmseg.models.losses import HuasdorffDisstanceLoss
6+
7+
8+
def test_huasdorff_distance_loss():
9+
loss_class = HuasdorffDisstanceLoss
10+
pred = torch.rand((10, 8, 6, 6))
11+
target = torch.rand((10, 6, 6))
12+
class_weight = torch.rand(8)
13+
14+
# Test loss forward
15+
loss = loss_class()(pred, target)
16+
assert isinstance(loss, torch.Tensor)
17+
18+
# Test loss forward with avg_factor
19+
loss = loss_class()(pred, target, avg_factor=10)
20+
assert isinstance(loss, torch.Tensor)
21+
22+
# Test loss forward with avg_factor and reduction is None, 'sum' and 'mean'
23+
for reduction in [None, 'sum', 'mean']:
24+
loss = loss_class()(pred, target, avg_factor=10, reduction=reduction)
25+
assert isinstance(loss, torch.Tensor)
26+
27+
# Test loss forward with class_weight
28+
with pytest.raises(AssertionError):
29+
loss_class(class_weight=class_weight)(pred, target)

0 commit comments

Comments
 (0)