
Commit

Add doc for torch_em.self_training
constantinpape committed Jan 4, 2025
1 parent 6b7c4be commit 65bc76e
Showing 16 changed files with 842 additions and 241 deletions.
1 change: 0 additions & 1 deletion torch_em/data/datasets/electron_microscopy/asem.py
@@ -36,7 +36,6 @@
"er": ["cell_3.zarr", "cell_6.zarr", "cell_13.zarr"],
}


VOLUMES = {
"cell_1": "cell_1/cell_1.zarr", # mito (Y) golgi (Y) er (Y)
"cell_2": "cell_2/cell_2.zarr", # mito (Y) golgi (Y) er (Y)
170 changes: 150 additions & 20 deletions torch_em/metric/instance_segmentation_metric.py
@@ -1,4 +1,5 @@
from functools import partial
from typing import List, Optional

import numpy as np
import elf.evaluation as elfval
@@ -211,16 +212,25 @@ def __call__(self, seg, target):


class EmbeddingMWSIOUMetric(BaseInstanceSegmentationMetric):
"""
"""Intersection over union metric based on mutex watershed computed from embedding-derived affinities.
This class can be used as validation metric when training a network for instance segmentation.
Args:
delta:
offsets:
min_seg_size:
iou_threshold:
strides:
delta: The hinge distance of the contrastive loss for training the embeddings.
offsets: The offsets for deriving the affinities from the embeddings.
min_seg_size: Size for filtering the segmentation objects.
iou_threshold: Threshold for the intersection over union metric.
strides: The strides for the mutex watershed.
"""
def __init__(self, delta, offsets, min_seg_size, iou_threshold=0.5, strides=None):
def __init__(
self,
delta: float,
offsets: List[List[int]],
min_seg_size: int,
iou_threshold: float = 0.5,
strides: Optional[List[int]] = None,
):
segmenter = EmbeddingMWS(delta, offsets, with_background=True, min_seg_size=min_seg_size)
metric = IOUError(iou_threshold)
super().__init__(segmenter, metric)
@@ -229,63 +239,136 @@ def __init__(self, delta, offsets, min_seg_size, iou_threshold=0.5, strides=None


class EmbeddingMWSSBDMetric(BaseInstanceSegmentationMetric):
def __init__(self, delta, offsets, min_seg_size, strides=None):
"""Symmetric best dice metric based on mutex watershed computed from embedding-derived affinities.
This class can be used as validation metric when training a network for instance segmentation.
Args:
delta: The hinge distance of the contrastive loss for training the embeddings.
offsets: The offsets for deriving the affinities from the embeddings.
min_seg_size: Size for filtering the segmentation objects.
strides: The strides for the mutex watershed.
"""
def __init__(self, delta: float, offsets: List[List[int]], min_seg_size: int, strides: Optional[List[int]] = None):
segmenter = EmbeddingMWS(delta, offsets, with_background=True, min_seg_size=min_seg_size)
metric = SymmetricBestDice()
super().__init__(segmenter, metric)
self.init_kwargs = {"delta": delta, "offsets": offsets, "min_seg_size": min_seg_size, "strides": strides}
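
For context (not part of this commit): a minimal sketch of how one of the embedding-based mutex watershed metrics above might be constructed for validation. The import path, the delta value, and the 2D offsets are assumptions chosen for illustration, not values prescribed by torch_em.

from torch_em.metric import EmbeddingMWSIOUMetric  # assumed export path

offsets = [[-1, 0], [0, -1], [-3, 0], [0, -3], [-9, 0], [0, -9]]  # example 2D offsets
metric = EmbeddingMWSIOUMetric(delta=1.5, offsets=offsets, min_seg_size=50, iou_threshold=0.5)
# Typically passed to a trainer as its validation metric, e.g. (assuming the trainer accepts one):
# trainer = torch_em.default_segmentation_trainer(..., metric=metric)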


class EmbeddingMWSVOIMetric(BaseInstanceSegmentationMetric):
def __init__(self, delta, offsets, min_seg_size, strides=None):
"""Variation of information metric based on mutex watershed computed from embedding-derived affinities.
This class can be used as validation metric when training a network for instance segmentation.
Args:
delta: The hinge distance of the contrastive loss for training the embeddings.
offsets: The offsets for deriving the affinities from the embeddings.
min_seg_size: Size for filtering the segmentation objects.
strides: The strides for the mutex watershed.
"""
def __init__(self, delta: float, offsets: List[List[int]], min_seg_size: int, strides: Optional[List[int]] = None):
segmenter = EmbeddingMWS(delta, offsets, with_background=False, min_seg_size=min_seg_size)
metric = VariationOfInformation()
super().__init__(segmenter, metric)
self.init_kwargs = {"delta": delta, "offsets": offsets, "min_seg_size": min_seg_size, "strides": strides}


class EmbeddingMWSRandMetric(BaseInstanceSegmentationMetric):
def __init__(self, delta, offsets, min_seg_size, strides=None):
"""Rand index metric based on mutex watershed computed from embedding-derived affinities.
This class can be used as validation metric when training a network for instance segmentation.
Args:
delta: The hinge distance of the contrastive loss for training the embeddings.
offsets: The offsets for deriving the affinities from the embeddings.
min_seg_size: Size for filtering the segmentation objects.
strides: The strides for the mutex watershed.
"""
def __init__(self, delta: float, offsets: List[List[int]], min_seg_size: int, strides: Optional[List[int]] = None):
segmenter = EmbeddingMWS(delta, offsets, with_background=False, min_seg_size=min_seg_size)
metric = AdaptedRandError()
super().__init__(segmenter, metric)
self.init_kwargs = {"delta": delta, "offsets": offsets, "min_seg_size": min_seg_size, "strides": strides}


class HDBScanIOUMetric(BaseInstanceSegmentationMetric):
def __init__(self, min_size, eps, iou_threshold=0.5):
"""Intersection over union metric based on HDBScan computed from embeddings.
This class can be used as validation metric when training a network for instance segmentation.
Args:
min_size: The minimal segment size.
eps: The epsilon value for HDBScan.
iou_threshold: The threshold for the intersection over union value.
"""
def __init__(self, min_size: int, eps: float, iou_threshold: float = 0.5):
segmenter = HDBScan(min_size=min_size, eps=eps, remove_largest=True)
metric = IOUError(iou_threshold)
super().__init__(segmenter, metric)
self.init_kwargs = {"min_size": min_size, "eps": eps, "iou_threshold": iou_threshold}
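
For context (not part of this commit): a sketch of constructing one of the HDBScan-based metrics above. The import path is assumed and the parameter values are placeholders, not recommended defaults.

from torch_em.metric import HDBScanIOUMetric  # assumed export path

# min_size and eps are forwarded to the HDBScan-based segmenter (see the docstring above).
metric = HDBScanIOUMetric(min_size=50, eps=1e-4, iou_threshold=0.5)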


class HDBScanSBDMetric(BaseInstanceSegmentationMetric):
def __init__(self, min_size, eps):
"""Symmetric best dice metric based on HDBScan computed from embeddings.
This class can be used as validation metric when training a network for instance segmentation.
Args:
min_size: The minimal segment size.
eps: The epsilon value for HDBScan.
"""
def __init__(self, min_size: int, eps: float):
segmenter = HDBScan(min_size=min_size, eps=eps, remove_largest=True)
metric = SymmetricBestDice()
super().__init__(segmenter, metric)
self.init_kwargs = {"min_size": min_size, "eps": eps}


class HDBScanRandMetric(BaseInstanceSegmentationMetric):
def __init__(self, min_size, eps):
"""Rand index metric based on HDBScan computed from embeddings.
This class can be used as validation metric when training a network for instance segmentation.
Args:
min_size: The minimal segment size.
eps: The epsilon value for HDBScan.
"""
def __init__(self, min_size: int, eps: float):
segmenter = HDBScan(min_size=min_size, eps=eps, remove_largest=True)
metric = AdaptedRandError()
super().__init__(segmenter, metric)
self.init_kwargs = {"min_size": min_size, "eps": eps}


class HDBScanVOIMetric(BaseInstanceSegmentationMetric):
def __init__(self, min_size, eps):
"""Variation of information metric based on HDBScan computed from embeddings.
This class can be used as validation metric when training a network for instance segmentation.
Args:
min_size: The minimal segment size.
eps: The epsilon value for HDBScan.
"""
def __init__(self, min_size: int, eps: float):
segmenter = HDBScan(min_size=min_size, eps=eps, remove_largest=True)
metric = VariationOfInformation()
super().__init__(segmenter, metric)
self.init_kwargs = {"min_size": min_size, "eps": eps}


class MulticutVOIMetric(BaseInstanceSegmentationMetric):
def __init__(self, min_seg_size, anisotropic=False, dt_threshold=0.25, sigma_seeds=2.0):
"""Variation of information metric based on a multicut computed from boundary predictions.
This class can be used as validation metric when training a network for instance segmentation.
Args:
min_seg_size: The minimal segment size.
anisotropic: Whether to compute the watersheds in 2d for volumetric data.
dt_threshold: The threshold to apply to the boundary predictions before computing the distance transform.
sigma_seeds: The sigma value for smoothing the distance transform before computing seeds.
"""
def __init__(self, min_seg_size: int, anisotropic: bool = False, dt_threshold: float = 0.25, sigma_seeds: float = 2.0):
segmenter = Multicut(dt_threshold, anisotropic, sigma_seeds)
metric = VariationOfInformation()
super().__init__(segmenter, metric)
@@ -294,7 +377,17 @@ def __init__(self, min_seg_size, anisotropic=False, dt_threshold=0.25, sigma_see


class MulticutRandMetric(BaseInstanceSegmentationMetric):
def __init__(self, min_seg_size, anisotropic=False, dt_threshold=0.25, sigma_seeds=2.0):
"""Rand index metric based on a multicut computed from boundary predictions.
This class can be used as validation metric when training a network for instance segmentation.
Args:
min_seg_size: The minimal segment size.
anisotropic: Whether to compute the watersheds in 2d for volumetric data.
dt_threshold: The threshold to apply to the boundary predictions before computing the distance transform.
sigma_seeds: The sigma value for smoothing the distance transform before computing seeds.
"""
def __init__(self, min_seg_size: int, anisotropic: bool = False, dt_threshold: float = 0.25, sigma_seeds: float = 2.0):
segmenter = Multicut(dt_threshold, anisotropic, sigma_seeds)
metric = AdaptedRandError()
super().__init__(segmenter, metric)
@@ -303,7 +396,17 @@ def __init__(self, min_seg_size, anisotropic=False, dt_threshold=0.25, sigma_see


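For context (not part of this commit): a sketch of constructing one of the multicut-based metrics above, with placeholder parameter values and an assumed import path.

from torch_em.metric import MulticutVOIMetric  # assumed export path

# The segmenter computes a seeded watershed from the thresholded boundary predictions and then a
# multicut; anisotropic=True computes the watersheds in 2d for volumetric data (see docstring above).
metric = MulticutVOIMetric(min_seg_size=100, anisotropic=True, dt_threshold=0.25, sigma_seeds=2.0)
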
class MWSIOUMetric(BaseInstanceSegmentationMetric):
def __init__(self, offsets, min_seg_size, iou_threshold=0.5, strides=None):
"""Intersection over union metric based on a mutex watershed computed from affinity predictions.
This class can be used as validation metric when training a network for instance segmentation.
Args:
offsets: The offsets corresponding to the affinity channels.
min_seg_size: The minimal segment size.
iou_threshold: The threshold for the intersection over union value.
strides: The strides for the mutex watershed.
"""
def __init__(self, offsets: List[List[int]], min_seg_size: int, iou_threshold: float = 0.5, strides: Optional[List[int]] = None):
segmenter = MWS(offsets, with_background=True, min_seg_size=min_seg_size, strides=strides)
metric = IOUError(iou_threshold)
super().__init__(segmenter, metric)
@@ -312,23 +415,50 @@ def __init__(self, offsets, min_seg_size, iou_threshold=0.5, strides=None):


class MWSSBDMetric(BaseInstanceSegmentationMetric):
def __init__(self, offsets, min_seg_size, strides=None):
"""Symmetric best dice score metric based on a mutex watershed computed from affinity predictions.
This class can be used as validation metric when training a network for instance segmentation.
Args:
offsets: The offsets corresponding to the affinity channels.
min_seg_size: The minimal segment size.
strides: The strides for the mutex watershed.
"""
def __init__(self, offsets: List[List[int]], min_seg_size: int, strides: Optional[List[int]] = None):
segmenter = MWS(offsets, with_background=True, min_seg_size=min_seg_size, strides=strides)
metric = SymmetricBestDice()
super().__init__(segmenter, metric)
self.init_kwargs = {"offsets": offsets, "min_seg_size": min_seg_size, "strides": strides}
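
For context (not part of this commit): a sketch of constructing one of the affinity-based mutex watershed metrics above. The offsets must match the affinity channels the network predicts; the values below, like the import path, are illustrative assumptions.

from torch_em.metric import MWSIOUMetric  # assumed export path

offsets = [[-1, 0], [0, -1], [-3, 0], [0, -3], [-9, 0], [0, -9]]  # must match the predicted affinity channels
metric = MWSIOUMetric(offsets=offsets, min_seg_size=50, iou_threshold=0.5, strides=[2, 2])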


class MWSVOIMetric(BaseInstanceSegmentationMetric):
def __init__(self, offsets, min_seg_size, strides=None):
"""Variation of information metric based on a mutex watershed computed from affinity predictions.
This class can be used as validation metric when training a network for instance segmentation.
Args:
offsets: The offsets corresponding to the affinity channels.
min_seg_size: The minimal segment size.
strides: The strides for the mutex watershed.
"""
def __init__(self, offsets: List[List[int]], min_seg_size: int, strides: Optional[List[int]] = None):
segmenter = MWS(offsets, with_background=False, min_seg_size=min_seg_size, strides=strides)
metric = VariationOfInformation()
super().__init__(segmenter, metric)
self.init_kwargs = {"offsets": offsets, "min_seg_size": min_seg_size, "strides": strides}


class MWSRandMetric(BaseInstanceSegmentationMetric):
def __init__(self, offsets, min_seg_size, strides=None):
"""Rand index metric based on a mutex watershed computed from affinity predictions.
This class can be used as validation metric when training a network for instance segmentation.
Args:
offsets: The offsets corresponding to the affinity channels.
min_seg_size: The minimal segment size.
strides: The strides for the mutex watershed.
"""
def __init__(self, offsets: List[List[int]], min_seg_size: int, strides: Optional[List[int]] = None):
segmenter = MWS(offsets, with_background=False, min_seg_size=min_seg_size, strides=strides)
metric = AdaptedRandError()
super().__init__(segmenter, metric)
14 changes: 11 additions & 3 deletions torch_em/model/probabilistic_unet.py
@@ -1,3 +1,6 @@
"""@private
"""

# This code is based on the original TensorFlow implementation: https://github.com/SimonKohl/probabilistic_unet
# The below implementation is from: https://github.com/stefanknegt/Probabilistic-Unet-Pytorch

@@ -12,6 +15,8 @@


def truncated_normal_(tensor, mean=0, std=1):
"""@private
"""
size = tensor.shape
tmp = tensor.new_empty(size + (4,)).normal_()
valid = (tmp < 2) & (tmp > -2)
@@ -21,15 +26,19 @@ def truncated_normal_(tensor, mean=0, std=1):


def init_weights(m):
if type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d:
"""@private
"""
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
# nn.init.normal_(m.weight, std=0.001)
# nn.init.normal_(m.bias, std=0.001)
truncated_normal_(m.bias, mean=0, std=0.001)
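
For context (not part of this commit): init_weights is typically applied to a model via nn.Module.apply; a small sketch with a toy module (the module itself is illustrative).

import torch.nn as nn

net = nn.Sequential(nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(), nn.ConvTranspose2d(16, 1, 2, stride=2))
net.apply(init_weights)  # initializes the weights and biases of all Conv2d / ConvTranspose2d submodules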


def init_weights_orthogonal_normal(m):
if type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d:
"""@private
"""
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
nn.init.orthogonal_(m.weight)
truncated_normal_(m.bias, mean=0, std=0.001)
# nn.init.normal_(m.bias, std=0.001)
@@ -41,7 +50,6 @@ class Encoder(nn.Module):
convolutional layers, after each block a pooling operation is performed.
And after each convolutional layer a non-linear (ReLU) activation function is applied.
"""

def __init__(
self,
input_channels,
