Skip to content

Add testing workflow and use pytest #315

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
2f75420
Add `pytest` as an extra requirement
smokestacklightnin Aug 15, 2024
386d089
Add `pytest.ini`
smokestacklightnin Aug 15, 2024
c9d54fc
Add first pass at a GitHub workflow that runs tests
smokestacklightnin Aug 15, 2024
646ec4b
Add matrix for python versions
smokestacklightnin Aug 16, 2024
230e73a
Fix typo
smokestacklightnin Aug 16, 2024
13e62cb
Remove debugging trigger
smokestacklightnin Aug 16, 2024
7c287e7
Don't use editable install. Instead use regular install
smokestacklightnin Aug 16, 2024
7643665
Update install instructions
smokestacklightnin Aug 16, 2024
f305a18
Add note about running tests
smokestacklightnin Aug 16, 2024
ab5f4f8
Remove `if __name__ == "__main__":` from test files
smokestacklightnin Aug 16, 2024
843c394
Remove unnecessary packages from testing workflow
smokestacklightnin Aug 16, 2024
fa0013c
Change `pip3` to `pip`
smokestacklightnin Aug 16, 2024
6b0f11a
Change `python3` to `python`
smokestacklightnin Aug 16, 2024
a0e6463
Remove logging options in favor of defaults
smokestacklightnin Aug 16, 2024
f63a78f
Remove verbose flag
smokestacklightnin Aug 27, 2024
ffbb959
Add xfail mark to classes with failing tests
smokestacklightnin Aug 27, 2024
fbac9af
Remove timeout
smokestacklightnin Aug 27, 2024
82466a2
Don't run xfailed tests
smokestacklightnin Aug 27, 2024
8c29b52
Merge remote-tracking branch 'upstream/master' into ci/testing/use-py…
smokestacklightnin Jul 3, 2025
46f0716
Add `xfail` mark to failing test
smokestacklightnin Jul 3, 2025
3676584
Undo ruff formatting
smokestacklightnin Jul 4, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions .github/workflows/ci-test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Github action definitions for unit-tests with PRs.

name: tft-unit-tests
on:
pull_request:
branches: [ master ]
paths-ignore:
- '**.md'
- 'docs/**'
workflow_dispatch:

jobs:
unit-tests:
if: github.actor != 'copybara-service[bot]'
runs-on: ubuntu-latest

strategy:
matrix:
python-version: ['3.9', '3.10', '3.11']

steps:
- uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: |
setup.py

- name: Install dependencies
run: |
pip install .[test]

- name: Run unit tests
shell: bash
run: |
pytest
29 changes: 21 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,23 +42,28 @@ pip install tensorflow-transform
To build from source follow the following steps:
Create a virtual environment by running the commands

