
Commit c75ace6

Merge pull request #241 from basf/master
Release Version 1.3.0
2 parents 9bdd113 + 180f126 commit c75ace6


63 files changed: +2737 −937 lines

README.md (+8 −5)
@@ -76,7 +76,10 @@ Mambular is a Python package that brings the power of advanced deep learning arc
 | `TabulaRNN` | A Recurrent Neural Network for Tabular data, introduced [here](https://arxiv.org/pdf/2411.17207). |
 | `MambAttention` | A combination between Mamba and Transformers, also introduced [here](https://arxiv.org/pdf/2411.17207). |
 | `NDTF` | A neural decision forest using soft decision trees. See [Kontschieder et al.](https://openaccess.thecvf.com/content_iccv_2015/html/Kontschieder_Deep_Neural_Decision_ICCV_2015_paper.html) for inspiration. |
-| `SAINT` | Improve neural networs via Row Attention and Contrastive Pre-Training, introduced [here](https://arxiv.org/pdf/2106.01342). |
+| `SAINT` | Improve neural networs via Row Attention and Contrastive Pre-Training, introduced [here](https://arxiv.org/pdf/2106.01342). |
+| `AutoInt` | Automatic Feature Interaction Learning via Self-Attentive Neural Networks introduced [here](https://arxiv.org/abs/1810.11921). |
+| `Trompt ` | Trompt: Towards a Better Deep Neural Network for Tabular Data introduced [here](https://arxiv.org/abs/2305.18446). |
+
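The three rows added above introduce SAINT, AutoInt, and Trompt to the model table. If they follow the same sklearn-style wrapper convention as the existing architectures, usage would look roughly like the sketch below; the class name `SAINTRegressor` is hypothetical and should be verified against `mambular.models`.

```python
# Hypothetical usage sketch: SAINTRegressor is an assumed name following the
# <Architecture>Regressor convention used by the other Mambular models.
import numpy as np
from mambular.models import SAINTRegressor  # name not confirmed by this diff

X = np.random.randn(256, 10)
y = np.random.randn(256)

model = SAINTRegressor(d_model=64, n_layers=4)
model.fit(X, y)
preds = model.predict(X)
```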

@@ -211,13 +214,13 @@ random_search.fit(X, y, **fit_params)
 print("Best Parameters:", random_search.best_params_)
 print("Best Score:", random_search.best_score_)
 ```
-Note, that using this, you can also optimize the preprocessing. Just use the prefix ``prepro__`` when specifying the preprocessor arguments you want to optimize:
+Note, that using this, you can also optimize the preprocessing. Just specify the necessary parameters when specifying the preprocessor arguments you want to optimize:
 ```python
 param_dist = {
     'd_model': randint(32, 128),
     'n_layers': randint(2, 10),
     'lr': uniform(1e-5, 1e-3),
-    "prepro__numerical_preprocessing": ["ple", "standardization", "box-cox"]
+    "numerical_preprocessing": ["ple", "standardization", "box-cox"]
 }

 ```
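To make the change above concrete, here is a sketch of the full search with the new, prefix-free preprocessing parameter. `MambularRegressor` and the surrounding search setup mirror the README section this hunk belongs to; the data is synthetic, and the accepted parameter names should be checked against the installed version.

```python
# Sketch only: preprocessing options are now plain parameter names (no "prepro__" prefix).
import numpy as np
from scipy.stats import randint, uniform
from sklearn.model_selection import RandomizedSearchCV
from mambular.models import MambularRegressor

X = np.random.randn(500, 10)
y = np.random.randn(500)

param_dist = {
    "d_model": randint(32, 128),
    "n_layers": randint(2, 10),
    "lr": uniform(1e-5, 1e-3),
    "numerical_preprocessing": ["ple", "standardization", "box-cox"],
}

random_search = RandomizedSearchCV(
    MambularRegressor(),
    param_distributions=param_dist,
    n_iter=5,
    cv=3,
)
random_search.fit(X, y)
print("Best Parameters:", random_search.best_params_)
print("Best Score:", random_search.best_score_)
```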
@@ -321,7 +324,7 @@ Here's how you can implement a custom model with Mambular:
 Define your custom model just as you would for an `nn.Module`. The main difference is that you will inherit from `BaseModel` and use the provided feature information to construct your layers. To integrate your model into the existing API, you only need to define the architecture and the forward pass.

 ```python
-from mambular.base_models import BaseModel
+from mambular.base_models.utils import BaseModel
 from mambular.utils.get_feature_dimensions import get_feature_dimensions
 import torch
 import torch.nn
@@ -365,7 +368,7 @@ Here's how you can implement a custom model with Mambular:
 You can build a regression, classification, or distributional regression model that can leverage all of Mambular's built-in methods by using the following:

 ```python
-from mambular.models import SklearnBaseRegressor
+from mambular.models.utils import SklearnBaseRegressor

 class MyRegressor(SklearnBaseRegressor):
     def __init__(self, **kwargs):
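Taken together, the two import changes in this section mean a custom model now pulls `BaseModel` from `mambular.base_models.utils` and the sklearn wrapper base from `mambular.models.utils`. A rough structural sketch follows; the constructor signatures, the use of `get_feature_dimensions`, and the way inputs are concatenated are assumptions, not the README's full example.

```python
# Structural sketch only: the signatures of BaseModel, SklearnBaseRegressor, and
# get_feature_dimensions are assumed and should be checked against the library.
import torch
import torch.nn as nn
from mambular.base_models.utils import BaseModel
from mambular.utils.get_feature_dimensions import get_feature_dimensions
from mambular.models.utils import SklearnBaseRegressor


class MyCustomModel(BaseModel):
    def __init__(self, feature_information, num_classes=1, config=None, **kwargs):
        super().__init__(**kwargs)
        input_dim = get_feature_dimensions(*feature_information)  # assumed call pattern
        self.net = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes),
        )

    def forward(self, *data):
        # Flatten and concatenate all feature tensors (assumed input layout)
        x = torch.cat([t.flatten(1) for group in data for t in group], dim=1)
        return self.net(x)


class MyRegressor(SklearnBaseRegressor):
    def __init__(self, **kwargs):
        super().__init__(model=MyCustomModel, config=None, **kwargs)  # pass-through assumed
```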

mambular/__version__.py (+3 −1)

@@ -16,4 +16,6 @@
 #

 # The following line *must* be the last in the module, exactly as formatted:
-__version__ = "1.2.1"
+
+__version__ = "1.3.0"
+

mambular/arch_utils/enode_utils.py (+305 −0, new file)
@@ -0,0 +1,305 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mambular.arch_utils.layer_utils.sparsemax import sparsemax, sparsemoid
+from .data_aware_initialization import ModuleWithInit
+from .numpy_utils import check_numpy
+import numpy as np
+from warnings import warn
+
+
+class ODSTE(ModuleWithInit):
+
+    def __init__(
+        self,
+        in_features,  # J (number of features)
+        num_trees,
+        embed_dim,  # D (embedding dimension per feature)
+        depth=6,
+        tree_dim=1,
+        flatten_output=True,
+        choice_function=sparsemax,
+        bin_function=sparsemoid,
+        initialize_response_=nn.init.normal_,
+        initialize_selection_logits_=nn.init.uniform_,
+        threshold_init_beta=1.0,
+        threshold_init_cutoff=1.0,
+    ):
+        """Oblivious Differentiable Sparsemax Trees (ODST) with Feature & Embedding Splitting."""
+        super().__init__()
+        self.depth, self.num_trees, self.tree_dim, self.flatten_output = (
+            depth,
+            num_trees,
+            tree_dim,
+            flatten_output,
+        )
+        self.choice_function, self.bin_function = choice_function, bin_function
+        self.in_features, self.embed_dim = in_features, embed_dim
+        self.threshold_init_beta, self.threshold_init_cutoff = (
+            threshold_init_beta,
+            threshold_init_cutoff,
+        )
+
+        # Response values for each leaf
+        self.response = nn.Parameter(
+            torch.zeros([num_trees, tree_dim, embed_dim, 2**depth]), requires_grad=True
+        )
+
+        initialize_response_(self.response)
+
+        # Feature selection logits (choose J)
+        self.feature_selection_logits = nn.Parameter(
+            torch.zeros([num_trees, depth, in_features]), requires_grad=True
+        )
+        initialize_selection_logits_(self.feature_selection_logits)
+
+        # Embedding selection logits (choose D within J)
+        self.embedding_selection_logits = nn.Parameter(
+            torch.randn([num_trees, depth, in_features, embed_dim])
+        )
+
+        # Thresholds & temperatures (random initialization)
+        self.feature_thresholds = nn.Parameter(torch.randn([num_trees, depth]))
+        self.log_temperatures = nn.Parameter(torch.randn([num_trees, depth]))
+
+        # Binary code mappings
+        with torch.no_grad():
+            indices = torch.arange(2**self.depth)
+            offsets = 2 ** torch.arange(self.depth)
+            bin_codes = (indices.view(1, -1) // offsets.view(-1, 1) % 2).to(
+                torch.float32
+            )
+            bin_codes_1hot = torch.stack([bin_codes, 1.0 - bin_codes], dim=-1)
+            self.bin_codes_1hot = nn.Parameter(bin_codes_1hot, requires_grad=False)
+
+    def initialize(self, x, eps=1e-6):
+        """Data-aware initialization of thresholds and log-temperatures based on input data.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor of shape [batch_size, in_features, embed_dim] used for threshold initialization.
+        eps : float, optional
+            Small value added to avoid log(0) errors in temperature initialization. Default is 1e-6.
+        """
+        if len(x.shape) != 3:
+            raise ValueError("Input tensor must have shape (batch_size, J, D)")
+
+        if x.shape[0] < 1000:
+            warn(
+                "Data-aware initialization is performed on less than 1000 data points. This may cause instability."
+                "To avoid potential problems, run this model on a data batch with at least 1000 data samples."
+                "You can do so manually before training. Use with torch.no_grad() for memory efficiency."
+            )
+
+        with torch.no_grad():
+            # Select features (J)
+            feature_selectors = self.choice_function(
+                self.feature_selection_logits, dim=-1
+            )
+            # feature_selectors shape: (num_trees, depth, J)
+
+            selected_features = torch.einsum("bjd,ntj->bntd", x, feature_selectors)
+            # selected_features shape: (B, num_trees, depth, D)
+
+            # Select embeddings (D)
+            embedding_selectors = self.choice_function(
+                self.embedding_selection_logits, dim=-1
+            )
+            # embedding_selectors shape: (num_trees, depth, J, D)
+
+            selected_embeddings = torch.einsum(
+                "bntd,ntjd->bntd", selected_features, embedding_selectors
+            )
+            # selected_embeddings shape: (B, num_trees, depth, D)
+
+            # Initialize thresholds using percentiles from the data
+            percentiles_q = 100 * np.random.beta(
+                self.threshold_init_beta,
+                self.threshold_init_beta,
+                size=[self.num_trees, self.depth],
+            )
+
+            reshaped_embeddings = selected_embeddings.permute(1, 2, 0, 3).reshape(
+                self.num_trees * self.depth, -1
+            )
+            self.feature_thresholds.data[...] = torch.as_tensor(
+                list(
+                    map(
+                        np.percentile,
+                        check_numpy(reshaped_embeddings),  # Now correctly 2D
+                        percentiles_q.flatten(),
+                    )
+                ),
+                dtype=selected_embeddings.dtype,
+                device=selected_embeddings.device,
+            ).view(self.num_trees, self.depth)
+
+            # Initialize temperatures based on the threshold differences
+            temperatures = np.percentile(
+                check_numpy(
+                    abs(selected_embeddings - self.feature_thresholds.unsqueeze(-1))
+                ),
+                q=100 * min(1.0, self.threshold_init_cutoff),
+                axis=0,
+            )
+
+            # Scale temperatures based on the cutoff
+            temperatures /= max(1.0, self.threshold_init_cutoff)
+
+            self.log_temperatures.data[...] = torch.log(
+                torch.as_tensor(
+                    temperatures.mean(-1),
+                    dtype=selected_embeddings.dtype,
+                    device=selected_embeddings.device,
+                )
+                + eps
+            )
+
+    def forward(self, x):
+        if len(x.shape) != 3:
+            raise ValueError("Input tensor must have shape (batch_size, J, D)")
+
+        # Select feature (J) and embedding dimension (D) separately
+        feature_selectors = self.choice_function(
+            self.feature_selection_logits, dim=-1
+        )  # [num_trees, depth, J]
+
+        embedding_selectors = self.choice_function(
+            self.embedding_selection_logits, dim=-1
+        )  # [num_trees, depth, J, D]
+
+        # Select features (J) first
+        selected_features = torch.einsum("bjd,ntj->bntd", x, feature_selectors)
+
+        # Select embeddings (D) within selected features
+        selected_embeddings = torch.einsum(
+            "bntd,ntjd->bntd", selected_features, embedding_selectors
+        )
+
+        # Compute threshold logits
+        threshold_logits = (
+            selected_embeddings - self.feature_thresholds.unsqueeze(0).unsqueeze(-1)
+        ) * torch.exp(-self.log_temperatures.unsqueeze(0).unsqueeze(-1))
+
+        threshold_logits = torch.stack([-threshold_logits, threshold_logits], dim=-1)
+
+        # Compute binary decisions
+        bins = self.bin_function(threshold_logits)
+
+        bin_matches = torch.einsum("bntds,tcs->bntdc", bins, self.bin_codes_1hot)
+
+        response_weights = torch.prod(bin_matches, dim=2)
+
+        # Compute final response
+        response = torch.einsum("bnds,ncds->bnd", response_weights, self.response)
+        return response
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(in_features={self.in_features}, embed_dim={self.embed_dim}, num_trees={self.num_trees}, depth={self.depth}, tree_dim={self.tree_dim}, flatten_output={self.flatten_output})"
+
+
+class DenseBlock(nn.Module):
+    """DenseBlock that sequentially stacks attention layers and `Module` layers (e.g., ODSTE)
+    with feature and embedding-aware splits.
+
+    Parameters
+    ----------
+    input_dim : int
+        Number of features (J) in the input.
+    embed_dim : int
+        Embedding dimension per feature (D).
+    layer_dim : int
+        Dimensionality of each ODSTE layer.
+    num_layers : int
+        Number of layers to stack in the block.
+    tree_dim : int, optional
+        Number of output channels from each tree. Default is 1.
+    max_features : int, optional
+        Maximum number of features for expansion. Default is None.
+    input_dropout : float, optional
+        Dropout rate applied to inputs during training. Default is 0.0.
+    flatten_output : bool, optional
+        If True, flattens the output along the tree dimension. Default is True.
+    Module : nn.Module, optional
+        Module class to use for each layer in the block. Default is `ODSTE`.
+    **kwargs : dict
+        Additional keyword arguments for `Module` instances.
+    """
+
+    def __init__(
+        self,
+        input_dim,
+        embed_dim,
+        layer_dim,
+        num_layers,
+        tree_dim=1,
+        max_features=None,
+        input_dropout=0.0,
+        flatten_output=True,
+        Module=ODSTE,
+        **kwargs,
+    ):
+        super().__init__()
+        self.num_layers = num_layers
+        self.layer_dim = layer_dim
+        self.tree_dim = tree_dim
+        self.max_features = max_features
+        self.input_dropout = input_dropout
+        self.flatten_output = flatten_output
+
+        self.attention_layers = nn.ModuleList()
+        self.odste_layers = nn.ModuleList()
+
+        for _ in range(num_layers):
+            # self.attention_layers.append(
+            #     nn.MultiheadAttention(
+            #         embed_dim=embed_dim, num_heads=1, batch_first=True
+            #     )
+            # )
+            self.odste_layers.append(
+                Module(
+                    in_features=input_dim,
+                    embed_dim=embed_dim,
+                    num_trees=layer_dim,
+                    tree_dim=tree_dim,
+                    flatten_output=True,
+                    **kwargs,
+                )
+            )
+            input_dim = min(
+                input_dim + layer_dim * tree_dim, max_features or float("inf")
+            )
+
+    def forward(self, x):
+        """Forward pass through the DenseBlock.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor of shape [batch_size, J, D].
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor with expanded features.
+        """
+        initial_features = x.shape[1]  # J (num features)
+
+        for odste_layer in self.odste_layers:
+            # x, _ = attn_layer(x, x, x)  # Apply attention
+
+            if self.max_features is not None:
+                tail_features = min(self.max_features, x.shape[1]) - initial_features
+                if tail_features > 0:
+                    x = torch.cat(
+                        [x[:, :initial_features, :], x[:, -tail_features:, :]], dim=1
+                    )
+
+            if self.training and self.input_dropout:
+                x = F.dropout(x, self.input_dropout)
+
+            h = odste_layer(x)  # Apply ODSTE layer
+            x = torch.cat([x, h], dim=1)  # Concatenate new features
+
+        return x
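For reviewers who want to poke at the new layers, here is a small smoke-test sketch. It assumes the module is importable from the path added in this commit and that `ModuleWithInit` triggers the data-aware `initialize()` on the first forward pass (as in the original NODE implementation); the concrete numbers are arbitrary.

```python
# Smoke-test sketch for the new ODSTE/DenseBlock layers (not part of the commit).
import torch
from mambular.arch_utils.enode_utils import DenseBlock

batch_size, num_features, embed_dim = 1200, 8, 16  # >= 1000 rows avoids the init warning

block = DenseBlock(
    input_dim=num_features,  # J
    embed_dim=embed_dim,     # D
    layer_dim=4,             # trees per ODSTE layer
    num_layers=2,
    tree_dim=1,
)

x = torch.randn(batch_size, num_features, embed_dim)  # (B, J, D)
with torch.no_grad():
    out = block(x)

# Each layer appends layer_dim * tree_dim new feature slots: 8 -> 12 -> 16
print(out.shape)  # expected: torch.Size([1200, 16, 16])
```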

mambular/arch_utils/layer_utils/embedding_layer.py (+5 −1)

@@ -156,7 +156,11 @@ def forward(self, num_features, cat_features, emb_features):
         # Process categorical embeddings
         if self.cat_embeddings and cat_features is not None:
             cat_embeddings = [
-                emb(cat_features[i]) if emb(cat_features[i]).ndim == 3 else emb(cat_features[i]).unsqueeze(1)
+                (
+                    emb(cat_features[i])
+                    if emb(cat_features[i]).ndim == 3
+                    else emb(cat_features[i]).unsqueeze(1)
+                )
                 for i, emb in enumerate(self.cat_embeddings)
             ]