Skip to content

Commit b1645b6

Browse files
authored
Fix IndexError when resuming after some workers are done (#567)
1 parent 779eca4 commit b1645b6

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

src/litdata/streaming/dataset.py

+6
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,12 @@ def _resume(self, workers_chunks: List[List[int]], workers_intervals: List[Any])
368368
self.worker_chunks = workers_chunks[worker_rank]
369369
self.worker_intervals = workers_intervals[worker_rank]
370370

371+
if self.worker_next_chunk_index >= self.num_chunks:
372+
# This can happen when interrupting and resuming after some but not all workers are done.
373+
# Proceeding would result in an indexing error when attempting to access the next chunk.
374+
# To prevent this we exit early and let the worker raise a StopIteration in __next__.
375+
return
376+
371377
# replay the indexes for the current chunks
372378
interval = self.worker_intervals[self.worker_next_chunk_index]
373379
current_indexes = np.arange(interval[1], interval[2])

tests/streaming/test_dataloader.py

+22
Original file line numberDiff line numberDiff line change
@@ -333,3 +333,25 @@ def test_resume_dataloader_with_new_dataset(tmpdir):
333333
dataloader.load_state_dict(dataloader_state)
334334
for _ in dataloader:
335335
assert dataloader.current_epoch == 2, "Current epoch should be 2"
336+
337+
338+
def test_resume_dataloader_after_some_workers_are_done(tmpdir):
339+
# see https://github.com/Lightning-AI/litData/issues/563
340+
dset_path = tmpdir.join("dataset")
341+
cache = Cache(input_dir=str(dset_path), chunk_size=1)
342+
for i in range(3):
343+
cache[i] = i
344+
cache.done()
345+
cache.merge()
346+
dset = StreamingDataset(str(dset_path), shuffle=False)
347+
dloader = StreamingDataLoader(dset, batch_size=1, num_workers=2, shuffle=False)
348+
# worker 0 is assigned with samples 0 and 1, worker 1 is assigned with sample 2
349+
# the workers alternate, so the expected sequence is [0, 2, 1] and not [0, 1, 2]
350+
expected_sequence = [0, 2, 1]
351+
for i, x in enumerate(dloader):
352+
assert x == expected_sequence[i]
353+
if i == 1:
354+
break
355+
dloader.load_state_dict(dloader.state_dict())
356+
for x in dloader:
357+
assert x == expected_sequence[2]

0 commit comments

Comments
 (0)