Skip to content

Commit 08b83b4

Browse files
committed
Moving loss function to a new module.
1 parent e1adeda commit 08b83b4

File tree

2 files changed

+72
-39
lines changed

2 files changed

+72
-39
lines changed

loss.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import torch
2+
3+
class MultipleChoiceLossCompute:
4+
"A Loss compute and train function for multiple choice tasks."
5+
6+
def __init__(self, lm_criterion, clf_criterion, lm_coef, opt=None):
7+
self.lm_criterion = lm_criterion
8+
self.clf_criterion = clf_criterion
9+
self.lm_coef = lm_coef
10+
self.opt = opt
11+
12+
def __call__(self, X, Y, M, clf_logits, lm_logits=None, only_return_losses=False):
13+
# Language modeling loss
14+
if lm_logits is not None:
15+
x_shifted = X[:, :, 1:, 0].contiguous().view(-1) # Shape: 252
16+
M = M.view(-1, M.size(2))
17+
lm_losses = self.lm_criterion(lm_logits, x_shifted)
18+
lm_losses = lm_losses.view(X.size(0) * X.size(1), X.size(2) - 1)
19+
lm_losses = lm_losses * M[:, 1:]
20+
lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1)
21+
# Classification loss
22+
clf_losses = self.clf_criterion(clf_logits, Y)
23+
if only_return_losses:
24+
return (clf_losses, lm_losses) if lm_logits is not None else clf_losses
25+
26+
if self.lm_coef > 0 and lm_logits is not None:
27+
train_loss = clf_losses.sum() + self.lm_coef * lm_losses.sum()
28+
else:
29+
train_loss = clf_losses.sum()
30+
train_loss.backward()
31+
if self.opt is not None:
32+
self.opt.step()
33+
self.opt.zero_grad()
34+
return train_loss.item()
35+
36+
class ClassificationLossCompute:
37+
"A Loss compute and train function for classification tasks."
38+
39+
def __init__(self, lm_criterion, clf_criterion, lm_coef, opt=None):
40+
self.lm_criterion = lm_criterion
41+
self.clf_criterion = clf_criterion
42+
self.lm_coef = lm_coef
43+
self.opt = opt
44+
45+
def __call__(self, X, Y, M, clf_logits, lm_logits=None, only_return_losses=False):
46+
# Language modeling loss
47+
if lm_logits is not None:
48+
x_shifted = X[:, 1:, 0].contiguous().view(-1)
49+
M = M.view(-1, M.size(-1))
50+
lm_losses = self.lm_criterion(lm_logits, x_shifted)
51+
lm_losses = lm_losses.view(X.size(0), X.size(-2) - 1)
52+
lm_losses = lm_losses * M[:, 1:]
53+
lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1)
54+
# Classification loss
55+
clf_losses = self.clf_criterion(clf_logits, Y)
56+
if only_return_losses:
57+
return (clf_losses, lm_losses) if lm_logits is not None else clf_losses
58+
59+
if self.lm_coef > 0 and lm_logits is not None:
60+
train_loss = clf_losses.sum() + self.lm_coef * lm_losses.sum()
61+
else:
62+
train_loss = clf_losses.sum()
63+
train_loss.backward()
64+
if self.opt is not None:
65+
self.opt.step()
66+
self.opt.zero_grad()
67+
return train_loss.item()

train.py

Lines changed: 5 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -15,41 +15,7 @@
1515
from text_utils import TextEncoder
1616
from utils import (encode_dataset, iter_data,
1717
ResultLogger, make_path)
18-
19-
20-
class LossCompute:
21-
"A Loss compute and train function."
22-
23-
def __init__(self, lm_criterion, clf_criterion, lm_coef, opt=None):
24-
self.lm_criterion = lm_criterion
25-
self.clf_criterion = clf_criterion
26-
self.lm_coef = lm_coef
27-
self.opt = opt
28-
29-
def __call__(self, X, Y, M, clf_logits, lm_logits=None, only_return_losses=False):
30-
# Language modeling loss
31-
if lm_logits is not None:
32-
x_shifted = X[:, :, 1:, 0].contiguous().view(-1) # Shape: 252
33-
M = M.view(-1, M.size(2))
34-
lm_losses = self.lm_criterion(lm_logits, x_shifted)
35-
lm_losses = lm_losses.view(X.size(0) * X.size(1), X.size(2) - 1)
36-
lm_losses = lm_losses * M[:, 1:]
37-
lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1)
38-
# Classification loss
39-
clf_losses = self.clf_criterion(clf_logits, Y)
40-
if only_return_losses:
41-
return (clf_losses, lm_losses) if lm_logits is not None else clf_losses
42-
43-
if self.lm_coef > 0 and lm_logits is not None:
44-
train_loss = clf_losses.sum() + self.lm_coef * lm_losses.sum()
45-
else:
46-
train_loss = clf_losses.sum()
47-
train_loss.backward()
48-
if self.opt is not None:
49-
self.opt.step()
50-
self.opt.zero_grad()
51-
return train_loss.item()
52-
18+
from loss import MultipleChoiceLossCompute
5319

5420
def transform_roc(X1, X2, X3):
5521
n_batch = len(X1)
@@ -277,10 +243,10 @@ def run_epoch():
277243
l2=args.l2,
278244
vector_l2=args.vector_l2,
279245
max_grad_norm=args.max_grad_norm)
280-
compute_loss_fct = LossCompute(criterion,
281-
criterion,
282-
args.lm_coef,
283-
model_opt)
246+
compute_loss_fct = MultipleChoiceLossCompute(criterion,
247+
criterion,
248+
args.lm_coef,
249+
model_opt)
284250
load_openai_pretrained_model(dh_model.transformer, n_ctx=n_ctx, n_special=n_special)
285251

286252
dh_model.to(device)

0 commit comments

Comments
 (0)