 import matplotlib.pyplot as plt
 
 NUM_PATCHES_PER_IMAGE = 2
-ROI_SIZE = [128, 128, 128]
+ROI_SIZE = [128, 128, 128]
+
 
 def plot_to_tensorboard(writer, epoch, inputs, labels, points, outputs):
     """
@@ -109,7 +110,7 @@ def __getitem__(self, idx):
                     keys=["image", "label"],
                     label_key="label",
                     num_classes=label.max() + 1,
-                    ratios=tuple(float(i > 0) for i in range(label.max()+1)),
+                    ratios=tuple(float(i > 0) for i in range(label.max() + 1)),
                     num_samples=NUM_PATCHES_PER_IMAGE,
                 ),
                 monai.transforms.RandScaleIntensityd(
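The ratios line deserves a note: RandCropByLabelClassesd picks crop centers per class in proportion to ratios, so zeroing index 0 stops crops from centering on background voxels. A minimal illustration, assuming a label map whose classes are 0..2:

# label.max() == 2  =>  num_classes == 3
# ratios == tuple(float(i > 0) for i in range(3)) == (0.0, 1.0, 1.0)
# i.e. no background-centered crops; classes 1 and 2 are sampled equally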
@@ -137,17 +138,19 @@ def __getitem__(self, idx):
                     mode=["constant", "constant"],
                     keys=["image", "label"],
                     spatial_size=ROI_SIZE,
-                )
+                ),
             ]
         )
         data = transforms(data)
         return data
 
+
 import re
 
+
 def get_latest_epoch(directory):
     # Pattern to match filenames like 'model_epoch<number>.pth'
-    pattern = re.compile(r'model_epoch(\d+)\.pth')
+    pattern = re.compile(r"model_epoch(\d+)\.pth")
     max_epoch = -1
 
     for filename in os.listdir(directory):
@@ -159,6 +162,7 @@ def get_latest_epoch(directory):
 
     return max_epoch if max_epoch != -1 else None
 
+
 # Training function
 def train():
     json_file = "allset.json"  # Update with your JSON file
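To make get_latest_epoch concrete, here is a sketch of its behavior on a hypothetical checkpoint directory (filenames invented for illustration):

# checkpoints/ holding model_epoch5.pth, model_epoch40.pth, notes.txt:
# the pattern matches the first two, max_epoch ends up 40, and 40 is returned;
# with no matching files it returns None, which train() treats as a fresh start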
@@ -169,7 +173,6 @@ def train():
     start_epoch = get_latest_epoch(checkpoint_dir)
     start_checkpoint = "./CPRR25_vista3D_model_final_10percent_data.pth"
 
-
     os.makedirs(checkpoint_dir, exist_ok=True)
     dist.init_process_group(backend="nccl")
     world_size = int(os.environ["WORLD_SIZE"])
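The script never spawns workers itself; init_process_group(backend="nccl") expects the launcher to populate the environment, so an invocation along these lines is the likely entry point (script name assumed):

# torchrun --nproc_per_node=8 train.py
# torchrun exports WORLD_SIZE, RANK and LOCAL_RANK; the code reads WORLD_SIZE
# back via os.environ, and local_rank presumably arrives the same way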
@@ -189,11 +192,12 @@ def train():
         model.load_state_dict(pretrained_ckpt, strict=True)
     else:
         print(f"Resuming from epoch {start_epoch}")
-        pretrained_ckpt = torch.load(os.path.join(checkpoint_dir, f"model_epoch{start_epoch}.pth"))
-        model.load_state_dict(pretrained_ckpt['model'], strict=True)
+        pretrained_ckpt = torch.load(
+            os.path.join(checkpoint_dir, f"model_epoch{start_epoch}.pth")
+        )
+        model.load_state_dict(pretrained_ckpt["model"], strict=True)
     model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)
 
-
     optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1.0e-05)
     lr_scheduler = monai.optimizers.WarmupCosineSchedule(
         optimizer=optimizer,
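Note the two checkpoint layouts this branch handles, made easier to see by the rewrapped lines; schematically (keys exactly as they appear in the diff):

# fresh start: torch.load(start_checkpoint)      -> a bare state_dict
# resume:      torch.load("model_epoch<N>.pth")  -> {"model": state_dict,
#                                                    "epoch": N, "step": step}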
@@ -265,10 +269,16 @@ def train():
             if local_rank == 0:
                 writer.add_scalar("loss", loss.item(), step)
             if local_rank == 0 and (epoch + 1) % save_interval == 0:
-                checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch{epoch + 1}.pth")
+                checkpoint_path = os.path.join(
+                    checkpoint_dir, f"model_epoch{epoch + 1}.pth"
+                )
                 if world_size > 1:
                     torch.save(
-                        {"model": model.module.state_dict(), "epoch": epoch + 1, "step": step},
+                        {
+                            "model": model.module.state_dict(),
+                            "epoch": epoch + 1,
+                            "step": step,
+                        },
                         checkpoint_path,
                     )
                     print(
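One detail worth flagging in the save path: the weights go through model.module. That is standard DDP practice rather than anything specific to this commit:

# DDP wraps the network, so model.state_dict() would prefix every key with
# "module."; saving model.module.state_dict() instead keeps the keys loadable
# by a plain, unwrapped model later (e.g. for single-GPU inference)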