6 changes: 6 additions & 0 deletions src/lightning/fabric/strategies/fsdp.py
@@ -795,6 +795,12 @@ def _optimizer_has_flat_params(optimizer: Optimizer) -> bool:
)


def _optimizer_has_dtensor_params(optimizer: Optimizer) -> bool:
    from torch.distributed.tensor import DTensor

    return any(isinstance(param, DTensor) for group in optimizer.param_groups for param in group["params"])


def _get_sharded_state_dict_context(module: Module) -> Generator[None, None, None]:
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.api import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType
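A minimal sketch of how the new helper behaves (illustration only; the import path matches this PR, the toy model and optimizer are assumptions):

import torch
from torch import nn

from lightning.fabric.strategies.fsdp import _optimizer_has_dtensor_params

# A plain (non-sharded) model holds ordinary torch.nn.Parameters, so the check is False.
model = nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
assert not _optimizer_has_dtensor_params(optimizer)
# After sharding the module with FSDP2's `fully_shard`, its parameters become DTensors
# and the same check returns True.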
1 change: 1 addition & 0 deletions src/lightning/fabric/utilities/imports.py
@@ -35,6 +35,7 @@
_TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0")
_TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1")
_TORCH_GREATER_EQUAL_2_5 = compare_version("torch", operator.ge, "2.5.0")
_TORCH_GREATER_EQUAL_2_6 = compare_version("torch", operator.ge, "2.6.0")
_TORCH_LESS_EQUAL_2_6 = compare_version("torch", operator.le, "2.6.0")
_TORCHMETRICS_GREATER_EQUAL_1_0_0 = compare_version("torchmetrics", operator.ge, "1.0.0")
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
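The new flag follows the same pattern as the existing ones; a sketch of how such a flag is typically consumed (the guard below is illustrative, not taken from this PR):

from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6


def _ensure_torch_2_6() -> None:
    # Hypothetical guard: gate a feature on the torch release the flag checks for.
    if not _TORCH_GREATER_EQUAL_2_6:
        raise RuntimeError("This feature requires torch >= 2.6.0")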
20 changes: 20 additions & 0 deletions src/lightning/fabric/utilities/init.py
@@ -113,3 +113,23 @@ def _has_meta_device_parameters_or_buffers(obj: Union[Module, Optimizer], recurse: bool = True) -> bool:
    if isinstance(obj, Module):
        return any(t.is_meta for t in itertools.chain(obj.parameters(recurse=recurse), obj.buffers(recurse=recurse)))
    raise TypeError(f"Expected `torch.nn.Module` or `torch.optim.Optimizer`, got: {type(obj).__name__}")


def _has_all_dtensor_params_or_buffers(obj: Union[Module, Optimizer], recurse: bool = True) -> bool:
    """Check whether all parameters and buffers of a given :class:`torch.nn.Module` or :class:`torch.optim.Optimizer`
    are instances of :class:`torch.distributed.tensor.DTensor`."""
    from torch.distributed.tensor import DTensor

    if isinstance(obj, Optimizer):
        return all(
            isinstance(t, DTensor)
            for param_group in obj.param_groups
            for t in param_group["params"]
            if isinstance(t, Parameter)
        )
    if isinstance(obj, Module):
        return all(
            isinstance(t, DTensor)
            for t in itertools.chain(obj.parameters(recurse=recurse), obj.buffers(recurse=recurse))
        )
    raise TypeError(f"Expected `torch.nn.Module` or `torch.optim.Optimizer`, got: {type(obj).__name__}")
4 changes: 1 addition & 3 deletions src/lightning/fabric/utilities/registry.py
@@ -36,9 +36,7 @@ def _load_external_callbacks(group: str) -> list[Any]:
        A list of all callbacks collected from external factories.

    """
-    factories = (
-        entry_points(group=group) if _PYTHON_GREATER_EQUAL_3_10_0 else entry_points().get(group, {})  # type: ignore[arg-type]
-    )
+    factories = entry_points(group=group) if _PYTHON_GREATER_EQUAL_3_10_0 else entry_points().get(group, {})  # type: ignore[arg-type]

    external_callbacks: list[Any] = []
    for factory in factories:
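For reference, the branch that the reformatted line selects between, written out per Python version (illustration only; the group name below is hypothetical):

import sys
from importlib.metadata import entry_points

group = "some.callbacks_factory"  # hypothetical entry-point group for illustration
if sys.version_info >= (3, 10):
    factories = entry_points(group=group)       # keyword filtering, Python >= 3.10
else:
    factories = entry_points().get(group, {})   # dict-style access on older Python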
2 changes: 2 additions & 0 deletions src/lightning/pytorch/plugins/__init__.py
@@ -8,6 +8,7 @@
from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecision
from lightning.pytorch.plugins.precision.double import DoublePrecision
from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision
from lightning.pytorch.plugins.precision.fsdp2 import FSDP2Precision
from lightning.pytorch.plugins.precision.half import HalfPrecision
from lightning.pytorch.plugins.precision.precision import Precision
from lightning.pytorch.plugins.precision.transformer_engine import TransformerEnginePrecision
@@ -28,6 +29,7 @@
"Precision",
"TransformerEnginePrecision",
"FSDPPrecision",
"FSDP2Precision",
"XLAPrecision",
"LayerSync",
"TorchSyncBatchNorm",
110 changes: 110 additions & 0 deletions src/lightning/pytorch/plugins/precision/fsdp2.py
@@ -0,0 +1,110 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import AbstractContextManager
from typing import Any

import torch
from lightning_utilities import apply_to_collection
from torch import Tensor
from torch.nn import Module
from typing_extensions import get_args, override

from lightning.fabric.plugins.precision.fsdp import _PRECISION_INPUT
from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
from lightning.pytorch.plugins.precision.precision import Precision
from lightning.pytorch.utilities.exceptions import MisconfigurationException


class FSDP2Precision(Precision):
    """Precision plugin for training with FSDP2 (Fully Sharded Data Parallel v2).

    .. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.

    Args:
        precision: Full precision (32-true) or half precision (16-true, bf16-true). Mixed-precision settings
            (16-mixed, bf16-mixed) are not supported here; configure FSDP2's mixed-precision policy
            (``mp_policy``) instead.
        scaler: Not supported by this plugin; passing a scaler raises a ``ValueError``. Configure gradient
            scaling through the mixed-precision policy instead.

    Raises:
        ValueError:
            If an unsupported ``precision`` is provided, or if ``scaler`` is not ``None``.

    """

    def __init__(self, precision: _PRECISION_INPUT, scaler: Any = None) -> None:
        supported_precision = get_args(_PRECISION_INPUT)
        if precision not in supported_precision:
            raise ValueError(
                f"`precision={precision!r}` is not supported in `{self.__class__.__name__}`."
                f" `precision` must be one of: {supported_precision}."
            )

        if scaler is not None:
            raise ValueError(
                f"`scaler` is not supported in `{self.__class__.__name__}`, found {scaler}."
                " Use `mixed-precision policy` instead to configure the scaler."
            )

        if "mixed" in precision:
            raise ValueError(
                f"`precision={precision!r}` is not supported in `{self.__class__.__name__}`."
                " Only `true` precision is supported."
                " Use `mixed-precision policy (mp_policy)` instead to configure mixed precision."
            )

        self.precision = precision

        precision_to_type = {
            "bf16-true": torch.bfloat16,
            "16-true": torch.float16,
            "32-true": torch.float32,
        }
        self._desired_input_dtype = precision_to_type[self.precision]

    @override
    def convert_module(self, module: Module) -> Module:
        if "true" in self.precision:
            return module.to(dtype=self._desired_input_dtype)
        return module

    @override
    def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
        # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
        # section `Gradient Clipping`: using `torch.nn.utils.clip_grad_norm_` is incorrect with FSDP.
        # To overcome this we would need to call `root_sharded_module.clip_grad_norm_(clip_val)`, but we don't have
        # a reference to the root module here.
        raise MisconfigurationException(
            f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`"
        )

    @override
    def tensor_init_context(self) -> AbstractContextManager:
        return _DtypeContextManager(self._desired_input_dtype)

    @override
    def module_init_context(self) -> AbstractContextManager:
        # Initialize module parameters directly in the configured (true) precision
        return _DtypeContextManager(self._desired_input_dtype)

    @override
    def forward_context(self) -> AbstractContextManager:
        return _DtypeContextManager(self._desired_input_dtype)

    @override
    def convert_input(self, data: Any) -> Any:
        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)

    @override
    def convert_output(self, data: Any) -> Any:
        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
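A brief usage sketch based on the behavior defined above (the import goes through the public `lightning.pytorch.plugins` namespace, which this PR extends; the sample tensor is an assumption):

import torch

from lightning.pytorch.plugins import FSDP2Precision

plugin = FSDP2Precision("bf16-true")
batch = torch.randn(2, 4)
# Float inputs are cast to the configured true precision on the way in.
assert plugin.convert_input(batch).dtype == torch.bfloat16
# FSDP2Precision("bf16-mixed") or passing a `scaler` raises a ValueError:
# mixed precision and gradient scaling are configured via FSDP2's mp_policy instead.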
2 changes: 2 additions & 0 deletions src/lightning/pytorch/strategies/__init__.py
@@ -18,6 +18,7 @@
from lightning.pytorch.strategies.ddp import DDPStrategy
from lightning.pytorch.strategies.deepspeed import DeepSpeedStrategy
from lightning.pytorch.strategies.fsdp import FSDPStrategy
from lightning.pytorch.strategies.fsdp2 import FSDP2Strategy
from lightning.pytorch.strategies.model_parallel import ModelParallelStrategy
from lightning.pytorch.strategies.parallel import ParallelStrategy
from lightning.pytorch.strategies.single_device import SingleDeviceStrategy
@@ -32,6 +33,7 @@
"DDPStrategy",
"DeepSpeedStrategy",
"FSDPStrategy",
"FSDP2Strategy",
"ModelParallelStrategy",
"ParallelStrategy",
"SingleDeviceStrategy",
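Finally, a minimal sketch of selecting the newly exported strategy from the public namespace (the constructor arguments and device count are assumptions for illustration; see the `FSDP2Strategy` definition referenced in this PR for the actual signature):

from lightning.pytorch import Trainer
from lightning.pytorch.strategies import FSDP2Strategy

# Assumed defaults: the strategy shards the model with FSDP2 across the configured devices.
trainer = Trainer(accelerator="gpu", devices=2, strategy=FSDP2Strategy())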