
Commit 4401f33

Soorya19Pradeep, ziw-liu, edyoshikun, and mattersoflight authored
Cellular infection phenotyping using annotated viral sensor data & label-free images (#70)
* refactor data loading into its own module
* update type annotations
* move the logging module out
* move old logging into utils
* rename tests to match module name
* bump torch
* draft fcmae encoder
* add stem to the encoder
* wip: masked stem layernorm
* wip: patchify masked features for linear
* use mlp from timm
* hack: POC training script for FCMAE
* fix mask for fitting
* remove training script
* default architecture
* fine-tuning options
* fix cli for finetuning
* draft combined data module
* fix import
* manual validation loss reduction
* update linting: new black version has different rules
* update development guide
* update type hints
* bump iohub
* draft ctmc v1 dataset
* update tests
* move test_data
* remove path conversion
* configurable normalizations (#68)
* initial commit adding the normalization.
* adding dataset_statistics to each fov to facilitate the configurable augmentations
* fix indentation
* ruff
* test preprocessing
* remove redundant field
* cleanup

---------

Co-authored-by: Ziwen Liu <[email protected]>

* fix ctmc dataloading
* add example ctmc v1 loading script
* changing the normalization and augmentations default from None to empty list.
* invert intensity transform
* concatenated data module
* subsample videos
* livecell dataset
* all sample fields are optional
* fix multi-dataloader validation
* lint
* fixing preprocessing for varying array shapes (i.e aics dataset)
* update loading scripts
* fix CombineMode
* added model and annotation code draft
* changed to simple unet model
* start with lesser augmentations
* added readme file
* added tensorboard logging
* added validation step
* changed to viscy 2d unet
* used crossentropyloss with one-hot encoding
* added sample image logging
* attempt to build magicgui annotation
* renamed infection annotation tool
* added normalization and augmentations
* added model testing code
* removed annotation refiner
* corrected conversion of class to int
* corrected prediction module
* cleaned up the code and comments for the LightningUNet
* removed confusion matrix code, finding runtime error with model
* moved scripts to viscy.scripts.infection_phenotyping module to enable imports across scripts
* combine the lightning modules for training and prediction, fix the DDP exception
* all the stubs for computing and logging confusion matrix per cell
* separated training and test scripts
* lightning module
* corrected test cm compute
* corrected test module
* separated test and prediction scripts
* changed confusion matrix compute
* fix merge error
* split 2D and 2.5D model scripts
* added covnext script
* fix model input parameter
* update input file
* add augmentations
* refactor infection_classification code to viscy/applications
* changes made for BJ5 classification
* format code
* add explicit packaging list
* rename testing script
* update readme
* move function to preprocessing
* format code
* formatting
* histogram with dask
* fix index and test
* fix import
* black
* fix float comp
* clean up headers
* clean up import
* add argument to change number of classes

---------

Co-authored-by: Ziwen Liu <[email protected]>
Co-authored-by: Eduardo Hirata-Miyasaki <[email protected]>
Co-authored-by: Shalin Mehta <[email protected]>
Co-authored-by: Ziwen Liu <[email protected]>
1 parent dbf4ddc commit 4401f33

12 files changed, +1588 -0 lines changed
Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
# %%
import lightning.pytorch as pl
import torch
import torch.nn as nn
from applications.infection_classification.classify_infection_25D import (
    SemanticSegUNet25D,
)

# Import callbacks and loggers from the same `lightning.pytorch` namespace as
# the Trainer; mixing it with the standalone `pytorch_lightning` package can
# fail at runtime.
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

from viscy.data.hcs import HCSDataModule
from viscy.preprocessing.pixel_ratio import sematic_class_weights
from viscy.transforms import NormalizeSampled, RandWeightedCropd

# %% Create a dataloader and visualize the batches.

# Set the path to the dataset
dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_trainVal.zarr"

# %% create data module

# Create an instance of HCSDataModule
data_module = HCSDataModule(
    dataset_path,
    source_channel=["Phase", "HSP90"],
    target_channel=["Inf_mask"],
    yx_patch_size=[512, 512],
    split_ratio=0.8,
    z_window_size=5,
    architecture="2.5D",
    num_workers=3,
    batch_size=32,
    normalizations=[
        NormalizeSampled(
            keys=["Phase", "HSP90"],
            level="fov_statistics",
            subtrahend="median",
            divisor="iqr",
        )
    ],
    augmentations=[
        RandWeightedCropd(
            num_samples=4,
            spatial_size=[-1, 512, 512],
            keys=["Phase", "HSP90"],
            w_key="Inf_mask",
        )
    ],
)

pixel_ratio = sematic_class_weights(dataset_path, target_channel="Inf_mask")

# Prepare the data
data_module.prepare_data()

# Setup the data
data_module.setup(stage="fit")

# Create a dataloader
train_dm = data_module.train_dataloader()

val_dm = data_module.val_dataloader()


# %% Define the logger
logger = TensorBoardLogger(
    "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/",
    name="logs",
)

# Pass the logger to the Trainer
trainer = pl.Trainer(
    logger=logger,
    max_epochs=200,
    default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/",
    log_every_n_steps=1,
    # Use a single GPU to avoid a runtime exception from distributed
    # training when the node has multiple GPUs.
    devices=1,
)

# Define the checkpoint callback
checkpoint_callback = ModelCheckpoint(
    dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/",
    filename="checkpoint_{epoch:02d}",
    save_top_k=-1,
    verbose=True,
    monitor="loss/validate",
    mode="min",
)

# Add the checkpoint callback to the trainer
trainer.callbacks.append(checkpoint_callback)

# Define the model
model = SemanticSegUNet25D(
    in_channels=2,
    out_channels=3,
    loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)),
)

print(model)

# %% Run training.

trainer.fit(model, data_module)

# %%
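The script above normalizes each source channel with `NormalizeSampled(..., subtrahend="median", divisor="iqr")`. As a rough illustration of that arithmetic, the sketch below assumes the transform computes `(x - median) / IQR`. The helper `robust_normalize` is hypothetical and not part of VisCy, and it computes the statistics on the fly instead of reading precomputed per-FOV values from the dataset.

```python
import statistics


def robust_normalize(values):
    """Subtract the median and divide by the interquartile range (IQR).

    Hypothetical helper for illustration only; it is not part of VisCy.
    """
    median = statistics.median(values)
    # quantiles(n=4) returns the three quartile cut points Q1, Q2, Q3.
    q1, _, q3 = statistics.quantiles(values, n=4, method="inclusive")
    iqr = q3 - q1
    if iqr == 0:
        raise ValueError("IQR is zero; cannot normalize a constant signal")
    return [(v - median) / iqr for v in values]


# Toy intensities with one bright outlier (made-up numbers).
pixels = [10.0, 12.0, 14.0, 16.0, 18.0, 200.0]
normalized = robust_normalize(pixels)
```

Because the median and IQR are insensitive to a few extreme pixels, the outlier does not compress the rest of the intensity range the way mean/standard-deviation scaling might.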
Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
# %%
import lightning.pytorch as pl
import torch
import torch.nn as nn
from applications.infection_classification.classify_infection_2D import (
    SemanticSegUNet2D,
)

# Import callbacks and loggers from the same `lightning.pytorch` namespace as
# the Trainer; mixing it with the standalone `pytorch_lightning` package can
# fail at runtime.
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

from viscy.data.hcs import HCSDataModule
from viscy.preprocessing.pixel_ratio import sematic_class_weights
from viscy.transforms import (
    NormalizeSampled,
    RandGaussianSmoothd,
    RandScaleIntensityd,
    RandWeightedCropd,
)

# %% calculate the ratio of background, uninfected and infected pixels in the input dataset

# Set the path to the dataset
dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/4-human_annotation/train_data.zarr"

# %% Create an instance of HCSDataModule

data_module = HCSDataModule(
    dataset_path,
    source_channel=["TXR_Density3D", "Phase3D"],
    target_channel=["Inf_mask"],
    yx_patch_size=[128, 128],
    split_ratio=0.7,
    z_window_size=1,
    architecture="2D",
    num_workers=1,
    batch_size=256,
    normalizations=[
        NormalizeSampled(
            keys=["Phase3D", "TXR_Density3D"],
            level="fov_statistics",
            subtrahend="median",
            divisor="iqr",
        )
    ],
    augmentations=[
        RandWeightedCropd(
            num_samples=16,
            spatial_size=[-1, 128, 128],
            keys=["TXR_Density3D", "Phase3D", "Inf_mask"],
            w_key="Inf_mask",
        ),
        RandScaleIntensityd(
            keys=["TXR_Density3D", "Phase3D"],
            factors=[0.5, 0.5],
            prob=0.5,
        ),
        RandGaussianSmoothd(
            keys=["TXR_Density3D", "Phase3D"],
            prob=0.5,
            sigma_x=[0.5, 1.0],
            sigma_y=[0.5, 1.0],
            sigma_z=[0.5, 1.0],
        ),
    ],
)
pixel_ratio = sematic_class_weights(dataset_path, target_channel="Inf_mask")

# Prepare the data
data_module.prepare_data()

# Setup the data
data_module.setup(stage="fit")

# Create a dataloader
train_dm = data_module.train_dataloader()

val_dm = data_module.val_dataloader()

# %% Set up for training

# Define the logger
logger = TensorBoardLogger(
    "/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/",
    name="logs",
)

# Pass the logger to the Trainer
trainer = pl.Trainer(
    logger=logger,
    max_epochs=500,
    default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/logs/",
    log_every_n_steps=1,
    # Use a single GPU to avoid a runtime exception from distributed
    # training when the node has multiple GPUs.
    devices=1,
)

# Define the checkpoint callback
checkpoint_callback = ModelCheckpoint(
    dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/logs/",
    filename="checkpoint_{epoch:02d}",
    save_top_k=-1,
    verbose=True,
    monitor="loss/validate",
    mode="min",
)

# Add the checkpoint callback to the trainer
trainer.callbacks.append(checkpoint_callback)

# Define the model
model = SemanticSegUNet2D(
    in_channels=2,
    out_channels=3,
    loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)),
)

# Print the model architecture
print(model)

# %% Run training.

trainer.fit(model, data_module)

# %%
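The augmentation list above uses `RandWeightedCropd` with `w_key="Inf_mask"`, so patches are drawn preferentially around annotated pixels. The sketch below shows one way such sampling could work on a 2D weight map, assuming patch centers are drawn with probability proportional to the weight. `sample_weighted_crops` is a hypothetical helper, not the MONAI or VisCy implementation, and the toy mask is made up.

```python
import random


def sample_weighted_crops(weight_map, patch_size, num_samples, seed=None):
    """Return top-left (row, col) corners of square patches whose centers are
    drawn with probability proportional to `weight_map`.

    Hypothetical helper for illustration; not the MONAI/VisCy implementation.
    """
    rng = random.Random(seed)
    height, width = len(weight_map), len(weight_map[0])
    half = patch_size // 2
    # Only centers that keep the whole patch inside the image are valid.
    centers, weights = [], []
    for r in range(half, height - patch_size + half + 1):
        for c in range(half, width - patch_size + half + 1):
            centers.append((r, c))
            weights.append(weight_map[r][c])
    if not centers:
        raise ValueError("patch_size is larger than the weight map")
    if sum(weights) == 0:
        # No labeled pixels among the valid centers: sample uniformly instead.
        weights = [1] * len(centers)
    picks = rng.choices(centers, weights=weights, k=num_samples)
    return [(r - half, c - half) for r, c in picks]


# Toy 8x8 mask with a 2x2 labeled region in the middle.
mask = [[0] * 8 for _ in range(8)]
for row in (3, 4):
    for col in (3, 4):
        mask[row][col] = 1

corners = sample_weighted_crops(mask, patch_size=4, num_samples=4, seed=0)
```

Every sampled patch here is centered on a labeled pixel, which is the reason to weight crops by the mask when most of the field of view is background.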
Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
# %%
# import sys
# sys.path.append("/hpc/mydata/soorya.pradeep/viscy_infection_phenotyping/Viscy/")
import lightning.pytorch as pl
import torch
import torch.nn as nn
from applications.infection_classification.classify_infection_covnext import (
    SemanticSegUNet22D,
)

# Import callbacks and loggers from the same `lightning.pytorch` namespace as
# the Trainer; mixing it with the standalone `pytorch_lightning` package can
# fail at runtime.
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

from viscy.data.hcs import HCSDataModule
from viscy.preprocessing.pixel_ratio import sematic_class_weights
from viscy.transforms import NormalizeSampled, RandWeightedCropd

# %% Create a dataloader and visualize the batches.

# Set the path to the dataset
dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_all_curated_train.zarr"

# %% create data module

# Create an instance of HCSDataModule
data_module = HCSDataModule(
    dataset_path,
    source_channel=["Phase", "HSP90", "phase_nucl_iqr", "hsp90_skew"],
    target_channel=["Inf_mask"],
    yx_patch_size=[256, 256],
    split_ratio=0.8,
    z_window_size=5,
    architecture="2.2D",
    num_workers=3,
    batch_size=16,
    normalizations=[
        NormalizeSampled(
            keys=["Phase", "HSP90", "phase_nucl_iqr", "hsp90_skew"],
            level="fov_statistics",
            subtrahend="median",
            divisor="iqr",
        )
    ],
    augmentations=[
        RandWeightedCropd(
            num_samples=4,
            spatial_size=[-1, 256, 256],
            keys=["Phase", "HSP90", "phase_nucl_iqr", "hsp90_skew"],
            w_key="Inf_mask",
        )
    ],
)
pixel_ratio = sematic_class_weights(dataset_path, target_channel="Inf_mask")

# Prepare the data
data_module.prepare_data()

# Setup the data
data_module.setup(stage="fit")

# Create a dataloader
train_dm = data_module.train_dataloader()

val_dm = data_module.val_dataloader()


# %% Define the logger
logger = TensorBoardLogger(
    "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/",
    name="logs",
)

# Pass the logger to the Trainer
trainer = pl.Trainer(
    logger=logger,
    max_epochs=200,
    default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/",
    log_every_n_steps=1,
    # Use a single GPU to avoid a runtime exception from distributed
    # training when the node has multiple GPUs.
    devices=1,
)

# Define the checkpoint callback
checkpoint_callback = ModelCheckpoint(
    dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/",
    filename="checkpoint_{epoch:02d}",
    save_top_k=-1,
    verbose=True,
    monitor="loss/validate",
    mode="min",
)

# Add the checkpoint callback to the trainer
trainer.callbacks.append(checkpoint_callback)

# Define the model
model = SemanticSegUNet22D(
    in_channels=4,
    out_channels=3,
    loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)),
)

print(model)

# %% Run training.

trainer.fit(model, data_module)

# %%
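All three training scripts pass the output of `sematic_class_weights` to `nn.CrossEntropyLoss(weight=...)`. The formula that function uses is not shown in this diff, so the sketch below assumes simple inverse-frequency weighting over the three classes (background, uninfected, infected). It also spells out in plain Python how a weighted, mean-reduced cross-entropy combines per-pixel losses. Both `inverse_frequency_weights` and `weighted_cross_entropy` are hypothetical helpers, and the label counts are made up.

```python
import math
from collections import Counter


def inverse_frequency_weights(labels, num_classes):
    """Weight each class by the inverse of its pixel fraction, scaled to sum to 1.

    Hypothetical stand-in; the actual formula in VisCy may differ.
    """
    counts = Counter(labels)
    total = len(labels)
    inverse = [total / counts[c] if counts[c] else 0.0 for c in range(num_classes)]
    norm = sum(inverse)
    return [w / norm for w in inverse]


def weighted_cross_entropy(logits, targets, weights):
    """Weighted mean cross-entropy: sum(w[y] * -log p[y]) / sum(w[y])."""
    weighted_sum = 0.0
    weight_total = 0.0
    for pixel_logits, target in zip(logits, targets):
        # Log-softmax with the maximum subtracted for numerical stability.
        peak = max(pixel_logits)
        log_norm = peak + math.log(sum(math.exp(v - peak) for v in pixel_logits))
        log_prob = pixel_logits[target] - log_norm
        weighted_sum += weights[target] * -log_prob
        weight_total += weights[target]
    return weighted_sum / weight_total


# Toy label mask: 0 = background, 1 = uninfected, 2 = infected.
labels = [0] * 80 + [1] * 15 + [2] * 5
weights = inverse_frequency_weights(labels, num_classes=3)

# Uninformative logits give the same loss for every pixel, whatever the weights.
loss = weighted_cross_entropy([[0.0, 0.0, 0.0]] * 3, [0, 1, 2], weights)
```

With these counts the rare infected class receives the largest weight, so errors on infected pixels contribute more to the loss than errors on the abundant background.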
