Fix seed random sampler to use initial seed from generator #3130

Closed
MekkCyber wants to merge 2 commits

Conversation


@MekkCyber MekkCyber commented Sep 26, 2024

What does this PR do?

This PR fixes the behavior of SeedableRandomSampler: when self.generator is not None, we use the generator's initial_seed instead of initializing a random seed. This is needed for the transformers PR huggingface/transformers#33731 to enable data_seed.
See the issue: huggingface/transformers#31818

For example, in this code:

from datasets import load_dataset
from transformers import AutoTokenizer, TrainingArguments, AutoModelForSequenceClassification, Trainer, set_seed
import torch

set_seed(125)

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

dataset = load_dataset("glue", "mrpc", split="train")

def tokenize_function(example):
    return tokenizer(example["sentence1"], example["sentence2"], truncation=True)

dataset = dataset.map(tokenize_function, batched=True)
SEED = 126
DATA_SEED = 762
training_args = TrainingArguments(f"test-trainer_dataseed_{DATA_SEED}_seed_{SEED}", max_steps=400, logging_steps=40, seed=SEED, data_seed=DATA_SEED)
trainer = Trainer(
    model,
    training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
)

trainer.train()

When we set the seed using set_seed, if the data sampler's seed is chosen via torch.random.initial_seed(), it stays the same even when the data_seed value changes. To address this, the PR checks whether a generator seeded with data_seed is already available: if it exists, we use the data_seed the user provided; if not, we pick a seed at random. This ensures that, even when set_seed is used, the data sampler's ordering still follows data_seed, so different data_seed values produce different shuffles.
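
To make the intent concrete, here is a minimal sketch of the seed-selection logic (illustrative only, not the actual diff; pick_sampler_seed is a hypothetical helper, and the real SeedableRandomSampler in accelerate may differ):

from typing import Optional

import torch

def pick_sampler_seed(generator: Optional[torch.Generator], epoch: int = 0) -> int:
    # Hypothetical helper: if the caller already passed a generator (e.g. one
    # seeded with data_seed), reuse its initial seed so the shuffle follows
    # data_seed; otherwise fall back to the global RNG's initial seed, which
    # is what set_seed controls.
    if generator is not None:
        base_seed = generator.initial_seed()
    else:
        base_seed = torch.random.initial_seed()
    # Offsetting by the epoch (as seedable samplers typically do) keeps each
    # epoch's ordering different while staying reproducible.
    return base_seed + epoch

With something like this, set_seed(125) alone no longer pins the sampler's shuffle; a generator seeded with data_seed=762 makes the ordering depend on 762.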

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan
Member

Thanks for this PR. Could you please describe a situation where this would make a difference?

@MekkCyber
Author

Hello @BenjaminBossan, I just edited the description to make it clearer.

@BenjaminBossan
Member

Ah thanks for the detailed explanation. From my POV, this looks good. I think the failing tests are unrelated. Let's wait for Zach's return before proceeding with this PR.

@SunMarc
Member

SunMarc left a comment

LGTM, as we discussed internally!

@SunMarc SunMarc requested a review from muellerzr October 1, 2024 15:15
@muellerzr
Collaborator

muellerzr left a comment

Sadly, these failing tests are actually quite important: training on CPU does not in fact give us the same model in the end, so this PR currently doesn't work properly (i.e. training doesn't give reproducible results).

@muellerzr muellerzr mentioned this pull request Oct 9, 2024
@SunMarc SunMarc closed this Oct 9, 2024
@SunMarc
Member

SunMarc commented Oct 9, 2024

Superseded by #3150

@SunMarc
Member

SunMarc commented Oct 9, 2024

This PR could potentially pass as well. To fix the failing test, we just need to reinitialize the generator before preparing the dataloader. What was happening: because we had multiple training loops using the same generator, the generator's seed changed after training (42 -> 44). Since we were then initializing the sampler's initial seed for the next training from that generator, we got 44.

The test that was failing:

    # first training
    set_seed(42)
    generator.manual_seed(42)
    accelerator = Accelerator()
    train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler)
    model = RegressionModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # initial_seed = 42 since we have generator.manual_seed(42)
    train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
    set_seed(42)
    generator.manual_seed(42)
    for _ in range(3):
        for batch in train_dl:
            model.zero_grad()
            output = model(batch["x"])
            loss = torch.nn.functional.mse_loss(output, batch["y"])
            accelerator.backward(loss)
            optimizer.step()
    # generator seed value changed to 44

    model = accelerator.unwrap_model(model).cpu()
    assert torch.allclose(old_model.a, model.a), "Did not obtain the same model on CPU or distributed training."
    assert torch.allclose(old_model.b, model.b), "Did not obtain the same model on CPU or distributed training."

    accelerator.print("Training yielded the same results on one CPU or distributed setup with no batch split.")
  
    dataloader_config = DataLoaderConfiguration(split_batches=True, use_seedable_sampler=use_seedable_sampler)
    accelerator = Accelerator(dataloader_config=dataloader_config)
    train_dl = generate_baseline_dataloader(
        train_set, generator, batch_size * state.num_processes, use_seedable_sampler
    )
    model = RegressionModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # initial_seed = 44 from the generator value -> to fix it, we need to reset the generator before. 
    train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
    set_seed(42)
    generator.manual_seed(42)
    for _ in range(3):
        for batch in train_dl:
            model.zero_grad()
            output = model(batch["x"])
            loss = torch.nn.functional.mse_loss(output, batch["y"])
            accelerator.backward(loss)
            optimizer.step()
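
Concretely, the fix amounts to resetting the generator to the intended seed before building and preparing the second dataloader, along these lines (a sketch reusing the names from the snippet above, not the exact patch):

    # Reset the RNGs before the second training so the sampler's initial_seed
    # is taken as 42 again instead of the drifted value 44.
    set_seed(42)
    generator.manual_seed(42)
    dataloader_config = DataLoaderConfiguration(split_batches=True, use_seedable_sampler=use_seedable_sampler)
    accelerator = Accelerator(dataloader_config=dataloader_config)
    train_dl = generate_baseline_dataloader(
        train_set, generator, batch_size * state.num_processes, use_seedable_sampler
    )
    model = RegressionModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # Now the sampler picks up initial_seed = 42 from the generator.
    train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)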
