Add stochastic mixer for supernet training #373
base: main
Conversation
Implements a stochastic mixer layer that randomly samples from multiple mixer options during training, enabling supernet training where different architecture variants (e.g., attention vs. Mamba) are trained with different data subsets.

Key components:
- StochasticMixerConfig: Configuration for stochastic sampling strategy (uniform or weighted) with configurable main_mixer_index for inference
- StochasticMixer: Layer implementation with distributed RNG support
- Checkpoint conversion: Apriel converter handles stochastic mixers
- Beam search tool: Hierarchical beam search for optimal mixer placement

The beam search tool finds which layers benefit most from expensive mixers (e.g., full attention) vs. efficient mixers (e.g., linear attention) by evaluating different configurations using Fast-LLM's evaluation system.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
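To make the configuration surface concrete, here is a rough, self-contained sketch of the shape described above; the stand-in dataclasses and every field default are hypothetical, not the actual Fast-LLM config classes:

```python
from dataclasses import dataclass, field

# Stand-ins so the sketch runs on its own; the real PR uses Fast-LLM's
# MixerConfig subclasses (attention, Mamba, ...) and its Field-based config system.
@dataclass
class AttentionOptionStandIn: ...

@dataclass
class MambaOptionStandIn: ...

@dataclass
class StochasticMixerSketch:
    """Shape of the configuration described in this PR (field names taken from the PR text)."""
    mixers: list = field(default_factory=list)  # mixer options sampled during training
    sampling_strategy: str = "uniform"          # "uniform" or "weighted"
    main_mixer_index: int = 0                   # used at inference and for checkpoint conversion

config = StochasticMixerSketch(
    mixers=[AttentionOptionStandIn(), MambaOptionStandIn()],  # e.g. full attention vs. Mamba
    main_mixer_index=0,
)
```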
- Fix Assert.gt_len AttributeError by moving validation to _validate() method
- Add AttentionConfig import to models/auto.py for proper registration
- Mark all mixer parameters with allow_no_grad=True since only one mixer is active per forward pass

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
Fixed a nested config structure bug in AprielStochasticMixerConverter.import_config that was causing validation errors when loading Apriel checkpoints. The converter was returning the entire block config (with mixer, mlp, and normalization keys) instead of just the mixer config, causing these fields to be incorrectly nested under the mixer field during import.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
Looks good, some minor comments
    with set_generator(generator):
        # Sample from categorical distribution
        idx = torch.multinomial(self._sampling_probs, num_samples=1).item()
This requires a costly CUDA sync. How about we sample for all layers at once during preprocessing?
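Something along these lines, as a rough sketch (the function name, its arguments, and the idea of stashing the result in the preprocessing kwargs are assumptions, not the repo's API):

```python
import torch

def presample_mixer_indices(
    sampling_probs: torch.Tensor,  # shape (num_mixers,); keep on CPU to avoid device syncs
    num_layers: int,
    generator: torch.Generator | None = None,
) -> list[int]:
    """Draw one mixer index per layer in a single multinomial call during preprocessing."""
    indices = torch.multinomial(
        sampling_probs, num_samples=num_layers, replacement=True, generator=generator
    )
    # One .tolist() here instead of one .item() per layer inside the forward pass.
    return indices.tolist()
```

Each layer would then just read its precomputed index from the preprocessed kwargs instead of sampling on its own.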
    _abstract = False

    mixers: list[MixerConfig] = Field(
Use a dict so we can refer to them by name, e.g. in debug?
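For illustration, a toy version of the dict-keyed layout (the "attention"/"mamba" names and the plain-dict stand-ins are hypothetical, not the real MixerConfig objects):

```python
import logging
import random

logger = logging.getLogger(__name__)

# Stand-ins for the real MixerConfig instances, keyed by a user-chosen name.
mixers: dict[str, object] = {"attention": object(), "mamba": object()}
sampling_probs: dict[str, float] = {"attention": 0.5, "mamba": 0.5}

# Sampling and debug output can refer to mixers by name rather than index,
# which stays unambiguous even when two options share the same mixer type.
name = random.choices(list(sampling_probs), weights=list(sampling_probs.values()))[0]
logger.debug("StochasticMixer selecting mixer '%s'", name)
```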
    mixer_idx = self._sample_mixer_index()

    if self._debug.enabled:
        logger.debug(f"StochasticMixer selecting mixer {mixer_idx}: {type(self.mixers[mixer_idx]).__name__}")
Ambiguous if multiple mixers share the same type. Use named mixers instead?
    we need to preprocess for all of them. This includes things like
    attention masks, rotary embeddings, etc.
    """
    for mixer in self.mixers:
There could be name conflicts between mixers. Consider namespacing?
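One way to read "namespacing" here, as a sketch under assumed names (the real preprocessing API and kwargs schema may differ): prefix each mixer's preprocessed entries with its name so two mixers can't overwrite each other's keys.

```python
from typing import Any

def preprocess_all(mixers: dict[str, Any], kwargs: dict[str, Any]) -> dict[str, Any]:
    """Collect each mixer's preprocessed tensors under a per-mixer prefix."""
    for name, mixer in mixers.items():
        mixer_kwargs: dict[str, Any] = {}
        mixer.preprocess(mixer_kwargs)  # each mixer fills its own scratch dict
        for key, value in mixer_kwargs.items():
            kwargs[f"{name}/{key}"] = value  # e.g. "attention/rotary_freqs"
    return kwargs
```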
    return int(expected_usage)

    def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
This is a bit dangerous: there could be name conflicts, and the counts will be wrong for averaging. Not sure how to fix, though.
    return converter_class.mixer_converter_class.export_config(inference_mixer)

    @classmethod
    def get_converters(
How about import? I don't think it will work.
    mixer_converter_class.get_converters(
        mixer,
        f"{fast_llm_prefix}.mixers.{mixer_index}",
        hf_prefix if is_main_mixer else None,
Just pass `hf_prefix`; `drop_on_export` handles the rest.
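If I read the suggestion right (sketch only; the argument order and the boolean use of `drop_on_export` are inferred from the surrounding diff, not verified), the call would become:

```python
mixer_converter_class.get_converters(
    mixer,
    f"{fast_llm_prefix}.mixers.{mixer_index}",
    hf_prefix,                          # always pass the real HF prefix ...
    drop_on_export=not is_main_mixer,   # ... and let drop_on_export skip non-main mixers
)
```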
| f"{hf_prefix}.{block_index}", | ||
| drop_on_export, | ||
| ) | ||
| match config: |
I don't think match is warranted here, since it involves a (slow) initialization of configs.
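A lighter-weight alternative, as a hypothetical sketch (the converter class names other than `AprielStochasticMixerConverter` are made up for illustration), is plain `isinstance` dispatch, which never builds config instances:

```python
def converter_class_for(config) -> type:
    # Dispatch on the config's class directly instead of a structural match,
    # so no default config objects have to be constructed along the way.
    if isinstance(config, StochasticMixerConfig):
        return AprielStochasticMixerConverter
    if isinstance(config, AttentionConfig):
        return AprielAttentionConverter  # hypothetical name
    raise NotImplementedError(f"No converter for {type(config).__name__}")
```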
    ModelTestingGroup.convert: ModelTestingGroupAction.normal,
    ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented,
    ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
    ModelTestingGroup.distributed: ModelTestingGroupAction.normal,
Let's leave as unimportant. All this tests is the consistency of stochastic sampling, and I don't think that warrants the overhead of testing every time.
Summary
Implements a stochastic mixer layer for supernet training, enabling random sampling from multiple mixer options (e.g., attention vs. Mamba) during training. Includes checkpoint conversion support and a hierarchical beam search tool for finding optimal mixer placement post-training.
Implementation Details
Stochastic Mixer (`fast_llm/layers/decoder/stochastic_mixer.py`)
- Randomly samples one mixer per training step, with distributed RNG support
- Inference uses `main_mixer_index` for deterministic behavior

Configuration (`fast_llm/layers/decoder/config.py`)
- `StochasticMixerConfig`: List-based mixer configuration with sampling strategy
- `main_mixer_index`: Specifies which mixer to use during inference and which receives pretrained weights during checkpoint conversion

Checkpoint Conversion (`fast_llm/models/gpt/conversion/apriel.py`)
- `AprielStochasticMixerConverter`: Handles conversion between Fast-LLM and Apriel formats
- Only the `main_mixer_index` weights are exported/imported (other mixers are randomly initialized during supernet training)

Beam Search Tool (`tools/supernet_beam_search.py`)
- Hierarchical beam search over mixer placements; modifies `main_mixer_index` in-place for each candidate

Tests (`tests/utils/model_configs.py`)
- `stochastic_mixer` test configuration with FA/Mamba mixers
- Conversion tested against `AprielHybridSSMCheckpointFormat`

Use Case
Supernet Training: Train a model where each layer can be either full attention or Mamba, with random sampling at each step. After training, use beam search to find which specific layers benefit most from full attention vs. Mamba, given a budget constraint (e.g., "I can afford 4 FA layers").
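As a rough illustration of the placement search (a plain beam-search sketch; the real tool is hierarchical, scores candidates with Fast-LLM's evaluation system, and flips `main_mixer_index` per layer, so every name below is hypothetical):

```python
from typing import Callable

def beam_search_placement(
    num_layers: int,
    budget: int,                                  # e.g. "I can afford 4 FA layers"
    beam_width: int,
    evaluate: Callable[[frozenset[int]], float],  # stand-in for the evaluation harness
) -> frozenset[int]:
    """Pick which layers get the expensive mixer, growing placements one layer at a time."""
    beam: list[frozenset[int]] = [frozenset()]
    for _ in range(budget):
        candidates = {
            placement | {layer}
            for placement in beam
            for layer in range(num_layers)
            if layer not in placement
        }
        # Keep only the `beam_width` best partial placements.
        beam = sorted(candidates, key=evaluate, reverse=True)[:beam_width]
    return beam[0]  # best placement that uses exactly `budget` expensive layers
```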
Testing
Run the stochastic mixer tests:
    pytest tests/models/test_checkpoint.py::test_checkpoint_and_eval tests/models/test_checkpoint.py::test_conversion -k "stochastic_mixer" -v

Example beam search usage:

    fast-llm tools/supernet_beam_search.py \
        training_config=path/to/supernet_config.yaml \
        budgets=[4,8] \
        beam_width=12 \
        score_metric="lm_eval/accuracy" \
        output_path=results.json

🤖 Generated with Claude Code
Co-Authored-By: Claude [email protected]