Merged

110 commits
82eed2b
TP mamba
jlamypoirier Jul 21, 2025
4e310c7
TP mamba
jlamypoirier Jul 22, 2025
3cc4118
fix
jlamypoirier Jul 22, 2025
9f7f75c
fix
jlamypoirier Jul 22, 2025
4054e04
fixes
jlamypoirier Jul 23, 2025
0014cc6
fix
jlamypoirier Jul 23, 2025
47ad548
fixes
jlamypoirier Jul 23, 2025
6a074fa
fixes
jlamypoirier Jul 23, 2025
d66651f
Update external
jlamypoirier Jul 23, 2025
50083ba
SSM debugging
jlamypoirier Jul 24, 2025
5006328
Merge branch 'main' into tp_mamba
jlamypoirier Jul 24, 2025
13176bd
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
7b32699
stuff
jlamypoirier Jul 24, 2025
73f591f
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
1feccc8
stuff
jlamypoirier Jul 24, 2025
e528b50
misc
jlamypoirier Jul 24, 2025
b49c42f
misc
jlamypoirier Jul 24, 2025
bb4dcd9
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
c1b7f44
misc
jlamypoirier Jul 24, 2025
31f5d41
misc
jlamypoirier Jul 24, 2025
051bb07
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
0a9ff25
misc
jlamypoirier Jul 24, 2025
e7d9636
Parallel discrete mamba 2
jlamypoirier Jul 24, 2025
c14b764
Mamba 2, misc
jlamypoirier Jul 25, 2025
b605bd2
doc
jlamypoirier Jul 25, 2025
5eea938
fix
jlamypoirier Jul 28, 2025
0a3e2a7
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 28, 2025
2e6d082
fixes
jlamypoirier Jul 28, 2025
b6c8613
misc
jlamypoirier Jul 28, 2025
f0c04cf
Merge remote-tracking branch 'origin/main' into debug_mamba
jlamypoirier Jul 28, 2025
acdfab1
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 28, 2025
e536af9
Concatenated dim
jlamypoirier Jul 28, 2025
017f5cc
fixes
jlamypoirier Jul 28, 2025
93e4c94
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Jul 28, 2025
c41efc2
doc
jlamypoirier Jul 28, 2025
0b8bd5d
cleanup
jlamypoirier Jul 28, 2025
02f8af5
Block interface
jlamypoirier Jul 29, 2025
6bf06d6
fix
jlamypoirier Jul 29, 2025
2ddc3a7
fix
jlamypoirier Jul 29, 2025
c0f1597
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Jul 29, 2025
b2f4476
Merge branch 'tp_mamba' into block_interface
jlamypoirier Jul 29, 2025
ce70b16
fixes
jlamypoirier Jul 29, 2025
a9f733d
fix
jlamypoirier Jul 29, 2025
cef7c15
fix
jlamypoirier Jul 30, 2025
a5eb076
stuff
jlamypoirier Jul 31, 2025
ab484ac
Revert "stuff"
jlamypoirier Jul 31, 2025
b68d360
stuff
jlamypoirier Jul 31, 2025
82c9dbd
misc
jlamypoirier Jul 31, 2025
9fbb9ff
misc
jlamypoirier Jul 31, 2025
44df195
misc
jlamypoirier Jul 31, 2025
3bb03cb
misc
jlamypoirier Jul 31, 2025
98bae95
misc
jlamypoirier Jul 31, 2025
fd731ef
fixes
jlamypoirier Aug 1, 2025
f483321
fixes
jlamypoirier Aug 1, 2025
5a0eabc
Merge remote-tracking branch 'origin/main' into debug_mamba
jlamypoirier Aug 8, 2025
dd288df
Merge branch 'debug_mamba' into concatenated_dim
jlamypoirier Aug 8, 2025
defd6e0
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Aug 8, 2025
8abf258
fixes
jlamypoirier Aug 8, 2025
c16c00f
Merge branch 'tp_mamba' into block_interface
jlamypoirier Aug 8, 2025
07c9211
stuff
jlamypoirier Aug 8, 2025
be99372
Merge branch 'main' into debug_mamba
jlamypoirier Aug 12, 2025
a505f3a
Merge branch 'debug_mamba' into concatenated_dim
jlamypoirier Aug 12, 2025
0cc859a
Merge remote-tracking branch 'origin/main' into concatenated_dim
jlamypoirier Aug 12, 2025
bd4ff0d
doc
jlamypoirier Aug 12, 2025
fd3307d
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Aug 12, 2025
0e2e124
stuff
jlamypoirier Aug 12, 2025
0a5e458
Remove tensor space, fixes
jlamypoirier Aug 14, 2025
797bd73
stuff
jlamypoirier Aug 14, 2025
c0a3782
stuff
jlamypoirier Aug 15, 2025
e60ded4
stuff
jlamypoirier Aug 15, 2025
1483bcc
stuff
jlamypoirier Aug 15, 2025
4deb501
misc
jlamypoirier Aug 15, 2025
fc809e0
Misc, tests pass
jlamypoirier Aug 15, 2025
cdb6710
misc
jlamypoirier Aug 20, 2025
9ce72e0
Move files
jlamypoirier Aug 20, 2025
065b34f
misc
jlamypoirier Aug 20, 2025
4510b7b
misc
jlamypoirier Aug 20, 2025
9a2a7a2
Pr comments
jlamypoirier Aug 21, 2025
8c382a9
Cleanup
jlamypoirier Aug 21, 2025
019e43d
Cleanup
jlamypoirier Aug 21, 2025
3e0f3e5
Cleanup
jlamypoirier Aug 21, 2025
90a3c98
Merge branch 'tp_mamba' into block_interface
jlamypoirier Aug 21, 2025
39960ce
Cleanup
jlamypoirier Aug 21, 2025
1abdd19
fixes
jlamypoirier Aug 21, 2025
7c24292
fixes
jlamypoirier Aug 21, 2025
af2964b
fixes
jlamypoirier Aug 21, 2025
0e62f7d
Merge branch 'tp_mamba' into block_interface
jlamypoirier Aug 21, 2025
654aeeb
Fix merge
jlamypoirier Aug 21, 2025
3f4a8ba
fix
jlamypoirier Aug 27, 2025
9741ba0
stuff
jlamypoirier Aug 27, 2025
be69677
fixes
jlamypoirier Aug 27, 2025
82a70aa
Simplify bias options
jlamypoirier Aug 27, 2025
680980a
stuff
jlamypoirier Aug 29, 2025
3ef7860
Dynamic mlp and block layer creation
jlamypoirier Aug 29, 2025
ecad96b
stuff
jlamypoirier Sep 3, 2025
3fd092c
fix
jlamypoirier Sep 3, 2025
1a3497c
stuff
jlamypoirier Sep 3, 2025
b6e7fce
stuff
jlamypoirier Sep 4, 2025
4dfe2a4
stuff
jlamypoirier Sep 9, 2025
4185741
misc
jlamypoirier Sep 9, 2025
188587e
Merge branch 'main' into concatenated_dim
jlamypoirier Sep 17, 2025
e111509
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Sep 17, 2025
95e0231
Merge branch 'tp_mamba' into block_interface
jlamypoirier Sep 17, 2025
e076c7a
Merge remote-tracking branch 'origin/main' into block_interface
jlamypoirier Sep 18, 2025
2315ac4
Merge branch 'block_interface' into block_interface_weight
jlamypoirier Sep 18, 2025
79356f7
Merge remote-tracking branch 'origin/main' into block_interface_weight
jlamypoirier Sep 18, 2025
e4198a6
Merge branch 'block_interface_weight' into block_interface_mixer_mlp_…
jlamypoirier Sep 18, 2025
7abf263
Merge branch 'block_interface_mixer_mlp_config' into block_interface_…
jlamypoirier Sep 18, 2025
bfc9f84
Merge branch 'block_interface_fine_grained' into block_interface_tflops
jlamypoirier Sep 18, 2025
4db4ccd
Merge remote-tracking branch 'origin/main' into block_interface_tflops
jlamypoirier Sep 18, 2025
2 changes: 1 addition & 1 deletion Megatron-LM
20 changes: 12 additions & 8 deletions examples/mistral.yaml
@@ -33,27 +33,31 @@ model:
rotary:
type: default
theta: 10000
num_attention_heads: 32
heads: 32
head_groups: 8
kv_channels: 128
head_size: 128
add_linear_biases: false
window_size: 4096
attention_dropout: 0.0
dropout: 0.0
mlp:
ffn_hidden_size: 14336
intermediate_size: 14336
add_linear_biases: false
gated: true
activation_type: silu
activation: silu
normalization:
type: rms_norm
epsilon: 1.0e-05
num_layers: 32
hidden_size: 4096
add_linear_biases: false
init_method_std: 0.009021
hidden_dropout: 0.0
dropout: 0.0
embeddings_layer:
vocab_size: 32000
dropout: 0.0
output_layer:
tied_weight: false
normalization:
type: rms_norm
epsilon: 1.0e-05
multi_stage:
zero_stage: 2
distributed:
8 changes: 6 additions & 2 deletions fast_llm/engine/base_model/base_model.py
@@ -6,7 +6,7 @@
import torch.nn

from fast_llm.config import Configurable
from fast_llm.engine.base_model.config import BaseModelConfig
from fast_llm.engine.base_model.config import BaseModelConfig, ResourceUsageConfig
from fast_llm.engine.distributed.config import DistributedConfig, PhaseType
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.tensor import ParameterMeta, TensorMeta
@@ -43,6 +43,9 @@ def forward(
) -> torch.Tensor:
pass

def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int:
raise NotImplementedError()


class Sequential(Layer):
def __init__(self, distributed_config: DistributedConfig):
@@ -94,7 +97,8 @@ def __init__(
distributed_config: DistributedConfig,
):
super().__init__(config, distributed_config)

for key, value in self.named_modules():
value.module_name = key
for key, value in self.named_parameters():
Assert.custom(isinstance, value, ParameterMeta)
# Rename to the parameter full name
12 changes: 12 additions & 0 deletions fast_llm/engine/base_model/config.py
@@ -63,3 +63,15 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None:
@abc.abstractmethod
def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None:
pass


@dataclasses.dataclass
class ResourceUsageConfig:
# Disable to get usage for current GPU only
global_: bool = True
# Enable to get hardware compute, i.e. include redundant computations.
hardware: bool = False
# Number of forward passes. Typically 1, may be 2 with full activation recomputation.
forward: int = 1
# Number of backward passes. Typically 1 for training, 0 for inference.
backward: int = 1
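
For reference, here is a minimal sketch of how a concrete layer might implement the `get_compute_usage` hook introduced above against `ResourceUsageConfig`. The layer, its dimensions, and the FLOP formula are illustrative assumptions, not code from this PR; only the imports and the hook signature come from the diff.

import typing

from fast_llm.engine.base_model.config import ResourceUsageConfig
from fast_llm.tensor import TensorMeta


class HypotheticalLinearLayer:
    # Illustrative dense layer with an (input_size x output_size) weight matrix.
    def __init__(self, input_size: int, output_size: int):
        self._input_size = input_size
        self._output_size = output_size

    def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int:
        # Tokens seen by this layer, assuming TensorMeta exposes torch's numel().
        tokens = input_.numel() // self._input_size
        # 2 FLOPs per multiply-accumulate in the forward matmul.
        forward_flops = 2 * tokens * self._input_size * self._output_size
        # Backward needs roughly two matmuls (gradients w.r.t. input and weight).
        backward_flops = 2 * forward_flops
        # This toy example ignores config.global_ and config.hardware; a real layer would
        # use them to choose between per-GPU and global counts and to include redundant
        # hardware compute such as recomputation.
        return config.forward * forward_flops + config.backward * backward_flops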
16 changes: 14 additions & 2 deletions fast_llm/engine/config_utils/run.py
@@ -5,11 +5,11 @@

import yaml

from fast_llm.config import Config, Field, FieldHint, FieldVerboseLevel, config_class
from fast_llm.config import Config, Field, FieldHint, FieldVerboseLevel, check_field, config_class
from fast_llm.engine.config_utils.logging import TensorLogs, TensorLogsConfig, configure_logging
from fast_llm.engine.config_utils.runnable import RunnableConfig
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.utils import log, set_global_variables
from fast_llm.utils import Assert, log, set_global_variables

if typing.TYPE_CHECKING:
from fast_llm.engine.distributed.distributed import Distributed
@@ -58,6 +58,12 @@ class RunConfig(Config):
desc="Global switch to use triton kernels for linear layers. These may be slightly slower than the defaults.",
hint=FieldHint.performance,
)
model_debug_level: int = Field(
default=0,
desc="Debugging level for the model, ex. for printing intermediate model states.",
hint=FieldHint.logging,
valid=check_field(Assert.geq, 0),
)

def _validate(self):
if self.experiment_dir is None:
@@ -204,15 +210,21 @@ def open_artifact(self, name: str, mode: str | None = "w", verbose=True) -> path
return path if mode is None else path.open(mode)

def __enter__(self):
from fast_llm.logging import set_model_debug_level

assert not self._is_running
global _run
_run = self
TensorLogs.reset(self._config.tensor_logs)
set_model_debug_level(self._config.model_debug_level)

def __exit__(self, exc_type, exc_val: OSError, exc_tb):
from fast_llm.logging import set_model_debug_level

assert self._is_running
global _run
self.save_logged_tensors("none")
set_model_debug_level(0)
_run = None


18 changes: 11 additions & 7 deletions fast_llm/engine/evaluation/evaluator.py
@@ -1,6 +1,8 @@
import abc
import dataclasses
import functools
import logging
import math
import time
import typing

@@ -203,12 +205,10 @@ def _evaluate_loss(
)
end_time = time.perf_counter()
time_per_iteration = (end_time - begin_time) / num_iters
model_tflops, hardware_tflops = self._multi_stage.get_tflops(
phase,
time_per_iteration,
self._batch_config.batch_size,
self._batch_config.sequence_length,
)

model_compute, hardware_compute = self._schedule.compute_usage
model_tflops = math.nan if model_compute is None else model_compute / time_per_iteration
hardware_tflops = math.nan if hardware_compute is None else hardware_compute / time_per_iteration
# TODO add other relevant eval metrics
metrics = {
"batch_size": self._batch_config.batch_size,
@@ -218,7 +218,7 @@ def _evaluate_loss(
"hardware_tflops": hardware_tflops,
"tokens_per_sec_per_gpu": (
(self._batch_config.sequence_length * self._batch_config.batch_size)
/ self._schedule._distributed.world_size
/ self._schedule._distributed_config.world_size
/ time_per_iteration
),
**get_and_reset_memory_usage_mib(),
@@ -240,6 +240,10 @@ def _get_data_iterator(
prefetch_factor=prefetch_factor,
)

@functools.cached_property
def compute_usage(self) -> tuple[int | None, int | None]:
return self._schedule.get_compute_usage(hardware=False), self._schedule.get_compute_usage(hardware=True)


# NOTE: This is not a standalone runnable; it's a submodule of Trainer used for code encapsulation.
class EvaluatorRunner:
8 changes: 1 addition & 7 deletions fast_llm/engine/multi_stage/multi_stage.py
@@ -1,4 +1,3 @@
import abc
import dataclasses
import logging
import typing
@@ -13,7 +12,7 @@
from fast_llm.engine.config_utils.data_type import DataType
from fast_llm.engine.config_utils.run import log_main_rank, log_model_parallel_main_rank
from fast_llm.engine.config_utils.tensor_dim import TensorDim
from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames, PhaseType
from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName, StageMode
from fast_llm.engine.multi_stage.fsdp import FSDP
@@ -252,11 +251,6 @@ def setup(self, distributed: Distributed | None = None, mode: StageMode = StageM

self.train(self._mode.support_backward)

@abc.abstractmethod
def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration, batch_size, sequence_length) -> tuple[int, int]:
# TODO: Do in model, automate/generalize, get other stats
pass

def _allocate_buffers(
self, buffer_meta: TensorMeta, sizes: list[int], name: str
) -> tuple[tuple[torch.Tensor, ...], int]:
13 changes: 13 additions & 0 deletions fast_llm/engine/multi_stage/stage.py
@@ -5,6 +5,7 @@
import torch

from fast_llm.core.distributed import check_parallel_match
from fast_llm.engine.base_model.config import ResourceUsageConfig
from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.engine.multi_stage.config import StageConfig, StageMode
@@ -81,6 +82,7 @@ def setup( # noqa

def forward_meta(self, input_: TensorMeta, kwargs: dict) -> TensorMeta:
# Store the meta inputs and outputs, for debugging only.
# TODO: Varies if there are multiple schedules.
self._meta_inputs, self._meta_outputs = [], []
# TODO: use layer.forward_meta
for layer in self._layers:
@@ -93,6 +95,17 @@ def forward_meta(self, input_: TensorMeta, kwargs: dict) -> TensorMeta:
self._meta_outputs.append(input_)
return input_

def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int:
total = 0
for layer in self._layers:
total += layer.get_compute_usage(input_, kwargs, config)
input_ = layer(
input_,
kwargs,
losses={},
)
return total

def forward(
self,
input_: torch.Tensor,
57 changes: 43 additions & 14 deletions fast_llm/engine/schedule/schedule.py
@@ -1,5 +1,6 @@
import abc
import dataclasses
import functools
import logging
import typing
import warnings
@@ -9,6 +10,7 @@
import torch.utils
import torch.utils.data

from fast_llm.engine.base_model.config import ResourceUsageConfig
from fast_llm.engine.distributed.config import DistributedConfig, PhaseType
from fast_llm.engine.multi_stage.multi_stage import MultiStageModel
from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig, StepType
@@ -127,12 +129,12 @@ def __init__(
self._multi_stage = multi_stage
self._batch_config = batch_config
self._schedule_config = schedule_config
self._distributed = distributed_config
self._distributed_config = distributed_config
self._num_stages = len(self._multi_stage.stages)
self._phase = phase
self._is_training = self._phase.is_training

if self._batch_config.num_inputs < self._distributed.pipeline_parallel:
if self._batch_config.num_inputs < self._distributed_config.pipeline_parallel:
warnings.warn("Not enough input to achieve true pipeline parallelism.")

# Setup the activation metas.
@@ -172,7 +174,7 @@ def iterate(self, pipeline_rank: int | None = None) -> typing.Iterator[Step]:
return iter(self._steps if pipeline_rank is None else self._device_steps[pipeline_rank])

def __iter__(self) -> typing.Iterator[Step]:
return self.iterate(self._distributed.pipeline_rank)
return self.iterate(self._distributed_config.pipeline_rank)

def __repr__(self) -> str:
return "Schedule with steps:\n" + "\n".join(
@@ -191,7 +193,7 @@ def get_step(
return self._step_map[(type_, stage, data_index)]

def _create_index(self) -> None:
self._device_steps: list[list[Step]] = [[] for _ in range(self._distributed.pipeline_parallel)]
self._device_steps: list[list[Step]] = [[] for _ in range(self._distributed_config.pipeline_parallel)]
self._step_map = {}
for i, step in enumerate(self._steps):
Assert.in_range(step.stage, 0, self._num_stages)
@@ -204,7 +206,7 @@ def _create_index(self) -> None:
step.global_index = i
# TODO: More configurable placement?

step.pipeline_rank = step.stage % self._distributed.pipeline_parallel
step.pipeline_rank = step.stage % self._distributed_config.pipeline_parallel
step.local_index = len(self._device_steps[step.pipeline_rank])
self._device_steps[step.pipeline_rank].append(step)
Assert.not_incl(map_index := step.map_index, self._step_map)
@@ -272,7 +274,7 @@ def _create_index(self) -> None:

def _setup_restore_steps(self, weight_buffer_indices: dict[int, int]) -> None:
for rank, device_steps in enumerate(self._device_steps):
if rank != self._distributed.pipeline_rank:
if rank != self._distributed_config.pipeline_rank:
# TODO: Make restore schedule for all ranks (need all buffer indices)
continue
buffer_contents, buffer_last_used = {}, {}
@@ -292,7 +294,7 @@ def _setup_reduce_steps(self, grad_buffer_indices: dict[int, int]) -> None:
if not self._is_training:
return
for rank, device_steps in enumerate(self._device_steps):
if rank != self._distributed.pipeline_rank:
if rank != self._distributed_config.pipeline_rank:
# TODO: Make restore schedule for all ranks (need all buffer indices)
continue
buffer_last_steps = {}
@@ -314,12 +316,12 @@ def _setup_reduce_steps(self, grad_buffer_indices: dict[int, int]) -> None:
for stage, count in enumerate(reduction_count):
assert (count > 0) == (
stage >= self._first_grad_stage
and (stage % self._distributed.pipeline_parallel == self._distributed.pipeline_rank)
and (stage % self._distributed_config.pipeline_parallel == self._distributed_config.pipeline_rank)
)

def _setup_timeline(self) -> None:
# TODO: Include network time
idx = [0] * self._distributed.pipeline_parallel
idx = [0] * self._distributed_config.pipeline_parallel
done = False
while not done:
done = True
@@ -380,11 +382,11 @@ def _setup_send_recv_steps(self) -> None:
recv_step.recv_event = torch.cuda.Event()

def _validate_send_recv_steps(self) -> None:
times = [0.0] * self._distributed.pipeline_parallel
idx = [0] * self._distributed.pipeline_parallel
recv_idx = [0] * self._distributed.pipeline_parallel
statuses = ["Ok"] * self._distributed.pipeline_parallel
recv_queues: list[list[Step | None]] = [[] for _ in range(self._distributed.pipeline_parallel)]
times = [0.0] * self._distributed_config.pipeline_parallel
idx = [0] * self._distributed_config.pipeline_parallel
recv_idx = [0] * self._distributed_config.pipeline_parallel
statuses = ["Ok"] * self._distributed_config.pipeline_parallel
recv_queues: list[list[Step | None]] = [[] for _ in range(self._distributed_config.pipeline_parallel)]
done = False
while not done:
done = True
@@ -519,3 +521,30 @@ def _create_steps(self) -> tuple[list[Step], int]:
)
)
return steps, first_grad_stage

def get_compute_usage(
self,
global_: bool = True,
hardware: bool = False,
) -> int | None:
total = 0
try:
for step in self._steps if global_ else self._device_steps[self._distributed_config.pipeline_rank]:
if step.type_ == StepType.forward:
total += self._multi_stage.stages[step.stage].get_compute_usage(
step.meta_input,
step.meta_kwargs,
ResourceUsageConfig(
global_=global_,
hardware=hardware,
forward=1,
backward=int(self._is_training),
),
)
return total
except NotImplementedError:
return None

@functools.cached_property
def compute_usage(self) -> tuple[int | None, int | None]:
return self.get_compute_usage(True, False), self.get_compute_usage(True, True)
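
As a usage note, a minimal sketch of how a caller might query the new schedule-level API, mirroring the evaluator change earlier in this diff; the `schedule` object and `time_per_iteration` value are assumed to exist and are illustrative, not part of this PR:

import math

# Model compute counts the useful work once; hardware compute also includes redundant
# work such as recomputation (see ResourceUsageConfig above).
model_compute = schedule.get_compute_usage(global_=True, hardware=False)
hardware_compute = schedule.get_compute_usage(global_=True, hardware=True)

# None means some layer does not implement get_compute_usage; the evaluator then reports NaN.
model_tflops = math.nan if model_compute is None else model_compute / time_per_iteration
hardware_tflops = math.nan if hardware_compute is None else hardware_compute / time_per_iteration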