1
1
import multiprocessing as mp
2
2
import os
3
+ from collections .abc import Iterator
3
4
from queue import Queue
4
- from typing import Iterator
5
5
6
6
import numpy as np
7
+ from torch .utils .data import DataLoader , IterableDataset
8
+
7
9
from lightning import Trainer
8
10
from lightning .pytorch .demos .boring_classes import BoringModel
9
- from torch . utils . data import DataLoader , IterableDataset
11
+
10
12
11
13
class QueueDataset (IterableDataset ):
12
14
def __init__ (self , queue : Queue ) -> None :
@@ -18,13 +20,15 @@ def __iter__(self) -> Iterator:
18
20
tensor , _ = self .queue .get (timeout = 10 )
19
21
yield tensor
20
22
23
+
21
24
def create_queue():
    """Build a multiprocessing queue pre-filled with ten ``(array, index)`` items.

    A single random ``float32`` array of shape ``[1, 32]`` is generated once and the same array object is
    enqueued ten times, paired with the indices ``0`` through ``9``.

    Returns:
        The populated ``multiprocessing.Queue``.
    """
    sample = np.random.random([1, 32]).astype(np.float32)
    filled_queue = mp.Queue()
    for index in range(10):
        # Every item carries the same array; only the index differs.
        filled_queue.put((sample, index))
    return filled_queue
27
30
31
+
28
32
def train_model (queue , maxEpochs , ckptPath ):
29
33
dataloader = DataLoader (QueueDataset (queue ), num_workers = 1 , batch_size = None , persistent_workers = True )
30
34
trainer = Trainer (max_epochs = maxEpochs , enable_progress_bar = False , devices = 1 )
@@ -36,24 +40,24 @@ def train_model(queue, maxEpochs, ckptPath):
36
40
trainer .save_checkpoint (ckptPath )
37
41
return trainer
38
42
43
+
39
44
def test_training():
    """Reproduce Queue Empty errors when resuming a queue-based IterableDataset from a checkpoint.

    The issue under test is that calling ``iter`` twice on a queue-based ``IterableDataset`` leads to Queue
    Empty errors when resuming from a checkpoint. The test calls ``train_model`` once and verifies that a
    non-empty checkpoint file exists at the given path, then calls ``train_model`` a second time with the same
    queue and the same checkpoint path.
    """
    import tempfile  # local import: only this test needs it

    queue = create_queue()

    # Use a fresh temporary directory so a stale ``model.ckpt`` from an earlier run cannot satisfy the
    # existence/size checks below, and so the test leaves no artifact in the current working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        ckpt_path = os.path.join(tmp_dir, "model.ckpt")

        trainer = train_model(queue, 1, ckpt_path)
        assert trainer is not None

        assert os.path.exists(ckpt_path), "Checkpoint file wasn't created"

        ckpt_size = os.path.getsize(ckpt_path)
        assert ckpt_size > 0, f"Checkpoint file is empty (size: {ckpt_size} bytes)"

        # Second run reuses the same queue and the checkpoint path written above.
        trainer = train_model(queue, 1, ckpt_path)
        assert trainer is not None
57
60
61
+
58
62
# Run the reproduction test when this file is executed directly as a script.
if __name__ == "__main__":
    test_training()
0 commit comments