1
- """
2
- """
1
+ """"""
3
2
from pathlib import Path
4
- from typing import TYPE_CHECKING , List , Tuple
3
+ from typing import TYPE_CHECKING , List
5
4
6
- from srai .utils ._optional import import_optional_dependencies
7
5
from srai .embedders .hex2vec .model import Hex2VecModel
6
+ from srai .utils ._optional import import_optional_dependencies
8
7
9
8
if TYPE_CHECKING : # pragma: no cover
10
9
import torch
18
17
19
18
20
19
class Hex2VecModelForRegionClassification (LightningModule ): # type: ignore
21
- """
22
- Hex2Vec classification model.
23
- """
20
+ """Hex2Vec classification model."""
24
21
25
- def __init__ (self , hex2vec_layer_sizes : List [int ], n_classes : int , learning_rate : float = 0.001 ):
22
+ def __init__ (
23
+ self , hex2vec_layer_sizes : List [int ], n_classes : int , learning_rate : float = 0.001
24
+ ):
26
25
"""
27
26
Initialize Hex2VecModel.
28
27
@@ -39,9 +38,8 @@ def __init__(self, hex2vec_layer_sizes: List[int], n_classes: int, learning_rate
39
38
super ().__init__ ()
40
39
self .learning_rate = learning_rate
41
40
self .n_classes = n_classes
42
- self .hex2vec_model = Hex2VecModel (layer_sizes = layer_sizes )
43
- self .classification_head = nn .Linear (layer_sizes [- 1 ], n_classes )
44
-
41
+ self .hex2vec_model = Hex2VecModel (layer_sizes = hex2vec_layer_sizes )
42
+ self .classification_head = nn .Linear (hex2vec_layer_sizes [- 1 ], n_classes )
45
43
46
44
def forward (self , X_anchor : "torch.Tensor" ) -> "torch.Tensor" :
47
45
"""
@@ -50,31 +48,27 @@ def forward(self, X_anchor: "torch.Tensor") -> "torch.Tensor":
50
48
Args:
51
49
X_anchor (torch.Tensor): Region features.
52
50
"""
53
- import torch
54
51
import torch .nn .functional as F
55
- from torchmetrics . functional import f1_score as f1
52
+
56
53
x = self .hex2vec_model (X_anchor )
57
54
x = F .relu (x )
58
55
x = F .log_softmax (self .classification_head (x ), dim = 1 )
59
56
return x
60
57
61
-
62
-
63
58
def training_step (self , batch : List ["torch.Tensor" ], batch_idx : int ) -> "torch.Tensor" :
64
- """
65
- """
59
+ """"""
66
60
import torch
67
61
import torch .nn .functional as F
68
62
from torchmetrics .functional import f1_score as f1
69
63
70
64
x , y = batch
71
65
logits = self (x )
72
66
loss = F .nll_loss (logits , y )
73
-
67
+
74
68
preds = torch .argmax (logits , dim = 1 )
75
69
acc = self .accuracy (preds , y )
76
- train_f1 = f1 (preds , y , task = "multiclass" , num_classes = self .n_classes , average = "macro" )
77
-
70
+ f_score = f1 (preds , y , task = "multiclass" , num_classes = self .n_classes , average = "macro" )
71
+
78
72
self .log ("train_loss" , loss , on_step = True , on_epoch = True )
79
73
self .log ("train_acc" , acc , on_step = True , on_epoch = True )
80
74
self .log ("train_f1" , f_score , on_step = True , on_epoch = True )
@@ -89,11 +83,11 @@ def validation_step(self, batch: List["torch.Tensor"], batch_idx: int) -> "torch
89
83
x , y = batch
90
84
logits = self (x )
91
85
loss = F .nll_loss (logits , y )
92
-
86
+
93
87
preds = torch .argmax (logits , dim = 1 )
94
88
acc = self .accuracy (preds , y )
95
- val_f1 = f1 (preds , y , task = "multiclass" , num_classes = self .n_classes , average = "macro" )
96
-
89
+ f_score = f1 (preds , y , task = "multiclass" , num_classes = self .n_classes , average = "macro" )
90
+
97
91
self .log ("val_loss" , loss , on_step = True , on_epoch = True )
98
92
self .log ("val_acc" , acc , on_step = True , on_epoch = True )
99
93
self .log ("val_f1" , f_score , on_step = True , on_epoch = True )
@@ -108,7 +102,11 @@ def configure_optimizers(self) -> "torch.optim.Optimizer":
108
102
109
103
def get_kwargs (self ) -> dict :
110
104
"""Get model save kwargs."""
111
- return {"hex2vec_layer_sizes" : self .hex2vec_layer_sizes , "n_classes" : n_classes , "learning_rate" : self .learning_rate }
105
+ return {
106
+ "hex2vec_layer_sizes" : self .hex2vec_layer_sizes ,
107
+ "n_classes" : self .n_classes ,
108
+ "learning_rate" : self .learning_rate ,
109
+ }
112
110
113
111
@classmethod
114
112
def load (cls , path : Path , ** kwargs : dict ) -> "Hex2VecModelForRegionClassification" :
@@ -124,5 +122,3 @@ def load(cls, path: Path, **kwargs: dict) -> "Hex2VecModelForRegionClassificatio
124
122
model = cls (** kwargs )
125
123
model .load_state_dict (torch .load (path ))
126
124
return model
127
-
128
-
0 commit comments