Commit 5eb7937

loss computation
1 parent 8858f99 commit 5eb7937

2 files changed: 28 additions, 48 deletions

model.py

Lines changed: 16 additions & 1 deletion
@@ -160,4 +160,19 @@ def forward(self, x, m):
         h = e.sum(dim=2)
         for block in self.blocks:
             h = block(h)
-        return h
+
+        # Language modeling logits
+        h_trunc = h[:, :-1].contiguous().view(-1, self.n_embed)  # Shape: 252, 768
+        lm_logits = self.decoder(h_trunc)
+
+        # Classification logits
+        clf_h = h.view(-1, self.n_embed)
+        pool_idx = torch.eq(x[:, :, 0].contiguous().view(-1), self.clf_token)
+        clf_h = clf_h[pool_idx, :]
+        clf_h = clf_h.view(-1, 2, self.n_embed, 1)
+        clf_dropout = nn.Dropout2d(clf_pdrop)  # To reproduce the noise_shape parameter of the TF implementation
+        clf_h = clf_dropout(clf_h)
+        clf_h = clf_h.view(-1, self.n_embed)
+        clf_logits = self.linear(clf_h)
+
+        return lm_logits, clf_logits
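
For reference, a minimal standalone sketch of the classification pooling added above: torch.eq builds a boolean mask over the flattened token ids, which selects the hidden state at each sequence's classifier-token position, and nn.Dropout2d then zeroes whole channels of the 4-D view (the diff's stand-in for TF's noise_shape). All sizes and the clf_token id below are toy values for illustration, not the model's real configuration.

import torch
import torch.nn as nn

# Toy sizes for illustration only (not the real config)
n_seq, n_ctx, n_embed = 4, 6, 8   # 4 sequences = 2 examples x 2 choices
clf_token = 99                    # hypothetical id of the special classifier token

x = torch.randint(0, 50, (n_seq, n_ctx, 2))   # x[..., 0] holds token ids
x[:, -1, 0] = clf_token                       # classifier token ends each sequence
h = torch.randn(n_seq, n_ctx, n_embed)        # transformer hidden states

# Select the hidden state at each sequence's classifier-token position
clf_h = h.view(-1, n_embed)                                        # (n_seq*n_ctx, n_embed)
pool_idx = torch.eq(x[:, :, 0].contiguous().view(-1), clf_token)   # boolean mask, one True per sequence
clf_h = clf_h[pool_idx, :]                                         # (n_seq, n_embed)

# Dropout2d zeroes entire channels (dim 1) of a 4-D tensor; the diff applies it
# to a (n_examples, 2, n_embed, 1) view in place of TF's noise_shape argument
clf_h = clf_h.view(-1, 2, n_embed, 1)
clf_h = nn.Dropout2d(0.1)(clf_h)              # 0.1 stands in for clf_pdrop
clf_h = clf_h.view(-1, n_embed)
print(clf_h.shape)                            # torch.Size([4, 8])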

train.py

Lines changed: 12 additions & 47 deletions
@@ -34,64 +34,29 @@

 class LossCompute:
     "A Loss compute and train function."
-    def __init__(self, generator, lm_criterion, n_embed, opt=None):
+    def __init__(self, generator, lm_criterion, n_embed, clf_token, opt=None):
         self.generator = generator
         self.lm_criterion = lm_criterion
         self.opt = opt
         self.n_embed = n_embed
+        self.clf_token = clf_token

-    def __call__(self, X, Y, M, h, norm):
+    def __call__(self, X, Y, M, lm_logits, clf_logits, norm):
         # Language modeling loss
-        h_trunc = h[:, :-1].contiguous().view(-1, self.n_embed)  # Shape: 252, 768
         x_shifted = X[:, 1:, 0].contiguous().view(-1)  # Shape: 252
-        lm_logits = self.generator(h_trunc)
-        lm_losses = self.lm_criterion(h_trunc, x_shifted)
-        lm_losses = lm_losses.view(x.size(0), X.size(1))
+        lm_losses = self.lm_criterion(lm_logits, x_shifted)
+        lm_losses = lm_losses.view(X.size(0), X.size(1) - 1)
         lm_losses = lm_losses * M[:, 1:]
         lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1)

         # Classification loss
-        clf_h = h.view(-1, self.n_embed)
-
-        # loss.backward()
-        # if self.opt is not None:
-        # self.opt.step()
-        # self.opt.optimizer.zero_grad()
-        return lm_losses
-
-def model(X, M, Y, train=False, reuse=False):
-    we = tf.get_variable("we", [n_vocab+n_special+n_ctx, n_embd],
-                         initializer=tf.random_normal_initializer(stddev=0.02))
-    we = dropout(we, embd_pdrop, train)
-
-    X = tf.reshape(X, [-1, n_ctx, 2])
-    M = tf.reshape(M, [-1, n_ctx])
-
-    h = embed(X, we)
-    for layer in range(n_layer):
-        h = block(h, 'h%d'%layer, train=train, scale=True)
-
-    lm_h = tf.reshape(h[:, :-1], [-1, n_embd])
-    lm_logits = tf.matmul(lm_h, we, transpose_b=True)
-    lm_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=lm_logits, labels=tf.reshape(X[:, 1:, 0], [-1]))
-    lm_losses = tf.reshape(lm_losses, [shape_list(X)[0], shape_list(X)[1]-1])
-    lm_losses = tf.reduce_sum(lm_losses*M[:, 1:], 1)/tf.reduce_sum(M[:, 1:], 1)
-
-    clf_h = tf.reshape(h, [-1, n_embd])
-    pool_idx = tf.cast(tf.argmax(tf.cast(tf.equal(X[:, :, 0], clf_token), tf.float32), 1), tf.int32)
-    clf_h = tf.gather(clf_h, tf.range(shape_list(X)[0], dtype=tf.int32)*n_ctx+pool_idx)
-
-    clf_h = tf.reshape(clf_h, [-1, 2, n_embd])
-    if train and clf_pdrop > 0:
-        shape = shape_list(clf_h)
-        shape[1] = 1
-        clf_h = tf.nn.dropout(clf_h, 1-clf_pdrop, shape)
-    clf_h = tf.reshape(clf_h, [-1, n_embd])
-    clf_logits = clf(clf_h, 1, train=train)
-    clf_logits = tf.reshape(clf_logits, [-1, 2])
-
-    clf_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=clf_logits, labels=Y)
-    return clf_logits, clf_losses, lm_losses
+        clf_losses = self.clf_criterion(clf_logits, Y)
+
+        if lm_coef > 0:
+            train_loss = clf_losses.sum() + lm_coef * lm_losses.sum()
+        else:
+            train_loss = clf_losses.sum()
+        return train_loss

 def mgpu_train(*xs):
     gpu_ops = []
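
And a hedged, self-contained sketch of the arithmetic now performed in LossCompute.__call__: a masked language-modeling loss over the shifted targets plus a classification loss, mixed with an lm_coef weight. The function name, the toy shapes, and the use of nn.CrossEntropyLoss(reduction='none') for both criteria are assumptions for illustration; in the commit itself, self.clf_criterion and lm_coef are still free names that remain to be wired in.

import torch
import torch.nn as nn

def double_head_loss(X, Y, M, lm_logits, clf_logits,
                     lm_criterion, clf_criterion, lm_coef=0.5):
    """Same arithmetic as LossCompute.__call__ above, with the criteria and
    lm_coef passed explicitly instead of read from attributes/globals."""
    # Masked language-modeling loss on the next-token targets
    x_shifted = X[:, 1:, 0].contiguous().view(-1)
    lm_losses = lm_criterion(lm_logits, x_shifted)            # per-token losses
    lm_losses = lm_losses.view(X.size(0), X.size(1) - 1)
    lm_losses = lm_losses * M[:, 1:]                          # mask out padding
    lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1)     # mean over real tokens

    # Classification loss on the pooled classifier-token states
    clf_losses = clf_criterion(clf_logits, Y)

    if lm_coef > 0:
        return clf_losses.sum() + lm_coef * lm_losses.sum()
    return clf_losses.sum()

# Toy shapes for illustration only
n_seq, n_ctx, n_vocab, n_class = 4, 5, 20, 2
X = torch.randint(0, n_vocab, (n_seq, n_ctx, 2))
M = torch.ones(n_seq, n_ctx)
Y = torch.randint(0, n_class, (n_seq // 2,))
lm_logits = torch.randn(n_seq * (n_ctx - 1), n_vocab)
clf_logits = torch.randn(n_seq // 2, n_class)

loss = double_head_loss(X, Y, M, lm_logits, clf_logits,
                        lm_criterion=nn.CrossEntropyLoss(reduction='none'),
                        clf_criterion=nn.CrossEntropyLoss(reduction='none'))
print(loss)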
