 class LossCompute:
     "A Loss compute and train function."
-    def __init__(self, generator, lm_criterion, n_embed, opt=None):
+    def __init__(self, generator, lm_criterion, clf_criterion, n_embed, clf_token, opt=None):
         self.generator = generator
         self.lm_criterion = lm_criterion
         self.opt = opt
         self.n_embed = n_embed
+        self.clf_criterion = clf_criterion
+        self.clf_token = clf_token

-    def __call__(self, X, Y, M, h, norm):
+    def __call__(self, X, Y, M, lm_logits, clf_logits, norm):
         # Language modeling loss
-        h_trunc = h[:, :-1].contiguous().view(-1, self.n_embed)  # Shape: 252, 768
         x_shifted = X[:, 1:, 0].contiguous().view(-1)  # Shape: 252
-        lm_logits = self.generator(h_trunc)
-        lm_losses = self.lm_criterion(h_trunc, x_shifted)
-        lm_losses = lm_losses.view(x.size(0), X.size(1))
+        lm_losses = self.lm_criterion(lm_logits, x_shifted)
+        lm_losses = lm_losses.view(X.size(0), X.size(1) - 1)
         lm_losses = lm_losses * M[:, 1:]
         lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1)

         # Classification loss
-        clf_h = h.view(-1, self.n_embed)
-
-        # loss.backward()
-        # if self.opt is not None:
-        #     self.opt.step()
-        #     self.opt.optimizer.zero_grad()
-        return lm_losses
-
-def model(X, M, Y, train=False, reuse=False):
-    we = tf.get_variable("we", [n_vocab+n_special+n_ctx, n_embd],
-                         initializer=tf.random_normal_initializer(stddev=0.02))
-    we = dropout(we, embd_pdrop, train)
-
-    X = tf.reshape(X, [-1, n_ctx, 2])
-    M = tf.reshape(M, [-1, n_ctx])
-
-    h = embed(X, we)
-    for layer in range(n_layer):
-        h = block(h, 'h%d'%layer, train=train, scale=True)
-
-    lm_h = tf.reshape(h[:, :-1], [-1, n_embd])
-    lm_logits = tf.matmul(lm_h, we, transpose_b=True)
-    lm_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=lm_logits, labels=tf.reshape(X[:, 1:, 0], [-1]))
-    lm_losses = tf.reshape(lm_losses, [shape_list(X)[0], shape_list(X)[1]-1])
-    lm_losses = tf.reduce_sum(lm_losses*M[:, 1:], 1)/tf.reduce_sum(M[:, 1:], 1)
-
-    clf_h = tf.reshape(h, [-1, n_embd])
-    pool_idx = tf.cast(tf.argmax(tf.cast(tf.equal(X[:, :, 0], clf_token), tf.float32), 1), tf.int32)
-    clf_h = tf.gather(clf_h, tf.range(shape_list(X)[0], dtype=tf.int32)*n_ctx+pool_idx)
-
-    clf_h = tf.reshape(clf_h, [-1, 2, n_embd])
-    if train and clf_pdrop > 0:
-        shape = shape_list(clf_h)
-        shape[1] = 1
-        clf_h = tf.nn.dropout(clf_h, 1-clf_pdrop, shape)
-    clf_h = tf.reshape(clf_h, [-1, n_embd])
-    clf_logits = clf(clf_h, 1, train=train)
-    clf_logits = tf.reshape(clf_logits, [-1, 2])
-
-    clf_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=clf_logits, labels=Y)
-    return clf_logits, clf_losses, lm_losses
+        clf_losses = self.clf_criterion(clf_logits, Y)
+
+        if lm_coef > 0:
+            train_loss = clf_losses.sum() + lm_coef * lm_losses.sum()
+        else:
+            train_loss = clf_losses.sum()
+        return train_loss

 def mgpu_train(*xs):
     gpu_ops = []
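
As a usage sketch (not part of the commit): the snippet below shows how the reworked LossCompute might be driven once the model produces lm_logits and clf_logits itself. The lm_head/clf_head modules, the sizes, and the random tensors are hypothetical stand-ins; only the constructor and __call__ signatures follow the code above, lm_coef is assumed to be a module-level hyperparameter as in the training script, and the clf_token pooling mirrors what the removed pool_idx/tf.gather lines did.

import torch
import torch.nn as nn

batch, n_ctx, n_embed, n_vocab = 4, 77, 768, 40478     # illustrative sizes, not the real config
clf_token = n_vocab - 1                                 # stand-in id for the classify token
lm_coef = 0.5                                           # LM auxiliary-loss weight, read as a global by __call__

lm_head = nn.Linear(n_embed, n_vocab, bias=False)       # stand-in for the tied-embedding LM head
clf_head = nn.Linear(n_embed, 1)                        # stand-in for the classification head

lm_criterion = nn.CrossEntropyLoss(reduction='none')    # per-token losses; masked inside LossCompute
clf_criterion = nn.CrossEntropyLoss(reduction='none')   # per-example classification losses
compute_loss = LossCompute(None, lm_criterion, clf_criterion, n_embed, clf_token, opt=None)

# Fake batch with the expected shapes: two sequences (choices) per labelled example.
X = torch.randint(0, n_vocab, (batch, n_ctx, 2))        # token ids and position ids
X[:, -1, 0] = clf_token                                 # every sequence ends with the classify token
M = torch.ones(batch, n_ctx)                            # loss mask
Y = torch.randint(0, 2, (batch // 2,))                  # one label per pair of sequences
h = torch.randn(batch, n_ctx, n_embed)                  # pretend transformer hidden states

# LM logits over all but the last position, flattened the way __call__ expects them.
lm_logits = lm_head(h[:, :-1]).contiguous().view(-1, n_vocab)

# Classification logits: pool the hidden state at each sequence's clf_token position
# (the PyTorch analogue of the removed pool_idx/tf.gather code), score it, and regroup
# the two choices of each example.
pool_idx = (X[:, :, 0] == clf_token).float().argmax(dim=1)
clf_h = h[torch.arange(batch), pool_idx]
clf_logits = clf_head(clf_h).view(-1, 2)

train_loss = compute_loss(X, Y, M, lm_logits, clf_logits, norm=batch)
print(train_loss.item())

Passing precomputed lm_logits/clf_logits into the loss keeps LossCompute a thin loss-only helper, the same split between logits and losses that the removed TensorFlow model() function had.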