"""train_utils.py: training, evaluation, and plotting utilities."""
import os

import torch
import matplotlib.pyplot as plt
def train(teacher_model, student_model, train_loader, optimizer, criterion, device, mode):
    """Run one training epoch; return (mean batch loss, accuracy)."""
    if mode == "distil":
        # Knowledge distillation: the teacher only supplies soft targets.
        teacher_model.eval()
        student_model.train()
        total_loss = 0
        correct_pred = 0
        total_pred = 0
        T = 2        # distillation temperature
        alpha = 0.8  # weight of the distillation term vs. the hard-label loss
        for batch in train_loader:
            optimizer.zero_grad()
            X, mask, y = batch
            X, mask, y = X.to(device), mask.to(device), y.to(device)
            with torch.no_grad():
                teacher_output = teacher_model(X, mask)
            student_output = student_model(X, mask)
            # KL divergence between temperature-softened distributions.
            # KLDivLoss expects log-probabilities as its first argument,
            # 'batchmean' averages over the batch, and the T*T factor keeps
            # the soft-target gradients on the same scale as cross-entropy.
            distill_loss = torch.nn.KLDivLoss(reduction="batchmean")(
                torch.nn.functional.log_softmax(student_output / T, dim=1),
                torch.nn.functional.softmax(teacher_output / T, dim=1),
            ) * (T * T)
            cross_entropy_loss = criterion(student_output, y)
            loss = alpha * distill_loss + (1 - alpha) * cross_entropy_loss
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            _, predicted = torch.max(student_output, 1)
            correct_pred += (predicted == y).sum().item()
            total_pred += len(y)
    elif mode == "LoRA":
        # Fine-tune the teacher; which parameters actually update is decided
        # by the parameter group the caller handed to the optimizer.
        teacher_model.train()
        total_loss = 0
        correct_pred = 0
        total_pred = 0
        for batch in train_loader:
            optimizer.zero_grad()
            X, mask, y = batch
            X, mask, y = X.to(device), mask.to(device), y.to(device)
            output = teacher_model(X, mask)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            _, predicted = torch.max(output, 1)
            correct_pred += (predicted == y).sum().item()
            total_pred += len(y)
    elif mode == "rnn":
        # Supervised training of the student model on hard labels only.
        student_model.train()
        total_loss = 0
        correct_pred = 0
        total_pred = 0
        for batch in train_loader:
            optimizer.zero_grad()
            X, mask, y = batch
            X, mask, y = X.to(device), mask.to(device), y.to(device)
            output = student_model(X, mask)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            _, predicted = torch.max(output, 1)
            correct_pred += (predicted == y).sum().item()
            total_pred += len(y)
    else:
        raise ValueError(f"Unknown training mode: {mode}")
    return total_loss / len(train_loader), correct_pred / total_pred
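
# A minimal usage sketch for one distillation epoch (hypothetical names:
# `teacher`, `student`, `loader`, `opt`, and `device` are assumed to be
# constructed elsewhere in the project):
#
#     loss, acc = train(teacher, student, loader, opt,
#                       torch.nn.CrossEntropyLoss(), device, mode="distil")
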
def evaluate(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0
    correct_pred = 0
    total_pred = 0
    with torch.no_grad():
        for batch in val_loader:
            X, mask, y = batch
            X, mask, y = X.to(device), mask.to(device), y.to(device)
            output = model(X, mask)
            loss = criterion(output, y)
            total_loss += loss.item()
            _, predicted = torch.max(output, 1)
            correct_pred += (predicted == y).sum().item()
            total_pred += len(y)
    return total_loss / len(val_loader), correct_pred / total_pred
def plot_losses(train_losses, val_losses, mode, args):
    """Plot per-epoch train/val loss curves and save them under plots/."""
    os.makedirs('plots', exist_ok=True)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.title(f'{mode} Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig(f'plots/{mode}_{args.sr_no}_loss.png')
    plt.close()
    print(f"Plot saved at plots/{mode}_{args.sr_no}_loss.png")
def plot_metrics(train_accs, val_accs, mode, args):
    """Plot per-epoch train/val accuracy curves and save them under plots/."""
    os.makedirs('plots', exist_ok=True)
    plt.plot(train_accs, label='Train Acc')
    plt.plot(val_accs, label='Val Acc')
    plt.title(f'{mode} Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.savefig(f'plots/{mode}_{args.sr_no}_acc.png')
    plt.close()
    print(f"Plot saved at plots/{mode}_{args.sr_no}_acc.png")
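
# A self-contained smoke test: a sketch under assumed shapes. The real models,
# data pipeline, and `args` come from the rest of the project; `TinyClassifier`,
# `sr_no="demo"`, and the random tensors below are hypothetical stand-ins that
# only exercise the utilities above end to end.
if __name__ == "__main__":
    from types import SimpleNamespace
    from torch.utils.data import DataLoader, TensorDataset

    class TinyClassifier(torch.nn.Module):
        """Stand-in for the real teacher/student: mean-pools masked embeddings."""
        def __init__(self, vocab=100, dim=16, classes=4):
            super().__init__()
            self.emb = torch.nn.Embedding(vocab, dim)
            self.fc = torch.nn.Linear(dim, classes)

        def forward(self, X, mask):
            h = self.emb(X) * mask.unsqueeze(-1)  # zero out padded positions
            pooled = h.sum(1) / mask.sum(1, keepdim=True).clamp(min=1)
            return self.fc(pooled)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    X = torch.randint(0, 100, (64, 12))  # token ids
    mask = torch.ones(64, 12)            # attention mask (no padding here)
    y = torch.randint(0, 4, (64,))       # class labels
    loader = DataLoader(TensorDataset(X, mask, y), batch_size=16)

    teacher = TinyClassifier().to(device)
    student = TinyClassifier().to(device)
    opt = torch.optim.Adam(student.parameters(), lr=1e-3)
    criterion = torch.nn.CrossEntropyLoss()
    args = SimpleNamespace(sr_no="demo")

    train_losses, val_losses, train_accs, val_accs = [], [], [], []
    for epoch in range(2):
        tr_loss, tr_acc = train(teacher, student, loader, opt,
                                criterion, device, mode="distil")
        va_loss, va_acc = evaluate(student, loader, criterion, device)
        train_losses.append(tr_loss)
        val_losses.append(va_loss)
        train_accs.append(tr_acc)
        val_accs.append(va_acc)
    plot_losses(train_losses, val_losses, "distil", args)
    plot_metrics(train_accs, val_accs, "distil", args)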