Skip to content

Commit 4d6e4c5

Browse files
committed
fixing the loading of weights
1 parent 6d58061 commit 4d6e4c5

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

viscy/light/engine.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,7 @@ def __init__(
146146
self.schedule = schedule
147147
self.log_batches_per_epoch = log_batches_per_epoch
148148
self.log_samples_per_batch = log_samples_per_batch
149-
if chkpt_path is not None:
150-
self.model.load_state_dict(
151-
torch.load(chkpt_path)["state_dict"], strict=False
152-
) # loading only weights
149+
153150
self.training_step_outputs = []
154151
self.validation_step_outputs = []
155152
# required to log the graph
@@ -166,6 +163,10 @@ def __init__(
166163
self.test_cellpose_model_path = test_cellpose_model_path
167164
self.test_cellpose_diameter = test_cellpose_diameter
168165
self.test_evaluate_cellpose = test_evaluate_cellpose
166+
if chkpt_path is not None:
167+
self.load_state_dict(
168+
torch.load(chkpt_path)["state_dict"]
169+
) # loading only weights
169170

170171
def forward(self, x) -> torch.Tensor:
171172
return self.model(x)

0 commit comments

Comments
 (0)