56 changes: 55 additions & 1 deletion docs/source/en/optimization/cache.md
@@ -97,4 +97,58 @@ config = TaylorSeerCacheConfig(
taylor_factors_dtype=torch.bfloat16,
)
pipe.transformer.enable_cache(config)
```

## MagCache

[MagCache](https://github.com/Zehong-Ma/MagCache) accelerates inference by skipping transformer blocks based on the magnitude of the residual update. It is built on the observation that the magnitude of the update (output minus input) decays predictably over the course of the diffusion process. By accumulating an "error budget" from pre-computed magnitude ratios, MagCache dynamically decides when to skip computation and reuse the previous residual.

MagCache relies on **magnitude ratios** (`mag_ratios`), which describe this decay curve. The ratios are specific to both the model checkpoint and the scheduler.
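
To make the skip decision concrete, here is a minimal sketch of the error-budget idea. It is illustrative only: the function name, the threshold, and the cap on consecutive skips are assumptions for this sketch, not the actual `diffusers` implementation.

```python
def should_skip(step, mag_ratios, state, threshold=0.12, max_skips=2):
    """Illustrative error-budget check: reuse the cached residual while the
    accumulated deviation of the magnitude ratios from 1.0 stays small."""
    if step == 0:
        return False  # always compute the first step to populate the cache
    state["error"] += abs(1.0 - mag_ratios[step])
    if state["error"] < threshold and state["skips"] < max_skips:
        state["skips"] += 1
        return True  # skip the transformer blocks, reuse the cached residual
    state["error"], state["skips"] = 0.0, 0  # over budget: recompute
    return False

state = {"error": 0.0, "skips": 0}
ratios = [1.0, 1.37, 0.97, 0.87]
print([should_skip(i, ratios, state) for i in range(len(ratios))])
# -> [False, False, True, False] with these example ratios
```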

### Usage

To use MagCache, you typically follow a two-step process: **Calibration** and **Inference**.

1. **Calibration**: Run inference once with `calibrate=True`. The hook will measure the residual magnitudes and print the calculated ratios to the console.
2. **Inference**: Pass these ratios to `MagCacheConfig` to enable acceleration.

```python
import torch
from diffusers import FluxPipeline, MagCacheConfig

pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=torch.bfloat16
).to("cuda")

# 1. Calibration Step
# Run full inference to measure model behavior.
calib_config = MagCacheConfig(calibrate=True, num_inference_steps=4)
pipe.transformer.enable_cache(calib_config)

# Run a prompt to trigger calibration
pipe("A cat playing chess", num_inference_steps=4)
# Logs will print something like: "MagCache Calibration Results: [1.0, 1.37, 0.97, 0.87]"

# 2. Inference Step
# Apply the specific ratios obtained from calibration for optimized speed.
# Note: For Flux models, you can also import defaults:
# from diffusers.hooks.mag_cache import FLUX_MAG_RATIOS
mag_config = MagCacheConfig(
mag_ratios=[1.0, 1.37, 0.97, 0.87],
num_inference_steps=4
)

pipe.transformer.enable_cache(mag_config)

image = pipe("A cat playing chess", num_inference_steps=4).images[0]
```
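
Depending on your `diffusers` version, calling `enable_cache` while another cache is already active may raise an error, so it can be safest to call `pipe.transformer.disable_cache()` between the calibration and inference steps.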

> [!NOTE]
> `mag_ratios` represent the model's intrinsic magnitude decay curve. Ratios calibrated for a high number of steps (e.g., 50) can be reused for lower step counts (e.g., 20). The implementation uses interpolation to map the curve to the current number of inference steps.
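
As a rough sketch of that remapping (assuming simple linear interpolation over a normalized timestep axis; the hook performs this internally, so you do not need to do it yourself, and `remap_ratios` is a hypothetical helper for illustration):

```python
import numpy as np

def remap_ratios(calibrated, num_steps):
    # Place both curves on a normalized [0, 1] timestep axis and
    # linearly interpolate the calibrated curve onto the new grid.
    x_old = np.linspace(0.0, 1.0, len(calibrated))
    x_new = np.linspace(0.0, 1.0, num_steps)
    return np.interp(x_new, x_old, np.asarray(calibrated))

# Ratios calibrated at 4 steps remapped onto an 8-step schedule:
ratios_8 = remap_ratios([1.0, 1.37, 0.97, 0.87], 8)
```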

> [!TIP]
> For pipelines that run Classifier-Free Guidance sequentially (like Kandinsky 5.0), the calibration log might print two arrays: one for the Conditional pass and one for the Unconditional pass. In most cases, you should use the first array (Conditional).

> [!TIP]
> For pipelines that run Classifier-Free Guidance in a **batched** manner (like SDXL or Flux), the `hidden_states` processed by the model contain both conditional and unconditional branches concatenated together. The calibration process automatically accounts for this, producing a single array of ratios that represents the joint behavior. You can use this resulting array directly without modification.
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
@@ -167,12 +167,14 @@
"FirstBlockCacheConfig",
"HookRegistry",
"LayerSkipConfig",
"MagCacheConfig",
"PyramidAttentionBroadcastConfig",
"SmoothedEnergyGuidanceConfig",
"TaylorSeerCacheConfig",
"apply_faster_cache",
"apply_first_block_cache",
"apply_layer_skip",
"apply_mag_cache",
"apply_pyramid_attention_broadcast",
"apply_taylorseer_cache",
]
@@ -912,12 +914,14 @@
FirstBlockCacheConfig,
HookRegistry,
LayerSkipConfig,
MagCacheConfig,
PyramidAttentionBroadcastConfig,
SmoothedEnergyGuidanceConfig,
TaylorSeerCacheConfig,
apply_faster_cache,
apply_first_block_cache,
apply_layer_skip,
apply_mag_cache,
apply_pyramid_attention_broadcast,
apply_taylorseer_cache,
)
1 change: 1 addition & 0 deletions src/diffusers/hooks/__init__.py
@@ -23,6 +23,7 @@
from .hooks import HookRegistry, ModelHook
from .layer_skip import LayerSkipConfig, apply_layer_skip
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
from .mag_cache import MagCacheConfig, apply_mag_cache
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
8 changes: 7 additions & 1 deletion src/diffusers/hooks/_common.py
@@ -23,7 +23,13 @@
_ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin)
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)

_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = (
"blocks",
"transformer_blocks",
"single_transformer_blocks",
"layers",
"visual_transformer_blocks",
)
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")

22 changes: 21 additions & 1 deletion src/diffusers/hooks/_helpers.py
@@ -26,6 +26,7 @@ class AttentionProcessorMetadata:
class TransformerBlockMetadata:
return_hidden_states_index: int = None
return_encoder_hidden_states_index: int = None
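# Name of the forward kwarg carrying the block's main input (e.g. "visual_embed" for Kandinsky 5.0)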
hidden_states_argument_name: str = "hidden_states"

_cls: Type = None
_cached_parameter_indices: Dict[str, int] = None
@@ -169,7 +170,7 @@ def _register_attention_processors_metadata():


def _register_transformer_blocks_metadata():
from ..models.attention import BasicTransformerBlock
from ..models.attention import BasicTransformerBlock, JointTransformerBlock
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
from ..models.transformers.transformer_bria import BriaTransformerBlock
from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
@@ -184,6 +185,7 @@ def _register_transformer_blocks_metadata():
HunyuanImageSingleTransformerBlock,
HunyuanImageTransformerBlock,
)
from ..models.transformers.transformer_kandinsky import Kandinsky5TransformerDecoderBlock
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
from ..models.transformers.transformer_mochi import MochiTransformerBlock
from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
@@ -331,6 +333,24 @@ def _register_transformer_blocks_metadata():
),
)

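# Stable Diffusion 3 (JointTransformerBlock)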
TransformerBlockRegistry.register(
model_class=JointTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=1,
return_encoder_hidden_states_index=0,
),
)

# Kandinsky 5.0 (Kandinsky5TransformerDecoderBlock)
TransformerBlockRegistry.register(
model_class=Kandinsky5TransformerDecoderBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
hidden_states_argument_name="visual_embed",
),
)


# fmt: off
def _skip_attention___ret___hidden_states(self, *args, **kwargs):