```
python3 -m venv <virtualenv_name>
```bash
python -m venv <virtualenv_name>
source <virtualenv_name>/bin/activate
pip3 install setuptools wheel
git clone https://github.com/tensorflow/transform.git
cd transform
python3 setup.py bdist_wheel
pip install .
```

This will build the TFT wheel in the dist directory. To install the wheel from
dist directory run the commands
If you are doing development on the TFT repo, replace

```bash
pip install .
```
cd dist
pip3 install tensorflow_transform-<version>-py3-none-any.whl

with

```bash
pip install -e .
```

The `-e` flag causes TFT to be installed in [development mode](https://setuptools.pypa.io/en/latest/userguide/development_mode.html).

### Nightly Packages

TFT also hosts nightly packages at https://pypi-nightly.tensorflow.org on
Expand All @@ -72,6 +77,14 @@ pip install --extra-index-url https://pypi-nightly.tensorflow.org/simple tensorf
This will install the nightly packages for the major dependencies of TFT such
as TensorFlow Metadata (TFMD), TFX Basic Shared Libraries (TFX-BSL).

### Running Tests

To run TFT tests, run the following command from the root of the repository:

```bash
pytest
```

### Notable Dependencies

TensorFlow is required.
Expand Down
5 changes: 5 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[pytest]
addopts = --import-mode=importlib
testpaths = tensorflow_transform
python_files = *_test.py
norecursedirs = .* *.egg
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def _make_docs_packages():
namespace_packages=[],
install_requires=_make_required_install_packages(),
extras_require= {
'test': ['pytest>=8.0'],
'docs': _make_docs_packages(),
},
python_requires='>=3.9,<4',
Expand Down
2 changes: 0 additions & 2 deletions tensorflow_transform/analyzers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,5 +624,3 @@ def testMinDiffFromAvg(self):
analyzers.calculate_recommended_min_diff_from_avg(100000000), 25)


if __name__ == '__main__':
test_case.main()
2 changes: 0 additions & 2 deletions tensorflow_transform/annotators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,5 +65,3 @@ def preprocessing_fn():
self.assertEqual(trackable_object, object_tracker.trackable_objects[0])


if __name__ == '__main__':
test_case.main()
5 changes: 3 additions & 2 deletions tensorflow_transform/beam/analysis_graph_builder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import os
import sys

import pytest
import tensorflow as tf
import tensorflow_transform as tft
from tensorflow_transform import analyzer_nodes
Expand Down Expand Up @@ -412,6 +413,8 @@ class AnalysisGraphBuilderTest(tft_unit.TransformTestCase):
],
)
)
@pytest.mark.xfail(run=False, reason="PR 315 This class contains tests that fail and needs to be fixed. "
"If all tests pass, please remove this mark.")
def test_build(
self,
feature_spec,
Expand Down Expand Up @@ -592,5 +595,3 @@ class _Analyzer(
structured_outputs)


if __name__ == '__main__':
tft_unit.main()
2 changes: 0 additions & 2 deletions tensorflow_transform/beam/analyzer_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,5 +311,3 @@ def expand(self, pbegin):
beam_test_util.equal_to([test_cache_dict[key].cache_dict['b']]))


if __name__ == '__main__':
test_case.main()
2 changes: 0 additions & 2 deletions tensorflow_transform/beam/analyzer_impls_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,5 +136,3 @@ def testJoinBoundarieRows(self, input_boundaries, expected_boundaries,
self.assertAllEqual(num_buckets, expected_num_buckets)


if __name__ == '__main__':
tft_unit.main()
2 changes: 0 additions & 2 deletions tensorflow_transform/beam/annotators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,5 +259,3 @@ def preprocessing_fn(inputs):
)


if __name__ == '__main__':
tft_unit.main()
6 changes: 4 additions & 2 deletions tensorflow_transform/beam/bucketize_integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.
"""Tests for tft.bucketize and tft.quantiles."""


import pytest
import contextlib
import random

Expand Down Expand Up @@ -339,6 +341,8 @@ def _compute_simple_per_key_bucket(val, key, weighted=False):
]


@pytest.mark.xfail(run=False, reason="PR 315 This class contains tests that fail and needs to be fixed. "
"If all tests pass, please remove this mark.")
class BucketizeIntegrationTest(tft_unit.TransformTestCase):

def setUp(self):
Expand Down Expand Up @@ -890,5 +894,3 @@ def testBucketizationSpecificDistribution(self):
inputs, expected_boundaries, tf.float32, num_buckets=5)


if __name__ == '__main__':
tft_unit.main()
2 changes: 0 additions & 2 deletions tensorflow_transform/beam/cached_impl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1993,5 +1993,3 @@ def preprocessing_fn(inputs):
)


if __name__ == '__main__':
tft_unit.main()
2 changes: 0 additions & 2 deletions tensorflow_transform/beam/combiner_packing_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,5 +760,3 @@ def _side_effect_fn(saved_model_future, cache_value_nodes,
)


if __name__ == '__main__':
test_case.main()
2 changes: 0 additions & 2 deletions tensorflow_transform/beam/context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,3 @@ def testNestedContextCreateBaseTempDir(self):
tft_beam.Context.create_base_temp_dir()


if __name__ == '__main__':
tft_unit.main()
2 changes: 0 additions & 2 deletions tensorflow_transform/beam/deep_copy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,5 +330,3 @@ def testDeepCopyTags(self):
self.assertEqual(DeepCopyTest._counts['Add2'], 3 * (num_copies + 1))
self.assertEqual(DeepCopyTest._counts['Add3'], 3)

if __name__ == '__main__':
unittest.main()
6 changes: 4 additions & 2 deletions tensorflow_transform/beam/impl_output_record_batches_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.
"""Same as impl_test.py, except that impl produces `pa.RecordBatch`es."""


import pytest
import collections

import numpy as np
Expand All @@ -28,6 +30,8 @@
_LARGE_BATCH_SIZE = 1 << 10


@pytest.mark.xfail(run=False, reason="PR 315 This class contains tests that fail and needs to be fixed. "
"If all tests pass, please remove this mark.")
class BeamImplOutputRecordBatchesTest(impl_test.BeamImplTest):

def _OutputRecordBatches(self):
Expand Down Expand Up @@ -199,5 +203,3 @@ def testConvertToLargeRecordBatch(
self.assertGreater(actual_num_batches, 1)


if __name__ == '__main__':
tft_unit.main()
6 changes: 4 additions & 2 deletions tensorflow_transform/beam/impl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import pytest
import itertools
import math
import os
Expand Down Expand Up @@ -110,6 +112,8 @@ def _mean_output_dtype(input_dtype):
return tf.float64 if input_dtype == tf.float64 else tf.float32


@pytest.mark.xfail(run=False, reason="PR 315 This class contains tests that fail and needs to be fixed. "
"If all tests pass, please remove this mark.")
class BeamImplTest(tft_unit.TransformTestCase):

def setUp(self):
Expand Down Expand Up @@ -4801,5 +4805,3 @@ def test_preprocessing_fn_returns_wrong_type(self):
expected_data=None)


if __name__ == '__main__':
tft_unit.main()
Original file line number Diff line number Diff line change
Expand Up @@ -92,5 +92,3 @@ def mock_write_metadata(metadata, path):
self.assertEqual(metadata, test_metadata.COMPLETE_METADATA)


if __name__ == '__main__':
tf.test.main()
2 changes: 0 additions & 2 deletions tensorflow_transform/beam/tft_beam_io/transform_fn_io_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,5 +149,3 @@ def mock_copy_tree_to_unique_temp_dir(source, base_temp_dir_path):
self.assertEqual(2, len(file_io.list_directory(transform_output_dir)))


if __name__ == '__main__':
tf.test.main()
6 changes: 4 additions & 2 deletions tensorflow_transform/beam/tukey_hh_params_integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.
"""Tests for tft.tukey_* calls (Tukey HH parameters)."""


import pytest
import itertools

import apache_beam as beam
Expand Down Expand Up @@ -92,6 +94,8 @@
]


@pytest.mark.xfail(run=False, reason="PR 315 This class contains tests that fail and needs to be fixed. "
"If all tests pass, please remove this mark.")
class TukeyHHParamsIntegrationTest(tft_unit.TransformTestCase):

def setUp(self):
Expand Down Expand Up @@ -627,5 +631,3 @@ def assert_and_cast_dtype(tensor):
# Runs the test deterministically on the whole batch.
beam_pipeline=beam.Pipeline())

if __name__ == '__main__':
tft_unit.main()
6 changes: 4 additions & 2 deletions tensorflow_transform/beam/vocabulary_integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# limitations under the License.
"""Tests for tft.vocabulary and tft.compute_and_apply_vocabulary."""


import pytest
import os

import apache_beam as beam
Expand Down Expand Up @@ -114,6 +116,8 @@
]


@pytest.mark.xfail(run=False, reason="PR 315 This class contains tests that fail and needs to be fixed. "
"If all tests pass, please remove this mark.")
class VocabularyIntegrationTest(tft_unit.TransformTestCase):

def setUp(self):
Expand Down Expand Up @@ -2088,5 +2092,3 @@ def preprocessing_fn(inputs):
)


if __name__ == '__main__':
tft_unit.main()
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,18 @@
# limitations under the License.
"""Tests for tfrecord_gzip tft.vocabulary and tft.compute_and_apply_vocabulary."""


import pytest
from tensorflow_transform.beam import vocabulary_integration_test
from tensorflow_transform.beam import tft_unit


@pytest.mark.xfail(run=False, reason="PR 315 This class contains tests that fail and needs to be fixed. "
"If all tests pass, please remove this mark.")
class TFRecordVocabularyIntegrationTest(
vocabulary_integration_test.VocabularyIntegrationTest):

def _VocabFormat(self):
return 'tfrecord_gzip'


if __name__ == '__main__':
tft_unit.main()
2 changes: 0 additions & 2 deletions tensorflow_transform/coders/csv_coder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,5 +292,3 @@ def test_picklable(self):
self.assertEqual(coder.encode(instance), csv_line.encode('utf-8'))


if __name__ == '__main__':
test_case.main()
2 changes: 0 additions & 2 deletions tensorflow_transform/coders/example_proto_coder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,5 +406,3 @@ def test_example_proto_coder_cache(self):
self.assertSerializedProtosEqual(coder.encode(instance), serialized_proto)


if __name__ == '__main__':
test_case.main()
2 changes: 0 additions & 2 deletions tensorflow_transform/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,3 @@ def fn2():
self.assertAllEqual([], graph.get_collection("another_collection"))


if __name__ == "__main__":
test_case.main()
2 changes: 0 additions & 2 deletions tensorflow_transform/gaussianization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,5 +252,3 @@ def test_inverse_tukey_hh(self, samples, hl, hr, expected_output):
self.assertAllClose(output, expected_output)


if __name__ == '__main__':
test_case.main()
2 changes: 0 additions & 2 deletions tensorflow_transform/graph_tools_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1264,5 +1264,3 @@ def _value_to_matcher(value, add_quotes=False):
type(value), value))


if __name__ == '__main__':
test_case.main()
2 changes: 0 additions & 2 deletions tensorflow_transform/impl_helper_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,5 +974,3 @@ def iteration(counter, x_minus_counter):
cond=stop_condition, body=iteration, loop_vars=initial_values)[1]


if __name__ == '__main__':
test_case.main()
2 changes: 0 additions & 2 deletions tensorflow_transform/info_theory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,5 +166,3 @@ def test_mutual_information(self, cell_count, row_count, col_count,
self.assertNear(per_cell_mi, expected_mi, EPSILON)


if __name__ == '__main__':
unittest.main()
2 changes: 0 additions & 2 deletions tensorflow_transform/inspect_preprocessing_fn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,5 +166,3 @@ def test_column_inference(self, preprocessing_fn,
expected_transform_input_columns)


if __name__ == '__main__':
test_case.main()
2 changes: 0 additions & 2 deletions tensorflow_transform/mappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,5 +932,3 @@ def testEstimatedProbabilityDensityMissingKey(self):
self.assertAllEqual(expected, sess.run(result))


if __name__ == '__main__':
test_case.main()
2 changes: 0 additions & 2 deletions tensorflow_transform/nodes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,5 +219,3 @@ def testGetDotGraph(self):
msg='Result dot graph is:\n{}'.format(dot_string))


if __name__ == '__main__':
test_case.main()
Loading