Skip to content

Commit 1df04ce

Browse files
authored
Merge pull request #78 from blei-lab/achille/prepare_release
Achille/prepare release
2 parents f3afbe5 + 488b472 commit 1df04ce

File tree

9 files changed

+82
-62
lines changed

9 files changed

+82
-62
lines changed

MANIFEST.in

Lines changed: 0 additions & 31 deletions
This file was deleted.

ci/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ six>=1.14.0
55
tox
66
twine
77
scipy
8+
poetry

pyproject.toml

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,42 @@
1-
[build-system]
2-
requires = [
3-
"setuptools>=30.3.0",
1+
[tool.poetry]
2+
name = "treeffuser"
3+
version = "0.1.0"
4+
description = ""
5+
authors = [
6+
"Nicolas Beltran-Velez <[email protected]>",
7+
"Alessandro Antonio Grande <[email protected]>",
8+
"Achille Nazaret <[email protected]>",
9+
]
10+
license = "MIT"
11+
readme = "README.rst"
12+
packages = [{include = "treeffuser", from= "src"}]
13+
repository = "https://github.com/blei-lab/treeffuser"
14+
include = [
15+
"pyproject.toml",
16+
"AUTHORS.rst",
17+
"README.rst",
18+
"LICENSE"
419
]
520

21+
[tool.poetry.dependencies]
22+
python = "^3.9"
23+
numpy = "^1.24"
24+
jaxtyping = "^0.2.19"
25+
einops = "^0.8.0"
26+
scipy = "^1.13.1"
27+
tqdm = "^4.66.4"
28+
lightgbm = "^4.3.0"
29+
ml-collections = "^0.1.1"
30+
scikit-learn = "^1.5.0"
31+
32+
[tool.poetry.dev-dependencies]
33+
pytest = "^8.2.2"
34+
tox = "^3.20.1"
35+
36+
[build-system]
37+
requires = ["poetry-core"]
38+
build-backend = "poetry.core.masonry.api"
39+
640
[tool.ruff.per-file-ignores]
741
"ci/*" = ["S"]
842

@@ -56,3 +90,6 @@ target-version = ["py38"]
5690

5791
[tool.ruff.format]
5892
indent-style="space"
93+
94+
[tool.pytest.ini_options]
95+
pythonpath = "src"

requirements.txt

Lines changed: 0 additions & 14 deletions
This file was deleted.
File renamed without changes.

tests/test_score_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44

55
import numpy as np
66
from einops import repeat
7-
from sklearn.metrics import r2_score
87

98
from treeffuser._score_models import LightGBMScore
109
from treeffuser.sde.sdes import VESDE
1110

1211
from .utils import generate_bimodal_linear_regression_data
12+
from .utils import r2_score
1313

1414

1515
def test_linear_regression():

tests/test_treeffuser.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import numpy as np
22
from scipy.stats import ks_2samp
3-
from sklearn.model_selection import train_test_split
43

54
from treeffuser import LightGBMTreeffuser
6-
from utils import gaussian_mixture_pdf
5+
6+
from .utils import gaussian_mixture_pdf
7+
from .utils import train_test_split
78

89

910
def test_treeffuser_bimodal_linear_regression():

tests/utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,34 @@ def gaussian_mixture_pdf(
6060
density = weight1 * gaussian_pdf(x, loc1, scale1, log=False)
6161
density += (1 - weight1) * gaussian_pdf(x, loc2, scale2, log=False)
6262
return np.log(density) if log else density
63+
64+
65+
def train_test_split(X, y, test_size=0.2, random_state=None):
66+
"""
67+
Split the data into training and test sets.
68+
"""
69+
n = X.shape[0]
70+
if random_state is not None:
71+
rng = np.random.default_rng(random_state)
72+
else:
73+
rng = np.random.default_rng()
74+
idx = rng.permutation(n)
75+
n_test = int(n * test_size)
76+
idx_train = idx[n_test:]
77+
idx_test = idx[:n_test]
78+
return X[idx_train], X[idx_test], y[idx_train], y[idx_test]
79+
80+
81+
def r2_score(y_true, y_pred):
82+
"""
83+
Compute the R^2 score.
84+
"""
85+
y_true = y_true.flatten()
86+
y_pred = y_pred.flatten()
87+
88+
y_mean = np.mean(y_true)
89+
ss_tot = np.sum((y_true - y_mean) ** 2)
90+
ss_res = np.sum((y_true - y_pred) ** 2)
91+
92+
r2 = 1 - ss_res / ss_tot
93+
return r2

tox.ini

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,13 @@ commands =
88
passenv =
99
*
1010

11-
; a generative tox configuration, see: https://tox.wiki/en/latest/user_guide.html#generative-environments
12-
# TODO: This should be fixed for the final release
1311
[tox]
12+
skipsdist = true
1413
envlist =
1514
clean,
1615
check,
1716
{py310,pypy310, py39,pypy39},
1817
report
19-
#clean,
20-
#check,
21-
#docs,
22-
#{py38,py39,py310,py311,py312,pypy38,pypy39,pypy310},
23-
#report
2418
ignore_basepython_conflict = true
2519

2620
[testenv]
@@ -40,24 +34,25 @@ setenv =
4034
passenv =
4135
*
4236
usedevelop = false
37+
allowlist_externals = poetry
4338
deps =
4439
pytest
4540
pytest-cov
4641
commands =
47-
{posargs:pytest --cov --cov-report=term-missing --cov-report=xml -vv tests}
42+
poetry install
43+
poetry run pytest --cov --cov-report=term-missing --cov-report=xml -vv tests
4844

4945
[testenv:check]
5046
deps =
5147
docutils
52-
check-manifest
5348
pre-commit
5449
readme-renderer
5550
pygments
5651
isort
5752
skip_install = true
53+
allowlist_externals = poetry
5854
commands =
59-
python setup.py check --strict --metadata --restructuredtext
60-
check-manifest .
55+
poetry check
6156
pre-commit run --all-files --show-diff-on-failure
6257

6358
[testenv:docs]

0 commit comments

Comments
 (0)