
Update vision mamba backbone to latest updates #401

Merged
merged 10 commits on Nov 13, 2024
70 changes: 23 additions & 47 deletions experiments/vision-mamba/vimunet/run_cremi.py
@@ -12,7 +12,6 @@
import torch_em
from torch_em.loss import DiceLoss
from torch_em.util import segmentation
from torch_em.data import MinInstanceSampler
from torch_em.model import get_vimunet_model
from torch_em.data.datasets import get_cremi_loader
from torch_em.util.prediction import predict_with_halo
@@ -28,70 +27,49 @@
CREMI_TEST_ROOT = "/scratch/projects/nim00007/sam/data/cremi/slices_original"


def get_loaders(args, patch_shape=(1, 512, 512)):
def get_loaders(input):
train_rois = {"A": np.s_[0:75, :, :], "B": np.s_[0:75, :, :], "C": np.s_[0:75, :, :]}
val_rois = {"A": np.s_[75:100, :, :], "B": np.s_[75:100, :, :], "C": np.s_[75:100, :, :]}

sampler = MinInstanceSampler()

train_loader = get_cremi_loader(
path=args.input,
patch_shape=patch_shape,
batch_size=2,
rois=train_rois,
sampler=sampler,
ndim=2,
label_dtype=torch.float32,
defect_augmentation_kwargs=None,
boundaries=True,
num_workers=16,
download=True,
)
val_loader = get_cremi_loader(
path=args.input,
patch_shape=patch_shape,
batch_size=1,
rois=val_rois,
sampler=sampler,
ndim=2,
label_dtype=torch.float32,
defect_augmentation_kwargs=None,
boundaries=True,
num_workers=16,
download=True,
)
kwargs = {
"path": input,
"patch_shape": (1, 512, 512),
"ndim": 2,
"label_dtype": torch.float32,
"defect_augmentation_kwargs": None,
"boundaries": True,
"num_workers": 16,
"download": True,
"shuffle": True,
}

train_loader = get_cremi_loader(batch_size=2, rois=train_rois, **kwargs)
val_loader = get_cremi_loader(batch_size=1, rois=val_rois, **kwargs)
return train_loader, val_loader
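
A quick aside on the ROI dicts above: np.s_ is numpy's slice helper, so each entry is simply a tuple of slices that reserves the first 75 z-slices of every CREMI volume for training and the last 25 for validation. A minimal standalone sketch with a random stand-in volume:

import numpy as np

vol = np.random.rand(100, 512, 512)  # stand-in for one CREMI volume, for illustration only
train_part = vol[np.s_[0:75, :, :]]    # first 75 slices
val_part = vol[np.s_[75:100, :, :]]    # remaining 25 slices
print(train_part.shape, val_part.shape)  # (75, 512, 512) (25, 512, 512)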


def run_cremi_training(args):
# the dataloaders for cremi dataset
train_loader, val_loader = get_loaders(args)
train_loader, val_loader = get_loaders(input=args.input)

# the vision-mamba + decoder (UNet-based) model
model = get_vimunet_model(
out_channels=1,
model_type=args.model_type,
with_cls_token=True
)

model = get_vimunet_model(out_channels=1, model_type=args.model_type, with_cls_token=True)
save_root = os.path.join(args.save_root, "scratch", "boundaries", args.model_type)

# loss function
loss = DiceLoss()

# trainer for the segmentation task
trainer = torch_em.default_segmentation_trainer(
name="cremi-vimunet",
model=model,
train_loader=train_loader,
val_loader=val_loader,
learning_rate=1e-4,
loss=loss,
metric=loss,
loss=DiceLoss(),
metric=DiceLoss(),
log_image_interval=50,
save_root=save_root,
compile_model=False,
scheduler_kwargs={"mode": "min", "factor": 0.9, "patience": 10}
scheduler_kwargs={"mode": "min", "factor": 0.9, "patience": 10},
mixed_precision=False,
)
trainer.fit(iterations=int(1e5))
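
The scheduler_kwargs passed to the trainer follow the signature of torch's ReduceLROnPlateau, which torch_em's default trainer uses for plateau-based learning-rate scheduling (our reading of its defaults). A standalone sketch of the schedule these values encode, with a toy model and a placeholder metric:

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = torch.nn.Linear(4, 1)  # toy model, for illustration only
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10)

for epoch in range(20):
    val_loss = 1.0  # placeholder; the trainer steps with the validation loss
    scheduler.step(val_loss)  # after 10 epochs without improvement, lr *= 0.9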

@@ -102,10 +80,7 @@ def run_cremi_inference(args, device):

# the vision-mamba + decoder (UNet-based) model
model = get_vimunet_model(
out_channels=1,
model_type=args.model_type,
with_cls_token=True,
checkpoint=checkpoint
out_channels=1, model_type=args.model_type, with_cls_token=True, checkpoint=checkpoint
)

all_test_images = glob(os.path.join(CREMI_TEST_ROOT, "raw", "cremi_test_*.tif"))
@@ -134,6 +109,7 @@ def run_cremi_inference(args, device):
"SA50": np.mean(sa50_list),
"SA75": np.mean(sa75_list)
}

res_path = os.path.join(args.result_path, "results.csv")
df = pd.DataFrame.from_dict([res])
df.to_csv(res_path)
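
The elided body of the inference loop runs tiled prediction via predict_with_halo (imported at the top of this file). A minimal sketch of how that call typically looks; the (input, model, gpu_ids, block_shape, halo) argument order is torch_em's, but the concrete block and halo sizes here are assumptions:

import imageio.v3 as imageio
import torch
from torch_em.util.prediction import predict_with_halo

image = imageio.imread(all_test_images[0]).astype("float32")
device = "cuda" if torch.cuda.is_available() else "cpu"

# Predict in 512x512 blocks with a 64-pixel halo, so outputs are stitched
# from the interior of each block and border artifacts are avoided.
prediction = predict_with_halo(
    image, model, gpu_ids=[device], block_shape=(512, 512), halo=(64, 64),
)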
104 changes: 37 additions & 67 deletions experiments/vision-mamba/vimunet/run_livecell.py
@@ -22,93 +22,66 @@
ROOT = "/scratch/usr/nimanwai"


def get_loaders(args, patch_shape=(512, 512)):
if args.distances:
def get_loaders(input, boundaries, distances):
label_trafo = None
if distances:
label_trafo = torch_em.transform.label.PerObjectDistanceTransform(
distances=True,
boundary_distances=True,
directed_distances=False,
foreground=True,
min_size=25
distances=True, boundary_distances=True, directed_distances=False, foreground=True, min_size=25,
)
else:
label_trafo = None

train_loader = get_livecell_loader(
path=args.input,
split="train",
patch_shape=patch_shape,
batch_size=2,
label_dtype=torch.float32,
boundaries=args.boundaries,
label_transform=label_trafo,
num_workers=16,
download=True,
)
val_loader = get_livecell_loader(
path=args.input,
split="val",
patch_shape=patch_shape,
batch_size=1,
label_dtype=torch.float32,
boundaries=args.boundaries,
label_transform=label_trafo,
num_workers=16,
download=True,
)

kwargs = {
"path": input,
"patch_shape": (512, 512),
"label_dtype": torch.float32,
"boundaries": boundaries,
"label_transform": label_trafo,
"num_workers": 16,
"download": True,
"shuffle": True,
}

train_loader = get_livecell_loader(split="train", batch_size=2, **kwargs)
val_loader = get_livecell_loader(split="val", batch_size=1, **kwargs)
return train_loader, val_loader
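
For the distance setting, the interesting piece is PerObjectDistanceTransform. With these flags it should emit a three-channel target per patch (foreground mask, per-object center distances, boundary distances), which lines up with get_output_channels below returning 3 for distances; the exact channel layout is our reading of the transform, not something this diff states. A small sketch:

import numpy as np
import torch_em

trafo = torch_em.transform.label.PerObjectDistanceTransform(
    distances=True, boundary_distances=True, directed_distances=False, foreground=True, min_size=25,
)
labels = np.zeros((256, 256), dtype="uint32")
labels[50:100, 50:100] = 1  # a single toy instance
target = trafo(labels)
print(target.shape)  # expected: (3, 256, 256)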


def get_output_channels(args):
if args.boundaries:
def get_output_channels(boundaries):
if boundaries:
output_channels = 2
else:
output_channels = 3

return output_channels


def get_loss_function(args):
if args.distances:
def get_loss_function(distances):
if distances:
loss = DiceBasedDistanceLoss(mask_distances_in_bg=True)

else:
loss = DiceLoss()

return loss


def get_save_root(args):
def get_save_root(boundaries, model_type, save_root):
# experiment_type
if args.boundaries:
if boundaries:
experiment_type = "boundaries"
else:
experiment_type = "distances"

model_name = args.model_type

# saving the model checkpoints
save_root = os.path.join(args.save_root, "scratch", experiment_type, model_name)
save_root = os.path.join(save_root, "scratch", experiment_type, model_type)
return save_root


def run_livecell_training(args):
# the dataloaders for livecell dataset
train_loader, val_loader = get_loaders(args)

output_channels = get_output_channels(args)
train_loader, val_loader = get_loaders(input=args.input, boundaries=args.boundaries, distances=args.distances)
output_channels = get_output_channels(boundaries=args.boundaries)
loss = get_loss_function(distances=args.distances)
save_root = get_save_root(boundaries=args.boundaries, model_type=args.model_type, save_root=args.save_root)

# the vision-mamba + decoder (UNet-based) model
model = get_vimunet_model(
out_channels=output_channels,
model_type=args.model_type,
with_cls_token=True,
)

save_root = get_save_root(args)

# loss function
loss = get_loss_function(args)
model = get_vimunet_model(out_channels=output_channels, model_type=args.model_type, with_cls_token=True)

# trainer for the segmentation task
trainer = torch_em.default_segmentation_trainer(
@@ -122,24 +95,20 @@ def run_livecell_training(args):
log_image_interval=50,
save_root=save_root,
compile_model=False,
scheduler_kwargs={"mode": "min", "factor": 0.9, "patience": 10}
scheduler_kwargs={"mode": "min", "factor": 0.9, "patience": 10},
mixed_precision=False,
)
trainer.fit(iterations=int(1e5))


def run_livecell_inference(args, device):
output_channels = get_output_channels(args)

save_root = get_save_root(args)

output_channels = get_output_channels(boundaries=args.boundaries)
save_root = get_save_root(boundaries=args.boundaries, model_type=args.model_type, save_root=args.save_root)
checkpoint = os.path.join(save_root, "checkpoints", "livecell-vimunet", "best.pt")

# the vision-mamba + decoder (UNet-based) model
model = get_vimunet_model(
out_channels=output_channels,
model_type=args.model_type,
with_cls_token=True,
checkpoint=checkpoint,
out_channels=output_channels, model_type=args.model_type, with_cls_token=True, checkpoint=checkpoint,
)

# the splits are provided with the livecell dataset
@@ -149,7 +118,7 @@ def run_livecell_inference(args, device):
all_test_labels = glob(os.path.join(ROOT, "data", "livecell", "annotations", "livecell_test_images", "*", "*"))

msa_list, sa50_list, sa75_list = [], [], []
for label_path in tqdm(all_test_labels):
for label_path in tqdm(all_test_labels, desc="Prediction for LIVECell"):
labels = imageio.imread(label_path)
image_id = os.path.split(label_path)[-1]

@@ -184,6 +153,7 @@ def run_livecell_inference(args, device):
"SA50": np.mean(sa50_list),
"SA75": np.mean(sa75_list)
}

res_path = os.path.join(args.result_path, "results.csv")
df = pd.DataFrame.from_dict([res])
df.to_csv(res_path)
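
The per-image scores collected in the elided part of the loop above are standard instance-segmentation metrics. A sketch of how msa_list, sa50_list and sa75_list are typically filled in these experiment scripts, assuming elf's mean_segmentation_accuracy with return_accuracies=True, which yields the mSA plus per-threshold accuracies at IoU 0.5, 0.55, ..., 0.95:

from elf.evaluation import mean_segmentation_accuracy

msa, sa_acc = mean_segmentation_accuracy(segmentation, labels, return_accuracies=True)
msa_list.append(msa)
sa50_list.append(sa_acc[0])  # accuracy at IoU threshold 0.5
sa75_list.append(sa_acc[5])  # accuracy at IoU threshold 0.75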
2 changes: 1 addition & 1 deletion torch_em/data/datasets/medical/cbis_ddsm.py
@@ -45,7 +45,7 @@ def get_cbis_ddsm_paths(
path: Union[os.PathLike, str],
split: Literal['Train', 'Val', 'Test'],
task: Literal['Calc', 'Mass'],
tumour_type: Literal['MALIGNANT', 'BENIGN'],
tumour_type: Optional[Literal["MALIGNANT", "BENIGN"]] = None,
download: bool = False
):
"""Get paths to the CBIS DDSM data.
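
The change itself is small: tumour_type now defaults to None instead of being required. A hypothetical call site, assuming (as the Optional default suggests) that omitting the argument returns paths for both tumour types:

from torch_em.data.datasets.medical.cbis_ddsm import get_cbis_ddsm_paths

# New: omit tumour_type to get both MALIGNANT and BENIGN cases (assumption
# based on the Optional[...] = None default introduced here).
all_paths = get_cbis_ddsm_paths(path="./cbis_ddsm", split="Train", task="Mass")

# The previous behavior is still available by passing the literal explicitly.
malignant = get_cbis_ddsm_paths(path="./cbis_ddsm", split="Train", task="Mass", tumour_type="MALIGNANT")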
8 changes: 4 additions & 4 deletions torch_em/data/datasets/medical/piccolo.py
@@ -40,15 +40,15 @@ def get_piccolo_data(path: Union[os.PathLike, str], download: bool = False) -> s
Returns:
Filepath where the data is downloaded.
"""
data_dir = os.path.join(path, r"piccolo dataset-release0.1")
if os.path.exists(data_dir):
return data_dir

if download:
raise NotImplementedError(
"Automatic download is not possible for this dataset. See 'get_piccolo_data' for details."
)

data_dir = os.path.join(path, r"piccolo dataset-release0.1")
if os.path.exists(data_dir):
return data_dir

rar_file = os.path.join(path, r"piccolo dataset_widefield-release0.1.rar")
if not os.path.exists(rar_file):
raise FileNotFoundError(
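
The piccolo change is purely an ordering fix: the existence check for the extracted data now runs before the download branch, so users who already have the data never hit the NotImplementedError. The resulting control flow, sketched with the names used in the function above (the final extraction step is elided here, as in the diff):

import os

def get_piccolo_data(path, download=False):
    # 1) Return early if the extracted data is already present.
    data_dir = os.path.join(path, "piccolo dataset-release0.1")
    if os.path.exists(data_dir):
        return data_dir

    # 2) Only then reject unsupported automatic downloads.
    if download:
        raise NotImplementedError(
            "Automatic download is not possible for this dataset. See 'get_piccolo_data' for details."
        )

    # 3) Otherwise expect the manually downloaded rar file and extract it.
    ...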
20 changes: 6 additions & 14 deletions torch_em/model/unetr.py
@@ -5,8 +5,8 @@
import torch.nn as nn
import torch.nn.functional as F

from .unet import Decoder, ConvBlock2d, Upsampler2d
from .vit import get_vision_transformer
from .unet import Decoder, ConvBlock2d, Upsampler2d

try:
from micro_sam.util import get_sam_model
@@ -22,18 +22,15 @@
class UNETR(nn.Module):

def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):

"""Function to load pretrained weights to the image encoder.
"""
if isinstance(checkpoint, str):
if backbone == "sam" and isinstance(encoder, str):
# If we have a SAM encoder, then we first try to load the full SAM Model
# (using micro_sam) and otherwise fall back on directly loading the encoder state
# from the checkpoint
try:
_, model = get_sam_model(
model_type=encoder,
checkpoint_path=checkpoint,
return_sam=True
)
_, model = get_sam_model(model_type=encoder, checkpoint_path=checkpoint, return_sam=True)
encoder_state = model.image_encoder.state_dict()
except Exception:
# Try loading the encoder state directly from a checkpoint.
@@ -47,8 +44,7 @@ def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
k: v for k, v in encoder_state.items()
if (k != "mask_token" and not k.startswith("decoder"))
})

# let's remove the `head` from our current encoder (as the MAE pretrained don't expect it)
# Let's remove the `head` from our current encoder (as the MAE-pretrained weights don't expect it)
current_encoder_state = self.encoder.state_dict()
if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
del self.encoder.head
@@ -72,7 +68,7 @@ def __init__(
final_activation: Optional[Union[str, nn.Module]] = None,
use_skip_connection: bool = True,
embed_dim: Optional[int] = None,
use_conv_transpose=True,
use_conv_transpose: bool = True,
) -> None:
super().__init__()

@@ -150,15 +146,11 @@ def __init__(
self.deconv4 = Deconv2DBlock(features_decoder[2], features_decoder[3])

self.base = ConvBlock2d(embed_dim, features_decoder[0])

self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)

self.deconv_out = _upsampler(
scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
)

self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])

self.final_activation = self._get_activation(final_activation)

def _get_activation(self, activation):
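
For context on the signature cleanup (use_conv_transpose is now annotated as bool), here is a hypothetical instantiation of UNETR; only final_activation, use_skip_connection, embed_dim and use_conv_transpose are confirmed by this diff, and the backbone/encoder/out_channels arguments are assumptions about the rest of the constructor:

from torch_em.model.unetr import UNETR

model = UNETR(
    backbone="sam",           # assumption: SAM-style ViT backbone
    encoder="vit_b",          # assumption: encoder identifier
    out_channels=1,
    final_activation="Sigmoid",
    use_skip_connection=True,
    use_conv_transpose=True,  # now explicitly typed as bool in the signature
)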