
Commit 8c55cec

fix decoding module docstrings & add static typing
1 parent: d8e969c · commit: 8c55cec

File tree

7 files changed: +1236 −397 lines


neuro_py/ensemble/decoding/lstm.py

Lines changed: 175 additions & 63 deletions
@@ -1,3 +1,5 @@
+from typing import List, Dict, Tuple, Optional
+
 import torch
 import torch.nn.functional as F
 import lightning as L
@@ -6,43 +8,60 @@
 
 
 class LSTM(L.LightningModule):
-    """Long Short-Term Memory (LSTM) model."""
-    def __init__(self, in_dim=100, out_dim=2, hidden_dims=(400, 1, .0), use_bias=True, args={}):
-        """
-        Constructs a LSTM model
-
-        Parameters
-        ----------
-        in_dim : int
-            Dimensionality of input data
-        out_dim : int
-            Dimensionality of output data
-        hidden_dims : List
-            Architectural parameters of the model
-            (hidden_size, num_layers, dropout)
-        use_bias : bool
-            Whether to use bias or not in the final linear layer
-        """
+    """
+    Long Short-Term Memory (LSTM) model.
+
+    This class implements an LSTM model using PyTorch Lightning.
+
+    Parameters
+    ----------
+    in_dim : int, optional
+        Dimensionality of input data, by default 100
+    out_dim : int, optional
+        Dimensionality of output data, by default 2
+    hidden_dims : Tuple[int, int, float], optional
+        Architectural parameters of the model (hidden_size, num_layers, dropout),
+        by default (400, 1, 0.0)
+    use_bias : bool, optional
+        Whether to use bias or not in the final linear layer, by default True
+    args : Dict, optional
+        Additional arguments for model configuration, by default {}
+
+    Attributes
+    ----------
+    lstm : nn.LSTM
+        LSTM layer
+    fc : nn.Linear
+        Fully connected layer
+    hidden_state : Optional[torch.Tensor]
+        Hidden state of the LSTM
+    cell_state : Optional[torch.Tensor]
+        Cell state of the LSTM
+    """
+    def __init__(self, in_dim: int = 100, out_dim: int = 2,
+                 hidden_dims: Tuple[int, int, float] = (400, 1, 0.0),
+                 use_bias: bool = True, args: Dict = {}):
         super().__init__()
         self.save_hyperparameters()
         self.in_dim = in_dim
         self.out_dim = out_dim
         if len(hidden_dims) != 3:
             raise ValueError('`hidden_dims` should be of size 3')
-        hidden_size, nlayers, dropout = hidden_dims
-        self.nlayers = nlayers
-        self.hidden_size = hidden_size
-        self.dropout = dropout
+        self.hidden_size, self.nlayers, self.dropout = hidden_dims
         self.args = args
 
-        # Add final layer to the number of classes
-        self.lstm = nn.LSTM(input_size=in_dim, hidden_size=hidden_size,
-                            num_layers=nlayers, batch_first=True, dropout=dropout, bidirectional=True)
-        self.fc = nn.Linear(in_features=2*hidden_size, out_features=out_dim, bias=use_bias)
-        self.hidden_state = None
-        self.cell_state = None
+        self.lstm = nn.LSTM(input_size=in_dim, hidden_size=self.hidden_size,
+                            num_layers=self.nlayers, batch_first=True,
+                            dropout=self.dropout, bidirectional=True)
+        self.fc = nn.Linear(in_features=2*self.hidden_size, out_features=out_dim, bias=use_bias)
+        self.hidden_state: Optional[torch.Tensor] = None
+        self.cell_state: Optional[torch.Tensor] = None
 
-        def init_params(m):
+        self._init_params()
+
+    def _init_params(self) -> None:
+        """Initialize model parameters."""
+        def init_params(m: nn.Module) -> None:
             if isinstance(m, nn.Linear):
                 torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='leaky_relu')
                 if m.bias is not None:
@@ -51,35 +70,63 @@ def init_params(m):
                     nn.init.uniform_(m.bias, -bound, bound) # LeCunn init
         init_params(self.fc)
 
