Skip to content

Commit f1da952

Browse files
committed
fix
1 parent 496dfc5 commit f1da952

File tree

2 files changed

+35
-49
lines changed

2 files changed

+35
-49
lines changed

Diff for: srai/classification/hex2vec.py

+22-26
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
"""
2-
"""
1+
""""""
32
from pathlib import Path
4-
from typing import TYPE_CHECKING, List, Tuple
3+
from typing import TYPE_CHECKING, List
54

6-
from srai.utils._optional import import_optional_dependencies
75
from srai.embedders.hex2vec.model import Hex2VecModel
6+
from srai.utils._optional import import_optional_dependencies
87

98
if TYPE_CHECKING: # pragma: no cover
109
import torch
@@ -18,11 +17,11 @@
1817

1918

2019
class Hex2VecModelForRegionClassification(LightningModule): # type: ignore
21-
"""
22-
Hex2Vec classification model.
23-
"""
20+
"""Hex2Vec classification model."""
2421

25-
def __init__(self, hex2vec_layer_sizes: List[int], n_classes: int, learning_rate: float = 0.001):
22+
def __init__(
23+
self, hex2vec_layer_sizes: List[int], n_classes: int, learning_rate: float = 0.001
24+
):
2625
"""
2726
Initialize Hex2VecModel.
2827
@@ -39,9 +38,8 @@ def __init__(self, hex2vec_layer_sizes: List[int], n_classes: int, learning_rate
3938
super().__init__()
4039
self.learning_rate = learning_rate
4140
self.n_classes = n_classes
42-
self.hex2vec_model = Hex2VecModel(layer_sizes=layer_sizes)
43-
self.classification_head = nn.Linear(layer_sizes[-1], n_classes)
44-
41+
self.hex2vec_model = Hex2VecModel(layer_sizes=hex2vec_layer_sizes)
42+
self.classification_head = nn.Linear(hex2vec_layer_sizes[-1], n_classes)
4543

4644
def forward(self, X_anchor: "torch.Tensor") -> "torch.Tensor":
4745
"""
@@ -50,31 +48,27 @@ def forward(self, X_anchor: "torch.Tensor") -> "torch.Tensor":
5048
Args:
5149
X_anchor (torch.Tensor): Region features.
5250
"""
53-
import torch
5451
import torch.nn.functional as F
55-
from torchmetrics.functional import f1_score as f1
52+
5653
x = self.hex2vec_model(X_anchor)
5754
x = F.relu(x)
5855
x = F.log_softmax(self.classification_head(x), dim=1)
5956
return x
6057

61-
62-
6358
def training_step(self, batch: List["torch.Tensor"], batch_idx: int) -> "torch.Tensor":
64-
"""
65-
"""
59+
""""""
6660
import torch
6761
import torch.nn.functional as F
6862
from torchmetrics.functional import f1_score as f1
6963

7064
x, y = batch
7165
logits = self(x)
7266
loss = F.nll_loss(logits, y)
73-
67+
7468
preds = torch.argmax(logits, dim=1)
7569
acc = self.accuracy(preds, y)
76-
train_f1 = f1(preds, y, task="multiclass", num_classes=self.n_classes, average="macro")
77-
70+
f_score = f1(preds, y, task="multiclass", num_classes=self.n_classes, average="macro")
71+
7872
self.log("train_loss", loss, on_step=True, on_epoch=True)
7973
self.log("train_acc", acc, on_step=True, on_epoch=True)
8074
self.log("train_f1", f_score, on_step=True, on_epoch=True)
@@ -89,11 +83,11 @@ def validation_step(self, batch: List["torch.Tensor"], batch_idx: int) -> "torch
8983
x, y = batch
9084
logits = self(x)
9185
loss = F.nll_loss(logits, y)
92-
86+
9387
preds = torch.argmax(logits, dim=1)
9488
acc = self.accuracy(preds, y)
95-
val_f1 = f1(preds, y, task="multiclass", num_classes=self.n_classes, average="macro")
96-
89+
f_score = f1(preds, y, task="multiclass", num_classes=self.n_classes, average="macro")
90+
9791
self.log("val_loss", loss, on_step=True, on_epoch=True)
9892
self.log("val_acc", acc, on_step=True, on_epoch=True)
9993
self.log("val_f1", f_score, on_step=True, on_epoch=True)
@@ -108,7 +102,11 @@ def configure_optimizers(self) -> "torch.optim.Optimizer":
108102

109103
def get_kwargs(self) -> dict:
110104
"""Get model save kwargs."""
111-
return {"hex2vec_layer_sizes": self.hex2vec_layer_sizes, "n_classes": n_classes, "learning_rate": self.learning_rate}
105+
return {
106+
"hex2vec_layer_sizes": self.hex2vec_layer_sizes,
107+
"n_classes": self.n_classes,
108+
"learning_rate": self.learning_rate,
109+
}
112110

113111
@classmethod
114112
def load(cls, path: Path, **kwargs: dict) -> "Hex2VecModelForRegionClassification":
@@ -124,5 +122,3 @@ def load(cls, path: Path, **kwargs: dict) -> "Hex2VecModelForRegionClassificatio
124122
model = cls(**kwargs)
125123
model.load_state_dict(torch.load(path))
126124
return model
127-
128-

Diff for: srai/regression/hex2vec.py

+13-23
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
"""
2-
"""
1+
""""""
32
from pathlib import Path
4-
from typing import TYPE_CHECKING, List, Tuple
3+
from typing import TYPE_CHECKING, List
54

6-
from srai.utils._optional import import_optional_dependencies
75
from srai.embedders.hex2vec.model import Hex2VecModel
6+
from srai.utils._optional import import_optional_dependencies
87

98
if TYPE_CHECKING: # pragma: no cover
109
import torch
@@ -18,9 +17,7 @@
1817

1918

2019
class Hex2VecModelForRegionRegression(LightningModule): # type: ignore
21-
"""
22-
Hex2Vec regression model.
23-
"""
20+
"""Hex2Vec regression model."""
2421

2522
def __init__(self, hex2vec_layer_sizes: List[int], learning_rate: float = 0.001):
2623
"""
@@ -39,27 +36,20 @@ def __init__(self, hex2vec_layer_sizes: List[int], learning_rate: float = 0.001)
3936
self.n_classes = n_classes
4037
self.hex2vec_model = Hex2VecModel(layer_sizes=layer_sizes)
4138
self.regression_head = nn.Linear(layer_sizes[-1], 1)
42-
4339

