Merged
Changes from all commits
37 commits
747769b
feat(fsdp2): add _broadcast_sharded_state_dict, _get_non_persistent_b…
kevssim Mar 18, 2026
a69fb6c
feat(fsdp2): enable cpu_ram_efficient_loading for both strategies; pa…
kevssim Mar 18, 2026
c015e13
refactor(fsdp2): use _non_persistent_buffers_set for precise non-pers…
kevssim Mar 18, 2026
5fa4c4b
Merge branch 'optimize_fsdp_init' of https://github.com/kevssim/twink…
kevssim Mar 18, 2026
587f001
wip
kevssim Mar 18, 2026
983cdbc
test(fsdp2): make tests platform-agnostic (cuda/npu) via Platform API
kevssim Mar 18, 2026
d5d2832
fix(test): pass inputs as List[InputFeature] to forward_backward
kevssim Mar 18, 2026
5973173
fix(test): add position_ids to e2e test batch
kevssim Mar 18, 2026
accd03b
fix(test): simplify e2e test to only verify wrap_model, avoid process…
kevssim Mar 18, 2026
2c72aa4
fix(fsdp2): handle non-DTensor params (e.g. tied weights) in _broadca…
kevssim Mar 19, 2026
bf0e155
fix(fsdp2): move remaining CPU/meta params to device after tie_weights
kevssim Mar 19, 2026
e35a3d9
debug: add verbose logging to e2e test to diagnose CPU param issue
kevssim Mar 19, 2026
60c4a3a
debug: add verbose logging to wrap_model to trace execution path
kevssim Mar 19, 2026
93ef7e4
debug: add verbose logging to _lazy_wrap_model
kevssim Mar 19, 2026
638a996
debug: add device_mesh check to e2e test
kevssim Mar 19, 2026
a61db10
debug: print mesh before TransformersModel init
kevssim Mar 19, 2026
a06e894
fix(test): call twinkle.initialize() before TransformersModel to pres…
kevssim Mar 19, 2026
f8def97
cleanup: remove all debug print statements from native_fsdp.py and tr…
kevssim Mar 19, 2026
13c1d5f
wip
kevssim Mar 19, 2026
0438d9e
lint
kevssim Mar 19, 2026
44bf3d4
fix
kevssim Mar 19, 2026
3b82d1c
wip
kevssim Mar 19, 2026
beaa4fd
wip
kevssim Mar 19, 2026
560eb23
wip
kevssim Mar 19, 2026
d8f39b1
wip
kevssim Mar 19, 2026
cbb6191
lint
kevssim Mar 19, 2026
38e75cd
lint
kevssim Mar 19, 2026
e482625
wip
kevssim Mar 19, 2026
00fd199
clean
kevssim Mar 19, 2026
0b77d31
Merge remote-tracking branch 'origin/main' into optimize_fsdp_init
kevssim Mar 24, 2026
404096c
wip
kevssim Mar 24, 2026
eb13cda
wip
kevssim Mar 24, 2026
9158465
rename
kevssim Mar 25, 2026
9d97d84
wip
kevssim Mar 25, 2026
5fbd998
wip
kevssim Mar 25, 2026
109cf28
doc
kevssim Mar 25, 2026
62e680c
fix
kevssim Mar 25, 2026
2 changes: 2 additions & 0 deletions docs/source_en/Components/Model/TransformersModel.md
@@ -15,6 +15,7 @@ class TransformersModel:
ddp_config: Dict[str, Any] = None,
fsdp_config: Dict[str, Any] = None,
grad_scaler_config: Dict[str, Any] = None,
memory_efficient_init: bool = False,
**kwargs):
...

@@ -30,6 +31,7 @@ class TransformersModel:
- ddp_config: DDP configuration when strategy is `accelerate`, see: [DDPKwargs](https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/dataclasses.py#L155)
- fsdp_config: FSDP configuration when strategy is `accelerate`, see: [FSDPConfig](https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/dataclasses.py#L1566)
- grad_scaler_config: PyTorch's grad_scaler initialization configuration, see: [PyTorch's GradScaler constructor](https://github.com/pytorch/pytorch/blob/main/torch/cuda/amp/grad_scaler.py#L25)
- memory_efficient_init: Whether to enable memory-efficient model initialization for FSDP. When enabled, only rank 0 loads the full weights; the other ranks receive their sharded parameters via broadcast, reducing peak memory usage during initialization. Default `False`. Note: this optimization currently only applies to transformers <= 4.57.6; with transformers >= 5.0.0 it may degrade performance. A usage sketch follows this list.
- kwargs:
- If you don't want to pass the model config field, you can put scattered configurations here. These parameters will be passed to `from_pretrained` or `from_config` later.
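
A minimal usage sketch (illustrative only: the import path, checkpoint name, and launcher setup are assumptions, not part of this change; one process per device is expected to have been started by e.g. torchrun):

from twinkle.model.transformers import TransformersModel  # assumed import path

model = TransformersModel(
    model_id='Qwen/Qwen2.5-7B-Instruct',  # illustrative checkpoint
    strategy='native_fsdp',
    memory_efficient_init=True,  # rank 0 loads full weights; other ranks receive their shards
)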

2 changes: 2 additions & 0 deletions docs/source_zh/组件/模型/TransformersModel.md
@@ -15,6 +15,7 @@ class TransformersModel:
ddp_config: Dict[str, Any] = None,
fsdp_config: Dict[str, Any] = None,
grad_scaler_config: Dict[str, Any] = None,
memory_efficient_init: bool = False,
**kwargs):
...

@@ -30,6 +31,7 @@ class TransformersModel:
- ddp_config: DDP configuration when strategy is `accelerate`, see: [DDPKwargs](https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/dataclasses.py#L155)
- fsdp_config: FSDP configuration when strategy is `accelerate`, see: [FSDPConfig](https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/dataclasses.py#L1566)
- grad_scaler_config: PyTorch's grad_scaler initialization configuration, see: [PyTorch's GradScaler constructor](https://github.com/pytorch/pytorch/blob/main/torch/cuda/amp/grad_scaler.py#L25)
- memory_efficient_init: Whether to enable memory-efficient FSDP initialization. When enabled, only rank 0 loads the full weights and the remaining ranks obtain their sharded parameters via broadcast, reducing peak host and device memory during initialization. Default `False`. Note: this optimization currently only applies to transformers <= 4.57.6; with transformers >= 5.0.0 it may degrade performance.
- kwargs:
- If you don't want to pass the model config field, you can put scattered configurations here. These parameters will be passed to `from_pretrained` or `from_config` later.

16 changes: 10 additions & 6 deletions src/twinkle/model/transformers/strategy/accelerate.py
@@ -1,8 +1,8 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import os
from typing import Any, Dict, Literal, Optional

from twinkle import DeviceMesh
from .load_context import fsdp_pretrained_load_context


class AccelerateStrategy:
@@ -21,13 +21,15 @@ def __init__(
mixed_precision: Literal['no', 'fp8', 'fp16', 'bf16'] = 'bf16',
ddp_config: Dict[str, Any] = None,
fsdp_config: Dict[str, Any] = None,
memory_efficient_init: bool = False,
):
from accelerate import Accelerator

self.device_mesh = device_mesh
self.mixed_precision = mixed_precision
self._memory_efficient_init = memory_efficient_init
parallelism_config = self._parallelism_config_from_device_mesh(device_mesh)
fsdp_plugin = self._fsdp_config_from_device_mesh(device_mesh, fsdp_config)
fsdp_plugin = self._fsdp_config_from_device_mesh(device_mesh, fsdp_config, memory_efficient_init)

kwargs_handlers = []
if ddp_config is not None:
@@ -42,6 +44,9 @@ def __init__(
kwargs_handlers=kwargs_handlers,
)

def pretrained_load_context(self):
return fsdp_pretrained_load_context(self._memory_efficient_init and self.device_mesh is not None)

@staticmethod
def _parallelism_config_from_device_mesh(device_mesh: DeviceMesh):
# TODO should test with transformers v5.0
@@ -69,7 +74,8 @@ def _parallelism_config_from_device_mesh(device_mesh: DeviceMesh):

return parallelism_config

def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Dict[str, Any]):
def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Dict[str, Any],
memory_efficient: bool):
from accelerate import FullyShardedDataParallelPlugin
from torch.distributed.fsdp import BackwardPrefetch
from torch.distributed.fsdp import ShardingStrategy as FSDPShardingStrategy
@@ -107,11 +113,9 @@ def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Di
activation_checkpointing=fsdp_config.pop('activation_checkpointing', False),
auto_wrap_policy=fsdp_config.pop('auto_wrap_policy', 'transformer_based_wrap'), # noqa
reshard_after_forward=fsdp_config.pop('reshard_after_forward', True),
cpu_ram_efficient_loading=fsdp_config.pop('cpu_ram_efficient_loading', memory_efficient),
**fsdp_config,
)
# Enable memory efficient model loading in transformers(see `is_fsdp_enabled` in transformers)
# os.environ['ACCELERATE_USE_FSDP'] = '1'
# os.environ['FSDP_CPU_RAM_EFFICIENT_LOADING'] = '1'
return fsdp_plugin

def wrap_model(self, model, *args):
27 changes: 27 additions & 0 deletions src/twinkle/model/transformers/strategy/load_context.py
@@ -0,0 +1,27 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import contextlib
import os

_FSDP_EFFICIENT_LOADING_ENV = {
'ACCELERATE_USE_FSDP': 'true',
'FSDP_CPU_RAM_EFFICIENT_LOADING': 'true',
}


@contextlib.contextmanager
def fsdp_pretrained_load_context(enabled: bool):
"""Enable the env flags required for transformers FSDP-aware loading when needed."""
if not enabled:
yield
return

saved_env = {key: os.environ.get(key) for key in _FSDP_EFFICIENT_LOADING_ENV}
os.environ.update(_FSDP_EFFICIENT_LOADING_ENV)
try:
yield
finally:
for key, old_val in saved_env.items():
if old_val is None:
os.environ.pop(key, None)
else:
os.environ[key] = old_val
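
A minimal sketch of how this context manager is used (illustrative: the absolute module path is inferred from this file's location, and the checkpoint name is a placeholder; a process group must already be initialized for rank-0-only loading to make sense):

import torch.distributed as dist
from transformers import AutoModelForCausalLM
from twinkle.model.transformers.strategy.load_context import fsdp_pretrained_load_context

assert dist.is_initialized()
# Inside the context, transformers' from_pretrained sees ACCELERATE_USE_FSDP and
# FSDP_CPU_RAM_EFFICIENT_LOADING (cf. `is_fsdp_enabled`) and switches to meta-device
# init with rank-0-only weight loading; the previous env values are restored on exit.
with fsdp_pretrained_load_context(True):
    model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen2.5-0.5B')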
119 changes: 109 additions & 10 deletions src/twinkle/model/transformers/strategy/native_fsdp.py
@@ -7,6 +7,7 @@
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Set

from twinkle.utils import DeviceMesh, Platform, torch_util
from .load_context import fsdp_pretrained_load_context

if TYPE_CHECKING:
from torch.distributed.fsdp import MixedPrecisionPolicy
@@ -18,14 +19,19 @@ def __init__(self,
device_mesh: Optional[DeviceMesh] = None,
mixed_precision: Literal['no', 'fp8', 'fp16', 'bf16'] = 'bf16',
fsdp_config: Dict[str, Any] = None,
memory_efficient_init: bool = False,
enable_ep: bool = True,
ep_size: Optional[int] = None):
self.device_mesh = device_mesh
self.mixed_precision = mixed_precision
self.fsdp_config = fsdp_config or {}
self._memory_efficient_init = memory_efficient_init
self.enable_ep = enable_ep
self.ep_fsdp_device_mesh = self._build_ep_fsdp_device_mesh(ep_size) if enable_ep else None

def pretrained_load_context(self):
return fsdp_pretrained_load_context(self._memory_efficient_init and self.device_mesh is not None)

def _build_ep_fsdp_device_mesh(self, ep_size: Optional[int] = None) -> Optional[TorchDeviceMesh]:
if self.device_mesh is None:
return None
@@ -48,6 +54,23 @@ def wrap_model(self, model, optimizer=None):
fsdp_mesh = _build_fsdp_mesh(self.device_mesh)
if fsdp_mesh is not None:
ep_enabled = (self.enable_ep and self.ep_fsdp_device_mesh is not None)

# Drop optimizer references to pre-shard params before fully_shard to reduce peak memory.
if optimizer is not None:
_unbind_optimizer_params(optimizer)

# EP path requires experts on a real device, incompatible with meta-device flow.
use_meta = self._memory_efficient_init and not ep_enabled

original_sd = None
saved_buffers = None
if use_meta:
original_sd = model.state_dict()
saved_buffers = _get_non_persistent_buffers(model)
model = model.to(torch.device('meta'))
if hasattr(model, 'tie_weights'):
model.tie_weights()

if ep_enabled:
_ensure_moe_patched_if_needed(model, self.ep_fsdp_device_mesh)
_place_ep_experts_on_local_device(model, self.ep_fsdp_device_mesh)
@@ -57,19 +80,16 @@ def wrap_model(self, model, optimizer=None):
if ep_enabled:
_ensure_ep_fsdp_supported(model)

# Collect experts map and expert params
experts_map = _collect_ep_experts_map(model) if ep_enabled else {}
expert_params = _collect_expert_params(model) if self.enable_ep else None

# Build layer_pairs: [(layer_mod, experts_mod_or_None)]
layers = _get_decoder_layers(model)
layer_pairs = []
if layers is not None:
for layer_mod in layers:
experts_mod = _find_experts_in_layer(layer_mod, experts_map)
layer_pairs.append((layer_mod, experts_mod))

# FSDP2 wrapping per layer
world_size = self.device_mesh.world_size
ep_fsdp_mesh_1d = self.ep_fsdp_device_mesh['ep_fsdp'] if ep_enabled else None

@@ -79,9 +99,6 @@ def wrap_model(self, model, optimizer=None):
if experts_mod is not None and ep_fsdp_mesh_1d is not None:
from torch.distributed.tensor import Shard

# PreMulSum (used by set_gradient_divide_factor) only supports
# float16/float32/float64; override reduce_dtype to float32
# when the base policy uses bfloat16.
ep_mp_policy = _build_ep_mp_policy(mp_policy)
fully_shard(
experts_mod,
Expand All @@ -90,7 +107,6 @@ def wrap_model(self, model, optimizer=None):
mp_policy=ep_mp_policy,
shard_placement_fn=lambda param: Shard(1),
)
# gradient_divide_factor = world_size
experts_mod.set_gradient_divide_factor(world_size)
layer_mod._fsdp_modules.append(experts_mod)

@@ -103,7 +119,6 @@ def wrap_model(self, model, optimizer=None):
)
layer_mod._fsdp_modules.append(layer_mod)

# Root model
fully_shard(
model,
mesh=fsdp_mesh,
@@ -112,11 +127,22 @@ def wrap_model(self, model, optimizer=None):
ignored_params=expert_params,
)

# Manual prefetch
if use_meta:
device_type = self.device_mesh.device_type or 'cuda'
is_rank0 = (dist.get_rank() == 0)
_broadcast_sharded_state_dict(
model,
original_sd if is_rank0 else {},
device_type=device_type,
)
target_device = torch.device(device_type)
_restore_non_persistent_buffers(model, saved_buffers, device=target_device)
if hasattr(model, 'tie_weights'):
model.tie_weights()

if ep_enabled and layer_pairs:
_setup_manual_prefetch([lp[0] for lp in layer_pairs])

# Rebuild groups after wrapping so grad clip sees the live Parameter objects.
if ep_enabled:
_rebuild_ep_param_groups(model)

@@ -398,3 +424,76 @@ def _rebind_optimizer(optimizer: torch.optim.Optimizer, model: nn.Module) -> tor
return optimizer
optimizer.param_groups[0]['params'] = list(model.parameters())
return optimizer


def _broadcast_sharded_state_dict(
model: nn.Module,
full_sd: dict,
device_type: str = 'cuda',
) -> None:
"""Broadcast full state dict from rank 0 and materialise local shards via distribute_tensor."""
from torch.distributed.tensor import DTensor, distribute_tensor

meta_sharded_sd = model.state_dict()
sharded_sd = {}
is_rank0 = (dist.get_rank() == 0)

for param_name, sharded_param in meta_sharded_sd.items():
shape = sharded_param.size()
dtype = sharded_param.dtype

if is_rank0:
full_param = full_sd[param_name]
full_tensor = full_param.detach().to(device_type)
if isinstance(full_tensor, DTensor):
full_tensor = full_tensor.to_local()
else:
full_tensor = torch.empty(shape, device=device_type, dtype=dtype)

dist.broadcast(full_tensor, src=0)
torch_util.synchronize()

device_mesh = sharded_param.device_mesh
placements = sharded_param.placements
sharded_tensor = distribute_tensor(full_tensor, device_mesh, placements)
del full_tensor

sharded_sd[param_name] = sharded_tensor

model.load_state_dict(sharded_sd, assign=True)
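
As an aside (not part of this diff), a minimal sketch of the broadcast-then-shard step the helper above performs per parameter, assuming a CUDA backend, a 1-D mesh over all ranks, and a Shard(0) placement:

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

mesh = init_device_mesh('cuda', (dist.get_world_size(),))
# After dist.broadcast, every rank holds the same full tensor; arange stands in for that here.
full = torch.arange(16, dtype=torch.float32, device='cuda').reshape(8, 2)
local = distribute_tensor(full, mesh, [Shard(0)])  # DTensor; each rank keeps 8 // world_size rows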


def _get_non_persistent_buffers(model: nn.Module) -> Dict[str, torch.Tensor]:
"""Return {fqn: tensor} for non-persistent buffers (lost on to('meta'))."""
non_persistent_fqns: Set[str] = set()
for fqn, module in model.named_modules():
for buf_name in getattr(module, '_non_persistent_buffers_set', set()):
full_fqn = f'{fqn}.{buf_name}' if fqn else buf_name
non_persistent_fqns.add(full_fqn)

return {k: v.clone() for k, v in model.named_buffers() if k in non_persistent_fqns}


def _unbind_optimizer_params(optimizer: torch.optim.Optimizer) -> None:
"""Drop optimizer refs to pre-shard params before fully_shard to lower peak memory."""
for group in optimizer.param_groups:
for i in range(len(group['params'])):
param = group['params'][i]
group['params'][i] = torch.empty(1, dtype=param.dtype, device=param.device)


def _restore_non_persistent_buffers(
model: nn.Module,
saved_buffers: Dict[str, torch.Tensor],
device: torch.device,
) -> None:
"""Re-register non-persistent buffers saved before to('meta')."""
for fqn, buf_tensor in saved_buffers.items():
buf_tensor = buf_tensor.to(device)
if '.' in fqn:
parent_fqn, local_name = fqn.rsplit('.', 1)
parent = model.get_submodule(parent_fqn)
else:
local_name = fqn
parent = model
parent.register_buffer(local_name, buf_tensor, persistent=False)
18 changes: 14 additions & 4 deletions src/twinkle/model/transformers/transformers.py
@@ -189,6 +189,7 @@ def __init__(
ddp_config: Dict[str, Any] = None,
fsdp_config: Dict[str, Any] = None,
grad_scaler_config: Dict[str, Any] = None,
memory_efficient_init: bool = False,
**kwargs):
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
self._try_init_process_group()
@@ -201,6 +202,7 @@ def __init__(
self.mixed_precision = mixed_precision
self._fsdp_config = dict(fsdp_config or {})
self._ddp_config = ddp_config or {}
self._memory_efficient_init = memory_efficient_init
self._decide_strategy(strategy)
self.grad_scaler_config = grad_scaler_config
if isinstance(model_cls, str):
@@ -209,8 +211,9 @@ def __init__(
self.model = model_cls.from_config(config, **kwargs)
else:
model_id = HubOperation.download_model(model_id)
self.model = model_cls.from_pretrained(model_id, config=config, **kwargs)
# Construct sequence-parallel strategy lazily during wrapping to reduce init-time side effects.
# Trigger transformers' FSDP-aware loading: meta-device init + rank-0-only weight load.
with self.strategy.pretrained_load_context():
self.model = model_cls.from_pretrained(model_id, config=config, **kwargs)
self.model.gradient_checkpointing_enable()
self.sp_strategy = None
self._model_wrapped = False
@@ -235,6 +238,7 @@ def _decide_strategy(self, strategy: Literal['accelerate', 'native_fsdp']):
mixed_precision=self.mixed_precision,
fsdp_config=self._fsdp_config,
device_mesh=self.device_mesh,
memory_efficient_init=self._memory_efficient_init,
enable_ep=self._enable_expert_parallel,
ep_size=ep_size,
)
@@ -243,7 +247,8 @@ def _decide_strategy(self, strategy: Literal['accelerate', 'native_fsdp']):
mixed_precision=self.mixed_precision,
ddp_config=self._ddp_config,
fsdp_config=self._fsdp_config,
device_mesh=self.device_mesh)
device_mesh=self.device_mesh,
memory_efficient_init=self._memory_efficient_init)

# Sequence parallel ("ulysses") is derived from dp/fsdp ranks; it does not change world size.
# We construct `sp_strategy` after the underlying HF model is initialized (see __init__).
@@ -290,6 +295,7 @@ def _lazy_wrap_model(self):
self._ensure_sp_strategy()
if self.sp_strategy is not None:
self.sp_strategy.initialize()

if len(optimizer_groups) == 1:
optimizer_group = optimizer_groups[0]
optimizer = optimizer_group.optimizer
@@ -299,7 +305,11 @@ def _lazy_wrap_model(self):
self.register_mm_forward_hook(optimizer_group)
else:
# maybe forward_only, no optimizer_group available
self.model = self.strategy.wrap_model(self.model)
result = self.strategy.wrap_model(self.model)
if isinstance(result, tuple):
self.model = result[0]
else:
self.model = result
self._model_wrapped = True

def register_mm_forward_hook(self, optimizer_group: OptimizerGroup):