Add support for Magcache #12744
Conversation
@leffff could you review as well if possible?

Hi @AlanPonnachan @sayakpaul
@leffff, thank you for your review. To address this, I am implementing a Calibration Mode. My plan is to add a `calibrate=True` flag to `MagCacheConfig`.
Users can then simply run one calibration pass for their specific model/scheduler, copy the output ratios, and pass them into `MagCacheConfig`. I am working on this update now and will push the changes shortly!
Sounds great!
Thanks for the thoughtful discussions here @AlanPonnachan and @leffff! I will leave my two cents below:
Cc'ing @DN6 to get his thoughts here, too.

Thanks @sayakpaul and @leffff for the feedback! I have updated the PR to address these points. Instead of a standalone utility script, I integrated the calibration logic directly into the hook configuration for better usability:
Ready for review!
Looks great! Could you please provide a usage example?
And, to be sure it works, please provide generations for SD3.5 Medium, Flux, and Wan 2.1 T2V 1.3B. Since caching is suitable for all tasks, could we also try Kandinsky 5.0 Video Pro I2V (kandinskylab/Kandinsky-5.0-I2V-Pro-sft-5s-Diffusers)?
1. Usage Example

```python
import torch
from diffusers import FluxPipeline
from diffusers.hooks import MagCacheConfig, apply_mag_cache

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to("cuda")

# CALIBRATION STEP
config = MagCacheConfig(calibrate=True, num_inference_steps=4)
apply_mag_cache(pipe.transformer, config)
pipe("A cat playing chess", num_inference_steps=4)
# Logs: [1.0, 1.37, 0.97, 0.87]

# INFERENCE STEP
config = MagCacheConfig(mag_ratios=[1.0, 1.37, 0.97, 0.87], num_inference_steps=4)
apply_mag_cache(pipe.transformer, config)
pipe("A cat playing chess", num_inference_steps=4)
```

2. Benchmark Results

I validated the implementation on Flux, SD 3.5, and Wan 2.1 using a T4 Colab environment.

3. Generations

Attached below are the outputs for the successful runs.
Here is the Colab notebook used to generate the benchmarks above. It includes the full setup, memory optimizations (sequential offloading/dummy embeds), and the execution logs:
@bot /style

Style bot fixed some files and pushed the changes.

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

This looks good!
@AlanPonnachan thanks for your great work thus far! Some minor questions (mostly out of curiosity) below:
Additionally, I could obtain outputs with Wan 1.3B and they look reasonable to me.

Code:

```python
import torch
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.hooks import MagCacheConfig, apply_mag_cache
from diffusers.utils import export_to_video
# Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
num_inference_steps = 50
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")
# config = MagCacheConfig(calibrate=True, num_inference_steps=num_inference_steps)
# apply_mag_cache(pipe.transformer, config)
config = MagCacheConfig(
mag_ratios=[1.0, 1.0337707996368408, 0.9908783435821533, 0.9898878931999207, 0.990186870098114, 0.989551305770874, 0.9898356199264526, 0.9901290535926819, 0.9913457632064819, 0.9893063902854919, 0.990695059299469, 0.9892956614494324, 0.9910416603088379, 0.9908630847930908, 0.9897039532661438, 0.9907404184341431, 0.98955237865448, 0.9905906915664673, 0.9881031513214111, 0.98977130651474, 0.9878108501434326, 0.9873648285865784, 0.98862624168396, 0.9870336055755615, 0.9855726957321167, 0.9857151508331299, 0.98496013879776, 0.9846605658531189, 0.9835416674613953, 0.984062671661377, 0.9805435538291931, 0.9828993678092957, 0.9804039001464844, 0.9776313304901123, 0.9769471883773804, 0.9752448201179504, 0.973810076713562, 0.9708614349365234, 0.9703076481819153, 0.9666262865066528, 0.9658275246620178, 0.9612534046173096, 0.9553734064102173, 0.9522399306297302, 0.9467942118644714, 0.9430344104766846, 0.9335862994194031, 0.9285727739334106, 0.9244886636734009, 0.9560992121696472],
num_inference_steps=num_inference_steps
)
apply_mag_cache(pipe.transformer, config)
prompt = "A cat walks on the grass, realistic"
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
output = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=480,
width=832,
num_frames=81,
guidance_scale=5.0,
num_inference_steps=num_inference_steps,
).frames[0]
export_to_video(output, "output.mp4", fps=15)Outputs: # Calibation
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [01:35<00:00, 1.91s/it]
# After using the `mag_ratios`
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:27<00:00, 1.82it/s]
```

Video output: output.mp4

However, there seems to be a problem when using Kandinsky 5, and the error seems obvious to me. Error: https://pastebin.com/F7arxTWg

Code:

```python
import torch
from diffusers import Kandinsky5T2VPipeline
from diffusers.hooks import MagCacheConfig, apply_mag_cache
from diffusers.utils import export_to_video
model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers"
num_inference_steps = 50
pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
config = MagCacheConfig(calibrate=True, num_inference_steps=num_inference_steps)
apply_mag_cache(pipe.transformer, config)
# config = MagCacheConfig(
# mag_ratios=[...],
# num_inference_steps=num_inference_steps
# )
# apply_mag_cache(pipe.transformer, config)
prompt = "A cat and a dog baking a cake together in a kitchen."
negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
output = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=512,
    width=768,
    num_frames=121,  # ~5 seconds at 24fps
    num_inference_steps=num_inference_steps,
    guidance_scale=5.0,
).frames[0]
export_to_video(output, "output_kandinsky.mp4", fps=24, quality=9)
```

For this, instead of hardcoding a line like the following, maybe we could pass the parameter name through the cache config (sketched below)? I understand this could be difficult for users, but my thought is that since they have to perform calibration anyway, this is still reasonable. Just out of curiosity, I changed it to:

```diff
diff --git a/src/diffusers/hooks/mag_cache.py b/src/diffusers/hooks/mag_cache.py
index 71ebfcb25..0a7c333db 100644
--- a/src/diffusers/hooks/mag_cache.py
+++ b/src/diffusers/hooks/mag_cache.py
@@ -183,7 +183,7 @@ class MagCacheHeadHook(ModelHook):
self.state_manager.set_context("inference")
# Capture input hidden_states
- hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
+ hidden_states = self._metadata._get_parameter_from_args_kwargs("visual_embed", args, kwargs)
state: MagCacheState = self.state_manager.get_state()
state.head_block_input = hidden_states
@@ -297,7 +297,7 @@ class MagCacheBlockHook(ModelHook):
state: MagCacheState = self.state_manager.get_state()
if not state.should_compute:
- hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
+ hidden_states = self._metadata._get_parameter_from_args_kwargs("visual_embed", args, kwargs)
if self.is_tail:
# Still need to advance step index even if we skip
self._advance_step(state)
```

And I ran the above code. But I am getting a pair of calibration outputs:

```
[MagCache] Calibration Complete. Copy these values to MagCacheConfig(mag_ratios=...):
[1.0, 1.0096147060394287, 0.8601706027984619, 1.0066865682601929, 1.1018145084381104, 1.0066889524459839, 1.07235848903656, 1.006271243095398, 1.0583757162094116, 1.0066468715667725, 1.0803261995315552, 1.0059221982955933, 1.0304542779922485, 1.0061317682266235, 1.0251237154006958, 1.006355881690979, 1.0230522155761719, 1.0063568353652954, 1.0354706048965454, 1.006076455116272, 1.0154225826263428, 1.0064369440078735, 1.0257697105407715, 1.0066747665405273, 1.012341856956482, 1.0068379640579224, 1.017471432685852, 1.0070058107376099, 1.008599877357483, 1.00702702999115, 1.0158008337020874, 1.0070949792861938, 1.0113613605499268, 1.0063375234603882, 1.0122487545013428, 1.0064034461975098, 1.0091496706008911, 1.0062494277954102, 1.0109937191009521, 1.0061204433441162, 1.0084550380706787, 1.0059889554977417, 1.006821870803833, 1.0058847665786743, 1.0106556415557861, 1.005847454071045, 1.0057544708251953, 1.0058276653289795, 1.0092748403549194, 1.005746841430664]
[MagCache] Calibration Complete. Copy these values to MagCacheConfig(mag_ratios=...):
[1.0, 1.0056898593902588, 1.0074970722198486, 1.005563735961914, 1.0061627626419067, 1.0054070949554443, 1.0053973197937012, 1.0052893161773682, 1.0067739486694336, 1.0051906108856201, 1.0049010515213013, 1.0050380229949951, 1.0056493282318115, 1.0049028396606445, 1.0056771039962769, 1.0048167705535889, 1.0038255453109741, 1.0047082901000977, 1.0041747093200684, 1.004562258720398, 1.002451777458191, 1.0044060945510864, 1.0022073984146118, 1.0042728185653687, 1.0011045932769775, 1.0041989088058472, 0.9996317625045776, 1.0040632486343384, 0.9980409741401672, 1.0038821697235107, 0.9960299134254456, 1.004146933555603, 0.9924721717834473, 1.0041824579238892, 0.9876144528388977, 1.0041331052780151, 0.9839898943901062, 1.003833293914795, 0.976319432258606, 1.0032036304473877, 0.9627748131752014, 1.002505898475647, 0.9450504779815674, 1.001646637916565, 0.9085856080055237, 0.9999536275863647, 0.8368133306503296, 0.9975034594535828, 0.6354470252990723, 0.9997955560684204]
```

When applying the first one, I got: output_kandinsky.mp4

When applying the second one, I got: output_kandinsky_2.mp4

Thought this would help :)
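To make that idea concrete, here is a minimal sketch (illustrative only, not code from this PR) of how the tracked parameter name could live in the config rather than being hardcoded; the class name and the `hidden_states_name` field are hypothetical:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class MagCacheConfigSketch:
    """Illustrative stand-in for MagCacheConfig; not the actual class in this PR."""

    mag_ratios: Optional[List[float]] = None
    calibrate: bool = False
    num_inference_steps: int = 50
    # Hypothetical field: which transformer argument the hooks should track, so that
    # models whose blocks take e.g. "visual_embed" (Kandinsky 5) work without patching.
    hidden_states_name: str = "hidden_states"


# Inside the hooks, the hardcoded lookup would then become something like:
#   self._metadata._get_parameter_from_args_kwargs(self.config.hidden_states_name, args, kwargs)
```

Users calibrating a new model would set this field alongside the calibrated ratios, which keeps the hooks themselves model-agnostic.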
@sayakpaul thank you for running inferences from your side; it helped a lot.

1. Regarding
Makes sense, yeah!

This is awesome. Let's make sure we document it once we're at that point.

Okay, then this needs to be documented as well. However, there are some small models where we run CFG in a batched manner. Would that affect the calibration?

Cc: @Zehong-Ma! Hey, maybe you would like to review the PR as well :)
sayakpaul
left a comment
Thanks a lot for working on this!
I have left some comments, LMK what you think of them.
Let's add documentation and button up testing :)
Thanks for your review, and thanks to @AlanPonnachan for the contribution. I have briefly reviewed the pull request. Most of your discussion is correct and concise. There are two important things that may need to be clearly discussed or fixed.
From my observation of the codebase, I found that maintaining distinct states for conditional/unconditional passes is quite difficult with the current architecture.

Updated with this config.
@sayakpaul, added documentation and a test. Please check and let me know if any changes are required.

Thanks for this (and I agree with this approach)! Should we also document what users would need to do / how to proceed when CFG is implemented in a batched manner (SDXL, for example)? Additionally, @Zehong-Ma WDYT?
sayakpaul
left a comment
This is looking good to me, but I am debating whether we should have a single unified hook class instead of maintaining two. That would reduce the complexity a bit and likely make it simpler to maintain.
Here's the signature I had in mind: MagCacheHook(state_manager, config, role=role), where role would be "head" or "tail" (sketched below).
Or do you think it's easier to maintain two classes?
Not strongly opinionated. @DN6 WDYT?
Also, if you have a chance to test it with torch.compile and report any performance gains from that, it'd be golden. Not a priority, though.
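For illustration, a rough sketch of what the single-class version could look like with the role-based signature above; the `ModelHook` base class, `new_forward`, and `fn_ref` usage follow the general diffusers hook pattern, but the details here are assumptions rather than this PR's actual code:

```python
from diffusers.hooks.hooks import ModelHook  # import path assumed


class MagCacheHook(ModelHook):
    """Hypothetical unified hook: one class covers both positions, selected by `role`."""

    def __init__(self, state_manager, config, role: str):
        super().__init__()
        if role not in ("head", "tail"):
            raise ValueError(f"Unknown role: {role}")
        self.state_manager = state_manager
        self.config = config
        self.role = role

    def new_forward(self, module, *args, **kwargs):
        state = self.state_manager.get_state()
        if self.role == "head":
            # Head position: record the block input and decide whether this step
            # can be skipped based on the accumulated magnitude ratios.
            pass
        else:
            # Tail position: record the block output/residual, update the ratio
            # statistics, and advance the step counter.
            pass
        return self.fn_ref.original_forward(*args, **kwargs)
```

Presumably `apply_mag_cache` would then register one instance with `role="head"` on the first transformer block and one with `role="tail"` on the last.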
@Meatfucker if you want to test the PR.

@bot /style

Style bot fixed some files and pushed the changes.
I ran the torch.compile benchmark. I used the script below:

```python
import torch
import time
import gc
from diffusers import StableDiffusion3Pipeline
from diffusers.hooks import MagCacheConfig, apply_mag_cache
torch.set_float32_matmul_precision('high')
def flush():
    gc.collect()
    torch.cuda.empty_cache()
print(f" Benchmarking on {torch.cuda.get_device_name(0)}...")
# Load SD3.5 (No T5 to fit in 8GB VRAM)
pipe = StableDiffusion3Pipeline.from_pretrained(
"stabilityai/stable-diffusion-3.5-medium",
text_encoder_3=None,
tokenizer_3=None,
torch_dtype=torch.bfloat16
).to("cuda")
# Setup MagCache
steps = 20
config = MagCacheConfig(
mag_ratios=[1.0] * steps,
num_inference_steps=steps,
threshold=0.05
)
apply_mag_cache(pipe.transformer, config)
# Resolution: 512x512 is safe for 8GB. 1024x1024 will likely OOM.
kwargs = {"height": 512, "width": 512, "num_inference_steps": steps}
prompt = "A photo of a fast car"
# --- RUN 1: EAGER MODE ---
print("\n>> [1/2] Benchmarking MagCache (Eager)...")
# Warmup
pipe(prompt, **kwargs)
torch.cuda.synchronize()
start = time.time()
pipe(prompt, **kwargs)
torch.cuda.synchronize()
eager_time = time.time() - start
print(f" Eager Time: {eager_time:.4f}s")
# --- RUN 2: COMPILED MODE ---
print("\n>> [2/2] Benchmarking MagCache (torch.compile)...")
print(" Compiling transformer ...")
# 'max-autotune' gives best speed but uses more memory/time.
# 'reduce-overhead' is safer for 8GB VRAM.
pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=False)
try:
    # Warmup (trigger compilation)
    start_compile = time.time()
    pipe(prompt, **kwargs)
    print(f" Compilation + Warmup took: {time.time() - start_compile:.2f}s")

    # Benchmark
    torch.cuda.synchronize()
    start = time.time()
    pipe(prompt, **kwargs)
    torch.cuda.synchronize()
    compile_time = time.time() - start
    print(f" Compile Time: {compile_time:.4f}s")
    print(f"\nSpeedup: {eager_time / compile_time:.2f}x")
except Exception as e:
    print(f"\nCompilation failed: {e}")
```

Results (SD 3.5 Medium, 512px):
The successful execution confirms that the hooks remain compatible with torch.compile.
Regarding single vs. two classes: merging them would likely require injecting conditional logic (e.g., branching on a role flag inside each method). That said, I don't hold a strong opinion here; if you prefer a unified class to keep the file smaller, I am happy to refactor it!
sayakpaul
left a comment
Thanks a lot for working on this!
It definitely wasn't trivial at all. I am also on the same page regarding your findings for torch.compile.
@leffff do you want to test it one final time for Kandinsky?
@bot /style

Style bot fixed some files and pushed the changes.
What does this PR do?
This PR adds support for MagCache (Magnitude-aware Cache), a training-free inference acceleration method for diffusion models, specifically targeting Transformer-based architectures like Flux.
This implementation follows the `ModelHook` pattern (similar to `FirstBlockCache`) to integrate seamlessly into Diffusers.

Key features:
- `MagCacheConfig`: configuration class to control the threshold, retention ratio, and skipping limits.
- `calibrate=True` flag: when enabled, the hook runs full inference and calculates/prints the magnitude ratios for the specific model and scheduler. This makes MagCache compatible with any transformer model (e.g., Hunyuan, Wan, SD3), not just Flux.
- `mag_ratios` must be explicitly provided in the config (or calibration enabled).
- `FLUX_MAG_RATIOS` is provided as a constant for convenience, derived from the official implementation (see the sketch below).

Fixes Magcache Support. #12697
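A minimal usage sketch for the `FLUX_MAG_RATIOS` constant mentioned above; the import location and the choice to tie the step count to the length of the ratio list are assumptions based on the examples earlier in this thread:

```python
import torch
from diffusers import FluxPipeline
from diffusers.hooks import MagCacheConfig, apply_mag_cache
from diffusers.hooks import FLUX_MAG_RATIOS  # assumed export location

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

# Reuse the bundled Flux ratios instead of running a calibration pass.
num_inference_steps = len(FLUX_MAG_RATIOS)  # the thread's examples pair one ratio per step
config = MagCacheConfig(mag_ratios=FLUX_MAG_RATIOS, num_inference_steps=num_inference_steps)
apply_mag_cache(pipe.transformer, config)

image = pipe("A cat playing chess", num_inference_steps=num_inference_steps).images[0]
```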
Who can review?
@sayakpaul