
Cannot instantiate an infinite EpochDataset from current config access #884

Closed · kothasuhas opened this issue Feb 10, 2025 · 5 comments

Labels: bug (Something isn't working)

Comments

@kothasuhas (Contributor) commented Feb 10, 2025

I want to instantiate an infinite EpochDataset, which is done by passing max_epochs=None. Looking at LMDatasetConfig, I see that train_set receives an epochs argument that determines the number of epochs. However, if this epochs argument is None, the EpochDataset is never instantiated, leaving me with the standard single-pass dataset:

```python
def train_set(
    self,
    seq_len: int,
    monitors: Union[bool, List[MetricsMonitor]] = True,
    *,
    key: Optional[PRNGKeyArray] = None,
    epochs: Optional[int] = None,
) -> AsyncDataset[np.ndarray]:
    ds: AsyncDataset[np.ndarray] | None = self.token_seq_dataset("train", seq_len, monitors)
    # add epoch flag here.
    if ds is None:
        raise ValueError("No training set!")
    if epochs:
        logger.info("Wrapping dataset in epoch dataset")
        ds = EpochDataset(ds, max_epochs=epochs)
```

I think the interface needs to be changed. I can see two easy fixes:

  • By default, infinitely epoch any finite dataset
  • Have a separate flag for no epoching vs. infinite epoching (a rough sketch of this option follows below)
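
To make the second option concrete, here is a rough, hedged sketch. Only EpochDataset(ds, max_epochs=...) and the meaning of max_epochs=None come from the snippet above; the helper name, its signature, and the import path are illustrative, not Levanter's actual API.

```python
# Hedged sketch of the second option: distinguish "no epoching" from
# "infinite epoching" instead of overloading None for both.
from typing import Literal, Union

from levanter.data.dataset import AsyncDataset, EpochDataset  # import path assumed


def maybe_epoch(ds: AsyncDataset, epochs: Union[int, Literal["infinite"], None] = None) -> AsyncDataset:
    if epochs is None:
        return ds  # no epoching: a single pass over the data
    if epochs == "infinite":
        return EpochDataset(ds, max_epochs=None)  # loop indefinitely
    return EpochDataset(ds, max_epochs=epochs)  # loop a fixed number of epochs
```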

cc @Helw150

@kothasuhas added the bug (Something isn't working) label on Feb 10, 2025
@dlwh (Member) commented Feb 11, 2025 via email

@Helw150 (Collaborator) commented Feb 11, 2025

Yeah, creating a mixture dataset with a data mix weight of 1 would support this!
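
For concreteness, a minimal sketch of that idea in Python. LMMixtureDatasetConfig is the class mentioned later in this thread, but the import path and the field names used below (configs, train_weights, train_urls) are assumptions rather than a verified API, so check the class definitions before relying on them.

```python
# Illustrative only: wrap a single dataset in a mixture config with weight 1.0
# so the mixture's default looping behavior applies. Field names and import
# path are assumptions about Levanter's text data config classes.
from levanter.data.text import LMDatasetSourceConfig, LMMixtureDatasetConfig  # path assumed

mixture = LMMixtureDatasetConfig(
    configs={"my_data": LMDatasetSourceConfig(train_urls=["gs://my-bucket/train-*.jsonl.gz"])},  # hypothetical URLs
    train_weights={"my_data": 1.0},  # single component with weight 1
)
```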

@kothasuhas (Contributor, Author)

I see; my use case was epoching one dataset many times while making a single pass over the other dataset. If the mixture dataset loops indefinitely anyway, I don't need this as long as I crop my repetition dataset to the correct number of sequences (which I'm currently doing manually by modifying levanter).

@Helw150 (Collaborator) commented Feb 12, 2025

@kothasuhas If I understand correctly, you should be able to set up an experiment that supports that without modifying Levanter now!

You should be able to just call dataset.slice_dataset(num_sequences), which will return a slice containing only a fixed number of sequences:

https://github.com/stanford-crfm/levanter/blob/main/src/levanter/data/dataset.py#L381-L387
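
For reference, a minimal sketch of that suggestion, assuming ds is the AsyncDataset for the dataset being repeated. The single-argument call mirrors the comment above; check slice_dataset's exact parameters in the linked dataset.py.

```python
# Minimal sketch: crop the repetition dataset to a fixed number of sequences.
# `ds` is assumed to be an AsyncDataset built from an LMDatasetConfig; the
# single-argument form mirrors the suggestion above and may need adjusting to
# slice_dataset's actual parameters.
num_sequences = 10_000  # illustrative crop size
cropped_ds = ds.slice_dataset(num_sequences)
# cropped_ds can then be mixed with the single-pass dataset and looped as usual
```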

@kothasuhas (Contributor, Author)

I'm currently editing a local copy of levanter to do the slicing for me in LMMixtureDatasetConfig. Regardless, it works, and I imagine there's little need for infinite epoching if LMMixtureDatasetConfig loops infinitely by default. Thanks for the clarifications!
