fix: enable parallel test execution with pytest-xdist in CI workflow #620


Open

deependujha wants to merge 28 commits into main from feat/run-tests-parallely
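This PR parallelizes the test suite in CI using pytest-xdist. Tests are distributed across workers one file at a time (`--dist=loadfile`), failures are retried up to twice via pytest-rerunfailures, each test gets a unique optimizer cache directory so workers cannot collide on shared paths, and several 60-second test timeouts are raised to 90 seconds to absorb the slower per-test wall clock under contention. On Windows the parallel path is enabled only for Python 3.12; other Windows jobs keep the serial invocation.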
Commits (28)

e6172a7 fix: enable parallel test execution with pytest-xdist in CI workflow (deependujha, Jun 11, 2025)
1ee614e temporary fix to handle parallelly running tests in ci (deependujha, Jun 12, 2025)
cf5fb70 update (deependujha, Jun 12, 2025)
71eafa9 update (deependujha, Jun 12, 2025)
17eed41 update (deependujha, Jun 12, 2025)
311beae update (deependujha, Jun 12, 2025)
46b5843 7 pm (deependujha, Jun 12, 2025)
93305ba pytest-xdist ==3.4.0 (Borda, Jun 12, 2025)
cbf1ca5 fix tmp path on windows (deependujha, Jun 12, 2025)
ecad3d8 add fixture for unique HF URL to support parallel test runs (deependujha, Jun 12, 2025)
5bad28f Merge branch 'main' into feat/run-tests-parallely (deependujha, Jun 12, 2025)
4d730a5 update (deependujha, Jun 12, 2025)
c6ad3f9 Merge branch 'main' into feat/run-tests-parallely (Borda, Jun 18, 2025)
70f7501 update (deependujha, Jun 13, 2025)
f3bdcf8 increase timeout of 60s to 90s (deependujha, Jun 18, 2025)
c35ac31 bump pytest & pytest-xdist (deependujha, Jun 19, 2025)
1956b7f rerun failing tests twice (deependujha, Jun 19, 2025)
f890dd6 refactor: update pytest command and adjust fixture scopes for better … (deependujha, Jun 20, 2025)
7b09b9f update (deependujha, Jun 20, 2025)
43450e3 update (deependujha, Jun 20, 2025)
6f47a5b update (deependujha, Jun 20, 2025)
6c47e98 update (deependujha, Jun 20, 2025)
e52dd98 update (deependujha, Jun 20, 2025)
51aea13 update (deependujha, Jun 20, 2025)
14c0e79 update (deependujha, Jun 20, 2025)
140350a update (deependujha, Jun 20, 2025)
613309c update (deependujha, Jun 24, 2025)
59056f0 let's just wait (deependujha, Jun 24, 2025)
10 changes: 8 additions & 2 deletions .github/workflows/ci-testing.yml
@@ -41,13 +41,19 @@ jobs:
      - name: Install package & dependencies
        run: |
          pip --version
-          pip install -U lightning-sdk
+          pip install -U --force-reinstall --no-cache-dir lightning-sdk
          pip install -e ".[extras]" -r requirements/test.txt -U -q --find-links $TORCH_URL
          pip list

      - name: Tests
        working-directory: tests
-        run: pytest . -v --cov=litdata --durations=100
+        run: |
+          PYVER="${{ matrix.python-version }}"
+          if [ "$RUNNER_OS" = "Windows" ] && [ "$PYVER" != "3.12" ]; then
+            pytest . -v --cov=litdata --durations=100
+          else
+            pytest . -v --cov=litdata --durations=100 -n auto --dist=loadfile --reruns=2
+          fi

      - name: Statistics
        continue-on-error: true
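For context: `-n auto` starts one xdist worker per available CPU, `--dist=loadfile` sends all tests from a given file to the same worker, and `--reruns=2` (pytest-rerunfailures) retries failing tests up to twice. A minimal sketch, not taken from this PR, of what loadfile grouping guarantees:

```python
# Hypothetical example: under `pytest -n auto --dist=loadfile`, every test in
# this file runs on the same xdist worker, so a module-scoped fixture is
# created once per file instead of racing across workers.
import pytest

@pytest.fixture(scope="module")
def shared_store(tmp_path_factory):
    # Safe under loadfile: no other worker executes this module's tests.
    return tmp_path_factory.mktemp("store")

def test_write(shared_store):
    (shared_store / "a.txt").write_text("hello")

def test_read(shared_store):
    assert (shared_store / "a.txt").read_text() == "hello"
```

Without loadfile, these two tests could land on different workers, each with its own `shared_store`, breaking the write-then-read assumption; grouping by file keeps such intra-file dependencies intact.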
1 change: 1 addition & 0 deletions requirements/test.txt
@@ -6,6 +6,7 @@ pytest-cov ==6.1.1
pytest-timeout ==2.3.1
pytest-rerunfailures ==14.0
pytest-random-order ==1.1.1
+pytest-xdist >=3.7.0
pandas
lightning
transformers <4.53.0
1 change: 1 addition & 0 deletions src/litdata/streaming/resolver.py
@@ -150,6 +150,7 @@ def _resolve_s3_connections(dir_path: str) -> Dir:
    target_name = dir_path.split("/")[3]

    data_connections = client.data_connection_service_list_data_connections(project_id).data_connections
+    print(f"{data_connections=}")

    data_connection = [dc for dc in data_connections if dc.name == target_name]
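As an aside, the added print relies on the f-string `=` specifier (Python 3.8+), which renders the expression text together with its value; a standalone illustration:

```python
# Standalone illustration of the f-string `=` debug specifier used above.
data_connections = ["conn-a", "conn-b"]
print(f"{data_connections=}")  # prints: data_connections=['conn-a', 'conn-b']
```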
21 changes: 18 additions & 3 deletions tests/conftest.py
@@ -1,7 +1,10 @@
import os
import shutil
+import signal
import sys
+import tempfile
import threading
+import uuid
from collections import OrderedDict
from types import ModuleType
from unittest.mock import Mock
@@ -16,7 +19,7 @@
from litdata.utilities.dataset_utilities import get_default_cache_dir


-@pytest.fixture(autouse=True)
+@pytest.fixture(autouse=True, scope="session")
def teardown_process_group():
    """Ensures distributed process group gets closed before the next test runs."""
    yield
@@ -25,10 +28,22 @@ def teardown_process_group():


@pytest.fixture(autouse=True)
-def set_env():
+def set_env(monkeypatch):
    # Set environment variable before each test to configure BaseWorker's maximum wait time
    os.environ["DATA_OPTIMIZER_TIMEOUT"] = "20"

+    uuid_str = uuid.uuid4().hex
+    tmp_base = tempfile.gettempdir() if sys.platform == "win32" else "/tmp"  # noqa: S108
+    tmp_path = os.path.join(tmp_base, uuid_str)
+
+    monkeypatch.setenv("DATA_OPTIMIZER_DATA_CACHE_FOLDER", tmp_path)
+    monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", tmp_path)
+
+
+@pytest.fixture(autouse=True)
+def disable_signals(monkeypatch):
+    monkeypatch.setattr(signal, "signal", lambda *args, **kwargs: None)
+

@pytest.fixture
def mosaic_mds_index_data():
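The core of the `set_env` change is giving every test its own optimizer cache directory so parallel workers never share cache state, while `disable_signals` stubs out `signal.signal`, presumably so tests and the code under test cannot install real handlers that would interfere across workers. A distilled, hypothetical sketch of the cache-isolation pattern (names are illustrative, not from the PR):

```python
# Sketch: build a unique per-test cache directory, mirroring set_env above.
import os
import sys
import tempfile
import uuid

def unique_cache_dir() -> str:
    # /tmp does not exist on Windows, so fall back to the platform temp dir.
    base = tempfile.gettempdir() if sys.platform == "win32" else "/tmp"
    path = os.path.join(base, uuid.uuid4().hex)
    os.makedirs(path, exist_ok=True)
    return path

print(unique_cache_dir())  # e.g. /tmp/9f1c2b... (unique on every call)
```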
@@ -134,7 +149,7 @@ def lightning_sdk_mock(monkeypatch):
    return lightning_sdk


-@pytest.fixture(autouse=True)
+@pytest.fixture(autouse=True, scope="session")
def _thread_police():
    """Attempts stopping left-over threads to avoid test interactions.
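Note: under pytest-xdist, a session-scoped fixture runs once per worker process rather than once for the entire run, so the now session-scoped `teardown_process_group` and `_thread_police` execute at each worker's shutdown instead of after every test.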
6 changes: 5 additions & 1 deletion tests/processing/test_data_processor.py
@@ -4,6 +4,7 @@
import sys
from functools import partial
from queue import Empty
+from time import sleep
from typing import Any, List
from unittest import mock
from unittest.mock import ANY, Mock
@@ -463,6 +464,8 @@ def test_data_processsor(fast_dev_run, delete_cached_files, tmpdir, monkeypatch)

    chunks = fast_dev_run_enabled_chunks if fast_dev_run == 10 else fast_dev_run_disabled_chunks

+    sleep(5)  # wait for some time to ensure all files are written
+
    assert sorted(os.listdir(cache_dir)) == chunks

    files = []
@@ -624,7 +627,6 @@ def test_data_process_transform(monkeypatch, tmpdir):

    input_dir = os.path.join(tmpdir, "input_dir")
    os.makedirs(input_dir)
-
    imgs = []
    for i in range(5):
        np_data = np.random.randint(255, size=(28, 28), dtype=np.uint32)
@@ -648,6 +650,8 @@
    )
    data_processor.run(ImageResizeRecipe())

+    sleep(5)  # Ensure all files are written
+
    assert sorted(os.listdir(output_dir)) == ["0.JPEG", "1.JPEG", "2.JPEG", "3.JPEG", "4.JPEG"]

    from PIL import Image
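Both `sleep(5)` calls are fixed waits that give the data processor's writer processes time to flush files before the directory contents are asserted. A hypothetical alternative, not part of this PR, would be to poll for the expected files with a deadline:

```python
# Sketch: poll for the expected directory contents instead of a fixed sleep.
import os
import time

def wait_for_files(directory: str, expected: list, timeout: float = 30.0) -> None:
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if sorted(os.listdir(directory)) == sorted(expected):
            return
        time.sleep(0.25)
    raise TimeoutError(f"{directory} never contained exactly {expected}")
```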
2 changes: 1 addition & 1 deletion tests/streaming/test_dataloader.py
@@ -319,7 +319,7 @@ def test_dataloader_states_with_persistent_workers(tmpdir):
    assert count >= 25, "There should be at least 25 batches in the third epoch"


-@pytest.mark.timeout(60)
+@pytest.mark.timeout(90)
def test_resume_dataloader_with_new_dataset(tmpdir):
    dataset_1_path = tmpdir.join("dataset_1")
    dataset_2_path = tmpdir.join("dataset_2")
14 changes: 7 additions & 7 deletions tests/streaming/test_dataset.py
@@ -310,7 +310,7 @@ def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir, compression
        pytest.param("zstd", marks=pytest.mark.skipif(condition=not _ZSTD_AVAILABLE, reason="Requires: ['zstd']")),
    ],
)
-@pytest.mark.timeout(60)
+@pytest.mark.timeout(90)
def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir, compression):
    seed_everything(42)
@@ -363,7 +363,7 @@ def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir, compr
        ),
    ],
)
-@pytest.mark.timeout(60)
+@pytest.mark.timeout(90)
def test_streaming_dataset_distributed_full_shuffle_even(drop_last, tmpdir, compression):
    seed_everything(42)
@@ -411,7 +411,7 @@ def test_streaming_dataset_distributed_full_shuffle_even(drop_last, tmpdir, comp
        pytest.param("zstd", marks=pytest.mark.skipif(condition=not _ZSTD_AVAILABLE, reason="Requires: ['zstd']")),
    ],
)
-@pytest.mark.timeout(60)
+@pytest.mark.timeout(90)
def test_streaming_dataset_distributed_full_shuffle_even_multi_nodes(drop_last, tmpdir, compression):
    seed_everything(42)
@@ -684,7 +684,7 @@ def test_dataset_for_text_tokens_multiple_workers(tmpdir):
    assert result == expected


-@pytest.mark.timeout(60)
+@pytest.mark.timeout(90)
def test_dataset_for_text_tokens_with_large_block_size_multiple_workers(tmpdir):
    # test to reproduce ERROR: Unexpected segmentation fault encountered in worker
    seed_everything(42)
@@ -1021,7 +1021,7 @@ def _get_simulated_s3_dataloader(cache_dir, data_dir, shuffle=False):

@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs")
@mock.patch.dict(os.environ, {}, clear=True)
-@pytest.mark.timeout(60)
+@pytest.mark.timeout(90)
@pytest.mark.parametrize("shuffle", [True, False])
def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch):
    """Tests resuming from a chunk past the first chunk, when subsequent chunks don't have the same size."""
@@ -1076,7 +1076,7 @@ def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch):
    assert torch.equal(next(iter(train_dataloader)), batch_to_resume_from)


-@pytest.mark.timeout(60)
+@pytest.mark.timeout(90)
@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs")
def test_dataset_valid_state(tmpdir, monkeypatch):
    seed_everything(42)
@@ -1212,7 +1212,7 @@ def fn(remote_chunkpath: str, local_chunkpath: str):
    dataset._validate_state_dict()


-@pytest.mark.timeout(60)
+@pytest.mark.timeout(90)
@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs")
def test_dataset_valid_state_override(tmpdir, monkeypatch):
    seed_everything(42)
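All seven hunks in this file raise `@pytest.mark.timeout(60)` to 90 seconds. With several xdist workers sharing the runner's CPUs, individual tests take longer in wall-clock terms, and per the commit history ("increase timeout of 60s to 90s") the old ceiling was presumably producing spurious timeouts under load.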