28 changes: 28 additions & 0 deletions deepspeed/runtime/compiler.py
@@ -4,8 +4,10 @@
# DeepSpeed Team

import torch
import contextlib
import functools
from deepspeed.utils.torch import required_torch_version
from deepspeed.accelerator import get_accelerator

try:
from torch.compiler import is_compiling as torch_is_compiling
@@ -16,6 +18,16 @@
# Torch does not have compiler support
torch_is_compiling = lambda: False

try:
if required_torch_version(min_version="2.6.0a"):
from torch._dynamo.compiled_autograd import _enable as compiled_autograd_enable
else:
from torch._dynamo.compiled_autograd import enable as compiled_autograd_enable

_COMPILED_AUTOGRAD_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
_COMPILED_AUTOGRAD_AVAILABLE = False


def is_compile_supported():
return required_torch_version(min_version=2.1)
@@ -73,6 +85,22 @@ def is_compiling():
return torch_is_compiling()


@contextlib.contextmanager
def compiled_autograd(enabled: bool, kwargs: dict):
if not enabled or not _COMPILED_AUTOGRAD_AVAILABLE:
yield
return

if torch_is_compiling():
yield
return

compiler_fn = torch.compile(backend=get_accelerator().get_compile_backend(), **kwargs)

with compiled_autograd_enable(compiler_fn):
yield


def dummy_decorator(func):
return func

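For reference, a minimal usage sketch of the new compiled_autograd context manager, assuming a DeepSpeed build that includes this change and a torch version with compiled-autograd support; the toy model, loss tensor, and compile kwargs below are illustrative and not part of this PR:

import torch
from deepspeed.runtime.compiler import compiled_autograd

# Hypothetical toy workload, used only to drive a backward pass.
model = torch.nn.Linear(8, 8)
loss = model(torch.randn(4, 8)).sum()

# When enabled and compiled autograd is available, the backward graph is
# captured and compiled via torch._dynamo's compiled autograd; otherwise the
# context manager simply yields and backward runs as usual. The kwargs dict
# is forwarded to torch.compile() ("dynamic": False is an arbitrary example).
with compiled_autograd(enabled=True, kwargs={"dynamic": False}):
    loss.backward()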
39 changes: 27 additions & 12 deletions deepspeed/runtime/engine.py
@@ -106,7 +106,7 @@

from .pipe.module import PipelineModule
from .utils import get_ma_status
from .compiler import is_compile_supported
from .compiler import is_compile_supported, compiled_autograd
from ..ops.adam import FusedAdam
from ..moe.sharded_moe import TopKGate, MOELayer
from ..moe.layer import MoE
@@ -446,6 +446,9 @@ def __init__(self,
# See also: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_full_backward_hook
self.optimizer.register_grad_acc_post_hook(self._backward_post_hook)

self._is_compiled_autograd_enabled = False
self._compile_kwargs = {}

def _optimized_linear_offload_setup(self):
self.optimized_linear_base_weight_sharding = False
self.optimized_linear_lora_enabled = False
@@ -2476,17 +2479,18 @@ def backward(self, loss, retain_graph=False, scale_wrt_gas=True):
elif self.torch_autocast_z0_gradscaler:
loss = self.torch_autocast_z0_gradscaler.scale(loss)

if self.zero_optimization() or not self.amp_enabled():
loss.backward(**backward_kwargs)
elif self.amp_enabled():
# AMP requires delaying unscale when inside gradient accumulation boundaries
# https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
delay_unscale = not self.is_gradient_accumulation_boundary()
with amp.scale_loss(loss, self.optimizer, delay_unscale=delay_unscale) as scaled_loss:
scaled_loss.backward(**backward_kwargs)
with compiled_autograd(self._is_compiled_autograd_enabled, self._compile_kwargs):
if self.zero_optimization() or not self.amp_enabled():
loss.backward(**backward_kwargs)
elif self.amp_enabled():
# AMP requires delaying unscale when inside gradient accumulation boundaries
# https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
delay_unscale = not self.is_gradient_accumulation_boundary()
with amp.scale_loss(loss, self.optimizer, delay_unscale=delay_unscale) as scaled_loss:
scaled_loss.backward(**backward_kwargs)

# backward_epilogue is not called in a hook when self._support_torch_style_backward is False
self._backward_epilogue()
# backward_epilogue is not called in a hook when self._support_torch_style_backward is False
self._backward_epilogue()

self._running_engine_backward = False

@@ -4205,7 +4209,11 @@ def empty_partition_cache(self):
gc.collect()
get_accelerator().empty_cache()

def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwargs={}, schedule=None) -> None:
def compile(self,
backend=get_accelerator().get_compile_backend(),
compile_kwargs={},
schedule=None,
compiled_autograd_enabled=False) -> None:
"""Compile the module using the specified backend and kwargs.
If a compiler_fn is set, it will be used instead of torch.compile().
"""
@@ -4271,6 +4279,13 @@ def passes_name_to_fn(passes):
raise

self._is_compiled = True
self._compile_kwargs = compile_kwargs
if compiled_autograd_enabled:
if not self._deepcompile_active:
self._is_compiled_autograd_enabled = compiled_autograd_enabled
else:
logger.warning("Compiled autograd is not compatible with DeepCompile, disabling compiled autograd.")
self._is_compiled_autograd_enabled = False

def _set_deepcompile_active(self, active: bool) -> None:
"""Toggle DeepCompile runtime state and manage forward hooks accordingly."""
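A sketch of the end-to-end opt-in path through the engine API; the model, DeepSpeed config values, and compile kwargs are hypothetical placeholders, and only the compiled_autograd_enabled flag is the new surface introduced by this change:

import torch
import deepspeed

# Hypothetical minimal setup; the config values are illustrative only.
model = torch.nn.Linear(8, 8)
ds_config = {"train_batch_size": 4, "optimizer": {"type": "Adam", "params": {"lr": 1e-3}}}
engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=ds_config)

# Request compiled autograd alongside module compilation. Per the check in
# compile(), the flag is ignored with a warning when DeepCompile is active,
# and the same compile_kwargs are reused for the compiled-autograd compiler_fn.
engine.compile(compile_kwargs={"dynamic": False}, compiled_autograd_enabled=True)

x = torch.randn(4, 8).to(engine.device)
loss = engine(x).sum()
# engine.backward() now wraps loss.backward() in the compiled_autograd context manager.
engine.backward(loss)
engine.step()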