Skip to content

Commit b32df4d

Browse files
authored
Fixing sdmetrics import and adding install and import test workflow (#765)
1 parent 89f2df2 commit b32df4d

File tree

4 files changed

+80
-2
lines changed

4 files changed

+80
-2
lines changed

.github/workflows/install.yml

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
name: Install Tests
2+
on:
3+
pull_request:
4+
types: [opened, synchronize]
5+
push:
6+
branches:
7+
- main
8+
9+
concurrency:
10+
group: ${{ github.workflow }}-${{ github.ref }}
11+
cancel-in-progress: true
12+
13+
jobs:
14+
install:
15+
name: ${{ matrix.python_version }} install
16+
strategy:
17+
fail-fast: true
18+
matrix:
19+
python_version: ["3.8", "3.13"]
20+
runs-on: ubuntu-latest
21+
steps:
22+
- name: Set up python ${{ matrix.python_version }}
23+
uses: actions/setup-python@v5
24+
with:
25+
python-version: ${{ matrix.python_version }}
26+
- uses: actions/checkout@v4
27+
- name: Build package
28+
run: |
29+
make package
30+
- name: Install package
31+
run: |
32+
python -m pip install "unpacked_sdist/."
33+
- name: Test by importing packages
34+
run: |
35+
python -c "import sdmetrics"
36+
- name: Check package conflicts
37+
run: |
38+
python -m pip check

Makefile

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,3 +232,23 @@ release-minor: check-release bumpversion-minor release
232232

233233
.PHONY: release-major
234234
release-major: check-release bumpversion-major release
235+
236+
# Packaging Targets
237+
.PHONY: upgradepip
238+
upgradepip:
239+
python -m pip install --upgrade pip
240+
241+
.PHONY: upgradebuild
242+
upgradebuild:
243+
python -m pip install --upgrade build
244+
245+
.PHONY: upgradesetuptools
246+
upgradesetuptools:
247+
python -m pip install --upgrade setuptools
248+
249+
.PHONY: package
250+
package: upgradepip upgradebuild upgradesetuptools
251+
python -m build ; \
252+
$(eval VERSION=$(shell python -c 'import setuptools; setuptools.setup()' --version))
253+
tar -zxvf "dist/sdmetrics-${VERSION}.tar.gz"
254+
mv "sdmetrics-${VERSION}" unpacked_sdist

sdmetrics/single_table/bayesian_network.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,15 @@
44

55
import numpy as np
66
import pandas as pd
7-
import torch
87

98
from sdmetrics.goal import Goal
109
from sdmetrics.single_table.base import SingleTableMetric
1110

11+
try:
12+
import torch
13+
except ModuleNotFoundError:
14+
torch = None
15+
1216
LOGGER = logging.getLogger(__name__)
1317

1418

@@ -19,6 +23,9 @@ class BNLikelihoodBase(SingleTableMetric):
1923
def _likelihoods(cls, real_data, synthetic_data, metadata=None, structure=None):
2024
try:
2125
from pomegranate.bayesian_network import BayesianNetwork
26+
27+
if torch is None:
28+
raise ImportError
2229
except ImportError:
2330
raise ImportError(
2431
'Please install pomegranate with `pip install sdmetrics[pomegranate]`.'

tests/unit/single_table/test_bayesian_network.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,20 @@ def metadata():
4646
class TestBNLikelihood:
4747
@patch.dict('sys.modules', {'pomegranate.bayesian_network': None})
4848
def test_compute_error(self):
49-
"""Test that an `ImportError` is raised."""
49+
"""Test that an `ImportError` is raised when pomegranate isn't installed."""
50+
# Setup
51+
metric = BNLikelihood()
52+
53+
# Run and Assert
54+
expected_message = re.escape(
55+
'Please install pomegranate with `pip install sdmetrics[pomegranate]`.'
56+
)
57+
with pytest.raises(ImportError, match=expected_message):
58+
metric.compute(Mock(), Mock())
59+
60+
@patch.dict('sys.modules', {'torch': None})
61+
def test_compute_error_torch_is_none(self):
62+
"""Test that an `ImportError` is raised when torch isn't installed."""
5063
# Setup
5164
metric = BNLikelihood()
5265

0 commit comments

Comments
 (0)