
Set seed with prefetch for reproducibility #74

@stockeh

Description


Currently, the only way to set the seed is with `mlx.data.core.set_state`, but this only controls the seed used by `.shuffle()`. When using `.prefetch()` with `num_threads > 1`, the returned samples are not deterministic and are therefore not reproducible.
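The difference between a seeded shuffle and multi-threaded prefetching can be illustrated without MLX. The sketch below is a pure-Python analogy, not mlx-data's implementation: `random.Random(42)` stands in for `set_state(42)`, and a `ThreadPoolExecutor` read with `as_completed` stands in for prefetch workers that hand back batches as soon as they finish. `load_batch` and `unordered_prefetch` are hypothetical names. In this toy, the shuffled order is identical on every run, while the order in which worker results arrive depends on thread scheduling.

```python
import random
import time
from concurrent.futures import ThreadPoolExecutor, as_completed


def shuffled_indices(n, seed):
    # Seeded shuffle: plays the role of set_state(seed) followed by .shuffle().
    rng = random.Random(seed)
    idx = list(range(n))
    rng.shuffle(idx)
    return idx


def load_batch(i):
    # Hypothetical per-batch work. The random sleep mimics uneven load times,
    # so workers finish in an order that depends on thread scheduling.
    time.sleep(random.uniform(0, 0.01))
    return i * 10


def unordered_prefetch(indices, num_threads):
    # Collect results in completion order, i.e. whichever worker finishes first.
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        futures = [pool.submit(load_batch, i) for i in indices]
        return [f.result() for f in as_completed(futures)]


order = shuffled_indices(8, seed=42)
assert order == shuffled_indices(8, seed=42)  # the shuffle itself is reproducible

out = unordered_prefetch(order, num_threads=4)
# Same set of batches on every run, but their order can change between runs.
assert sorted(out) == sorted(i * 10 for i in order)
print(out)
```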

Is there a way to set the seed when prefetching with more than one thread?

```python
import mlx.data.core as dmx
from mlx.data.datasets import load_mnist

# Seeds the RNG used by .shuffle(); this alone does not make multi-threaded prefetch deterministic
dmx.set_state(42)

train = load_mnist(root=None, train=True)
dset = (
    train.shuffle()
    .to_stream()
    .key_transform("image", lambda x: x.astype("float32") / 255)
    .batch(32)
    .prefetch(prefetch_size=4, num_threads=4)  # non-deterministic with > 1 thread
)

# Print the pixel sum of the first three batches to compare across runs
for i, data in enumerate(dset):
    print(data["image"].sum())
    if i == 2:
        break
```
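One general way to keep several worker threads while getting a deterministic order back is to hand out results in submission order, keeping at most `prefetch_size` batches in flight. The sketch below shows that idea in plain Python; `ordered_prefetch` and `load_batch` are hypothetical names, and this is neither an mlx-data API nor a description of how `.prefetch()` works internally. It assumes each batch's content depends only on its index (no randomness drawn inside the worker threads), so ordering is the only source of nondeterminism.

```python
import random
import time
from collections import deque
from concurrent.futures import ThreadPoolExecutor

_DONE = object()  # sentinel marking an exhausted index iterator


def load_batch(i):
    # Hypothetical per-batch work with uneven duration; the result depends
    # only on the index, so any nondeterminism comes from ordering alone.
    time.sleep(random.uniform(0, 0.01))
    return i * 10


def ordered_prefetch(indices, prefetch_size, num_threads):
    """Run load_batch on a thread pool but yield results in input order.

    Futures are kept in a FIFO queue and consumed from the front, so the
    output order matches `indices` no matter which worker finishes first.
    """
    it = iter(indices)
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        pending = deque()
        # Fill the prefetch window with up to prefetch_size submitted batches.
        for i in it:
            pending.append(pool.submit(load_batch, i))
            if len(pending) >= prefetch_size:
                break
        while pending:
            # Block on the oldest submitted batch, not the first to finish.
            result = pending.popleft().result()
            # Refill the window with the next index, if any remain.
            nxt = next(it, _DONE)
            if nxt is not _DONE:
                pending.append(pool.submit(load_batch, nxt))
            yield result


# Seeded sampler, playing the role of set_state(42) + .shuffle().
rng = random.Random(42)
order = list(range(16))
rng.shuffle(order)

run1 = list(ordered_prefetch(order, prefetch_size=4, num_threads=4))
run2 = list(ordered_prefetch(order, prefetch_size=4, num_threads=4))
assert run1 == run2 == [i * 10 for i in order]
```

The trade-off in this design is that a slow batch at the front of the queue stalls delivery even if later batches are already done, so some throughput is given up in exchange for reproducibility.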
