Skip to content

Commit 286e050

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 696257b commit 286e050

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed
Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import multiprocessing as mp
22
import os
3+
from collections.abc import Iterator
34
from queue import Queue
4-
from typing import Iterator
55

66
import numpy as np
7+
from torch.utils.data import DataLoader, IterableDataset
8+
79
from lightning import Trainer
810
from lightning.pytorch.demos.boring_classes import BoringModel
9-
from torch.utils.data import DataLoader, IterableDataset
11+
1012

1113
class QueueDataset(IterableDataset):
1214
def __init__(self, queue: Queue) -> None:
@@ -18,13 +20,15 @@ def __iter__(self) -> Iterator:
1820
tensor, _ = self.queue.get(timeout=10)
1921
yield tensor
2022

23+
2124
def create_queue():
2225
q = mp.Queue()
2326
arr = np.random.random([1, 32]).astype(np.float32)
2427
for ind in range(10):
2528
q.put((arr, ind))
2629
return q
2730

31+
2832
def train_model(queue, maxEpochs, ckptPath):
2933
dataloader = DataLoader(QueueDataset(queue), num_workers=1, batch_size=None, persistent_workers=True)
3034
trainer = Trainer(max_epochs=maxEpochs, enable_progress_bar=False, devices=1)
@@ -36,24 +40,24 @@ def train_model(queue, maxEpochs, ckptPath):
3640
trainer.save_checkpoint(ckptPath)
3741
return trainer
3842

43+
3944
def test_training():
40-
"""
41-
Test that reproduces issue in calling iter twice on a queue-based
42-
IterableDataset leads to Queue Empty errors when resuming from a checkpoint.
43-
"""
45+
"""Test that reproduces issue in calling iter twice on a queue-based IterableDataset leads to Queue Empty errors
46+
when resuming from a checkpoint."""
4447
queue = create_queue()
4548

4649
ckpt_path = "model.ckpt"
4750
trainer = train_model(queue, 1, ckpt_path)
4851
assert trainer is not None
49-
52+
5053
assert os.path.exists(ckpt_path), "Checkpoint file wasn't created"
51-
54+
5255
ckpt_size = os.path.getsize(ckpt_path)
5356
assert ckpt_size > 0, f"Checkpoint file is empty (size: {ckpt_size} bytes)"
54-
57+
5558
trainer = train_model(queue, 1, ckpt_path)
5659
assert trainer is not None
5760

61+
5862
if __name__ == "__main__":
5963
test_training()

0 commit comments

Comments
 (0)