|
| 1 | +# Copyright (c) OpenMMLab. All rights reserved. |
| 2 | +"""Modified from https://github.com/JunMa11/SegWithDistMap/blob/ |
| 3 | +master/code/train_LA_HD.py (Apache-2.0 License)""" |
| 4 | +import torch |
| 5 | +import torch.nn as nn |
| 6 | +import torch.nn.functional as F |
| 7 | +from scipy.ndimage import distance_transform_edt as distance |
| 8 | +from torch import Tensor |
| 9 | + |
| 10 | +from mmseg.registry import MODELS |
| 11 | +from .utils import get_class_weight, weighted_loss |
| 12 | + |
| 13 | + |
| 14 | +def compute_dtm(img_gt: Tensor, pred: Tensor) -> Tensor: |
| 15 | + """ |
| 16 | + compute the distance transform map of foreground in mask |
| 17 | + Args: |
| 18 | + img_gt: Ground truth of the image, (b, h, w) |
| 19 | + pred: Predictions of the segmentation head after softmax, (b, c, h, w) |
| 20 | +
|
| 21 | + Returns: |
| 22 | + output: the foreground Distance Map (SDM) |
| 23 | + dtm(x) = 0; x in segmentation boundary |
| 24 | + inf|x-y|; x in segmentation |
| 25 | + """ |
| 26 | + |
| 27 | + fg_dtm = torch.zeros_like(pred) |
| 28 | + out_shape = pred.shape |
| 29 | + for b in range(out_shape[0]): # batch size |
| 30 | + for c in range(1, out_shape[1]): # default 0 channel is background |
| 31 | + posmask = img_gt[b].byte() |
| 32 | + if posmask.any(): |
| 33 | + posdis = distance(posmask) |
| 34 | + fg_dtm[b][c] = torch.from_numpy(posdis) |
| 35 | + |
| 36 | + return fg_dtm |
| 37 | + |
| 38 | + |
| 39 | +@weighted_loss |
| 40 | +def hd_loss(seg_soft: Tensor, |
| 41 | + gt: Tensor, |
| 42 | + seg_dtm: Tensor, |
| 43 | + gt_dtm: Tensor, |
| 44 | + class_weight=None, |
| 45 | + ignore_index=255) -> Tensor: |
| 46 | + """ |
| 47 | + compute huasdorff distance loss for segmentation |
| 48 | + Args: |
| 49 | + seg_soft: softmax results, shape=(b,c,x,y) |
| 50 | + gt: ground truth, shape=(b,x,y) |
| 51 | + seg_dtm: segmentation distance transform map, shape=(b,c,x,y) |
| 52 | + gt_dtm: ground truth distance transform map, shape=(b,c,x,y) |
| 53 | +
|
| 54 | + Returns: |
| 55 | + output: hd_loss |
| 56 | + """ |
| 57 | + assert seg_soft.shape[0] == gt.shape[0] |
| 58 | + total_loss = 0 |
| 59 | + num_class = seg_soft.shape[1] |
| 60 | + if class_weight is not None: |
| 61 | + assert class_weight.ndim == num_class |
| 62 | + for i in range(1, num_class): |
| 63 | + if i != ignore_index: |
| 64 | + delta_s = (seg_soft[:, i, ...] - gt.float())**2 |
| 65 | + s_dtm = seg_dtm[:, i, ...]**2 |
| 66 | + g_dtm = gt_dtm[:, i, ...]**2 |
| 67 | + dtm = s_dtm + g_dtm |
| 68 | + multiplied = torch.einsum('bxy, bxy->bxy', delta_s, dtm) |
| 69 | + hd_loss = multiplied.mean() |
| 70 | + if class_weight is not None: |
| 71 | + hd_loss *= class_weight[i] |
| 72 | + total_loss += hd_loss |
| 73 | + |
| 74 | + return total_loss / num_class |
| 75 | + |
| 76 | + |
| 77 | +@MODELS.register_module() |
| 78 | +class HuasdorffDisstanceLoss(nn.Module): |
| 79 | + """HuasdorffDisstanceLoss. This loss is proposed in `How Distance Transform |
| 80 | + Maps Boost Segmentation CNNs: An Empirical Study. |
| 81 | +
|
| 82 | + <http://proceedings.mlr.press/v121/ma20b.html>`_. |
| 83 | + Args: |
| 84 | + reduction (str, optional): The method used to reduce the loss into |
| 85 | + a scalar. Defaults to 'mean'. |
| 86 | + class_weight (list[float] | str, optional): Weight of each class. If in |
| 87 | + str format, read them from a file. Defaults to None. |
| 88 | + loss_weight (float): Weight of the loss. Defaults to 1.0. |
| 89 | + ignore_index (int | None): The label index to be ignored. Default: 255. |
| 90 | + loss_name (str): Name of the loss item. If you want this loss |
| 91 | + item to be included into the backward graph, `loss_` must be the |
| 92 | + prefix of the name. Defaults to 'loss_boundary'. |
| 93 | + """ |
| 94 | + |
| 95 | + def __init__(self, |
| 96 | + reduction='mean', |
| 97 | + class_weight=None, |
| 98 | + loss_weight=1.0, |
| 99 | + ignore_index=255, |
| 100 | + loss_name='loss_huasdorff_disstance', |
| 101 | + **kwargs): |
| 102 | + super().__init__() |
| 103 | + self.reduction = reduction |
| 104 | + self.loss_weight = loss_weight |
| 105 | + self.class_weight = get_class_weight(class_weight) |
| 106 | + self._loss_name = loss_name |
| 107 | + self.ignore_index = ignore_index |
| 108 | + |
| 109 | + def forward(self, |
| 110 | + pred: Tensor, |
| 111 | + target: Tensor, |
| 112 | + avg_factor=None, |
| 113 | + reduction_override=None, |
| 114 | + **kwargs) -> Tensor: |
| 115 | + """Forward function. |
| 116 | +
|
| 117 | + Args: |
| 118 | + pred (Tensor): Predictions of the segmentation head. (B, C, H, W) |
| 119 | + target (Tensor): Ground truth of the image. (B, H, W) |
| 120 | + avg_factor (int, optional): Average factor that is used to |
| 121 | + average the loss. Defaults to None. |
| 122 | + reduction_override (str, optional): The reduction method used |
| 123 | + to override the original reduction method of the loss. |
| 124 | + Options are "none", "mean" and "sum". |
| 125 | + Returns: |
| 126 | + Tensor: Loss tensor. |
| 127 | + """ |
| 128 | + assert reduction_override in (None, 'none', 'mean', 'sum') |
| 129 | + reduction = ( |
| 130 | + reduction_override if reduction_override else self.reduction) |
| 131 | + if self.class_weight is not None: |
| 132 | + class_weight = pred.new_tensor(self.class_weight) |
| 133 | + else: |
| 134 | + class_weight = None |
| 135 | + |
| 136 | + pred_soft = F.softmax(pred, dim=1) |
| 137 | + valid_mask = (target != self.ignore_index).long() |
| 138 | + target = target * valid_mask |
| 139 | + |
| 140 | + with torch.no_grad(): |
| 141 | + gt_dtm = compute_dtm(target.cpu(), pred_soft) |
| 142 | + gt_dtm = gt_dtm.float() |
| 143 | + seg_dtm2 = compute_dtm( |
| 144 | + pred_soft.argmax(dim=1, keepdim=False).cpu(), pred_soft) |
| 145 | + seg_dtm2 = seg_dtm2.float() |
| 146 | + |
| 147 | + loss_hd = self.loss_weight * hd_loss( |
| 148 | + pred_soft, |
| 149 | + target, |
| 150 | + seg_dtm=seg_dtm2, |
| 151 | + gt_dtm=gt_dtm, |
| 152 | + reduction=reduction, |
| 153 | + avg_factor=avg_factor, |
| 154 | + class_weight=class_weight, |
| 155 | + ignore_index=self.ignore_index) |
| 156 | + return loss_hd |
| 157 | + |
| 158 | + @property |
| 159 | + def loss_name(self): |
| 160 | + return self._loss_name |
0 commit comments