-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
124 lines (103 loc) · 5 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import pdb
import numpy as np
from utils import cal_performance, normalize_duration
def train(args, model, train_loader, optimizer, scheduler, criterion, model_save_path, pad_idx, device):
    """Run the training loop for ``args.epochs`` epochs and return the model.

    For every batch the model is called as
    ``model(inputs, detected_objs, target_nodes)`` and is expected to return
    ``(outputs, importance_loss, _)`` where ``outputs`` is a mapping that
    provides ``'seg'``, ``'action'`` and ``'duration'`` entries as required by
    the enabled heads. The per-head losses are summed, back-propagated and the
    optimizer is stepped once per batch; the scheduler is stepped once per
    epoch. Per-epoch statistics are printed to stdout.

    Args:
        args: Configuration object. Attributes read here: ``epochs``,
            ``input_type`` (``'i3d_transcript'`` or ``'gt'``), ``kg_attn``,
            ``use_gsnn``, ``seg``, ``anticipate`` and ``task``.
        model: Network to train; moved to ``device`` before training.
        train_loader: Iterable yielding 7-tuples ``(features, past_label,
            trans_dur_future, trans_future_target, detections, gt_nodes,
            imp_weights)``.
        optimizer: Optimizer; ``zero_grad()`` / ``step()`` are called per batch.
        scheduler: LR scheduler; ``step()`` is called once per epoch.
        criterion: Element-wise loss used for the duration head; its output is
            summed and divided by the number of non-padded duration entries.
        model_save_path: Directory that receives ``checkpoint<epoch>.ckpt``.
        pad_idx: Padding value; ignored in ``cal_performance`` and used to
            build the duration mask.
        device: Device the model and the batch tensors are moved to.

    Returns:
        The trained ``model``.

    Raises:
        ValueError: If ``args.input_type`` is not supported, or if
            ``train_loader`` yields no batches (epoch averages are undefined).
    """
    model.to(device)
    print("Training Start")

    # The knowledge-graph inputs and the importance loss are only used when
    # both switches are on. ``== True`` is kept to match the original check.
    use_importance = bool(args.kg_attn == True and args.use_gsnn)

    # NOTE(review): checkpoints are only written from this (0-based) epoch on,
    # so runs with ``args.epochs <= 30`` never save a checkpoint.
    first_checkpoint_epoch = 30

    for epoch in range(args.epochs):
        epoch_loss = 0
        epoch_loss_class = 0
        epoch_loss_dur = 0
        epoch_loss_seg = 0
        total_class = 0
        total_class_correct = 0
        total_seg = 0
        total_seg_correct = 0
        num_batches = 0

        for data in train_loader:
            num_batches += 1
            optimizer.zero_grad()

            # gt_nodes: list of length = batch size, each element of
            #     size (num_future_action, num_nodes).
            # imp_weights: list of length = batch size, each element of
            #     size (num_actions); not used in this function.
            features, past_label, trans_dur_future, trans_future_target, detections, gt_nodes, imp_weights = data

            features = features.to(device)  # [B, S, C]
            past_label = past_label.to(device)  # [B, S]
            trans_dur_future = trans_dur_future.to(device)
            trans_future_target = trans_future_target.to(device)

            # 1 where a future duration is real, 0 where it is padding.
            trans_dur_future_mask = (trans_dur_future != pad_idx).long().to(device)
            # Zero out padded durations so they do not contribute to the loss.
            target_dur = trans_dur_future * trans_dur_future_mask
            target = trans_future_target

            if args.input_type == 'i3d_transcript':
                inputs = (features, past_label)
            elif args.input_type == 'gt':
                gt_features = past_label.int()
                inputs = (gt_features, past_label)
            else:
                raise ValueError(
                    "Unsupported args.input_type: %r "
                    "(expected 'i3d_transcript' or 'gt')" % (args.input_type,)
                )

            detected_objs = detections if use_importance else None
            target_nodes = gt_nodes if use_importance else None
            outputs, importance_loss, _ = model(inputs, detected_objs, target_nodes)

            losses = 0
            if args.seg:
                output_seg = outputs['seg']
                _, _, num_seg_classes = output_seg.size()
                output_seg = output_seg.view(-1, num_seg_classes).to(device)
                target_past_label = past_label.view(-1)
                loss_seg, n_seg_correct, n_seg_total = cal_performance(output_seg, target_past_label, pad_idx)
                losses += loss_seg
                if use_importance:
                    losses = losses + importance_loss
                total_seg += n_seg_total
                total_seg_correct += n_seg_correct
                epoch_loss_seg += loss_seg.item()

            if args.anticipate:
                output = outputs['action']
                _, _, num_action_classes = output.size()
                output = output.view(-1, num_action_classes).to(device)
                target = target.contiguous().view(-1)
                loss, n_correct, n_total = cal_performance(output, target, pad_idx)
                losses += loss
                total_class += n_total
                total_class_correct += n_correct
                epoch_loss_class += loss.item()

                output_dur = outputs['duration']
                output_dur = normalize_duration(output_dur, trans_dur_future_mask)
                # Sum of element-wise losses averaged over non-padded entries.
                loss_dur = torch.sum(criterion(output_dur, target_dur)) / \
                    torch.sum(trans_dur_future_mask)
                losses += loss_dur
                # NOTE(review): when both ``args.seg`` and ``args.anticipate``
                # are enabled the importance loss is added a second time here
                # (it was already added in the seg branch). Kept as in the
                # original; confirm whether the double weighting is intended.
                if use_importance:
                    losses = losses + importance_loss
                epoch_loss_dur += loss_dur.item()

            epoch_loss += losses.item()
            losses.backward()
            optimizer.step()

        if num_batches == 0:
            raise ValueError("train_loader yielded no batches; cannot compute epoch statistics")

        epoch_loss = epoch_loss / num_batches
        print("Epoch [", (epoch+1), '/', args.epochs, '] Loss : %.3f'%epoch_loss)
        if args.anticipate:
            accuracy = total_class_correct/total_class
            epoch_loss_class = epoch_loss_class / num_batches
            print('Training Acc :%.3f'%accuracy, 'CE loss :%.3f'%epoch_loss_class )
            if args.task == 'long':
                epoch_loss_dur = epoch_loss_dur / num_batches
                print('dur loss: %.5f'%epoch_loss_dur)

        if args.seg:
            acc_seg = total_seg_correct / total_seg
            epoch_loss_seg = epoch_loss_seg / num_batches
            print('seg loss :%.3f'%epoch_loss_seg, ', seg acc : %.5f'%acc_seg)

        if args.kg_attn == True and args.use_gsnn == True:
            # Reports the importance loss of the last batch only (not averaged).
            print('imp loss :%.3f'%importance_loss.item())

        scheduler.step()

        save_path = os.path.join(model_save_path)
        # Create the directory BEFORE saving; torch.save does not create
        # missing parent directories.
        os.makedirs(save_path, exist_ok=True)
        if epoch >= first_checkpoint_epoch:
            save_file = os.path.join(save_path, 'checkpoint'+str(epoch)+'.ckpt')
            print ('Saving to ', save_file)
            torch.save(model.state_dict(), save_file)

    return model