-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathclassifier.py
33 lines (25 loc) · 1.12 KB
/
classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch.nn as nn
from transformers import BertModel
import torch
class SentimentClassifier(nn.Module):
    """BERT encoder followed by a linear classification head on position 0.

    The contextualised representation at sequence position 0 (the ``[CLS]``
    token, assuming ``seq`` was produced by a BERT tokenizer that prepends it)
    is passed through a single ``nn.Linear`` layer that emits one
    unnormalised logit per class.
    """

    def __init__(self, num_classes, device, freeze_bert=True,
                 pretrained_model_name='bert-base-uncased'):
        """
        Args:
            num_classes: number of output logits per example.
            device: device that ``forward`` moves the returned logits to.
            freeze_bert: if True, gradients are disabled for every BERT
                parameter, so only the classification head is trainable.
            pretrained_model_name: checkpoint name or path passed to
                ``BertModel.from_pretrained``; defaults to the original
                ``'bert-base-uncased'``.
        """
        super(SentimentClassifier, self).__init__()
        self.bert_layer = BertModel.from_pretrained(pretrained_model_name)
        self.device = device
        if freeze_bert:
            # Freeze the encoder only; cls_layer is created afterwards and
            # therefore stays trainable.
            self.bert_layer.requires_grad_(False)
        # Size the head from the loaded model's config instead of hard-coding
        # 768 (bert-base), so checkpoints with another hidden size also work.
        self.cls_layer = nn.Linear(self.bert_layer.config.hidden_size, num_classes)

    def forward(self, seq, attn_masks):
        """Compute classification logits for a batch of token sequences.

        Args:
            seq: Tensor of shape [B, T] containing token ids of sequences.
            attn_masks: Tensor of shape [B, T] containing attention masks used
                to avoid contribution of PAD tokens.

        Returns:
            Tensor of shape [B, num_classes] with unnormalised logits, moved
            to ``self.device``.
        """
        # Feed the input to BERT to obtain contextualised representations.
        outputs = self.bert_layer(seq, attention_mask=attn_masks)
        # Index instead of tuple-unpacking: newer transformers versions return
        # a ModelOutput whose iteration yields its *keys*, so `a, _ = outputs`
        # would bind strings (or fail if there are more than two fields).
        # Integer indexing returns the sequence output for both ModelOutput
        # and the legacy tuple return.
        cont_reps = outputs[0]
        # Representation at position 0 (the [CLS] head), shape [B, hidden].
        cls_rep = cont_reps[:, 0]
        logits = self.cls_layer(cls_rep)
        # NOTE(review): the computation above already ran on the model's
        # device; this move is kept for backward compatibility and only
        # matters if self.device differs from that device.
        return logits.to(self.device)