1 parent 8833a83 commit 2eb381c
generate.py
@@ -65,6 +65,7 @@ def append_batch(X, next_idx):
     parser.add_argument('--e', type=float, default=1e-8)
     parser.add_argument('--n_valid', type=int, default=374)
     parser.add_argument('--gen_len', type=int, default=20)
+    parser.add_argument('--topk', type=int, default=10)
 
     args = parser.parse_args()
     print(args)
@@ -108,7 +109,11 @@ def append_batch(X, next_idx):
 
     for _ in range(args.gen_len):
        lm_probs = lm_model(XMB)
-        next_idx = torch.multinomial(lm_probs[:, -1, :], 1)
+        if args.topk == 0:
+            next_idx = torch.multinomial(lm_probs[:, -1, :], 1)
+        else:
+            values, indices = lm_probs[:, -1, :].topk(args.topk)
+            next_idx = indices.gather(-1, torch.multinomial(values, 1))
        next_token = text_encoder.decoder[next_idx.item()].replace('</w>', '')
        print(next_token, end=' ')
        XMB = append_batch(XMB, next_idx)
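
The new branch restricts sampling to the args.topk most probable next tokens, while --topk 0 keeps the old behavior of sampling from the full distribution. A minimal standalone sketch of the same logic, with made-up shapes (probs stands in for lm_probs[:, -1, :]):

import torch

torch.manual_seed(0)
batch, vocab, k = 2, 8, 3

# Stand-in for lm_probs[:, -1, :]: a normalized next-token distribution.
probs = torch.softmax(torch.randn(batch, vocab), dim=-1)

# Keep the k largest probabilities and their vocabulary indices.
values, indices = probs.topk(k)         # both (batch, k)

# multinomial renormalizes `values`, so sampling stays inside the top-k set
# without an explicit division by values.sum().
sampled = torch.multinomial(values, 1)  # (batch, 1), offsets into the top-k set

# Map the sampled offsets back to vocabulary ids.
next_idx = indices.gather(-1, sampled)  # (batch, 1)
print(next_idx)

Because torch.multinomial normalizes its input rows itself, no further scaling of values is needed before sampling.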
model_pytorch.py
@@ -86,7 +86,7 @@ def _attn(self, q, k, v):
         w = w / math.sqrt(v.size(-1))
         # w = w * self.b + -1e9 * (1 - self.b)  # TF implem method: mask_attn_weights
         # XD: self.b may be larger than w, so we need to crop it
-        b = self.b[:, :, w.size(-2), w.size(-1)]
+        b = self.b[:, :, :w.size(-2), :w.size(-1)]
         w = w * b + -1e9 * (1 - b)
 
         w = nn.Softmax(dim=-1)(w)
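
The one-character fix replaces integer indexing with slicing. In the original line, w.size(-2) and w.size(-1) are used as element indices, so b collapses to a single mask entry of shape (1, 1) that silently broadcasts over every attention position (and reads past the mask once the sequence fills the context window). A small sketch with made-up sizes (b mirrors the precomputed causal mask self.b of shape (1, 1, n_ctx, n_ctx)):

import torch

n_ctx, n_head, seq_len = 512, 12, 7

# Lower-triangular causal mask, built once for the full context window.
b = torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)

# Attention logits for a sequence shorter than n_ctx.
w = torch.randn(1, n_head, seq_len, seq_len)

# Buggy form: integer indices pick one mask entry, shape (1, 1),
# and would raise IndexError when seq_len == n_ctx.
bad = b[:, :, w.size(-2), w.size(-1)]

# Fixed form: slices crop the mask to the current sequence length.
mask = b[:, :, :w.size(-2), :w.size(-1)]  # (1, 1, seq_len, seq_len)
w = w * mask + -1e9 * (1 - mask)          # future positions get -1e9
print(bad.shape, mask.shape)

With the cropped mask, the softmax that follows assigns negligible weight to future positions, matching the commented-out TF masking the code references.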