4440
def forward(self, X_anchor: "torch.Tensor") -> "torch.Tensor":
45-
"""
46-
"""
47-
import torch
41+
""""""
4842
import torch.nn.functional as F
49-
from torchmetrics.functional import f1_score as f1
43+
5044
x = self.hex2vec_model(X_anchor)
5145
x = F.relu(x)
5246
x = self.regression_head(x)
5347
return x
5448

55-
56-
5749
def training_step(self, batch: List["torch.Tensor"], batch_idx: int) -> "torch.Tensor":
58-
"""
59-
"""
60-
import torch
50+
""""""
6151
import torch.nn.functional as F
62-
from torchmetrics.functional import mean_squared_error, mean_absolute_error
52+
from torchmetrics.functional import mean_absolute_error, mean_squared_error
6353

6454
x, y = batch
6555
y_hat = self.forward(x)
@@ -73,9 +63,8 @@ def training_step(self, batch: List["torch.Tensor"], batch_idx: int) -> "torch.T
7363
return loss
7464

7565
def validation_step(self, batch: List["torch.Tensor"], batch_idx: int) -> "torch.Tensor":
76-
import torch
7766
import torch.nn.functional as F
78-
from torchmetrics.functional import mean_squared_error, mean_absolute_error
67+
from torchmetrics.functional import mean_absolute_error, mean_squared_error
7968

8069
x, y = batch
8170
y_hat = self.forward(x)
@@ -96,7 +85,10 @@ def configure_optimizers(self) -> "torch.optim.Optimizer":
9685

9786
def get_kwargs(self) -> dict:
9887
"""Get model save kwargs."""
99-
return {"hex2vec_layer_sizes": self.hex2vec_layer_sizes, "learning_rate": self.learning_rate}
88+
return {
89+
"hex2vec_layer_sizes": self.hex2vec_layer_sizes,
90+
"learning_rate": self.learning_rate,
91+
}
10092

10193
@classmethod
10294
def load(cls, path: Path, **kwargs: dict) -> "Hex2VecModelForRegionRegression":
@@ -112,5 +104,3 @@ def load(cls, path: Path, **kwargs: dict) -> "Hex2VecModelForRegionRegression":
112104
model = cls(**kwargs)
113105
model.load_state_dict(torch.load(path))
114106
return model
115-
116-

0 commit comments

Comments
 (0)