Commit 67c5077

FelixAbrahamsson committed
feature: seeded random sampler
1 parent 7f19810 commit 67c5077

File tree

2 files changed: +42 -2 lines changed


datastream/datastream.py

Lines changed: 26 additions & 0 deletions
@@ -584,3 +584,29 @@ def test_last_batch():
         SequentialSampler(3),
     )
     assert list(map(len, datastream.data_loader(batch_size=2))) == [2, 1]
+
+
+def test_seeded_random_sampler():
+    dataset = Dataset.from_subscriptable(np.arange(100))
+    datastream = Datastream(dataset, sampler=StandardSampler(len(dataset), seed=1))
+
+    loader = datastream.data_loader(batch_size=1, collate_fn=tuple)
+    batches1 = [batch for batch in loader]
+    batches2 = [batch for batch in loader]
+    assert all(
+        batch1[0] == batch2[0]
+        for batch1, batch2 in zip(batches1, batches2)
+    )
+
+
+def test_unseeded_random_sampler():
+    dataset = Dataset.from_subscriptable(np.arange(100))
+    datastream = Datastream(dataset, sampler=StandardSampler(len(dataset)))
+
+    loader = datastream.data_loader(batch_size=1, collate_fn=tuple)
+    batches1 = [batch for batch in loader]
+    batches2 = [batch for batch in loader]
+    assert any(
+        batch1[0] != batch2[0]
+        for batch1, batch2 in zip(batches1, batches2)
+    )

datastream/samplers/standard_sampler.py

Lines changed: 16 additions & 2 deletions
@@ -1,18 +1,26 @@
 from __future__ import annotations
 from pydantic import BaseModel
+from typing import Optional
 import torch


 class StandardSampler(BaseModel, torch.utils.data.Sampler):
     proportion: float
     replacement: bool
     sampler: torch.utils.data.WeightedRandomSampler
+    seed: Optional[int]
+    generator: Optional[torch.Generator]

     class Config:
         arbitrary_types_allowed = True
         allow_mutation = False

-    def __init__(self, length, proportion=1.0, replacement=False):
+    def __init__(self, length, proportion=1.0, replacement=False, seed=None):
+        if seed is not None:
+            generator = torch.Generator()
+            generator.manual_seed(seed)
+        else:
+            generator = None
         BaseModel.__init__(
             self,
             proportion=proportion,

@@ -21,13 +29,18 @@ def __init__(self, length, proportion=1.0, replacement=False):
                 torch.ones(length).double(),
                 num_samples=int(max(1, min(length, length * proportion))),
                 replacement=replacement,
-            )
+                generator=generator,
+            ),
+            seed=seed,
+            generator=generator,
         )

     def __len__(self):
         return len(self.sampler)

     def __iter__(self):
+        if self.generator is not None:
+            self.generator.manual_seed(self.seed)
         return iter(self.sampler)

     @property

@@ -51,6 +64,7 @@ def sample_proportion(self, proportion):
             len(self),
             proportion,
             self.replacement,
+            self.seed,
         )
         sampler.sampler.weights = self.sampler.weights
         return sampler
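A minimal usage sketch of the seeded sampler introduced in this commit, following the new tests above (the import paths for Dataset, Datastream, and StandardSampler are assumed from this repository's file layout and may differ in a packaged release):

    import numpy as np
    from datastream import Dataset, Datastream
    from datastream.samplers import StandardSampler

    dataset = Dataset.from_subscriptable(np.arange(100))

    # With a seed, StandardSampler builds a torch.Generator seeded with that
    # value, and __iter__ re-seeds it on every call, so each pass over the
    # data loader repeats the same sample order.
    datastream = Datastream(dataset, sampler=StandardSampler(len(dataset), seed=1))
    loader = datastream.data_loader(batch_size=1, collate_fn=tuple)

    first_epoch = [batch[0] for batch in loader]
    second_epoch = [batch[0] for batch in loader]
    assert first_epoch == second_epoch  # deterministic order from the fixed seed

Because __iter__ re-seeds the generator, a seeded sampler repeats the identical order every epoch, which is what test_seeded_random_sampler asserts; omitting seed (the default None) keeps the previous non-deterministic behaviour, as test_unseeded_random_sampler checks.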