-    def forward(self, x):
-        lstm_out, (self.hidden_state, self.cell_state) = self.lstm(x, (self.hidden_state, self.cell_state))
-        B, L, H = lstm_out.shape
-        # Shape: [batch_size x max_length x hidden_dim]
-
-        # Select the activation of the last Hidden Layer
-        # lstm_out = lstm_out.view(B, L, 2, -1).sum(dim=2)
-        lstm_out = lstm_out[:,-1,:].contiguous()
-
-        # Shape: [batch_size x hidden_dim]
-
-        # Fully connected layer
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass of the LSTM model.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor of shape (batch_size, sequence_length, input_dim)
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor of shape (batch_size, output_dim)
+        """
+        lstm_out, (self.hidden_state, self.cell_state) = \
+            self.lstm(x, (self.hidden_state, self.cell_state))
+        lstm_out = lstm_out[:, -1, :].contiguous()
         out = self.fc(lstm_out)
-        if self.args['clf']:
+        if self.args.get('clf', False):
             out = F.log_softmax(out, dim=1)
-
         return out
 
-    def init_hidden(self, batch_size):
-        ''' Initializes hidden state '''
-        # Create two new tensors with sizes n_layers x batch_size x hidden_dim,
-        # initialized to zero, for hidden state and cell state of LSTM
+    def init_hidden(self, batch_size: int) -> None:
+        """
+        Initialize hidden state and cell state.
+
+        Parameters
+        ----------
+        batch_size : int
+            Batch size for initialization
+        """
         self.batch_size = batch_size
-        h0 = torch.zeros((2*self.nlayers,batch_size,self.hidden_size), requires_grad=False)
-        c0 = torch.zeros((2*self.nlayers,batch_size,self.hidden_size), requires_grad=False)
+        h0 = torch.zeros(
+            (2*self.nlayers, batch_size, self.hidden_size),
+            requires_grad=False
+        )
+        c0 = torch.zeros(
+            (2*self.nlayers, batch_size, self.hidden_size),
+            requires_grad=False
+        )
         self.hidden_state = h0
         self.cell_state = c0
 
-    def predict(self, x):
+    def predict(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Make predictions using the LSTM model.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor
+
+        Returns
+        -------
+        torch.Tensor
+            Predicted output
+        """
         self.hidden_state = self.hidden_state.to(x.device)
         self.cell_state = self.cell_state.to(x.device)
         preds = []
@@ -93,45 +140,110 @@ def predict(self, x):
                 pred_loc = pred_loc[:batch_size-(i-x.shape[0])]
             preds.extend(pred_loc)
         out = torch.stack(preds)
-        if self.args['clf']:
+        if self.args.get('clf', False):
             out = F.log_softmax(out, dim=1)
         return out
 
-    def _step(self, batch, batch_idx) -> torch.Tensor:
-        xs, ys = batch # unpack the batch
-        outs = self(xs) # apply the model
+    def _step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
+        """
+        Perform a single step (forward pass + loss calculation).
+
+        Parameters
+        ----------
+        batch : Tuple[torch.Tensor, torch.Tensor]
+            Batch of input data and labels
+        batch_idx : int
+            Index of the current batch
+
+        Returns
+        -------
+        torch.Tensor
+            Computed loss
+        """
+        xs, ys = batch
+        outs = self(xs)
         loss = self.args['criterion'](outs, ys)
         return loss
 
-    def training_step(self, batch, batch_idx) -> torch.Tensor:
+    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
+        """
+        Lightning method for training step.
+
+        Parameters
+        ----------
+        batch : Tuple[torch.Tensor, torch.Tensor]
+            Batch of input data and labels
+        batch_idx : int
+            Index of the current batch
+
+        Returns
+        -------
+        torch.Tensor
+            Computed loss
+        """
         loss = self._step(batch, batch_idx)
         self.log('train_loss', loss)
         return loss
 
-    def on_after_backward(self):
-        # LSTM specific
+    def on_after_backward(self) -> None:
+        """Lightning method called after backpropagation."""
         self.hidden_state.detach_()
         self.cell_state.detach_()
-        # self.hidden_state.data.fill_(.0)
-        # self.cell_state.data.fill_(.0)
 
-    def validation_step(self, batch, batch_idx) -> torch.Tensor:
+    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
+        """
+        Lightning method for validation step.
+
+        Parameters
+        ----------
+        batch : Tuple[torch.Tensor, torch.Tensor]
+            Batch of input data and labels
+        batch_idx : int
+            Index of the current batch
+
+        Returns
+        -------
+        torch.Tensor
+            Computed loss
+        """
         loss = self._step(batch, batch_idx)
         self.log('val_loss', loss)
         return loss
 
-    def test_step(self, batch, batch_idx) -> torch.Tensor:
+    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
+        """
+        Lightning method for test step.
+
+        Parameters
+        ----------
+        batch : Tuple[torch.Tensor, torch.Tensor]
+            Batch of input data and labels
+        batch_idx : int
+            Index of the current batch
+
+        Returns
+        -------
+        torch.Tensor
+            Computed loss
+        """
         loss = self._step(batch, batch_idx)
         self.log('test_loss', loss)
         return loss
 
-    def configure_optimizers(self):
-        args = self.args
+    def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[Dict]]:
+        """
+        Configure optimizers and learning rate schedulers.
+
+        Returns
+        -------
+        Tuple[List[torch.optim.Optimizer], List[Dict]]
+            Tuple containing a list of optimizers and a list of scheduler configurations
+        """
         optimizer = torch.optim.AdamW(
-            self.parameters(), weight_decay=args['weight_decay'])
+            self.parameters(), weight_decay=self.args['weight_decay'])
         scheduler = torch.optim.lr_scheduler.OneCycleLR(
-            optimizer, max_lr=args['lr'],
-            epochs=args['epochs'],
+            optimizer, max_lr=self.args['lr'],
+            epochs=self.args['epochs'],
             steps_per_epoch=len(
                 self.trainer._data_connector._train_dataloader_source.dataloader()
             )
