Merged
102 commits
82eed2b
TP mamba
jlamypoirier Jul 21, 2025
4e310c7
TP mamba
jlamypoirier Jul 22, 2025
3cc4118
fix
jlamypoirier Jul 22, 2025
9f7f75c
fix
jlamypoirier Jul 22, 2025
4054e04
fixes
jlamypoirier Jul 23, 2025
0014cc6
fix
jlamypoirier Jul 23, 2025
47ad548
fixes
jlamypoirier Jul 23, 2025
6a074fa
fixes
jlamypoirier Jul 23, 2025
d66651f
Update external
jlamypoirier Jul 23, 2025
50083ba
SSM debugging
jlamypoirier Jul 24, 2025
5006328
Merge branch 'main' into tp_mamba
jlamypoirier Jul 24, 2025
13176bd
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
7b32699
stuff
jlamypoirier Jul 24, 2025
73f591f
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
1feccc8
stuff
jlamypoirier Jul 24, 2025
e528b50
misc
jlamypoirier Jul 24, 2025
b49c42f
misc
jlamypoirier Jul 24, 2025
bb4dcd9
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
c1b7f44
misc
jlamypoirier Jul 24, 2025
31f5d41
misc
jlamypoirier Jul 24, 2025
051bb07
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
0a9ff25
misc
jlamypoirier Jul 24, 2025
e7d9636
Parallel discrete mamba 2
jlamypoirier Jul 24, 2025
c14b764
Mamba 2, misc
jlamypoirier Jul 25, 2025
b605bd2
doc
jlamypoirier Jul 25, 2025
5eea938
fix
jlamypoirier Jul 28, 2025
0a3e2a7
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 28, 2025
2e6d082
fixes
jlamypoirier Jul 28, 2025
b6c8613
misc
jlamypoirier Jul 28, 2025
f0c04cf
Merge remote-tracking branch 'origin/main' into debug_mamba
jlamypoirier Jul 28, 2025
acdfab1
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 28, 2025
e536af9
Concatenated dim
jlamypoirier Jul 28, 2025
017f5cc
fixes
jlamypoirier Jul 28, 2025
93e4c94
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Jul 28, 2025
c41efc2
doc
jlamypoirier Jul 28, 2025
0b8bd5d
cleanup
jlamypoirier Jul 28, 2025
02f8af5
Block interface
jlamypoirier Jul 29, 2025
6bf06d6
fix
jlamypoirier Jul 29, 2025
2ddc3a7
fix
jlamypoirier Jul 29, 2025
c0f1597
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Jul 29, 2025
b2f4476
Merge branch 'tp_mamba' into block_interface
jlamypoirier Jul 29, 2025
ce70b16
fixes
jlamypoirier Jul 29, 2025
a9f733d
fix
jlamypoirier Jul 29, 2025
cef7c15
fix
jlamypoirier Jul 30, 2025
a5eb076
stuff
jlamypoirier Jul 31, 2025
ab484ac
Revert "stuff"
jlamypoirier Jul 31, 2025
b68d360
stuff
jlamypoirier Jul 31, 2025
82c9dbd
misc
jlamypoirier Jul 31, 2025
9fbb9ff
misc
jlamypoirier Jul 31, 2025
44df195
misc
jlamypoirier Jul 31, 2025
3bb03cb
misc
jlamypoirier Jul 31, 2025
98bae95
misc
jlamypoirier Jul 31, 2025
fd731ef
fixes
jlamypoirier Aug 1, 2025
f483321
fixes
jlamypoirier Aug 1, 2025
5a0eabc
Merge remote-tracking branch 'origin/main' into debug_mamba
jlamypoirier Aug 8, 2025
dd288df
Merge branch 'debug_mamba' into concatenated_dim
jlamypoirier Aug 8, 2025
defd6e0
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Aug 8, 2025
8abf258
fixes
jlamypoirier Aug 8, 2025
c16c00f
Merge branch 'tp_mamba' into block_interface
jlamypoirier Aug 8, 2025
07c9211
stuff
jlamypoirier Aug 8, 2025
be99372
Merge branch 'main' into debug_mamba
jlamypoirier Aug 12, 2025
a505f3a
Merge branch 'debug_mamba' into concatenated_dim
jlamypoirier Aug 12, 2025
0cc859a
Merge remote-tracking branch 'origin/main' into concatenated_dim
jlamypoirier Aug 12, 2025
bd4ff0d
doc
jlamypoirier Aug 12, 2025
fd3307d
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Aug 12, 2025
0e2e124
stuff
jlamypoirier Aug 12, 2025
0a5e458
Remove tensor space, fixes
jlamypoirier Aug 14, 2025
797bd73
stuff
jlamypoirier Aug 14, 2025
c0a3782
stuff
jlamypoirier Aug 15, 2025
e60ded4
stuff
jlamypoirier Aug 15, 2025
1483bcc
stuff
jlamypoirier Aug 15, 2025
4deb501
misc
jlamypoirier Aug 15, 2025
fc809e0
Misc, tests pass
jlamypoirier Aug 15, 2025
cdb6710
misc
jlamypoirier Aug 20, 2025
9ce72e0
Move files
jlamypoirier Aug 20, 2025
065b34f
misc
jlamypoirier Aug 20, 2025
4510b7b
misc
jlamypoirier Aug 20, 2025
9a2a7a2
Pr comments
jlamypoirier Aug 21, 2025
8c382a9
Cleanup
jlamypoirier Aug 21, 2025
019e43d
Cleanup
jlamypoirier Aug 21, 2025
3e0f3e5
Cleanup
jlamypoirier Aug 21, 2025
90a3c98
Merge branch 'tp_mamba' into block_interface
jlamypoirier Aug 21, 2025
39960ce
Cleanup
jlamypoirier Aug 21, 2025
1abdd19
fixes
jlamypoirier Aug 21, 2025
7c24292
fixes
jlamypoirier Aug 21, 2025
af2964b
fixes
jlamypoirier Aug 21, 2025
0e62f7d
Merge branch 'tp_mamba' into block_interface
jlamypoirier Aug 21, 2025
654aeeb
Fix merge
jlamypoirier Aug 21, 2025
3f4a8ba
fix
jlamypoirier Aug 27, 2025
9741ba0
stuff
jlamypoirier Aug 27, 2025
be69677
fixes
jlamypoirier Aug 27, 2025
82a70aa
Simplify bias options
jlamypoirier Aug 27, 2025
680980a
stuff
jlamypoirier Aug 29, 2025
3ef7860
Dynamic mlp and block layer creation
jlamypoirier Aug 29, 2025
188587e
Merge branch 'main' into concatenated_dim
jlamypoirier Sep 17, 2025
e111509
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Sep 17, 2025
95e0231
Merge branch 'tp_mamba' into block_interface
jlamypoirier Sep 17, 2025
e076c7a
Merge remote-tracking branch 'origin/main' into block_interface
jlamypoirier Sep 18, 2025
2315ac4
Merge branch 'block_interface' into block_interface_weight
jlamypoirier Sep 18, 2025
79356f7
Merge remote-tracking branch 'origin/main' into block_interface_weight
jlamypoirier Sep 18, 2025
e4198a6
Merge branch 'block_interface_weight' into block_interface_mixer_mlp_…
jlamypoirier Sep 18, 2025
e9900d2
Merge remote-tracking branch 'origin/main' into block_interface_mixer…
jlamypoirier Sep 18, 2025
25 changes: 14 additions & 11 deletions examples/mistral.yaml
@@ -28,24 +28,27 @@ optimizer:
model:
base_model:
transformer:
mixer:
type: attention
rotary:
type: default
theta: 10000
num_attention_heads: 32
head_groups: 8
kv_channels: 128
window_size: 4096
attention_dropout: 0.0
mlp:
ffn_hidden_size: 14336
gated: true
activation_type: silu
normalization:
type: rms_norm
epsilon: 1.0e-05
rotary:
type: default
theta: 10000
num_layers: 32
hidden_size: 4096
ffn_hidden_size: 14336
num_attention_heads: 32
head_groups: 8
add_linear_biases: false
gated: true
activation_type: silu
kv_channels: 128
window_size: 4096
init_method_std: 0.009021
attention_dropout: 0.0
hidden_dropout: 0.0
vocab_size: 32000
tie_word_embeddings: false
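The reorganized example nests the attention-specific settings under `mixer` (with an explicit `type: attention`) and the feed-forward settings under `mlp`, instead of keeping them flat on `transformer`. As a rough illustration of the new nesting (these dataclasses are a hypothetical stand-in, not Fast-LLM's actual config classes), a config like the YAML above could be modeled as:

```python
# Toy dataclasses mirroring the nesting in the YAML above; field names follow the
# example config, but these are not Fast-LLM's real config classes.
from dataclasses import dataclass, field


@dataclass
class MixerSketch:
    type: str = "attention"
    num_attention_heads: int = 32
    head_groups: int = 8
    kv_channels: int = 128
    window_size: int = 4096
    attention_dropout: float = 0.0


@dataclass
class MlpSketch:
    ffn_hidden_size: int = 14336
    gated: bool = True
    activation_type: str = "silu"


@dataclass
class TransformerSketch:
    mixer: MixerSketch = field(default_factory=MixerSketch)
    mlp: MlpSketch = field(default_factory=MlpSketch)
    num_layers: int = 32
    hidden_size: int = 4096
    add_linear_biases: bool = False


transformer = TransformerSketch()
assert transformer.mixer.kv_channels == 128  # attention options now live under `mixer`
assert transformer.mlp.activation_type == "silu"  # feed-forward options under `mlp`
```

Grouping the mixer options this way is presumably what allows the mixer block to be swapped for the SSM/Mamba variants this PR works on, without disturbing the rest of the transformer settings.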
2 changes: 1 addition & 1 deletion fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py
@@ -104,7 +104,7 @@ def max_length(self):
# check if it is absolute positional encoding and return max_position_embeddings
if hasattr(self._config.fast_llm_config.base_model, "transformer"):
# NOTE: will need to extend if more relative encoding types will be added
if isinstance(self._config.fast_llm_config.base_model.transformer.rotary, NoRotaryConfig):
if isinstance(self._config.fast_llm_config.base_model.transformer.mixer.rotary, NoRotaryConfig):
return self._config.fast_llm_config.base_model.max_position_embeddings

# check if tokenizer holds model sequence length info
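Since rotary settings now live on the nested mixer config, the wrapper reaches them through `transformer.mixer.rotary` instead of `transformer.rotary`. A minimal sketch of that lookup (the helper name is illustrative; only the attribute path and `NoRotaryConfig` come from the diff):

```python
# Illustrative sketch of the updated attribute path; not the actual wrapper code.
def uses_no_rotary(base_model) -> bool:
    transformer = getattr(base_model, "transformer", None)
    if transformer is None:
        return False
    # Rotary config moved from `transformer.rotary` to `transformer.mixer.rotary`.
    rotary = transformer.mixer.rotary
    return type(rotary).__name__ == "NoRotaryConfig"
```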
19 changes: 5 additions & 14 deletions fast_llm/layers/attention/attention.py
@@ -94,17 +94,6 @@ def __init__(

self._softmax_scale = self._config.kv_channels ** (-self._config.attention_softmax_scale_power)

init_method_qkv = init_normal_(
std=self._config.init_method_std_qkv,
min_val=self._config.init_method_min_qkv,
max_val=self._config.init_method_max_qkv,
)
init_method_std_attn_proj = init_normal_(
std=self._config.init_method_std_attn_proj,
min_val=self._config.init_method_min_attn_proj,
max_val=self._config.init_method_max_attn_proj,
)

lr_scale = combine_lr_scales(
self._lr_scale,
self._config.attention_lr_scale,
@@ -114,7 +103,7 @@ def __init__(
self.query = self._config.query_layer.get_layer(
hidden_dim,
query_dim,
default_weight_initializer=init_method_qkv,
default_weight_initializer=init_normal_(std=self._block_config.init_method_std),
default_add_bias=self._block_config.add_linear_biases,
sequence_parallel=self._sequence_parallel,
lr_scale=lr_scale,
@@ -123,7 +112,7 @@ def __init__(
self.key_value = self._config.query_layer.get_layer(
hidden_dim,
key_value_dim,
default_weight_initializer=init_method_qkv,
default_weight_initializer=init_normal_(std=self._block_config.init_method_std),
default_add_bias=self._block_config.add_linear_biases,
sequence_parallel=self._sequence_parallel,
lr_scale=lr_scale,
@@ -137,7 +126,9 @@ def __init__(
self.dense = self._config.dense_layer.get_layer(
dense_dim,
hidden_dim,
default_weight_initializer=init_method_std_attn_proj,
default_weight_initializer=init_normal_(
std=self._block_config.init_method_std / max(2 * self._block_config.num_layers, 1) ** 0.5,
),
default_add_bias=self._block_config.add_linear_biases,
sequence_parallel=self._sequence_parallel,
lr_scale=lr_scale,
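The per-projection initialization knobs (`init_method_std_qkv`, `init_method_*_attn_proj`, and their clamping bounds) disappear here: query and key/value now default to a plain normal init with the block's `init_method_std`, while the output projection scales that std by `1 / sqrt(2 * num_layers)`, the same formula the old `_validate` computed (see the config diff below). A standalone sketch of that scaling rule, assuming the `hidden_size ** -0.5` default for `init_method_std` and ignoring the removed min/max clamping:

```python
import torch


def depth_scaled_stds(hidden_size: int, num_layers: int) -> tuple[float, float]:
    # Default std (used when init_method_std is unset) and the depth-scaled std
    # applied to the attention output projection.
    init_method_std = hidden_size**-0.5
    qkv_std = init_method_std
    dense_std = init_method_std / max(2 * num_layers, 1) ** 0.5
    return qkv_std, dense_std


def init_normal_sketch(weight: torch.Tensor, std: float) -> torch.Tensor:
    # Simplified stand-in for `init_normal_`: plain normal init, no clamping.
    return torch.nn.init.normal_(weight, mean=0.0, std=std)


qkv_std, dense_std = depth_scaled_stds(hidden_size=4096, num_layers=32)
print(f"qkv std = {qkv_std:.5f}, dense std = {dense_std:.5f}")  # ~0.01563 and ~0.00195
```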
8 changes: 1 addition & 7 deletions fast_llm/layers/attention/block.py
@@ -1,8 +1,6 @@
import functools
import logging
import typing

from fast_llm.layers.attention.attention import Attention
from fast_llm.layers.attention.config import AttentionConfig, TransformerConfig
from fast_llm.layers.block.block import Block

@@ -13,10 +11,6 @@ class TransformerBlock[ConfigType: TransformerConfig](Block[ConfigType]):
# TODO: Standardize to `mixer`
_mixer_module_name: typing.ClassVar[str] = "self_attn"

@functools.cached_property
def _mixer_class(self) -> type[Attention]:
return Attention

@property
def _mixer_config(self) -> AttentionConfig:
return self._config
return self._config.mixer
95 changes: 22 additions & 73 deletions fast_llm/layers/attention/config.py
@@ -1,16 +1,20 @@
import functools
import logging
import typing
import warnings

from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none
from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none
from fast_llm.engine.config_utils.data_type import DataType
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.functional.config import TritonConfig
from fast_llm.layers.attention.rotary.config import RotaryConfig
from fast_llm.layers.block.config import BlockConfig, BlockKwargs
from fast_llm.layers.block.config import BlockConfig, BlockKwargs, MixerConfig
from fast_llm.layers.common.linear.config import AffineLinearConfig
from fast_llm.utils import Assert, div

if typing.TYPE_CHECKING:
from fast_llm.layers.attention.attention import Attention

logger = logging.getLogger(__name__)


@@ -28,8 +32,8 @@ class AttentionKwargs(BlockKwargs):
past_key_values = "past_key_values"


@config_class()
class AttentionConfig(Config):
@config_class(dynamic_type={MixerConfig: "attention"})
class AttentionConfig(MixerConfig):
# TODO: Make mixer class dynamic.
_abstract = False

@@ -106,72 +110,26 @@ class AttentionConfig(Config):
" Under muP (if scaling number of heads instead of kv_channels): use 0.5.",
valid=skip_valid_if_none(check_field(Assert.geq, 0)),
)
# TODO: Review initialization
init_method_std_qkv: float = Field(
default=None,
desc="Scale for the query, key and value weight initialization. Default: init_method_std",
hint=FieldHint.optional,
valid=check_field(Assert.geq, 0),
)
init_method_max_qkv: float | None = Field(
default=None,
desc="Max value for clamping initialized weights for query, key and value matrices. Default: float('inf')",
hint=FieldHint.optional,
)
init_method_min_qkv: float | None = Field(
default=None,
desc="Min value for clamping initialized weights for query, key and value matrices. Default: -float('inf')",
hint=FieldHint.optional,
)
init_method_std_attn_proj: float = Field(
default=None,
desc="Scale for the attention projection weight initialization. Default: init_method_std",
hint=FieldHint.optional,
valid=check_field(Assert.geq, 0),
)
init_method_max_attn_proj: float | None = Field(
default=None,
desc="Max value for clamping initialized weights for attention projection. Default: float('inf')",
hint=FieldHint.optional,
)
init_method_min_attn_proj: float | None = Field(
default=None,
desc="Min value for clamping initialized weights for attention projection. Default: -float('inf')",
hint=FieldHint.optional,
)

def _validate(self) -> None:
with self._set_implicit_default():
# TODO: Make this work without inheritance.
if self.kv_channels is None:
self.kv_channels = div(self.hidden_size, self.num_attention_heads)
# TODO: Review initialization
if self.init_method_std_qkv is None:
self.init_method_std_qkv = self.init_method_std
if self.init_method_std_attn_proj is None:
self.init_method_std_attn_proj = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5
if self.init_method_max_qkv is None:
self.init_method_max_qkv = self.init_method_max
if self.init_method_min_qkv is None:
self.init_method_min_qkv = self.init_method_min
if self.init_method_max_attn_proj is None:
self.init_method_max_attn_proj = self.init_method_max
if self.init_method_min_attn_proj is None:
self.init_method_min_attn_proj = self.init_method_min
if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None:
Assert.leq(self.init_method_min, self.init_method_max)
if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None:
Assert.leq(self.init_method_min_qkv, self.init_method_max_qkv)
if self.init_method_min_attn_proj is not None and self.init_method_max_attn_proj is not None:
Assert.leq(self.init_method_min_attn_proj, self.init_method_max_attn_proj)
def set_defaults(self, hidden_size: int):
if self.kv_channels is None:
with self._set_implicit_default():
self.kv_channels = div(hidden_size, self.num_attention_heads)

def _validate(self) -> None:
super()._validate()

if not TritonConfig.TRITON_ENABLED:
warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.")

Assert.multiple(self.num_attention_heads, self.head_groups)

@property
def layer_class(self) -> "type[Attention]":
from fast_llm.layers.attention.attention import Attention

return Attention

@functools.cached_property
def projection_size(self):
assert self._validated
@@ -183,16 +141,7 @@ def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool:

@config_class()
# TODO: Use composition instead
class TransformerConfig(AttentionConfig, BlockConfig):
class TransformerConfig(BlockConfig):
_abstract = False

def _validate(self) -> None:
with self._set_implicit_default():
# Kept here for initialization order.
# TODO: Review initialization
if self.init_method_std is None:
self.init_method_std = self.hidden_size**-0.5
if self.init_method_min is not None and self.init_method_max is not None:
Assert.leq(self.init_method_min, self.init_method_max)

super()._validate()
# TODO: Make this unnecessary
mixer: AttentionConfig = FieldUpdate()
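`AttentionConfig` now registers itself as a dynamic subtype of `MixerConfig` under the key `"attention"` (which is what `type: attention` in the YAML selects) and exposes the module it builds through `layer_class`. A toy sketch of that registry-plus-factory pattern, under the assumption that `config_class(dynamic_type=...)` maintains a string-to-class registry (the names below are illustrative, not the actual Fast-LLM machinery):

```python
# Illustrative registry/factory pattern: a `type` key in the config selects the
# mixer subclass, and the selected config knows which layer class to build.
import typing


class MixerConfigSketch:
    _registry: typing.ClassVar[dict[str, type["MixerConfigSketch"]]] = {}

    def __init_subclass__(cls, dynamic_type: str | None = None, **kwargs):
        super().__init_subclass__(**kwargs)
        if dynamic_type is not None:
            MixerConfigSketch._registry[dynamic_type] = cls

    @classmethod
    def from_dict(cls, data: dict) -> "MixerConfigSketch":
        # The `type` key picks the registered subclass, mirroring `type: attention` in the YAML.
        config_cls = cls._registry[data.pop("type")]
        return config_cls(**data)

    @property
    def layer_class(self) -> type:
        raise NotImplementedError

    def get_layer(self, *args, **kwargs):
        # Mirrors the pattern in the diffs: the config builds its own layer.
        return self.layer_class(self, *args, **kwargs)


class DummyAttention:
    # Stand-in for the real Attention layer.
    def __init__(self, config: "AttentionConfigSketch"):
        self.config = config


class AttentionConfigSketch(MixerConfigSketch, dynamic_type="attention"):
    def __init__(self, num_attention_heads: int = 32, head_groups: int = 8):
        self.num_attention_heads = num_attention_heads
        self.head_groups = head_groups

    @property
    def layer_class(self) -> type[DummyAttention]:
        return DummyAttention


mixer_config = MixerConfigSketch.from_dict({"type": "attention", "num_attention_heads": 32, "head_groups": 8})
assert isinstance(mixer_config.get_layer(), DummyAttention)
```

The apparent payoff is that a new mixer (such as the Mamba variants in the commit history) only needs to register another subtype and point `layer_class` at its module; nothing that consumes `MixerConfig` has to name the concrete class.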
4 changes: 1 addition & 3 deletions fast_llm/layers/attention/preprocessing.py
@@ -86,9 +86,7 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None:

class FlashAttnVarlenPreprocessor(Preprocessor):
def __init__(self, config: AttentionConfig, distributed_config: DistributedConfig):
self._config = config
self._distributed_config = distributed_config
assert self._config.do_use_flash_attention(self._distributed_config)
assert config.do_use_flash_attention(distributed_config)

def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None:
"""
19 changes: 4 additions & 15 deletions fast_llm/layers/block/block.py
@@ -12,7 +12,7 @@
from fast_llm.engine.config_utils.tensor_dim import TensorDim
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.layers.block.config import BlockConfig, BlockKwargs
from fast_llm.layers.block.config import BlockConfig, BlockKwargs, MixerConfig
from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage
from fast_llm.tensor import TensorMeta

@@ -174,8 +173,7 @@ def __init__(
setattr(
self,
self._mixer_module_name,
self._mixer_class(
self._mixer_config,
self._mixer_config.get_layer(
self._config,
self._distributed_config,
self._hidden_dim,
Expand All @@ -185,12 +184,7 @@ def __init__(
),
)

# TODO: Use dynamic type.
from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP
from fast_llm.layers.block.mlp.mlp import MLP

self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)(
self._config,
self.mlp = self._config.mlp.get_layer(
self._config,
self._distributed_config,
self._hidden_dim,
Expand All @@ -199,14 +193,9 @@ def __init__(
self._lr_scale,
)

@functools.cached_property
@abc.abstractmethod
def _mixer_class(self) -> type[BlockLayer]:
pass

@property
@abc.abstractmethod
def _mixer_config(self) -> Config:
def _mixer_config(self) -> MixerConfig:
pass

def setup(self, distributed: Distributed) -> None:
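On the block side, the hard-coded `_mixer_class` hook and the inline MLP-vs-MoE branch are replaced by composition: the block asks each sub-config for its layer via `get_layer`. A self-contained sketch of that shape (toy configs and plain residual wiring; the real `Block` also passes the distributed config, hidden dim, and lr scales through `get_layer`):

```python
import torch


class LinearLayerConfig:
    # Toy sub-config: `get_layer` mirrors the mixer/MLP config factory pattern.
    def get_layer(self, hidden_dim: int) -> torch.nn.Module:
        return torch.nn.Linear(hidden_dim, hidden_dim)


class BlockConfigSketch:
    def __init__(self):
        self.mixer = LinearLayerConfig()
        self.mlp = LinearLayerConfig()


class BlockSketch(torch.nn.Module):
    def __init__(self, config: BlockConfigSketch, hidden_dim: int):
        super().__init__()
        # Each sub-config decides which concrete layer it builds; the block stays generic.
        self.mixer = config.mixer.get_layer(hidden_dim)
        self.mlp = config.mlp.get_layer(hidden_dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = hidden_states + self.mixer(hidden_states)
        return hidden_states + self.mlp(hidden_states)


block = BlockSketch(BlockConfigSketch(), hidden_dim=8)
assert block(torch.randn(2, 8)).shape == (2, 8)
```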