Skip to content

Commit 4a208be

Browse files
committed
make it work with transformers 4
1 parent 0cf2be0 commit 4a208be

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

transformers_multi_label_classification.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -912,7 +912,7 @@
912912
" self.l3 = torch.nn.Linear(768, 6)\n",
913913
" \n",
914914
" def forward(self, ids, mask, token_type_ids):\n",
915-
" _, output_1= self.l1(ids, attention_mask = mask, token_type_ids = token_type_ids)\n",
915+
" _, output_1= self.l1(ids, attention_mask = mask, token_type_ids = token_type_ids, return_dict=False)\n",
916916
" output_2 = self.l2(output_1)\n",
917917
" output = self.l3(output_2)\n",
918918
" return output\n",

0 commit comments

Comments
 (0)