@@ -173,7 +173,8 @@ def train_epoch(cfg, model, train_data_loader, loss, optimzer):
173
173
174
174
with torch .no_grad ():
175
175
batch_gt [:,:,:,2 ] = batch_gt [:,:,:,2 ] - batch_gt [:,0 :1 ,0 :1 ,2 ] # Place the depth of first frame root to 0
176
- batch_input = cfg .aug .augment2D (batch_input , mask = cfg .mask , noise = cfg .noise )
176
+ if cfg .noise or cfg .mask :
177
+ batch_input = cfg .aug .augment2D (batch_input , mask = cfg .mask , noise = cfg .noise )
177
178
178
179
predicted_3d_pos = model (batch_input ) # (N,T,17,3)
179
180
@@ -287,7 +288,7 @@ def train(args, cfg):
287
288
288
289
# Load checkpoint if given
289
290
if checkpoint :
290
- st = checkpoint ['epoch' ]
291
+ st = checkpoint ['epoch' ] if not cfg . finetune else 0
291
292
if 'optimizer' in checkpoint and checkpoint ['optimizer' ] != None :
292
293
optimizer .load_state_dict (checkpoint ['optimizer' ])
293
294
else :
@@ -299,7 +300,7 @@ def train(args, cfg):
299
300
cfg .mask = (cfg .mask_ratio > 0 and cfg .mask_T_ratio > 0 )
300
301
if cfg .mask or cfg .noise :
301
302
cfg .aug = Augmenter2D (cfg ) # Data Augmentation: flip and add noise
302
-
303
+
303
304
for epoch in range (st , args .epochs or cfg .epochs ): # Start training
304
305
print ("Training Epoch %d" % (epoch + 1 ))
305
306
start_time = time ()
0 commit comments