
Commit e0a6bb5

Enabled compiled autograd for backward pass (#7667)
Compiled Autograd is an extension to torch.compile that enhances the autograd engine by capturing a larger backward computation graph at runtime, allowing more comprehensive optimization of the backward pass during training. Overall, a 5-20% speedup is expected in backward-heavy workloads with stable graphs.

The feature is disabled by default; it can be enabled from a user script by passing `compiled_autograd_enabled=True` when invoking the engine's `compile` method.

Note that bfloat16 with the eager backend requires PyTorch >= 2.5 (where partial fixes landed), or disabling compiled autograd for bfloat16 models, due to a known torch.compile bug (PyTorch #152162/#161153).

---------

Signed-off-by: Max Kovalenko <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
1 parent 39a682d commit e0a6bb5
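For illustration, a minimal sketch of how a user script would opt in. The `model`, `ds_config`, and `data_loader` names and the training loop are placeholders assumed for this example, not part of the commit; the only commit-relevant call is `compile(compiled_autograd_enabled=True)`.

```python
# Illustrative sketch only: model, ds_config, and data_loader are assumed to exist
# in the user script.
import deepspeed

engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

# Opt in to compiled autograd (disabled by default) when compiling the engine.
engine.compile(compiled_autograd_enabled=True)

for batch in data_loader:
    loss = engine(batch)          # forward; assumes the model returns a scalar loss
    engine.backward(loss)         # backward now runs under the compiled_autograd context
    engine.step()
```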

File tree

2 files changed (+55, -12 lines)

deepspeed/runtime/compiler.py

Lines changed: 28 additions & 0 deletions
@@ -4,8 +4,10 @@
 # DeepSpeed Team

 import torch
+import contextlib
 import functools
 from deepspeed.utils.torch import required_torch_version
+from deepspeed.accelerator import get_accelerator

 try:
     from torch.compiler import is_compiling as torch_is_compiling
@@ -16,6 +18,16 @@
     # Torch does not have compiler support
     torch_is_compiling = lambda: False

+try:
+    if required_torch_version(min_version="2.6.0a"):
+        from torch._dynamo.compiled_autograd import _enable as compiled_autograd_enable
+    else:
+        from torch._dynamo.compiled_autograd import enable as compiled_autograd_enable
+
+    _COMPILED_AUTOGRAD_AVAILABLE = True
+except (ImportError, ModuleNotFoundError):
+    _COMPILED_AUTOGRAD_AVAILABLE = False
+

 def is_compile_supported():
     return required_torch_version(min_version=2.1)
@@ -73,6 +85,22 @@ def is_compiling():
     return torch_is_compiling()


+@contextlib.contextmanager
+def compiled_autograd(enabled: bool, kwargs: dict):
+    if not enabled or not _COMPILED_AUTOGRAD_AVAILABLE:
+        yield
+        return
+
+    if torch_is_compiling():
+        yield
+        return
+
+    compiler_fn = torch.compile(backend=get_accelerator().get_compile_backend(), **kwargs)
+
+    with compiled_autograd_enable(compiler_fn):
+        yield
+
+
 def dummy_decorator(func):
     return func
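For context, a minimal standalone sketch of the PyTorch mechanism the new `compiled_autograd` context manager wraps. The toy model and the bare `torch.compile()` compiler_fn below are illustrative assumptions, not DeepSpeed code; in the change above, the compiler_fn is built from the accelerator's compile backend and the user's `compile_kwargs`.

```python
# Standalone sketch, not part of this commit: enable compiled autograd around a
# backward pass using the same import fallback as compiler.py above.
import torch

try:
    # PyTorch >= 2.6 exposes the helper as _enable
    from torch._dynamo.compiled_autograd import _enable as compiled_autograd_enable
except ImportError:
    from torch._dynamo.compiled_autograd import enable as compiled_autograd_enable

model = torch.nn.Linear(16, 16)
loss = model(torch.randn(4, 16)).sum()

# compiler_fn decides how the captured backward graph is compiled; DeepSpeed passes
# torch.compile(backend=get_accelerator().get_compile_backend(), **compile_kwargs).
compiler_fn = torch.compile()

with compiled_autograd_enable(compiler_fn):
    loss.backward()  # the backward graph is captured and compiled at runtime
```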

deepspeed/runtime/engine.py

Lines changed: 27 additions & 12 deletions
@@ -106,7 +106,7 @@

 from .pipe.module import PipelineModule
 from .utils import get_ma_status
-from .compiler import is_compile_supported
+from .compiler import is_compile_supported, compiled_autograd
 from ..ops.adam import FusedAdam
 from ..moe.sharded_moe import TopKGate, MOELayer
 from ..moe.layer import MoE
@@ -446,6 +446,9 @@ def __init__(self,
         # See also: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_full_backward_hook
         self.optimizer.register_grad_acc_post_hook(self._backward_post_hook)

+        self._is_compiled_autograd_enabled = False
+        self._compile_kwargs = {}
+
     def _optimized_linear_offload_setup(self):
         self.optimized_linear_base_weight_sharding = False
         self.optimized_linear_lora_enabled = False
@@ -2476,17 +2479,18 @@ def backward(self, loss, retain_graph=False, scale_wrt_gas=True):
         elif self.torch_autocast_z0_gradscaler:
             loss = self.torch_autocast_z0_gradscaler.scale(loss)

-        if self.zero_optimization() or not self.amp_enabled():
-            loss.backward(**backward_kwargs)
-        elif self.amp_enabled():
-            # AMP requires delaying unscale when inside gradient accumulation boundaries
-            # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
-            delay_unscale = not self.is_gradient_accumulation_boundary()
-            with amp.scale_loss(loss, self.optimizer, delay_unscale=delay_unscale) as scaled_loss:
-                scaled_loss.backward(**backward_kwargs)
+        with compiled_autograd(self._is_compiled_autograd_enabled, self._compile_kwargs):
+            if self.zero_optimization() or not self.amp_enabled():
+                loss.backward(**backward_kwargs)
+            elif self.amp_enabled():
+                # AMP requires delaying unscale when inside gradient accumulation boundaries
+                # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
+                delay_unscale = not self.is_gradient_accumulation_boundary()
+                with amp.scale_loss(loss, self.optimizer, delay_unscale=delay_unscale) as scaled_loss:
+                    scaled_loss.backward(**backward_kwargs)

-        # backward_epilogue is not called in a hook when self._support_torch_style_backward is False
-        self._backward_epilogue()
+            # backward_epilogue is not called in a hook when self._support_torch_style_backward is False
+            self._backward_epilogue()

         self._running_engine_backward = False

@@ -4205,7 +4209,11 @@ def empty_partition_cache(self):
         gc.collect()
         get_accelerator().empty_cache()

-    def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwargs={}, schedule=None) -> None:
+    def compile(self,
+                backend=get_accelerator().get_compile_backend(),
+                compile_kwargs={},
+                schedule=None,
+                compiled_autograd_enabled=False) -> None:
         """Compile the module using the specified backend and kwargs.
         If a compiler_fn is set, it will be used instead of torch.compile().
         """
@@ -4271,6 +4279,13 @@ def passes_name_to_fn(passes):
             raise

         self._is_compiled = True
+        self._compile_kwargs = compile_kwargs
+        if compiled_autograd_enabled:
+            if not self._deepcompile_active:
+                self._is_compiled_autograd_enabled = compiled_autograd_enabled
+            else:
+                logger.warning("Compiled autograd is not compatible with DeepCompile, disabling compiled autograd.")
+                self._is_compiled_autograd_enabled = False

     def _set_deepcompile_active(self, active: bool) -> None:
         """Toggle DeepCompile runtime state and manage forward hooks accordingly."""
