Commit 648d029

Merge pull request #190 from basf/develop
SAINT, preprocessing, formatting, bug fixes etc.
2 parents af1ea08 + 4ae6afa commit 648d029

130 files changed, +5519 -5778 lines changed

.github/ISSUE_TEMPLATE/bug_report.md

+1 -1
@@ -25,4 +25,4 @@ If applicable, add screenshots to help explain your problem.
 - Mambular Version [e.g. 0.1.2]
 
 **Additional context**
-Add any other context about the problem here.
+Add any other context about the problem here.

.github/ISSUE_TEMPLATE/config.yml

+1 -1
@@ -1 +1 @@
-blank_issues_enabled: false
+blank_issues_enabled: false

.github/ISSUE_TEMPLATE/doc_request.md

+1 -1
@@ -8,4 +8,4 @@ assignees: ''
 ---
 
 **Description of the question**
-A clear and concise description of what should be documented.
+A clear and concise description of what should be documented.

.github/ISSUE_TEMPLATE/feature_request.md

+1 -1
@@ -17,4 +17,4 @@ A clear and concise description of what you want to happen.
 A clear and concise description of any alternative solutions or features you've considered.
 
 **Additional context**
-Add any other context or screenshots about the feature request here.
+Add any other context or screenshots about the feature request here.

.github/ISSUE_TEMPLATE/question.md

+1 -1
@@ -14,4 +14,4 @@ Gives some context if needed (environment, system, hardware).
 A clear and concise description of what the task is.
 
 **Describe the solution you'd like**
-A clear and concise description of what you want to happen.
+A clear and concise description of what you want to happen.

.github/workflows/build-publish-pypi.yml

+12 -8
@@ -6,7 +6,7 @@ on:
       - release
 
 jobs:
-  publish:
+  build-publish:
     runs-on: ubuntu-latest
 
     steps:
@@ -18,15 +18,19 @@ jobs:
         with:
           python-version: "3.8"
 
-      - name: Install dependencies
+      - name: Install Poetry
         run: |
-          python -m pip install --upgrade pip
-          pip install setuptools wheel twine
+          curl -sSL https://install.python-poetry.org | python3 -
+          export PATH="$HOME/.local/bin:$PATH"
+
+      - name: Install dependencies
+        run: poetry install
 
-      - name: Build and publish package
+      - name: Build package
+        run: poetry build
+
+      - name: Publish to PyPI
         env:
           TWINE_USERNAME: __token__
           TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
-        run: |
-          python setup.py sdist bdist_wheel
-          twine upload dist/*
+        run: poetry publish --username $TWINE_USERNAME --password $TWINE_PASSWORD

.gitignore

+1
@@ -172,4 +172,5 @@ examples/lightning_logs
 docs/_build/doctrees/*
 docs/_build/html/*
 
+
 dev/*

.pre-commit-config.yaml

+31
@@ -0,0 +1,31 @@
+exclude: "^$"
+fail_fast: false
+default_stages: [commit, push]
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.5.0
+    hooks:
+      - id: check-case-conflict
+      - id: check-merge-conflict
+      - id: end-of-file-fixer
+      - id: mixed-line-ending
+      - id: trailing-whitespace
+        args: [--markdown-linebreak-ext=md]
+
+  - repo: https://github.com/charliermarsh/ruff-pre-commit
+    rev: v0.1.14
+    hooks:
+      - id: ruff-format
+        types_or: [python, pyi, jupyter]
+      - id: ruff
+        types_or: [python, pyi, jupyter]
+        args: [ --fix, --exit-non-zero-on-fix ]
+
+  - repo: https://github.com/pre-commit/mirrors-prettier
+    rev: v4.0.0-alpha.8
+    hooks:
+      - id: prettier
+        types:
+          - yaml
+          - markdown
+          - json
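
Review note: many hunks in this commit show a removed line and an added line with identical visible text (the issue templates, `LICENSE`, several `.rst` files). That pattern is consistent with the `end-of-file-fixer` and `trailing-whitespace` hooks configured above having been run across the repository. The sketch below is a rough approximation of their combined effect, not the hooks' actual implementation; the function name `normalize_whitespace` and its `keep_markdown_breaks` flag are made up for illustration, and line endings are simplified to `\n`.

```python
def normalize_whitespace(text: str, keep_markdown_breaks: bool = False) -> str:
    """Approximate trailing-whitespace + end-of-file-fixer (illustrative only)."""
    cleaned = []
    for line in text.splitlines():
        stripped = line.rstrip()
        # Assumption: with --markdown-linebreak-ext, a non-blank line ending in
        # two spaces keeps exactly two spaces (a Markdown hard line break).
        if keep_markdown_breaks and stripped and line.endswith("  "):
            stripped += "  "
        cleaned.append(stripped)
    # Drop trailing blank lines, then end the file with exactly one newline.
    while cleaned and cleaned[-1] == "":
        cleaned.pop()
    return "\n".join(cleaned) + "\n" if cleaned else ""


# A file that lacked a final newline gains one: the visible text is unchanged,
# but the last line still shows up as removed and re-added in a diff.
print(repr(normalize_whitespace("SOFTWARE.")))
```

Under these assumptions, a last line such as `SOFTWARE.` without a final newline is rewritten with one, which a diff viewer renders as a `-`/`+` pair with the same text.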

.vscode/settings.json

+10
@@ -0,0 +1,10 @@
+{
+    "editor.formatOnSave": true,
+    "editor.codeActionsOnSave": {
+        "source.organizeImports": "explicit",
+        "source.fixAll": "explicit"
+    },
+    "[python]": {
+        "editor.defaultFormatter": "charliermarsh.ruff"
+    },
+}
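
Review note: the added settings file ends its `"[python]"` block with a trailing comma. VS Code reads its settings files as JSON with comments and generally tolerates this, but a strict JSON parser will not, which may matter if any other tool ever reads the file. The check below is a small sketch of that difference; the helper `is_strict_json` and the regex-based cleanup are illustrative assumptions, and the regex is a crude fix that would also touch a comma-plus-bracket sequence inside a string value.

```python
import json
import re

# Same structure as the settings file added in this commit, including the
# trailing comma after the "[python]" block.
SETTINGS_TEXT = """
{
    "editor.formatOnSave": true,
    "editor.codeActionsOnSave": {
        "source.organizeImports": "explicit",
        "source.fixAll": "explicit"
    },
    "[python]": {
        "editor.defaultFormatter": "charliermarsh.ruff"
    },
}
"""


def is_strict_json(text: str) -> bool:
    """Return True if the text parses with Python's strict JSON parser."""
    try:
        json.loads(text)
    except json.JSONDecodeError:
        return False
    return True


# Crude cleanup: drop a comma that is directly followed by a closing bracket.
without_trailing_commas = re.sub(r",\s*([}\]])", r"\1", SETTINGS_TEXT)

print(is_strict_json(SETTINGS_TEXT))            # trailing comma is rejected
print(is_strict_json(without_trailing_commas))  # parses once it is removed
```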

docs/codeofconduct.md → CODE_OF_CONDUCT.md

+2 -1
@@ -1,6 +1,7 @@
+
 # Code of Conduct
 
-- **Purpose**: The purpose of this Code of Conduct is to establish a welcoming and inclusive community around the `Mambular` project. We want to foster an environment where everyone feels respected, valued, and able to contribute to the project.
+- **Purpose**: The purpose of this Code of Conduct is to establish a welcoming and inclusive community around the `STREAM` project. We want to foster an environment where everyone feels respected, valued, and able to contribute to the project.
 
 - **Openness and Respect**: We strive to create an open and respectful community where everyone can freely express their opinions and ideas. We encourage constructive discussions and debates, but we will not tolerate any form of harassment, discrimination, or disrespectful behavior.
 

LICENSE

+1 -1
@@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
+SOFTWARE.

README.md

+18 -17
@@ -19,7 +19,7 @@
 <h1>Mambular: Tabular Deep Learning Made Simple</h1>
 </div>
 
-Mambular is a Python library for tabular deep learning. It includes models that leverage the Mamba (State Space Model) architecture, as well as other popular models like TabTransformer, FTTransformer, TabM and tabular ResNets. Check out our paper `Mambular: A Sequential Model for Tabular Deep Learning`, available [here](https://arxiv.org/abs/2408.06291). Also check out our paper introducing [TabulaRNN](https://arxiv.org/pdf/2411.17207) and analyzing the efficiency of NLP inspired tabular models.
+Mambular is a Python library for tabular deep learning. It includes models that leverage the Mamba (State Space Model) architecture, as well as other popular models like TabTransformer, FTTransformer, TabM and tabular ResNets. Check out our paper `Mambular: A Sequential Model for Tabular Deep Learning`, available [here](https://arxiv.org/abs/2408.06291). Also check out our paper introducing [TabulaRNN](https://arxiv.org/pdf/2411.17207) and analyzing the efficiency of NLP inspired tabular models.
 
 <h3> Table of Contents </h3>
 
@@ -66,6 +66,7 @@ Mambular is a Python package that brings the power of advanced deep learning arc
 | `TabulaRNN` | A Recurrent Neural Network for Tabular data, introduced [here](https://arxiv.org/pdf/2411.17207). |
 | `MambAttention` | A combination between Mamba and Transformers, also introduced [here](https://arxiv.org/pdf/2411.17207). |
 | `NDTF` | A neural decision forest using soft decision trees. See [Kontschieder et al.](https://openaccess.thecvf.com/content_iccv_2015/html/Kontschieder_Deep_Neural_Decision_ICCV_2015_paper.html) for inspiration. |
+| `SAINT` | Improve neural networks via Row Attention and Contrastive Pre-Training, introduced [here](https://arxiv.org/pdf/2106.01342). |
 
 
 
@@ -90,7 +91,7 @@ If you want to use the original mamba and mamba2 implementations, additionally i
 pip install mamba-ssm
 ```
 
-Be careful to use the correct torch and cuda versions:
+Be careful to use the correct torch and cuda versions:
 
 ```sh
 pip install torch==2.0.0+cu118 torchvision==0.15.0+cu118 torchaudio==2.0.0+cu118 -f https://download.pytorch.org/whl/cu118/torch_stable.html
@@ -115,7 +116,7 @@ Mambular simplifies data preprocessing with a range of tools designed for easy t
 - **Polynomial Features**: Automatically generates polynomial and interaction terms for numerical features, enhancing the ability to capture higher-order relationships.
 - **Box-Cox & Yeo-Johnson Transformations**: Performs power transformations to stabilize variance and normalize distributions.
 - **Custom Binning**: Enables user-defined bin edges for precise discretization of numerical data.
-
+
 
 
 
@@ -147,15 +148,15 @@ preds = model.predict_proba(X)
 ```
 
 <h3> Hyperparameter Optimization</h3>
-Since all of the models are sklearn base estimators, you can use the built-in hyperparameter optimization from sklearn.
+Since all of the models are sklearn base estimators, you can use the built-in hyperparameter optimization from sklearn.
 
 ```python
 from sklearn.model_selection import RandomizedSearchCV
 
 param_dist = {
-    'd_model': randint(32, 128),
-    'n_layers': randint(2, 10),
-    'lr': uniform(1e-5, 1e-3)
+    'd_model': randint(32, 128),
+    'n_layers': randint(2, 10),
+    'lr': uniform(1e-5, 1e-3)
 }
 
 random_search = RandomizedSearchCV(
@@ -179,10 +180,10 @@ print("Best Score:", random_search.best_score_)
 Note, that using this, you can also optimize the preprocessing. Just use the prefix ``prepro__`` when specifying the preprocessor arguments you want to optimize:
 ```python
 param_dist = {
-    'd_model': randint(32, 128),
-    'n_layers': randint(2, 10),
+    'd_model': randint(32, 128),
+    'n_layers': randint(2, 10),
     'lr': uniform(1e-5, 1e-3),
-    "prepro__numerical_preprocessing": ["ple", "standardization", "box-cox"]
+    "prepro__numerical_preprocessing": ["ple", "standardization", "box-cox"]
 }
 
 ```
@@ -239,16 +240,16 @@ model = MambularLSS(
     dropout=0.2,
     d_model=64,
     n_layers=8,
-
+
 )
 
 # Fit the model to your data
 model.fit(
-    X,
-    y,
-    max_epochs=150,
-    lr=1e-04,
-    patience=10,
+    X,
+    y,
+    max_epochs=150,
+    lr=1e-04,
+    patience=10,
     family="normal" # define your distribution
 )
 
@@ -305,7 +306,7 @@ Here's how you can implement a custom model with Mambular:
     def forward(self, num_features, cat_features):
         x = num_features + cat_features
         x = torch.cat(x, dim=1)
-
+
         # Pass through linear layer
         output = self.linear(x)
         return output
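
Review note: the README hunk on hyperparameter optimization says that preprocessor arguments are addressed with the `prepro__` prefix inside the same search space as the model arguments. The sketch below shows one way such a flat parameter dictionary could be split by prefix. It is a hypothetical illustration only: `split_params` is an invented helper, not Mambular's implementation, and the sampled values are made up.

```python
from typing import Any, Dict, Tuple

PREFIX = "prepro__"  # prefix named in the README text


def split_params(
    params: Dict[str, Any], prefix: str = PREFIX
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split a flat dict into (model params, preprocessor params) by key prefix."""
    model_params: Dict[str, Any] = {}
    prepro_params: Dict[str, Any] = {}
    for key, value in params.items():
        if key.startswith(prefix):
            # Strip the prefix so the preprocessor sees its own argument name.
            prepro_params[key[len(prefix):]] = value
        else:
            model_params[key] = value
    return model_params, prepro_params


# One hypothetical draw from a search space like the README's param_dist.
sampled = {
    "d_model": 64,
    "n_layers": 4,
    "lr": 1e-4,
    "prepro__numerical_preprocessing": "ple",
}
model_params, prepro_params = split_params(sampled)
print(model_params)
print(prepro_params)
```

A separate point worth checking in the README snippet: the hunk does not show where `randint` and `uniform` are imported from. If they are the `scipy.stats` distributions, `uniform(loc, scale)` samples from `[loc, loc + scale]`, so `uniform(1e-5, 1e-3)` covers roughly `1e-5` to `1.01e-3`, and `randint(32, 128)` excludes the upper bound 128.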

docs/api/base_models/BaseModels.rst

+4
@@ -48,3 +48,7 @@ mambular.base_models
 .. autoclass:: mambular.base_models.NDTF
     :members:
     :no-inherited-members:
+
+.. autoclass:: mambular.base_models.SAINT
+    :members:
+    :no-inherited-members:

docs/api/base_models/index.rst

+1
@@ -22,6 +22,7 @@ Modules Description
 :class:`NDTF`                             Neural Decision Tree Forest (NDTF) model for tabular tasks, blending decision tree concepts with neural networks.
 :class:`TabulaRNN`                        Recurrent neural network (RNN) model, including LSTM and GRU architectures, tailored for sequential or time-series tabular data.
 :class:`MambAttention`                    Attention-based architecture for tabular tasks, combining feature importance weighting with advanced normalization techniques.
+:class:`SAINT`                            SAINT model. Transformer-based model using row and column attention.
 ========================================= =======================================================================================================
 
 

docs/api/configs/Configurations.rst

+4
@@ -44,3 +44,7 @@ Configurations
 .. autoclass:: mambular.configs.DefaultTabMConfig
     :members:
     :undoc-members:
+
+.. autoclass:: mambular.configs.DefaultSAINTConfig
+    :members:
+    :undoc-members:

docs/api/configs/index.rst

+8
@@ -95,6 +95,14 @@ Dataclass Description
 :class:`DefaultTabMConfig`              Default configuration for the TabM model (Batch-Ensembling MLP).
 ======================================= =======================================================================================================
 
+SAINT
+-----
+======================================= =======================================================================================================
+Dataclass                               Description
+======================================= =======================================================================================================
+:class:`DefaultSAINTConfig`             Default configuration for the SAINT model.
+======================================= =======================================================================================================
+
 .. toctree::
     :maxdepth: 1
 

docs/api/models/Models.rst

+18 -6
@@ -5,7 +5,7 @@ mambular.models
     :members:
     :inherited-members:
 
-.. autoclass:: mambular.models.MambularRegressor
+.. autoclass:: mambular.models.MambularRegressor
     :members:
     :inherited-members:
 
@@ -29,7 +29,7 @@ mambular.models
     :members:
     :undoc-members:
 
-.. autoclass:: mambular.models.MLPRegressor
+.. autoclass:: mambular.models.MLPRegressor
     :members:
     :undoc-members:
 
@@ -49,7 +49,7 @@ mambular.models
     :members:
     :undoc-members:
 
-.. autoclass:: mambular.models.ResNetClassifier
+.. autoclass:: mambular.models.ResNetClassifier
     :members:
     :undoc-members:
 
@@ -101,7 +101,7 @@ mambular.models
     :members:
     :inherited-members:
 
-.. autoclass:: mambular.models.TabMRegressor
+.. autoclass:: mambular.models.TabMRegressor
     :members:
     :inherited-members:
 
@@ -113,7 +113,7 @@ mambular.models
     :members:
     :inherited-members:
 
-.. autoclass:: mambular.models.NODERegressor
+.. autoclass:: mambular.models.NODERegressor
     :members:
     :inherited-members:
 
@@ -125,14 +125,26 @@ mambular.models
     :members:
     :inherited-members:
 
-.. autoclass:: mambular.models.NDTFRegressor
+.. autoclass:: mambular.models.NDTFRegressor
     :members:
     :inherited-members:
 
 .. autoclass:: mambular.models.NDTFLSS
     :members:
     :undoc-members:
 
+.. autoclass:: mambular.models.SAINTClassifier
+    :members:
+    :inherited-members:
+
+.. autoclass:: mambular.models.SAINTRegressor
+    :members:
+    :inherited-members:
+
+.. autoclass:: mambular.models.SAINTLSS
+    :members:
+    :undoc-members:
+
 .. autoclass:: mambular.models.SklearnBaseClassifier
     :members:
     :undoc-members:

docs/api/models/index.rst

+11 -1
@@ -117,6 +117,16 @@ Modules Description
 :class:`NDTFLSS`                        Distributional tasks using a Neural Decision Forest.
 ======================================= =======================================================================================================
 
+SAINT
+-----
+======================================= =======================================================================================================
+Modules                                 Description
+======================================= =======================================================================================================
+:class:`SAINTClassifier`                Multi-class and binary classification tasks using SAINT.
+:class:`SAINTRegressor`                 Regression tasks using SAINT.
+:class:`SAINTLSS`                       Distributional tasks using SAINT.
+======================================= =======================================================================================================
+
 Base Classes
 ------------
 ======================================= =======================================================================================================
@@ -129,5 +139,5 @@ Modules Description
 
 .. toctree::
     :maxdepth: 1
-
+
     Models
