Skip to content

Commit

Permalink
corrected prediction module
Browse files Browse the repository at this point in the history
  • Loading branch information
Soorya19Pradeep committed Mar 26, 2024
1 parent 82428ed commit b470ed1
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions examples/infection_phenotyping/test_infection_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from viscy.data.hcs import Sample
import lightning.pytorch as pl
import torch

import torchmetrics
from viscy.light.predict_writer import HCSPredictionWriter
from monai.transforms import DivisiblePad

Expand All @@ -16,7 +16,7 @@
data_module = HCSDataModule(
test_datapath,
source_channel=["Sensor", "Phase"],
target_channel=[],
target_channel=["inf_mask"],
split_ratio=0.8,
z_window_size=1,
architecture="2D",
Expand All @@ -36,8 +36,6 @@
data_module.prepare_data()

data_module.setup(stage="predict")
test_dm = data_module.test_dataloader()
sample = next(iter(test_dm))

# %%
class LightningUNet(pl.LightningModule):
Expand All @@ -49,6 +47,9 @@ def __init__(
):
super(LightningUNet, self).__init__()
self.unet_model = Unet2d(in_channels=in_channels, out_channels=out_channels)
# self.pred_cm = torchmetrics.classification.ConfusionMatrix(
# task="multiclass", num_classes=self.n_classes
# )
if ckpt_path is not None:
state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[
"state_dict"
Expand All @@ -62,8 +63,8 @@ def forward(self, x):
def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0):
source = self._predict_pad(batch["source"])
pred_class = self.forward(source)
pred_int = torch.argmax(pred_class, dim=4, keepdim=True)
return self._predict_pad.inverse(pred_int)
pred_int = torch.argmax(pred_class, dim=1, keepdim=True)
return pred_int

def on_predict_start(self):
"""Pad the input shape to be divisible by the downsampling factor.
Expand All @@ -79,7 +80,7 @@ def on_predict_start(self):

trainer = pl.Trainer(
default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase",
callbacks=[HCSPredictionWriter(output_path, write_input=True)],
callbacks=[HCSPredictionWriter(output_path, write_input=False)],
)
model = LightningUNet(
in_channels=2,
Expand Down

0 comments on commit b470ed1

Please sign in to comment.