gru_pretrained.py
import torch
import torch.nn as nn


class LSTMClassifier(nn.Module):
    """Bidirectional LSTM sentence classifier over padded, variable-length batches."""

    def __init__(self, embedding_dim, hidden_dim, vocab_size, label_size, batch_size):
        super(LSTMClassifier, self).__init__()
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        # Each direction gets hidden_dim // 2 units, so the concatenated
        # bidirectional output has size hidden_dim.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, bidirectional=True)
        self.hidden2label = nn.Linear(hidden_dim, label_size)
        self.dropout = nn.Dropout(0.5)

    def last_timestep(self, unpacked, lengths):
        # Index of the last (non-padded) output for each sequence.
        idx = (lengths - 1).view(-1, 1).expand(unpacked.size(0),
                                               unpacked.size(2)).unsqueeze(1)
        if torch.cuda.is_available():
            idx = idx.cuda()
        return unpacked.gather(1, idx).squeeze()

    def init_hidden(self):
        # Shape: (num_layers * num_directions, batch, hidden) = (2, batch, hidden_dim // 2).
        h0 = torch.zeros(2, self.batch_size, self.hidden_dim // 2)
        c0 = torch.zeros(2, self.batch_size, self.hidden_dim // 2)
        if torch.cuda.is_available():
            h0, c0 = h0.cuda(), c0.cuda()
        return (h0, c0)

    def forward(self, sentence, lengths):
        # `sentence` is a padded (batch, max_len, embedding_dim) tensor; the
        # embedding lookup is not applied inside this method.
        packed = nn.utils.rnn.pack_padded_sequence(sentence, lengths, batch_first=True)
        lstm_out, self.hidden = self.lstm(packed, self.hidden)
        unpacked, unpacked_len = nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True)
        # Take the output of the last *non-masked* timestep of each sentence.
        last_outputs = self.last_timestep(unpacked, unpacked_len)
        last_outputs = self.dropout(last_outputs)
        # hidden_1 = self.relu1(last_outputs)
        y = self.hidden2label(last_outputs)
        return y
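

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): the dimensions, the
# random batch, and the device handling below are illustrative assumptions.
# forward() packs its input directly and never calls self.word_embeddings, so
# the sketch feeds pre-embedded, length-sorted sequences and initialises
# self.hidden via init_hidden() before the forward pass, as forward() expects.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    embedding_dim, hidden_dim, vocab_size, label_size, batch_size = 100, 64, 5000, 3, 4
    model = LSTMClassifier(embedding_dim, hidden_dim, vocab_size, label_size, batch_size)

    # Toy batch: padded embedding tensors sorted by length in descending
    # order, as pack_padded_sequence expects.
    lengths = torch.tensor([7, 5, 4, 2])
    sentences = torch.randn(batch_size, int(lengths.max()), embedding_dim)

    if torch.cuda.is_available():
        model, sentences = model.cuda(), sentences.cuda()

    # The recurrent state lives on the module and must be (re)set before each batch.
    model.hidden = model.init_hidden()

    logits = model(sentences, lengths)
    print(logits.shape)  # torch.Size([4, 3]) -> (batch_size, label_size)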