Supporting varied mixtures over training #868
Conversation
awesome! thanks for knocking this out so quickly
src/levanter/data/mixture.py (Outdated)

```diff
 Args:
     datasets: A dict of datasets, where the key is the name of the dataset and the value is the dataset itself
-    weights: weights for each dataset
+    weights: Weights for each dataset. This can be provided in a list of stages, where each stage is a tuple of (start_index, weights).
```
Suggested change:

```diff
-    weights: Weights for each dataset. This can be provided in a list of stages, where each stage is a tuple of (start_index, weights).
+    weights: Weights for each dataset. This can be provided in a list of stages, where each stage is a tuple of (start_step, weights).
```
Changed. Also added a clarification that this corresponds to the sequence index at which you want to change the distribution, not the batch index. This method doesn't get to know the batch size, and I think that generality is good (eventually, if you want batch-size curricula like most real LMs).
Actually, because of this, I'm thinking it's better to keep it as `start_seq_index`?
Thanks for the review! The only final request is to name the variable `start_seq_index` instead of `start_step`, since "step" is often conflated with batch indices. And for maximum generality, I wanted this method to work without knowing the batch size. If this is good, I'll merge.
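To make the step/sequence distinction concrete, a small sketch (the helper name `step_to_seq_index` is hypothetical, not part of the PR):

```python
# Hypothetical helper, not part of the PR: convert an optimizer/batch step
# into the sequence index expected by the staged-weights API.
def step_to_seq_index(start_step: int, train_batch_size: int) -> int:
    return start_step * train_batch_size

# Switching the mixture at step 1000 with batch size 1024 corresponds to
# start_seq_index = 1_024_000.
assert step_to_seq_index(1000, 1024) == 1_024_000
```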
Awesome, thanks!
Description
Currently, the LM mixture dataset can only handle a static mixture over the course of training. This PR enables varying the mixture over datasets as training progresses: the user can now specify a list of stages and the sequence index at which each should start.
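A minimal usage sketch, assuming the interface described above (the dataset values are placeholders, and any other required constructor arguments are omitted):

```python
from levanter.data.mixture import MixtureDataset

wiki_ds = ...  # placeholder for a real token dataset
c4_ds = ...    # placeholder for a real token dataset

# Staged weights: each entry is (start_seq_index, weights).
mixture = MixtureDataset(
    datasets={"wikipedia": wiki_ds, "c4": c4_ds},
    weights=[
        (0, {"wikipedia": 0.7, "c4": 0.3}),        # Wikipedia-heavy at first
        (500_000, {"wikipedia": 0.1, "c4": 0.9}),  # shift toward C4 after 500k sequences
    ],
)
```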
Internally, we map each training block to its stage, which defines its mixing weights. To efficiently translate a data point's index within a block to an index into the corresponding source dataset, we precompute prefix sums that track how many data points are consumed by previous stages.
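A rough sketch of that prefix-sum bookkeeping (illustrative only; all names here are assumptions, not the PR's actual implementation):

```python
# Illustrative sketch of the prefix-sum idea described above.
def precompute_stage_offsets(stage_num_blocks, stage_counts_per_block):
    """For each stage, record how many examples of each dataset were consumed
    by all *previous* stages.

    stage_num_blocks[i]: number of blocks spanned by stage i
    stage_counts_per_block[i][name]: examples drawn from `name` per block in stage i
    """
    offsets = []
    consumed: dict[str, int] = {}
    for num_blocks, counts in zip(stage_num_blocks, stage_counts_per_block):
        offsets.append(dict(consumed))  # prefix sum up to (not including) this stage
        for name, per_block in counts.items():
            consumed[name] = consumed.get(name, 0) + per_block * num_blocks
    return offsets
```

A draw from dataset `name` in stage `i` then maps to a global index into that source dataset by adding its within-stage count to `offsets[i][name]`.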
Fixes Issues
https://github.com/stanford-crfm/marin/issues/81
Unit test coverage
There are new unit tests in test_varying_mixture.py to ensure that the varying mixture behaves as expected.
Known breaking changes/behaviors
The design preserves the traditional usage of the MixtureDataset class. However, some of the private quantities are different (e.g., the expected counts per block now depend on the block and are no longer a member variable). To my knowledge, these variables are not accessed outside of tests.
Additional context
I have some changes I want to make to Marin to enable usage of this new functionality, though those updates are modular and can be separate PRs. I have spot-checked that training proceeds as expected with these changes. This is my first PR, so feedback is appreciated :))