
Commit 0a829f6

Merge pull request #171 from basf/develop
Version 1.0.0
2 parents d631275 + 7e22659 commit 0a829f6


78 files changed: +9046 −1948 lines

README.md (+156 −74)
@@ -16,21 +16,21 @@
</div>

<div style="text-align: center;">
-<h1>Mambular: Tabular Deep Learning (with Mamba)</h1>
+<h1>Mambular: Tabular Deep Learning Made Simple</h1>
</div>

-Mambular is a Python library for tabular deep learning. It includes models that leverage the Mamba (State Space Model) architecture, as well as other popular models like TabTransformer, FTTransformer, and tabular ResNets. Check out our paper `Mambular: A Sequential Model for Tabular Deep Learning`, available [here](https://arxiv.org/abs/2408.06291).
+Mambular is a Python library for tabular deep learning. It includes models that leverage the Mamba (State Space Model) architecture, as well as other popular models like TabTransformer, FTTransformer, TabM, and tabular ResNets. Check out our paper `Mambular: A Sequential Model for Tabular Deep Learning`, available [here](https://arxiv.org/abs/2408.06291). Also check out our paper introducing [TabulaRNN](https://arxiv.org/pdf/2411.17207), which analyzes the efficiency of NLP-inspired tabular models.

<h3> Table of Contents </h3>

- [🏃 Quickstart](#-quickstart)
- [📖 Introduction](#-introduction)
- [🤖 Models](#-models)
-- [🏆 Results](#-results)
- [📚 Documentation](#-documentation)
- [🛠️ Installation](#️-installation)
- [🚀 Usage](#-usage)
- [💻 Implement Your Own Model](#-implement-your-own-model)
+- [Custom Training](#custom-training)
- [🏷️ Citation](#️-citation)
- [License](#license)

@@ -55,75 +55,24 @@ Mambular is a Python package that brings the power of advanced deep learning arc

| Model            | Description |
| ---------------- | ----------- |
-| `Mambular`       | A sequential model using Mamba blocks [Gu and Dao](https://arxiv.org/pdf/2312.00752) specifically designed for various tabular data tasks. |
+| `Mambular`       | A sequential model using Mamba blocks, specifically designed for various tabular data tasks, introduced [here](https://arxiv.org/abs/2408.06291). |
+| `TabM`           | Batch ensembling for an MLP, as introduced by [Gorishniy et al.](https://arxiv.org/abs/2410.24210). |
+| `NODE`           | Neural Oblivious Decision Ensembles, as introduced by [Popov et al.](https://arxiv.org/abs/1909.06312). |
| `FTTransformer`  | A model leveraging transformer encoders, as introduced by [Gorishniy et al.](https://arxiv.org/abs/2106.11959), for tabular data. |
| `MLP`            | A classical Multi-Layer Perceptron (MLP) model for handling tabular data tasks. |
| `ResNet`         | An adaptation of the ResNet architecture for tabular data applications. |
| `TabTransformer` | A transformer-based model for tabular data introduced by [Huang et al.](https://arxiv.org/abs/2012.06678), enhancing feature learning capabilities. |
| `MambaTab`       | A tabular model using a Mamba block on a joint input representation, described [here](https://arxiv.org/abs/2401.08867). Not a sequential model. |
-| `TabulaRNN`      | A Recurrent Neural Network for Tabular data. Not yet included in the benchmarks |
+| `TabulaRNN`      | A recurrent neural network for tabular data, introduced [here](https://arxiv.org/pdf/2411.17207). |
+| `MambAttention`  | A combination of Mamba and Transformers, also introduced [here](https://arxiv.org/pdf/2411.17207). |
+| `NDTF`           | A neural decision forest using soft decision trees. See [Kontschieder et al.](https://openaccess.thecvf.com/content_iccv_2015/html/Kontschieder_Deep_Neural_Decision_ICCV_2015_paper.html) for inspiration. |
+

All models are available for `regression`, `classification`, and distributional regression, denoted by `LSS`.
Hence, they are available as e.g. `MambularRegressor`, `MambularClassifier` or `MambularLSS`.
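For instance, a minimal sketch of the regressor variant (it shares the sklearn-style `fit`/`predict` interface used in the classifier example further below; the toy data here is purely illustrative):

```python
import numpy as np
import pandas as pd
from mambular.models import MambularRegressor

# Purely illustrative toy data
X = pd.DataFrame({
    "num1": np.random.randn(100),
    "cat1": np.random.choice(["a", "b", "c"], size=100),
})
y = np.random.randn(100)

model = MambularRegressor(d_model=64, n_layers=4)
model.fit(X, y, max_epochs=10)
preds = model.predict(X)
```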

-# 🏆 Results
-Detailed results for the available methods can be found [here](https://arxiv.org/abs/2408.06291).
-Note, that these are achieved results with default hyperparameter and for our splits. Performing hyperparameter optimization could improve the performance of all models.
-
-The average rank table over all models and all datasets is given here:
-
-<div align="center">
-
-<table>
-  <tr><th style="text-align:center;">Model</th><th style="text-align:center;">Avg. Rank</th></tr>
-  <tr><td style="text-align:center;"><strong>Mambular</strong></td><td style="text-align:center;"><strong>2.083</strong> <sub>±1.037</sub></td></tr>
-  <tr><td style="text-align:center;">FT-Transformer</td><td style="text-align:center;">2.417 <sub>±1.256</sub></td></tr>
-  <tr><td style="text-align:center;">XGBoost</td><td style="text-align:center;">3.167 <sub>±2.577</sub></td></tr>
-  <tr><td style="text-align:center;">MambaTab*</td><td style="text-align:center;">4.333 <sub>±1.374</sub></td></tr>
-  <tr><td style="text-align:center;">ResNet</td><td style="text-align:center;">4.750 <sub>±1.639</sub></td></tr>
-  <tr><td style="text-align:center;">TabTransformer</td><td style="text-align:center;">6.222 <sub>±1.618</sub></td></tr>
-  <tr><td style="text-align:center;">MLP</td><td style="text-align:center;">6.500 <sub>±1.500</sub></td></tr>
-  <tr><td style="text-align:center;">MambaTab</td><td style="text-align:center;">6.583 <sub>±1.801</sub></td></tr>
-  <tr><td style="text-align:center;">MambaTab<sup>T</sup></td><td style="text-align:center;">7.917 <sub>±1.187</sub></td></tr>
-</table>
-
-</div>
-
# 📚 Documentation

You can find the Mamba-Tabular API documentation [here](https://mambular.readthedocs.io/en/latest/).
@@ -135,6 +84,19 @@ Install Mambular using pip:
pip install mambular
```

+If you want to use the original mamba and mamba2 implementations, additionally install mamba-ssm via:
+
+```sh
+pip install mamba-ssm
+```
+
+Be careful to use the correct torch and CUDA versions:
+
+```sh
+pip install torch==2.0.0+cu118 torchvision==0.15.0+cu118 torchaudio==2.0.0+cu118 -f https://download.pytorch.org/whl/cu118/torch_stable.html
+pip install mamba-ssm
+```
+
# 🚀 Usage

<h2> Preprocessing </h2>
@@ -143,12 +105,18 @@ Mambular simplifies data preprocessing with a range of tools designed for easy t

<h3> Data Type Detection and Transformation </h3>

-- **Ordinal & One-Hot Encoding**: Automatically transforms categorical data into numerical formats.
-- **Binning**: Discretizes numerical features; can use decision trees for optimal binning.
-- **Normalization & Standardization**: Scales numerical data appropriately.
-- **Periodic Linear Encoding (PLE)**: Encodes periodicity in numerical data.
-- **Quantile & Spline Transformations**: Applies advanced transformations to handle nonlinearity and distributional shifts.
-- **Polynomial Features**: Generates polynomial and interaction terms to capture complex relationships.
+- **Ordinal & One-Hot Encoding**: Automatically transforms categorical data into numerical formats using continuous ordinal encoding or one-hot encoding. Includes options for casting outputs to `float` for compatibility with downstream models.
+- **Binning**: Discretizes numerical features into bins, with support for both fixed binning strategies and optimal binning derived from decision tree models.
+- **MinMax**: Scales numerical data to a specific range, such as [-1, 1].
+- **Standardization**: Centers and scales numerical features to zero mean and unit variance for better compatibility with certain models.
+- **Quantile Transformations**: Maps numerical data onto a uniform or normal distribution, handling distributional shifts effectively.
+- **Spline Transformations**: Captures nonlinearity in numerical features using spline-based transformations, ideal for complex relationships.
+- **Piecewise Linear Encodings (PLE)**: Captures complex numerical patterns by applying piecewise linear encoding, suitable for data with periodic or nonlinear structures.
+- **Polynomial Features**: Automatically generates polynomial and interaction terms for numerical features, enhancing the ability to capture higher-order relationships.
+- **Box-Cox & Yeo-Johnson Transformations**: Performs power transformations to stabilize variance and normalize distributions.
+- **Custom Binning**: Enables user-defined bin edges for precise discretization of numerical data.
+

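The numerical transformations listed above are selected directly on the estimators; a minimal sketch, assuming the option strings match those used in the hyperparameter-optimization example later in this README:

```python
from mambular.models import MambularRegressor

# "ple", "standardization" and "box-cox" appear as option names later in this README;
# other transformations are selected analogously via the preprocessing arguments.
model = MambularRegressor(numerical_preprocessing="box-cox")
```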
<h2> Fit a Model </h2>
@@ -159,9 +127,10 @@ from mambular.models import MambularClassifier
# Initialize and fit your model
model = MambularClassifier(
    d_model=64,
-    n_layers=8,
+    n_layers=4,
    numerical_preprocessing="ple",
-    n_bins=50
+    n_bins=50,
+    d_conv=8
)

# X can be a DataFrame or something that can easily be transformed into a pd.DataFrame, such as a np.array
@@ -177,6 +146,59 @@ preds = model.predict(X)
preds = model.predict_proba(X)
```

+<h3> Hyperparameter Optimization </h3>
+Since all of the models are sklearn base estimators, you can use the built-in hyperparameter optimization from sklearn.
+
+```python
+from sklearn.model_selection import RandomizedSearchCV
+from scipy.stats import randint, uniform  # distributions for the search space
+
+param_dist = {
+    'd_model': randint(32, 128),
+    'n_layers': randint(2, 10),
+    'lr': uniform(1e-5, 1e-3)
+}
+
+random_search = RandomizedSearchCV(
+    estimator=model,
+    param_distributions=param_dist,
+    n_iter=50,            # Number of parameter settings sampled
+    cv=5,                 # 5-fold cross-validation
+    scoring='accuracy',   # Metric to optimize
+    random_state=42
+)
+
+fit_params = {"max_epochs": 5, "rebuild": False}
+
+# Fit the model
+random_search.fit(X, y, **fit_params)
+
+# Best parameters and score
+print("Best Parameters:", random_search.best_params_)
+print("Best Score:", random_search.best_score_)
+```
+
+Note that this also lets you optimize the preprocessing. Just use the prefix ``prepro__`` when specifying the preprocessor arguments you want to optimize:
+
+```python
+param_dist = {
+    'd_model': randint(32, 128),
+    'n_layers': randint(2, 10),
+    'lr': uniform(1e-5, 1e-3),
+    "prepro__numerical_preprocessing": ["ple", "standardization", "box-cox"]
+}
+```
+
+Since early stopping is integrated and the best model with respect to the validation loss is returned, setting max_epochs to a large number is sensible.
+
+Or use the built-in Bayesian HPO simply by running:
+
+```python
+best_params = model.optimize_hparams(X, y)
+```
+
+This automatically sets the search space based on the default config from ``mambular.configs``. See the documentation for all parameters of ``optimize_hparams()``. However, the preprocessor arguments are fixed and cannot be optimized here.
+

<h2> ⚖️ Distributional Regression with MambularLSS </h2>
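A rough sketch of what fitting such an LSS model could look like (a minimal example; the `family="normal"` argument is an assumption here, so check the documentation for the supported distribution families):

```python
from mambular.models import MambularLSS

# Assumed interface: a distributional family is chosen at construction time
lss_model = MambularLSS(
    d_model=64,
    n_layers=4,
    family="normal",  # assumption; see the docs for available families
)
lss_model.fit(X, y, max_epochs=50)
preds = lss_model.predict(X)
```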

@@ -260,6 +282,7 @@ Here's how you can implement a custom model with Mambular:

```python
from mambular.base_models import BaseModel
+from mambular.utils.get_feature_dimensions import get_feature_dimensions
import torch
import torch.nn as nn

@@ -275,11 +298,7 @@ Here's how you can implement a custom model with Mambular:
        super().__init__(**kwargs)
        self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])

-        input_dim = 0
-        for feature_name, input_shape in num_feature_info.items():
-            input_dim += input_shape
-        for feature_name, input_shape in cat_feature_info.items():
-            input_dim += 1
+        input_dim = get_feature_dimensions(num_feature_info, cat_feature_info)

        self.linear = nn.Linear(input_dim, num_classes)

@@ -311,6 +330,59 @@ Here's how you can implement a custom model with Mambular:
regressor.fit(X_train, y_train, max_epochs=50)
```

+# Custom Training
+If you prefer to set up custom training, preprocessing, and evaluation, you can simply use the `mambular.base_models`.
+Just be careful that all base models expect lists of features as inputs: one list for numerical features and one for categorical features. A custom training loop with random data could look like this:
+
+```python
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from mambular.base_models import Mambular
+from mambular.configs import DefaultMambularConfig
+
+# Dummy data and configuration
+cat_feature_info = {
+    "cat1": {
+        "preprocessing": "imputer -> continuous_ordinal",
+        "dimension": 1,
+        "categories": 4,
+    }
+}  # Example categorical feature information
+num_feature_info = {
+    "num1": {"preprocessing": "imputer -> scaler", "dimension": 1, "categories": None}
+}  # Example numerical feature information
+num_classes = 1
+config = DefaultMambularConfig()  # Use the desired configuration
+
+# Initialize model, loss function, and optimizer
+model = Mambular(cat_feature_info, num_feature_info, num_classes, config)
+criterion = nn.MSELoss()  # Use MSE for regression; change as appropriate for your task
+optimizer = optim.Adam(model.parameters(), lr=0.001)
+
+# Example training loop
+for epoch in range(10):  # Number of epochs
+    model.train()
+    optimizer.zero_grad()
+
+    # Dummy data
+    num_features = [torch.randn(32, 1) for _ in num_feature_info]
+    cat_features = [torch.randint(0, 5, (32,)) for _ in cat_feature_info]
+    labels = torch.randn(32, num_classes)
+
+    # Forward pass
+    outputs = model(num_features, cat_features)
+    loss = criterion(outputs, labels)
+
+    # Backward pass and optimization
+    loss.backward()
+    optimizer.step()
+
+    # Print loss for monitoring
+    print(f"Epoch [{epoch+1}/10], Loss: {loss.item():.4f}")
+```
+
# 🏷️ Citation

If you find this project useful in your research, please consider citing:
@@ -323,6 +395,16 @@ If you find this project useful in your research, please consider cite:
}
```

+If you use TabulaRNN, please also consider citing:
+```BibTeX
+@article{thielmann2024efficiency,
+  title={On the Efficiency of NLP-Inspired Methods for Tabular Deep Learning},
+  author={Thielmann, Anton Frederik and Samiee, Soheila},
+  journal={arXiv preprint arXiv:2411.17207},
+  year={2024}
+}
+```
+
# License

