examples/mistral.yaml (2 changes: 1 addition & 1 deletion)
@@ -44,7 +44,7 @@ model:
activation_type: silu
kv_channels: 128
window_size: 4096
init_method_std: 0.009021
# init_method_std: 0.009021
attention_dropout: 0.0
hidden_dropout: 0.0
vocab_size: 32000
fast_llm/config.py (1 change: 1 addition & 0 deletions)
@@ -1001,6 +1001,7 @@ def __init_subclass__(cls):
compare=value.pop("compare", base_class_field.compare),
metadata=value.pop("metadata", base_class_field.metadata),
kw_only=value.pop("kw_only", base_class_field.kw_only),
init=value.pop("init", base_class_field.init),
),
)
if name in cls.__annotations__:
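Forwarding `init` matters because a field rebuilt this way otherwise falls back to dataclasses' default of `init=True`. A minimal sketch with plain `dataclasses` (illustrative only, not fast_llm internals):

```python
import dataclasses


@dataclasses.dataclass
class Base:
    # Field excluded from __init__; subclass overrides must preserve this.
    derived: int = dataclasses.field(default=0, init=False)


# Rebuild the field the way __init_subclass__ does above, forwarding `init`.
_base_field = Base.__dataclass_fields__["derived"]
_rebuilt = dataclasses.field(
    default=_base_field.default,
    init=_base_field.init,  # dropping this would re-expose `derived` in __init__
)


@dataclasses.dataclass
class Child(Base):
    derived: int = _rebuilt


Child()  # still valid: `derived` keeps init=False and its default of 0
```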
fast_llm/engine/config_utils/initialization.py (109 changes: 109 additions & 0 deletions)
@@ -1,12 +1,121 @@
import abc
import typing

from fast_llm.config import Config, Field, FieldHint, check_field, config_class
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
import torch

from fast_llm.tensor import ParameterMeta


@config_class(registry=True)
class InitializationConfig(Config):
_abstract = True
is_default: typing.ClassVar[bool] = False

@classmethod
def _from_dict(
cls,
default: dict[str, typing.Any],
strict: bool = True,
flat: bool = False,
) -> typing.Self:
if cls is InitializationConfig and cls.get_subclass(default.get("type")) is None:
# Default subclass.
return DefaultInitializationConfig._from_dict(default, strict, flat)
return super()._from_dict(default, strict=strict, flat=flat)

def get_initializer(self) -> "Initializer":
raise NotImplementedError()


@config_class()
class DefaultInitializationConfig(InitializationConfig):
# A placeholder indicating that the class default should be used instead.
_abstract = False
is_default = True


@config_class(dynamic_type={InitializationConfig: "fill"})
class FillInitializationConfig(InitializationConfig):
"""
Fill initialization: fill(value)
"""

_abstract = False

value: float = Field(
default=1,
desc="Initialization value.",
hint=FieldHint.optional,
valid=check_field(Assert.geq, 0),
)

def get_initializer(self):
return init_fill_(self.value)


@config_class(dynamic_type={InitializationConfig: "normal"})
class NormalInitializationConfig(InitializationConfig):
"""
Normal initialization: normal(mean, std).clamp(min,max)
"""

_abstract = False

std: float = Field(
default=1,
desc="Standard deviation for normal initialization.",
hint=FieldHint.optional,
valid=check_field(Assert.geq, 0),
)
mean: float = Field(
default=0,
desc="Mean for normal initialization.",
hint=FieldHint.optional,
)
min: float | None = Field(
default=None,
desc="Min value for initialization clamping.",
hint=FieldHint.optional,
)
max: float | None = Field(
default=None,
desc="Min value for initialization clamping.",
hint=FieldHint.optional,
)

def get_initializer(self):
return init_normal_(self.mean, self.std, self.min, self.max)


@config_class(dynamic_type={InitializationConfig: "uniform"})
class UniformInitializationConfig(InitializationConfig):
"""
Uniform initialization: uniform(mean - scale, mean + scale)
"""

_abstract = False

scale: float = Field(
default=1,
desc="Initialization scale (half-width of the uniform interval).",
hint=FieldHint.optional,
valid=check_field(Assert.geq, 0),
)
mean: float = Field(
default=0,
desc="Initialization mean.",
hint=FieldHint.optional,
)

def get_initializer(self) -> "Initializer":
return init_uniform_centered_(self.scale, self.mean)


class Initializer(abc.ABC):
@abc.abstractmethod
def __call__(self, meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None:
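For reference, a minimal usage sketch of the new registry, assuming the base `Config._from_dict` dispatches on the registered `type` key as the `dynamic_type` decorators suggest (in practice this would be invoked indirectly through config loading):

```python
# Hypothetical resolution of a serialized initialization config.
config = InitializationConfig._from_dict({"type": "normal", "std": 0.02, "min": -0.06, "max": 0.06})
assert isinstance(config, NormalInitializationConfig)
initializer = config.get_initializer()  # wraps init_normal_(0, 0.02, -0.06, 0.06)

# Without a recognized "type", _from_dict falls back to the placeholder
# DefaultInitializationConfig, whose is_default flag tells the consuming layer
# to keep its own built-in initialization.
fallback = InitializationConfig._from_dict({})
assert fallback.is_default
```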
fast_llm/layers/attention/attention.py (53 changes: 12 additions & 41 deletions)
@@ -4,16 +4,13 @@

from fast_llm.core.distributed import set_generator
from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim
from fast_llm.engine.config_utils.initialization import init_normal_, init_zeros_
from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim
from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames
from fast_llm.functional.autograd import wrap_forward_backward
from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs
from fast_llm.layers.block.block import BlockLayer
from fast_llm.layers.block.config import BlockConfig, BlockDimNames
from fast_llm.layers.block.peft import TransformerSubLayerName
from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear
from fast_llm.utils import combine_lr_scales, div
from fast_llm.utils import div

try:
from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa
@@ -95,62 +92,36 @@ def __init__(

self._softmax_scale = self._config.kv_channels ** (-self._config.attention_softmax_scale_power)

init_method_qkv = init_normal_(
std=self._config.init_method_std_qkv,
min_val=self._config.init_method_min_qkv,
max_val=self._config.init_method_max_qkv,
)
init_method_std_attn_proj = init_normal_(
std=self._config.init_method_std_attn_proj,
min_val=self._config.init_method_min_attn_proj,
max_val=self._config.init_method_max_attn_proj,
)

lr_scale = combine_lr_scales(
self._lr_scale,
self._config.attention_lr_scale,
)

# TODO: Merge the query and key-value computations? (harder with sequence parallel.)
self.query = OutputParallelLinear(
self.query = self._config.query_layer.get_layer(
hidden_dim,
query_dim,
bias=self._config.add_qkv_bias,
weight_init_method=init_method_qkv,
bias_init_method=init_zeros_,
sequence_parallel=self._sequence_parallel,
lr_scale=lr_scale,
lr_scale=self._lr_scale,
peft=self._block_config.peft,
)
self.key_value = OutputParallelLinear(
# TODO: Mix key and value configs
self.key_value = self._config.query_layer.get_layer(
hidden_dim,
key_value_dim,
bias=self._config.add_qkv_bias,
weight_init_method=init_method_qkv,
bias_init_method=init_zeros_,
sequence_parallel=self._sequence_parallel,
lr_scale=lr_scale,
lr_scale=self._lr_scale,
peft=self._block_config.peft,
)
self._query_key_value = wrap_forward_backward(self._query_key_value_forward, self._query_key_value_backward)

# Rotary embeddings.
self._rotary = self._config.rotary.get_layer(kv_channels_dim)

# Output.
self.dense = InputParallelLinear(
self.dense = self._config.dense_layer.get_layer(
dense_dim,
hidden_dim,
bias=self._config.add_dense_bias,
weight_init_method=init_method_std_attn_proj,
bias_init_method=init_zeros_,
sequence_parallel=self._sequence_parallel,
lr_scale=lr_scale,
lr_scale=self._lr_scale,
peft=self._block_config.peft,
)

# PEFT.
self.query = self._block_config.peft.apply_linear(self.query, TransformerSubLayerName.query)
self.key_value = self._block_config.peft.apply_linear(self.key_value, TransformerSubLayerName.key_value)
self.dense = self._block_config.peft.apply_linear(self.dense, TransformerSubLayerName.dense)

if self._debug.enabled:
self._query_dims = (
BlockDimNames.batch,
@@ -323,7 +294,7 @@ def forward(
query = query.transpose(0, 1).contiguous()
key_value = key_value.transpose(0, 1).contiguous()

key, value = key_value.split(self._local_head_groups * self._config.kv_channels, dim=-1)
key, value = key_value.chunk(2, dim=-1)

query = query.view(*query.shape[:2], self._local_heads, self._config.kv_channels)
key = key.view(*key.shape[:2], self._local_head_groups, self._config.kv_channels)
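The `chunk(2, dim=-1)` rewrite matches the previous `split` call whenever the key-value tensor's last dimension is exactly `2 * local_head_groups * kv_channels`; a standalone check with illustrative sizes:

```python
import torch

head_groups, kv_channels = 4, 128  # illustrative values only
key_value = torch.randn(2, 16, 2 * head_groups * kv_channels)

k_split, v_split = key_value.split(head_groups * kv_channels, dim=-1)
k_chunk, v_chunk = key_value.chunk(2, dim=-1)

assert torch.equal(k_split, k_chunk) and torch.equal(v_split, v_chunk)
```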
fast_llm/layers/attention/config.py (116 changes: 33 additions & 83 deletions)
@@ -4,10 +4,12 @@

from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none
from fast_llm.engine.config_utils.data_type import DataType
from fast_llm.engine.config_utils.initialization import FillInitializationConfig, NormalInitializationConfig
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.functional.config import TritonConfig
from fast_llm.layers.attention.rotary.config import RotaryConfig
from fast_llm.layers.block.config import AddLinearBiasChoices, BlockConfig, BlockKwargs
from fast_llm.layers.block.config import BlockConfig, BlockKwargs
from fast_llm.layers.common.linear.config import AffineLinearConfig
from fast_llm.utils import Assert, div

logger = logging.getLogger(__name__)
@@ -33,6 +35,23 @@ class AttentionConfig(Config):
_abstract = False

# TODO: Review names
query_layer: AffineLinearConfig = Field(
desc="Configuration for the query layer.",
hint=FieldHint.architecture,
)
key_layer: AffineLinearConfig = Field(
desc="Configuration for the key layer.",
hint=FieldHint.architecture,
)
# TODO: Use
value_layer: AffineLinearConfig = Field(
desc="Configuration for the value layer.",
hint=FieldHint.architecture,
)
dense_layer: AffineLinearConfig = Field(
desc="Initialization configuration for the dense layer.",
hint=FieldHint.feature,
)
rotary: RotaryConfig = Field(
desc="Configuration for the rotary positional embeddings.",
hint=FieldHint.architecture,
@@ -88,65 +107,25 @@ class AttentionConfig(Config):
" Under muP (if scaling number of heads instead of kv_channels): use 0.5.",
valid=skip_valid_if_none(check_field(Assert.geq, 0)),
)
# TODO: Review initialization
init_method_std_qkv: float = Field(
default=None,
desc="Scale for the query, key and value weight initialization. Default: init_method_std",
hint=FieldHint.optional,
valid=check_field(Assert.geq, 0),
)
init_method_max_qkv: float | None = Field(
default=None,
desc="Max value for clamping initialized weights for query, key and value matrices. Default: float('inf')",
hint=FieldHint.optional,
)
init_method_min_qkv: float | None = Field(
default=None,
desc="Min value for clamping initialized weights for query, key and value matrices. Default: -float('inf')",
hint=FieldHint.optional,
)
init_method_std_attn_proj: float = Field(
default=None,
desc="Scale for the attention projection weight initialization. Default: init_method_std",
hint=FieldHint.optional,
valid=check_field(Assert.geq, 0),
)
init_method_max_attn_proj: float | None = Field(
default=None,
desc="Max value for clamping initialized weights for attention projection. Default: float('inf')",
hint=FieldHint.optional,
)
init_method_min_attn_proj: float | None = Field(
default=None,
desc="Min value for clamping initialized weights for attention projection. Default: -float('inf')",
hint=FieldHint.optional,
)

def _validate(self) -> None:
with self._set_implicit_default():
# TODO: Make this work without inheritance.
if self.kv_channels is None:
self.kv_channels = div(self.hidden_size, self.num_attention_heads)
# TODO: Review initialization
if self.init_method_std_qkv is None:
self.init_method_std_qkv = self.init_method_std
if self.init_method_std_attn_proj is None:
self.init_method_std_attn_proj = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5
if self.init_method_max_qkv is None:
self.init_method_max_qkv = self.init_method_max
if self.init_method_min_qkv is None:
self.init_method_min_qkv = self.init_method_min
if self.init_method_max_attn_proj is None:
self.init_method_max_attn_proj = self.init_method_max
if self.init_method_min_attn_proj is None:
self.init_method_min_attn_proj = self.init_method_min
if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None:
Assert.leq(self.init_method_min, self.init_method_max)
if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None:
Assert.leq(self.init_method_min_qkv, self.init_method_max_qkv)
if self.init_method_min_attn_proj is not None and self.init_method_max_attn_proj is not None:
Assert.leq(self.init_method_min_attn_proj, self.init_method_max_attn_proj)

# TODO: Block variables as defaults?
for layer, scale, apply_peft in zip(
(self.query_layer, self.key_layer, self.value_layer, self.dense_layer),
(1, 1, 1, 2 * max(self.num_layers, 1)),
(True, False, True, False),
):
layer.default = AffineLinearConfig(
bias=True,
weight_initialization=NormalInitializationConfig(std=(self.hidden_size * scale) ** -0.5),
bias_initialization=FillInitializationConfig(value=0),
lr_scale=None,
apply_peft=apply_peft,
)
super()._validate()

if not TritonConfig.TRITON_ENABLED:
@@ -162,37 +141,8 @@ def projection_size(self):
def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool:
return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16)

@property
def add_qkv_bias(self) -> bool:
# TODO: Make this work without inheritance.
if isinstance(self.add_linear_biases, bool):
return self.add_linear_biases
if self.add_linear_biases == AddLinearBiasChoices.nowhere:
return False
return True

@property
def add_dense_bias(self) -> bool:
# TODO: Make this work without inheritance.
if isinstance(self.add_linear_biases, bool):
return self.add_linear_biases
if self.add_linear_biases == AddLinearBiasChoices.everywhere:
return True
return False


@config_class()
# TODO: Use composition instead
class TransformerConfig(AttentionConfig, BlockConfig):
_abstract = False

def _validate(self) -> None:
with self._set_implicit_default():
# Kept here for initialization order.
# TODO: Review initialization
if self.init_method_std is None:
self.init_method_std = self.hidden_size**-0.5
if self.init_method_min is not None and self.init_method_max is not None:
Assert.leq(self.init_method_min, self.init_method_max)

super()._validate()
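The per-layer defaults installed in `_validate` reproduce the removed scalar logic: std `hidden_size ** -0.5` for query/key/value weights and the same value scaled down by `(2 * num_layers) ** 0.5` for the dense projection. A worked check with illustrative sizes (not taken from the diff):

```python
hidden_size, num_layers = 4096, 32  # illustrative values only

qkv_std = (hidden_size * 1) ** -0.5                          # old init_method_std default
dense_std = (hidden_size * 2 * max(num_layers, 1)) ** -0.5   # new dense-layer default

# Matches the removed init_method_std / max(2 * num_layers, 1) ** 0.5 rule.
assert abs(dense_std - qkv_std / (2 * num_layers) ** 0.5) < 1e-12
```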