
Commit 6386483

Merge pull request #254 from basf/master
Release v1.4.0
2 parents 7886cb3 + 5db5426

File tree

12 files changed: +554 −9 lines


Diff for: mambular/__version__.py (+2 −1)

@@ -17,4 +17,5 @@
 
 # The following line *must* be the last in the module, exactly as formatted:
 
-__version__ = "1.3.2"
+__version__ = "1.4.0"
+
Diff for: mambular/arch_utils/simple_utils.py (+29, new file)

@@ -0,0 +1,29 @@
import torch
import torch.nn as nn


class MLP_Block(nn.Module):
    """MLP block mapping d_in -> d -> d_in, so the input dimension is preserved."""

    def __init__(self, d_in: int, d: int, dropout: float):
        super().__init__()
        self.block = nn.Sequential(
            nn.BatchNorm1d(d_in),
            nn.Linear(d_in, d),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d, d_in),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)


def make_random_batches(train_size: int, batch_size: int, device=None):
    # Shuffle all training indices, then split them into batch-sized chunks
    # (the last batch may be smaller).
    permutation = torch.randperm(train_size, device=device)
    batches = permutation.split(batch_size)

    # Sanity check: every index appears exactly once across the batches.
    assert torch.equal(
        torch.arange(train_size, device=device), permutation.sort().values
    )
    return batches
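
For context, a quick usage sketch of make_random_batches; the sizes here are illustrative and not part of the commit:

    import torch
    from mambular.arch_utils.simple_utils import make_random_batches

    # 10 indices split into shuffled batches of 4 -> tensors of sizes (4, 4, 2)
    for idx in make_random_batches(train_size=10, batch_size=4):
        print(idx)  # a 1-D LongTensor of unique indices into the training set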

Diff for: mambular/base_models/__init__.py (+2)

@@ -14,8 +14,10 @@
 from .trompt import Trompt
 from .enode import ENODE
 from .tangos import Tangos
+from .modern_nca import ModernNCA

 __all__ = [
+    "ModernNCA",
     "Tangos",
     "ENODE",
     "Trompt",

Diff for: mambular/base_models/modern_nca.py (+216, new file)

@@ -0,0 +1,216 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from ..utils.get_feature_dimensions import get_feature_dimensions
from ..arch_utils.get_norm_fn import get_normalization_layer
from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer
from ..arch_utils.mlp_utils import MLPhead
from ..configs.modernnca_config import DefaultModernNCAConfig
from .utils.basemodel import BaseModel


class ModernNCA(BaseModel):
    def __init__(
        self,
        feature_information: tuple,
        num_classes=1,
        config: DefaultModernNCAConfig = DefaultModernNCAConfig(),  # noqa: B008
        **kwargs,
    ):
        super().__init__(config=config, **kwargs)
        self.save_hyperparameters(ignore=["feature_information"])

        self.returns_ensemble = False
        self.uses_nca_candidates = True

        self.T = config.temperature
        self.sample_rate = config.sample_rate
        if self.hparams.use_embeddings:
            self.embedding_layer = EmbeddingLayer(
                *feature_information,
                config=config,
            )
            input_dim = np.sum(
                [len(info) * self.hparams.d_model for info in feature_information]
            )
        else:
            input_dim = get_feature_dimensions(*feature_information)

        self.encoder = nn.Linear(input_dim, config.dim)

        if config.n_blocks > 0:
            self.post_encoder = nn.Sequential(
                *[self.make_layer(config) for _ in range(config.n_blocks)],
                nn.BatchNorm1d(config.dim),
            )

        self.tabular_head = MLPhead(
            input_dim=config.dim,
            config=config,
            output_dim=num_classes,
        )

        self.hparams.num_classes = num_classes

    def make_layer(self, config):
        return nn.Sequential(
            nn.BatchNorm1d(config.dim),
            nn.Linear(config.dim, config.d_block),
            nn.ReLU(inplace=True),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_block, config.dim),
        )

    def forward(self, *data):
        """Standard forward pass without candidate selection (for baseline compatibility)."""
        if self.hparams.use_embeddings:
            x = self.embedding_layer(*data)
            B, S, D = x.shape
            x = x.reshape(B, S * D)
        else:
            x = torch.cat([t for tensors in data for t in tensors], dim=1)
        x = self.encoder(x)
        if hasattr(self, "post_encoder"):
            x = self.post_encoder(x)
        return self.tabular_head(x)

    def nca_train(self, *data, targets, candidate_x, candidate_y):
        """NCA-style training forward pass selecting candidates."""
        if self.hparams.use_embeddings:
            x = self.embedding_layer(*data)
            B, S, D = x.shape
            x = x.reshape(B, S * D)
            candidate_x = self.embedding_layer(*candidate_x)
            B, S, D = candidate_x.shape
            candidate_x = candidate_x.reshape(B, S * D)
        else:
            x = torch.cat([t for tensors in data for t in tensors], dim=1)
            candidate_x = torch.cat(
                [t for tensors in candidate_x for t in tensors], dim=1
            )

        # Encode input
        x = self.encoder(x)
        candidate_x = self.encoder(candidate_x)

        if hasattr(self, "post_encoder"):
            x = self.post_encoder(x)
            candidate_x = self.post_encoder(candidate_x)

        # Select a subset of candidates
        data_size = candidate_x.shape[0]
        retrieval_size = int(data_size * self.sample_rate)
        sample_idx = torch.randperm(data_size)[:retrieval_size]
        candidate_x = candidate_x[sample_idx]
        candidate_y = candidate_y[sample_idx]

        # Concatenate with training batch
        candidate_x = torch.cat([x, candidate_x], dim=0)
        candidate_y = torch.cat([targets, candidate_y], dim=0)

        # One-hot encode if classification
        if self.hparams.num_classes > 1:
            candidate_y = F.one_hot(
                candidate_y, num_classes=self.hparams.num_classes
            ).to(x.dtype)
        elif len(candidate_y.shape) == 1:
            candidate_y = candidate_y.unsqueeze(-1)

        # Compute distances; mask the diagonal so each training sample cannot
        # retrieve its own label (the batch itself is the first block of candidates)
        distances = torch.cdist(x, candidate_x, p=2) / self.T
        distances = distances.fill_diagonal_(torch.inf)
        distances = F.softmax(-distances, dim=-1)
        logits = torch.mm(distances, candidate_y)
        eps = 1e-7
        if self.hparams.num_classes > 1:
            logits = torch.log(logits + eps)

        return logits

    def nca_validate(self, *data, candidate_x, candidate_y):
        """Validation forward pass with NCA-style candidate selection."""
        if self.hparams.use_embeddings:
            x = self.embedding_layer(*data)
            B, S, D = x.shape
            x = x.reshape(B, S * D)
            candidate_x = self.embedding_layer(*candidate_x)
            B, S, D = candidate_x.shape
            candidate_x = candidate_x.reshape(B, S * D)
        else:
            x = torch.cat([t for tensors in data for t in tensors], dim=1)
            candidate_x = torch.cat(
                [t for tensors in candidate_x for t in tensors], dim=1
            )

        # Encode input
        x = self.encoder(x)
        candidate_x = self.encoder(candidate_x)

        if hasattr(self, "post_encoder"):
            x = self.post_encoder(x)
            candidate_x = self.post_encoder(candidate_x)

        # One-hot encode if classification
        if self.hparams.num_classes > 1:
            candidate_y = F.one_hot(
                candidate_y, num_classes=self.hparams.num_classes
            ).to(x.dtype)
        elif len(candidate_y.shape) == 1:
            candidate_y = candidate_y.unsqueeze(-1)

        # Compute distances
        distances = torch.cdist(x, candidate_x, p=2) / self.T
        distances = F.softmax(-distances, dim=-1)

        # Compute logits
        logits = torch.mm(distances, candidate_y)
        eps = 1e-7
        if self.hparams.num_classes > 1:
            logits = torch.log(logits + eps)

        return logits

    def nca_predict(self, *data, candidate_x, candidate_y):
        """Prediction forward pass with candidate selection."""
        if self.hparams.use_embeddings:
            x = self.embedding_layer(*data)
            B, S, D = x.shape
            x = x.reshape(B, S * D)
            candidate_x = self.embedding_layer(*candidate_x)
            B, S, D = candidate_x.shape
            candidate_x = candidate_x.reshape(B, S * D)
        else:
            x = torch.cat([t for tensors in data for t in tensors], dim=1)
            candidate_x = torch.cat(
                [t for tensors in candidate_x for t in tensors], dim=1
            )

        # Encode input
        x = self.encoder(x)
        candidate_x = self.encoder(candidate_x)

        if hasattr(self, "post_encoder"):
            x = self.post_encoder(x)
            candidate_x = self.post_encoder(candidate_x)

        # One-hot encode if classification
        if self.hparams.num_classes > 1:
            candidate_y = F.one_hot(
                candidate_y, num_classes=self.hparams.num_classes
            ).to(x.dtype)
        elif len(candidate_y.shape) == 1:
            candidate_y = candidate_y.unsqueeze(-1)

        # Compute distances
        distances = torch.cdist(x, candidate_x, p=2) / self.T
        distances = F.softmax(-distances, dim=-1)

        # Compute logits
        logits = torch.mm(distances, candidate_y)
        eps = 1e-7
        if self.hparams.num_classes > 1:
            logits = torch.log(logits + eps)

        return logits
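
The heart of nca_train/nca_validate/nca_predict is a soft nearest-neighbour vote: each query's class scores are a distance-weighted average of the candidate labels. Below is a self-contained sketch of that step with made-up shapes; x, candidate_x, and candidate_y stand in for the encoded tensors above, and T is the temperature:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    B, N, D, num_classes, T = 4, 16, 8, 3, 1.0

    x = torch.randn(B, D)                    # encoded queries
    candidate_x = torch.randn(N, D)          # encoded candidates
    candidate_y = torch.randint(0, num_classes, (N,))

    one_hot = F.one_hot(candidate_y, num_classes).float()            # (N, C)
    weights = F.softmax(-torch.cdist(x, candidate_x, p=2) / T, -1)   # (B, N), rows sum to 1
    probs = weights @ one_hot                # (B, C): soft vote over candidate labels
    logits = torch.log(probs + 1e-7)         # log-probs, as fed to the loss above

    print(probs.sum(dim=-1))                 # ~1.0 for every query

Lowering T sharpens the softmax toward the single nearest candidate; raising it averages over more neighbours.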

Diff for: mambular/base_models/utils/lightning_wrapper.py (+64 −7)

@@ -96,6 +96,39 @@ def __init__(
             **kwargs,
         )

+    def setup(self, stage=None):
+        if stage == "fit" and hasattr(self.estimator, "uses_nca_candidates"):
+            all_train_num = []
+            all_train_cat = []
+            all_train_embeddings = []
+            all_train_targets = []
+
+            device = self.device if hasattr(self, "device") else self.trainer.device
+
+            for batch in self.trainer.datamodule.train_dataloader():
+                (num_features, cat_features, embeddings), labels = batch
+
+                all_train_num.append([f.to(device) for f in num_features])  # keep lists
+                all_train_cat.append([f.to(device) for f in cat_features])  # keep lists
+                if embeddings is not None:
+                    all_train_embeddings.append([f.to(device) for f in embeddings])
+                all_train_targets.append(labels.to(device))
+
+            # Maintain structure: each feature type remains a list of tensors
+            self.train_features = (
+                [torch.cat(features, dim=0) for features in zip(*all_train_num)],
+                [torch.cat(features, dim=0) for features in zip(*all_train_cat)],
+                (
+                    [
+                        torch.cat(features, dim=0)
+                        for features in zip(*all_train_embeddings)
+                    ]
+                    if all_train_embeddings
+                    else None
+                ),
+            )
+            self.train_targets = torch.cat(all_train_targets, dim=0)
+
     def forward(self, num_features, cat_features, embeddings):
         """Forward pass through the model.

@@ -184,7 +217,7 @@ def training_step(self, batch, batch_idx):  # type: ignore
             Index of the batch.

         Returns
-        -------
+        ------
         Tensor
             Training loss.
         """

@@ -194,6 +227,14 @@ def training_step(self, batch, batch_idx):  # type: ignore
         if hasattr(self.estimator, "penalty_forward"):
             preds, penalty = self.estimator.penalty_forward(*data)
             loss = self.compute_loss(preds, labels) + penalty
+        elif hasattr(self.estimator, "uses_nca_candidates"):
+            preds = self.estimator.nca_train(
+                *data,
+                targets=labels,
+                candidate_x=self.train_features,
+                candidate_y=self.train_targets,
+            )
+            loss = self.compute_loss(preds, labels)
         else:
             preds = self(*data)
             loss = self.compute_loss(preds, labels)

@@ -234,7 +275,12 @@ def validation_step(self, batch, batch_idx):  # type: ignore
         """

         data, labels = batch
-        preds = self(*data)
+        if hasattr(self.estimator, "nca_validate") and self.train_features is not None:
+            preds = self.estimator.nca_validate(
+                *data, candidate_x=self.train_features, candidate_y=self.train_targets
+            )
+        else:
+            preds = self(*data)
         val_loss = self.compute_loss(preds, labels)

         self.log(

@@ -276,7 +322,12 @@ def test_step(self, batch, batch_idx):  # type: ignore
             Test loss.
         """
         data, labels = batch
-        preds = self(*data)
+        if hasattr(self.estimator, "nca_predict") and self.train_features is not None:
+            preds = self.estimator.nca_predict(
+                *data, candidate_x=self.train_features, candidate_y=self.train_targets
+            )
+        else:
+            preds = self(*data)
         test_loss = self.compute_loss(preds, labels)

         self.log(

@@ -305,8 +356,14 @@ def predict_step(self, batch, batch_idx):
         Tensor
             Predictions.
         """
-
-        preds = self(*batch)
+        if hasattr(self.estimator, "nca_predict") and self.train_features is not None:
+            preds = self.estimator.nca_predict(
+                *batch,
+                candidate_x=self.train_features,
+                candidate_y=self.train_targets,
+            )
+        else:
+            preds = self(*batch)

         return preds

@@ -425,7 +482,7 @@ def pretrain_embeddings(
         temperature=0.1,
         save_path="pretrained_embeddings.pth",
         regression=True,
-        lr=1e-04
+        lr=1e-04,
     ):
         """Pretrain embeddings before full model training.

@@ -594,7 +651,7 @@ def contrastive_loss(self, embeddings, knn_indices, temperature=0.1):
         )  # Shape: (N * k_neighbors)

         # Compute cosine embedding loss
-        loss += -1.0 * loss_fn(embeddings_s, positive_pairs, labels)

         # Average loss across all sequence steps
         loss /= S
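
The setup hook above materialises the entire training set once per fit so that every NCA step can retrieve candidates from it. Stripped of the Lightning plumbing and the per-feature-type lists, the caching pattern reduces to the sketch below; loader is an illustrative stand-in yielding (features, labels) pairs:

    import torch

    def cache_candidates(loader, device):
        # Concatenate every batch from a dataloader into single on-device tensors.
        xs, ys = [], []
        for features, labels in loader:
            xs.append(features.to(device))
            ys.append(labels.to(device))
        return torch.cat(xs, dim=0), torch.cat(ys, dim=0)

Note the memory trade-off: the full training set stays on the device for the duration of fitting, which is what makes per-step candidate retrieval cheap.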
