Skip to content

Commit f204b68

Browse files
committed
fix 2d augment
1 parent e7a3459 commit f204b68

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

train.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,8 @@ def train_epoch(cfg, model, train_data_loader, loss, optimzer):
173173

174174
with torch.no_grad():
175175
batch_gt[:,:,:,2] = batch_gt[:,:,:,2] - batch_gt[:,0:1,0:1,2] # Place the depth of first frame root to 0
176-
batch_input = cfg.aug.augment2D(batch_input, mask=cfg.mask, noise=cfg.noise)
176+
if cfg.noise or cfg.mask:
177+
batch_input = cfg.aug.augment2D(batch_input, mask=cfg.mask, noise=cfg.noise)
177178

178179
predicted_3d_pos = model(batch_input) # (N,T,17,3)
179180

@@ -287,7 +288,7 @@ def train(args, cfg):
287288

288289
# Load checkpoint if given
289290
if checkpoint:
290-
st = checkpoint['epoch']
291+
st = checkpoint['epoch'] if not cfg.finetune else 0
291292
if 'optimizer' in checkpoint and checkpoint['optimizer'] != None:
292293
optimizer.load_state_dict(checkpoint['optimizer'])
293294
else:
@@ -299,7 +300,7 @@ def train(args, cfg):
299300
cfg.mask = (cfg.mask_ratio > 0 and cfg.mask_T_ratio > 0)
300301
if cfg.mask or cfg.noise:
301302
cfg.aug = Augmenter2D(cfg) # Data Augmentation: flip and add noise
302-
303+
303304
for epoch in range(st, args.epochs or cfg.epochs): # Start training
304305
print("Training Epoch %d" % (epoch + 1))
305306
start_time = time()

0 commit comments

Comments
 (0)