
Commit 2eb381c

fix masking bias bug in Attention._attn
1 parent 8833a83 commit 2eb381c

2 files changed: +7, -2

generate.py (+6, -1)

@@ -65,6 +65,7 @@ def append_batch(X, next_idx):
     parser.add_argument('--e', type=float, default=1e-8)
     parser.add_argument('--n_valid', type=int, default=374)
     parser.add_argument('--gen_len', type=int, default=20)
+    parser.add_argument('--topk', type=int, default=10)
 
     args = parser.parse_args()
     print(args)
@@ -108,7 +109,11 @@ def append_batch(X, next_idx):
 
     for _ in range(args.gen_len):
         lm_probs = lm_model(XMB)
-        next_idx = torch.multinomial(lm_probs[:, -1, :], 1)
+        if args.topk == 0:
+            next_idx = torch.multinomial(lm_probs[:, -1, :], 1)
+        else:
+            values, indices = lm_probs[:, -1, :].topk(args.topk)
+            next_idx = indices.gather(-1, torch.multinomial(values, 1))
         next_token = text_encoder.decoder[next_idx.item()].replace('</w>', '')
         print(next_token, end=' ')
         XMB = append_batch(XMB, next_idx)
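For reference, the added branch is top-k sampling: instead of drawing the next token from the full vocabulary distribution, it keeps only the topk highest-probability tokens, samples among them, and maps the sampled position back to a vocabulary id; --topk 0 preserves the old behaviour of sampling from the whole distribution. A minimal standalone sketch of the same idea (the function name and toy tensors below are illustrative, not part of the commit):

import torch

def sample_next_token(probs, topk=10):
    # probs: (batch, vocab) tensor of probabilities for the last position.
    if topk == 0:
        # Fall back to sampling from the full distribution.
        return torch.multinomial(probs, 1)
    # Keep the k largest probabilities; torch.multinomial accepts
    # unnormalized non-negative weights, so no renormalization is needed.
    values, indices = probs.topk(topk)        # both (batch, k)
    sampled = torch.multinomial(values, 1)    # position within the top-k set
    return indices.gather(-1, sampled)        # map back to vocabulary ids

# Toy usage with a hypothetical 8-token vocabulary
probs = torch.softmax(torch.randn(1, 8), dim=-1)
print(sample_next_token(probs, topk=3))       # e.g. tensor([[2]])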

model_pytorch.py (+1, -1)

@@ -86,7 +86,7 @@ def _attn(self, q, k, v):
         w = w / math.sqrt(v.size(-1))
         # w = w * self.b + -1e9 * (1 - self.b)  # TF implem method: mask_attn_weights
         # XD: self.b may be larger than w, so we need to crop it
-        b = self.b[:, :, w.size(-2), w.size(-1)]
+        b = self.b[:, :, :w.size(-2), :w.size(-1)]
         w = w * b + -1e9 * (1 - b)
 
         w = nn.Softmax(dim=-1)(w)
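The bug itself: self.b here is the causal attention mask (a lower-triangular buffer, presumably of shape (1, 1, n_ctx, n_ctx)), and the old line used integer indexing, which selects a single mask element (or raises IndexError once the sequence fills the full context) instead of cropping the mask to the current attention matrix. The fix slices it to the attention matrix size before applying the -1e9 bias to masked positions. A small sketch of the corrected masking step, with illustrative shapes (n_ctx and seq_len below are made up for the example):

import torch

n_ctx, seq_len = 8, 5
# Lower-triangular causal mask, analogous to the self.b buffer.
b_full = torch.tril(torch.ones(1, 1, n_ctx, n_ctx))
w = torch.randn(1, 1, seq_len, seq_len)          # raw attention scores

# Buggy: integer indexing yields a single element broadcast over all positions
# (and an IndexError when seq_len == n_ctx).
# b = b_full[:, :, w.size(-2), w.size(-1)]

# Fixed: slice the mask down to the current sequence length.
b = b_full[:, :, :w.size(-2), :w.size(-1)]       # (1, 1, seq_len, seq_len)
w = w * b + -1e9 * (1 - b)                       # large negative bias on future positions
print(torch.softmax(w, dim=-1)[0, 0])            # upper triangle is ~0 after softmax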
