Skip to content

Commit a7c46e6

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent aea7263 commit a7c46e6

File tree

2 files changed

+21
-10
lines changed

2 files changed

+21
-10
lines changed

vista3d/cvpr_workshop/infer_cvpr.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from train_cvpr import ROI_SIZE
1818

19+
1920
def convert_clicks(alldata):
2021
# indexes = list(alldata.keys())
2122
# data = [alldata[i] for i in indexes]

vista3d/cvpr_workshop/train_cvpr.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
import matplotlib.pyplot as plt
2323

2424
NUM_PATCHES_PER_IMAGE = 2
25-
ROI_SIZE= [128, 128, 128]
25+
ROI_SIZE = [128, 128, 128]
26+
2627

2728
def plot_to_tensorboard(writer, epoch, inputs, labels, points, outputs):
2829
"""
@@ -109,7 +110,7 @@ def __getitem__(self, idx):
109110
keys=["image", "label"],
110111
label_key="label",
111112
num_classes=label.max() + 1,
112-
ratios=tuple(float(i > 0) for i in range(label.max()+1)),
113+
ratios=tuple(float(i > 0) for i in range(label.max() + 1)),
113114
num_samples=NUM_PATCHES_PER_IMAGE,
114115
),
115116
monai.transforms.RandScaleIntensityd(
@@ -137,17 +138,19 @@ def __getitem__(self, idx):
137138
mode=["constant", "constant"],
138139
keys=["image", "label"],
139140
spatial_size=ROI_SIZE,
140-
)
141+
),
141142
]
142143
)
143144
data = transforms(data)
144145
return data
145146

147+
146148
import re
147149

150+
148151
def get_latest_epoch(directory):
149152
# Pattern to match filenames like 'model_epoch<number>.pth'
150-
pattern = re.compile(r'model_epoch(\d+)\.pth')
153+
pattern = re.compile(r"model_epoch(\d+)\.pth")
151154
max_epoch = -1
152155

153156
for filename in os.listdir(directory):
@@ -159,6 +162,7 @@ def get_latest_epoch(directory):
159162

160163
return max_epoch if max_epoch != -1 else None
161164

165+
162166
# Training function
163167
def train():
164168
json_file = "allset.json" # Update with your JSON file
@@ -169,7 +173,6 @@ def train():
169173
start_epoch = get_latest_epoch(checkpoint_dir)
170174
start_checkpoint = "./CPRR25_vista3D_model_final_10percent_data.pth"
171175

172-
173176
os.makedirs(checkpoint_dir, exist_ok=True)
174177
dist.init_process_group(backend="nccl")
175178
world_size = int(os.environ["WORLD_SIZE"])
@@ -189,11 +192,12 @@ def train():
189192
model.load_state_dict(pretrained_ckpt, strict=True)
190193
else:
191194
print(f"Resuming from epoch {start_epoch}")
192-
pretrained_ckpt = torch.load(os.path.join(checkpoint_dir, f"model_epoch{start_epoch}.pth"))
193-
model.load_state_dict(pretrained_ckpt['model'], strict=True)
195+
pretrained_ckpt = torch.load(
196+
os.path.join(checkpoint_dir, f"model_epoch{start_epoch}.pth")
197+
)
198+
model.load_state_dict(pretrained_ckpt["model"], strict=True)
194199
model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)
195200

196-
197201
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1.0e-05)
198202
lr_scheduler = monai.optimizers.WarmupCosineSchedule(
199203
optimizer=optimizer,
@@ -265,10 +269,16 @@ def train():
265269
if local_rank == 0:
266270
writer.add_scalar("loss", loss.item(), step)
267271
if local_rank == 0 and (epoch + 1) % save_interval == 0:
268-
checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch{epoch + 1}.pth")
272+
checkpoint_path = os.path.join(
273+
checkpoint_dir, f"model_epoch{epoch + 1}.pth"
274+
)
269275
if world_size > 1:
270276
torch.save(
271-
{"model": model.module.state_dict(), "epoch": epoch + 1, "step": step},
277+
{
278+
"model": model.module.state_dict(),
279+
"epoch": epoch + 1,
280+
"step": step,
281+
},
272282
checkpoint_path,
273283
)
274284
print(

0 commit comments

Comments
 (0)