Skip to content

Conversation

@tscholak
Copy link
Collaborator

@tscholak tscholak commented Oct 12, 2025

Summary

Implements a stochastic mixer layer for supernet training, enabling random sampling from multiple mixer options (e.g., attention vs. Mamba) during training. Includes checkpoint conversion support and a hierarchical beam search tool for finding optimal mixer placement post-training.

Implementation Details

Stochastic Mixer (fast_llm/layers/decoder/stochastic_mixer.py)

  • Training mode: Randomly samples from configured mixers using distributed RNG (ensures consistency across TP/PP ranks)
  • Eval/inference mode: Uses main_mixer_index for deterministic behavior
  • Sampling strategies: Uniform or weighted sampling
  • Preprocessing: Runs preprocessing for all mixers since we don't know which will be selected

Configuration (fast_llm/layers/decoder/config.py)

  • StochasticMixerConfig: List-based mixer configuration with sampling strategy
  • main_mixer_index: Specifies which mixer to use during inference and which receives pretrained weights during checkpoint conversion
  • Validation ensures sampling weights sum to 1.0 and all indices are valid

Checkpoint Conversion (fast_llm/models/gpt/conversion/apriel.py)

  • AprielStochasticMixerConverter: Handles conversion between Fast-LLM and Apriel formats
  • Only main_mixer_index weights are exported/imported (other mixers randomly initialized during supernet training)
  • Follows existing converter patterns (minimal, no verbose comments)

Beam Search Tool (tools/supernet_beam_search.py)

  • Hierarchical algorithm: Finds optimal placement for mixers at each quality/cost level
    • Phase 1: Find best N layers for primary mixer (e.g., full attention)
    • Phase 2: Find best M layers for secondary mixer (e.g., sliding window attention)
    • Remaining layers use tertiary mixer (e.g., linear attention)
  • Efficient evaluation: Loads checkpoint once, modifies main_mixer_index in-place for each candidate
  • Fast-LLM integration: Uses Fast-LLM's evaluation system directly (no subprocess or checkpoint reconversion)
  • Features: Pre-scoring, beam growth, early stopping, configurable score direction

Tests (tests/utils/model_configs.py)

  • Added stochastic_mixer test configuration with FA/Mamba mixers
  • Enabled checkpoint conversion testing via AprielHybridSSMCheckpointFormat

Use Case

Supernet Training: Train a model where each layer can be either full attention or Mamba, with random sampling at each step. After training, use beam search to find which specific layers benefit most from full attention vs. Mamba, given a budget constraint (e.g., "I can afford 4 FA layers").

Testing

Run the stochastic mixer tests:

pytest tests/models/test_checkpoint.py::test_checkpoint_and_eval tests/models/test_checkpoint.py::test_conversion -k "stochastic_mixer" -v

Example beam search usage:

fast-llm tools/supernet_beam_search.py \
  training_config=path/to/supernet_config.yaml \
  budgets=[4,8] \
  beam_width=12 \
  score_metric="lm_eval/accuracy" \
  output_path=results.json

🤖 Generated with Claude Code

Co-Authored-By: Claude [email protected]

Implements a stochastic mixer layer that randomly samples from multiple
mixer options during training, enabling supernet training where different
architecture variants (e.g., attention vs. Mamba) are trained with
different data subsets.

Key components:
- StochasticMixerConfig: Configuration for stochastic sampling strategy
  (uniform or weighted) with configurable main_mixer_index for inference
- StochasticMixer: Layer implementation with distributed RNG support
- Checkpoint conversion: Apriel converter handles stochastic mixers
- Beam search tool: Hierarchical beam search for optimal mixer placement

The beam search tool finds which layers benefit most from expensive mixers
(e.g., full attention) vs. efficient mixers (e.g., linear attention) by
evaluating different configurations using Fast-LLM's evaluation system.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
tscholak and others added 2 commits October 14, 2025 05:07
- Fix Assert.gt_len AttributeError by moving validation to _validate() method
- Add AttentionConfig import to models/auto.py for proper registration
- Mark all mixer parameters with allow_no_grad=True since only one mixer is active per forward pass

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
Fixed nested config structure bug in AprielStochasticMixerConverter.import_config
that was causing validation errors when loading Apriel checkpoints.

The converter was returning the entire block config (with mixer, mlp, and
normalization keys) instead of just the mixer config, causing these fields
to be incorrectly nested under the mixer field during import.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
Copy link
Collaborator

@jlamypoirier jlamypoirier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, some minor comments


with set_generator(generator):
# Sample from categorical distribution
idx = torch.multinomial(self._sampling_probs, num_samples=1).item()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This requires a costly cuda sync. How about we sample for all layers at once during preprocessing?


_abstract = False

mixers: list[MixerConfig] = Field(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use a dict so we can refer to them by name, ex. in debug?

mixer_idx = self._sample_mixer_index()

if self._debug.enabled:
logger.debug(f"StochasticMixer selecting mixer {mixer_idx}: {type(self.mixers[mixer_idx]).__name__}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ambiguous if multiple mixers share the same type. Use named mixers instead?

we need to preprocess for all of them. This includes things like
attention masks, rotary embeddings, etc.
"""
for mixer in self.mixers:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There could be name conflicts. Consider namespace?


return int(expected_usage)

def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit dangerous, there could be name conflicts and counts will be wrong for averaging. Not sure how to fix though.

return converter_class.mixer_converter_class.export_config(inference_mixer)

@classmethod
def get_converters(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about import? I don't think it will work.

mixer_converter_class.get_converters(
mixer,
f"{fast_llm_prefix}.mixers.{mixer_index}",
hf_prefix if is_main_mixer else None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hf_prefix. drop_on_export handles the rest.

f"{hf_prefix}.{block_index}",
drop_on_export,
)
match config:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think match is warranted here, since it involves a (slow) initialization of configs.

ModelTestingGroup.convert: ModelTestingGroupAction.normal,
ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented,
ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
ModelTestingGroup.distributed: ModelTestingGroupAction.normal,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's leave as unimportant. All this tests is the consistency of stochastic sampling, and I don't think that warrants the overhead of testing every time.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants