+from typing import List, Dict, Tuple, Optional
+
 import torch
 import torch.nn.functional as F
 import lightning as L


 class LSTM(L.LightningModule):
-    """Long Short-Term Memory (LSTM) model."""
-    def __init__(self, in_dim=100, out_dim=2, hidden_dims=(400, 1, .0), use_bias=True, args={}):
-        """
-        Constructs a LSTM model
-
-        Parameters
-        ----------
-        in_dim : int
-            Dimensionality of input data
-        out_dim : int
-            Dimensionality of output data
-        hidden_dims : List
-            Architectural parameters of the model
-            (hidden_size, num_layers, dropout)
-        use_bias : bool
-            Whether to use bias or not in the final linear layer
-        """
+    """
+    Long Short-Term Memory (LSTM) model.
+
+    This class implements an LSTM model using PyTorch Lightning.
+
+    Parameters
+    ----------
+    in_dim : int, optional
+        Dimensionality of input data, by default 100
+    out_dim : int, optional
+        Dimensionality of output data, by default 2
+    hidden_dims : Tuple[int, int, float], optional
+        Architectural parameters of the model (hidden_size, num_layers, dropout),
+        by default (400, 1, 0.0)
+    use_bias : bool, optional
+        Whether to use bias or not in the final linear layer, by default True
+    args : Dict, optional
+        Additional arguments for model configuration, by default {}
+
+    Attributes
+    ----------
+    lstm : nn.LSTM
+        LSTM layer
+    fc : nn.Linear
+        Fully connected layer
+    hidden_state : Optional[torch.Tensor]
+        Hidden state of the LSTM
+    cell_state : Optional[torch.Tensor]
+        Cell state of the LSTM
+    """
+    def __init__(self, in_dim: int = 100, out_dim: int = 2,
+                 hidden_dims: Tuple[int, int, float] = (400, 1, 0.0),
+                 use_bias: bool = True, args: Dict = {}):
         super().__init__()
         self.save_hyperparameters()
         self.in_dim = in_dim
         self.out_dim = out_dim
         if len(hidden_dims) != 3:
             raise ValueError('`hidden_dims` should be of size 3')
-        hidden_size, nlayers, dropout = hidden_dims
-        self.nlayers = nlayers
-        self.hidden_size = hidden_size
-        self.dropout = dropout
+        self.hidden_size, self.nlayers, self.dropout = hidden_dims
         self.args = args

-        # Add final layer to the number of classes
-        self.lstm = nn.LSTM(input_size=in_dim, hidden_size=hidden_size,
-                            num_layers=nlayers, batch_first=True, dropout=dropout, bidirectional=True)
-        self.fc = nn.Linear(in_features=2 * hidden_size, out_features=out_dim, bias=use_bias)
-        self.hidden_state = None
-        self.cell_state = None
+        self.lstm = nn.LSTM(input_size=in_dim, hidden_size=self.hidden_size,
+                            num_layers=self.nlayers, batch_first=True,
+                            dropout=self.dropout, bidirectional=True)
+        self.fc = nn.Linear(in_features=2 * self.hidden_size, out_features=out_dim, bias=use_bias)
+        self.hidden_state: Optional[torch.Tensor] = None
+        self.cell_state: Optional[torch.Tensor] = None

-        def init_params(m):
+        self._init_params()
+
+    def _init_params(self) -> None:
+        """Initialize model parameters."""
+        def init_params(m: nn.Module) -> None:
             if isinstance(m, nn.Linear):
                 torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='leaky_relu')
                 if m.bias is not None:
@@ -51,35 +70,63 @@ def init_params(m):
                     nn.init.uniform_(m.bias, -bound, bound)  # LeCunn init
         init_params(self.fc)

-    def forward(self, x):
-        lstm_out, (self.hidden_state, self.cell_state) = self.lstm(x, (self.hidden_state, self.cell_state))
-        B, L, H = lstm_out.shape
-        # Shape: [batch_size x max_length x hidden_dim]
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass of the LSTM model.

-        # Select the activation of the last Hidden Layer
-        # lstm_out = lstm_out.view(B, L, 2, -1).sum(dim=2)
-        lstm_out = lstm_out[:,-1,:].contiguous()
-
-        # Shape: [batch_size x hidden_dim]
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor of shape (batch_size, sequence_length, input_dim)

-        # Fully connected layer
+        Returns
+        -------
+        torch.Tensor
+            Output tensor of shape (batch_size, output_dim)
+        """
+        lstm_out, (self.hidden_state, self.cell_state) = \
+            self.lstm(x, (self.hidden_state, self.cell_state))
+        lstm_out = lstm_out[:, -1, :].contiguous()
         out = self.fc(lstm_out)
-        if self.args['clf']:
+        if self.args.get('clf', False):
             out = F.log_softmax(out, dim=1)
-
         return out

-    def init_hidden(self, batch_size):
-        ''' Initializes hidden state '''
-        # Create two new tensors with sizes n_layers x batch_size x hidden_dim,
-        # initialized to zero, for hidden state and cell state of LSTM
+    def init_hidden(self, batch_size: int) -> None:
+        """
+        Initialize hidden state and cell state.
+
+        Parameters
+        ----------
+        batch_size : int
+            Batch size for initialization
+        """
         self.batch_size = batch_size
-        h0 = torch.zeros((2 * self.nlayers,batch_size,self.hidden_size), requires_grad=False)
-        c0 = torch.zeros((2 * self.nlayers,batch_size,self.hidden_size), requires_grad=False)
+        h0 = torch.zeros(
+            (2 * self.nlayers, batch_size, self.hidden_size),
+            requires_grad=False
+        )
+        c0 = torch.zeros(
+            (2 * self.nlayers, batch_size, self.hidden_size),
+            requires_grad=False
+        )
         self.hidden_state = h0
         self.cell_state = c0

-    def predict(self, x):
+    def predict(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Make predictions using the LSTM model.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor
+
+        Returns
+        -------
+        torch.Tensor
+            Predicted output
+        """
         self.hidden_state = self.hidden_state.to(x.device)
         self.cell_state = self.cell_state.to(x.device)
         preds = []
@@ -93,45 +140,110 @@ def predict(self, x):
                 pred_loc = pred_loc[:batch_size - (i - x.shape[0])]
             preds.extend(pred_loc)
         out = torch.stack(preds)
-        if self.args['clf']:
+        if self.args.get('clf', False):
             out = F.log_softmax(out, dim=1)
         return out

-    def _step(self, batch, batch_idx) -> torch.Tensor:
-        xs, ys = batch  # unpack the batch
-        outs = self(xs)  # apply the model
+    def _step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
+        """
+        Perform a single step (forward pass + loss calculation).
+
+        Parameters
+        ----------
+        batch : Tuple[torch.Tensor, torch.Tensor]
+            Batch of input data and labels
+        batch_idx : int
+            Index of the current batch
+
+        Returns
+        -------
+        torch.Tensor
+            Computed loss
+        """
+        xs, ys = batch
+        outs = self(xs)
         loss = self.args['criterion'](outs, ys)
         return loss

-    def training_step(self, batch, batch_idx) -> torch.Tensor:
+    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
+        """
+        Lightning method for training step.
+
+        Parameters
+        ----------
+        batch : Tuple[torch.Tensor, torch.Tensor]
+            Batch of input data and labels
+        batch_idx : int
+            Index of the current batch
+
+        Returns
+        -------
+        torch.Tensor
+            Computed loss
+        """
         loss = self._step(batch, batch_idx)
         self.log('train_loss', loss)
         return loss

-    def on_after_backward(self):
-        # LSTM specific
+    def on_after_backward(self) -> None:
+        """Lightning method called after backpropagation."""
         self.hidden_state.detach_()
         self.cell_state.detach_()
-        # self.hidden_state.data.fill_(.0)
-        # self.cell_state.data.fill_(.0)

-    def validation_step(self, batch, batch_idx) -> torch.Tensor:
+    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
+        """
+        Lightning method for validation step.
+
+        Parameters
+        ----------
+        batch : Tuple[torch.Tensor, torch.Tensor]
+            Batch of input data and labels
+        batch_idx : int
+            Index of the current batch
+
+        Returns
+        -------
+        torch.Tensor
+            Computed loss
+        """
         loss = self._step(batch, batch_idx)
         self.log('val_loss', loss)
         return loss

-    def test_step(self, batch, batch_idx) -> torch.Tensor:
+    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
+        """
+        Lightning method for test step.
+
+        Parameters
+        ----------
+        batch : Tuple[torch.Tensor, torch.Tensor]
+            Batch of input data and labels
+        batch_idx : int
+            Index of the current batch
+
+        Returns
+        -------
+        torch.Tensor
+            Computed loss
+        """
         loss = self._step(batch, batch_idx)
         self.log('test_loss', loss)
         return loss

-    def configure_optimizers(self):
-        args = self.args
+    def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[Dict]]:
+        """
+        Configure optimizers and learning rate schedulers.
+
+        Returns
+        -------
+        Tuple[List[torch.optim.Optimizer], List[Dict]]
+            Tuple containing a list of optimizers and a list of scheduler configurations
+        """
         optimizer = torch.optim.AdamW(
-            self.parameters(), weight_decay=args['weight_decay'])
+            self.parameters(), weight_decay=self.args['weight_decay'])
         scheduler = torch.optim.lr_scheduler.OneCycleLR(
-            optimizer, max_lr=args['lr'],
-            epochs=args['epochs'],
+            optimizer, max_lr=self.args['lr'],
+            epochs=self.args['epochs'],
             steps_per_epoch=len(
                 self.trainer._data_connector._train_dataloader_source.dataloader()
             )
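
A minimal usage sketch of the class as it stands after this change, not part of the commit itself: the `args` keys mirror the ones the code reads ('clf', 'criterion', 'weight_decay', 'lr', 'epochs'), while the concrete values, tensor shapes, and loss choice below are illustrative assumptions only.

# Hypothetical usage sketch -- values and shapes are assumptions.
import torch
import torch.nn.functional as F

args = {
    'clf': True,               # triggers log_softmax on the outputs
    'criterion': F.nll_loss,   # pairs with log-probability outputs
    'weight_decay': 1e-2,
    'lr': 1e-3,
    'epochs': 10,
}

model = LSTM(in_dim=100, out_dim=2, hidden_dims=(400, 1, 0.0), args=args)

# hidden_state/cell_state start as None, so init_hidden must run before forward.
batch_size, seq_len = 8, 50
model.init_hidden(batch_size)
x = torch.randn(batch_size, seq_len, 100)   # (batch, sequence, in_dim)
log_probs = model(x)                         # (batch, out_dim) log-probabilities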