
Conversation

@samadwar

What does this PR do?

Added support for loading checkpoints from a single file. Some modifications to the convert_wan_transformer_to_diffusers method were required for it to work with WanAnimateTransformer3DModel.

best regards,
Sam
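
For readers unfamiliar with the single-file path: the conversion step remaps the original checkpoint's parameter names onto the diffusers module layout before the model is instantiated. Below is a schematic sketch of that kind of key remapping; the rename table is purely illustrative and is not the actual mapping implemented in convert_wan_transformer_to_diffusers.

# Hypothetical rename table -- for illustration only, not the real Wan Animate mapping.
EXAMPLE_RENAME_DICT = {
    "time_embedding.0.": "condition_embedder.time_embedder.linear_1.",
    "time_embedding.2.": "condition_embedder.time_embedder.linear_2.",
}

def convert_keys(state_dict):
    # Return a new state dict whose keys follow the diffusers naming scheme.
    converted = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in EXAMPLE_RENAME_DICT.items():
            new_key = new_key.replace(old, new)
        converted[new_key] = value
    return converted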

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@DN6
Collaborator

DN6 commented Nov 21, 2025

Hi @samadwar, do you have a single-file version of Wan Animate we can use to test this PR?

@samadwar
Author

samadwar commented Nov 21, 2025

@dg845
Collaborator

dg845 commented Nov 22, 2025

Hi @samadwar, thanks for the PR! Would you be able to share an example of a code snippet which uses WanAnimateTransformer3DModel.from_single_file? I tried to test the PR using the following script:

import os

import torch

from diffusers import GGUFQuantizationConfig, WanAnimatePipeline, WanAnimateTransformer3DModel
from diffusers.utils import export_to_video, load_image, load_video

single_file_ckpt = "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q4_K_M.gguf"
# single_file_ckpt = "https://huggingface.co/Kijai/WanVideo_comfy_fp8_scaled/blob/main/Wan22Animate/Wan2_2-Animate-14B_fp8_scaled_e4m3fn_KJ_v2.safetensors"
model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"

device = "cuda:0"
dtype = torch.bfloat16
seed = 42

transformer_kwargs = {}
_, single_file_ext = os.path.splitext(single_file_ckpt)
if single_file_ext == ".gguf":
    quantization_config = GGUFQuantizationConfig(compute_dtype=dtype)
    transformer_kwargs["quantization_config"] = quantization_config

transformer = WanAnimateTransformer3DModel.from_single_file(
    single_file_ckpt,
    config=model_id,
    subfolder="transformer",
    **transformer_kwargs,
)

pipe = WanAnimatePipeline.from_pretrained(
    model_id,
    transformer=transformer,
    torch_dtype=dtype,
)
pipe.to(device)

image = load_image("/path/to/reference_image.png")
pose_video = load_video("/path/to/pose_video.mp4")
face_video = load_video("/path/to/face_video.mp4")

video = pipe(
    image=image,
    pose_video=pose_video,
    face_video=face_video,
    prompt="People in the video are doing actions.",
    height=720,
    width=1280,
    mode="animate",
    guidance_scale=1.0,
    num_inference_steps=20,
    generator=torch.Generator(device=device).manual_seed(seed),
    output_type="np",
).frames[0]

export_to_video(video, "wan_animate_single_file.mp4", fps=30)

Using a checkpoint from QuantStack/Wan2.2-Animate-14B-GGUF doesn't raise any errors, but the generated samples seem to be just noise:

wan_animate_single_file_gguf_20_step.mp4

If I instead try a checkpoint from Kijai/WanVideo_comfy_fp8_scaled, I get an OOM error on an A100 (80 GB VRAM), and a lot of keys in the checkpoint don't seem to be used (they mainly end in .scale_weight, so they might be the FP8 scaling parameters?).
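
One quick way to check the .scale_weight hypothesis is to list the raw keys in the checkpoint with the safetensors package (a sketch; the local path below is a hypothetical download location):

from safetensors import safe_open

# Hypothetical local path to the downloaded fp8 checkpoint.
ckpt_path = "Wan2_2-Animate-14B_fp8_scaled_e4m3fn_KJ_v2.safetensors"

with safe_open(ckpt_path, framework="pt") as f:
    scale_keys = [k for k in f.keys() if k.endswith(".scale_weight")]

print(f"{len(scale_keys)} keys end in .scale_weight")
print(scale_keys[:5])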

@samadwar
Author

Quoting @dg845's comment above: "Using a checkpoint from QuantStack/Wan2.2-Animate-14B-GGUF doesn't raise any errors, but the generated samples seem to be just noise. If I instead try a checkpoint from Kijai/WanVideo_comfy_fp8_scaled, I get an OOM error on an A100 (80 GB VRAM) and a lot of keys in the checkpoint don't seem to be used (they mainly end in .scale_weight, so they might be the FP8 scaling parameters?)."

Yeah, I am experiencing the same issue today. I had it working before; I will check and get back to you.

For the GGUF checkpoint I am using an AWS ml.g6e.4xlarge, which comes with 45 GB of VRAM, so I don't have access to enough GPU memory to test fp8. But I guess one way to check is to load the file with the safetensors package and compare the actual weight values to see whether they match.
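
A rough sketch of that kind of check, assuming both files have already been downloaded locally (the file names below are placeholders, and the diffusers shard name is hypothetical). The raw key names won't line up one-to-one because the single-file checkpoint uses the original naming, so this is mainly for eyeballing shapes, dtypes, and value ranges:

from safetensors.torch import load_file

# Hypothetical local paths: the single-file checkpoint and one shard of the
# diffusers-format transformer.
single_sd = load_file("Wan2_2-Animate-14B_fp8_scaled_e4m3fn_KJ_v2.safetensors")
diffusers_sd = load_file("diffusion_pytorch_model-00001-of-00006.safetensors")

def summarize(sd, n=5):
    # Print shape, dtype, and a simple statistic for the first few tensors.
    for key in list(sd)[:n]:
        t = sd[key]
        print(f"{key}: shape={tuple(t.shape)} dtype={t.dtype} mean={t.float().mean().item():.4f}")

summarize(single_sd)
summarize(diffusers_sd)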

@samadwar
Author

samadwar commented Nov 22, 2025

@dg845 I fixed the issue, can you try now?

@samadwar
Author

samadwar commented Nov 22, 2025

Code I am using:

import torch
import numpy as np

from diffusers import GGUFQuantizationConfig, WanAnimatePipeline, WanAnimateTransformer3DModel
from diffusers.utils import export_to_video, load_image, load_video

LoRA = True
device_gpu = torch.device("cuda")
original_model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
lora_model_id = "Kijai/WanVideo_comfy"
lora_model_path = "Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank64_bf16.safetensors"

print("Loading transformer ....")
transformer = WanAnimateTransformer3DModel.from_single_file(
    "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q8_0.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    config=original_model_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    offload_device="cpu",
    device=device_gpu
)
print("Transformer loaded successfully ....")

print("Loading pipeline ....")
pipe = WanAnimatePipeline.from_pretrained(
    original_model_id,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)

if LoRA:
    pipe.load_lora_weights(
        lora_model_id,
        weight_name=lora_model_path,
        adapter_name="lightning",
        offload_device="cpu",
        device=device_gpu
    )

pipe.enable_model_cpu_offload()
print("Pipeline loaded successfully ....")

# Load the character image
image = load_image(
    "Wan2.2/examples/wan_animate/animate/image.jpeg"
)

# Load pose and face videos (preprocessed from reference video)
# Note: Videos should be preprocessed to extract pose keypoints and face features
# Refer to the Wan-Animate preprocessing documentation for details
pose_video = load_video("Wan2.2/examples/wan_animate/animate/process_results/src_pose.mp4")
face_video = load_video("Wan2.2/examples/wan_animate/animate/process_results/src_face.mp4")

# Calculate optimal dimensions based on VAE constraints
max_area = 1280 * 720
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))

prompt = "People in the video are doing actions."

# Animation mode: Animate the character with the motion from pose/face videos
print("Generating animation ....")
num_inference_steps = 4 if LoRA else 20

output = pipe(
    image=image,
    pose_video=pose_video,
    face_video=face_video,
    prompt=prompt,
    #  negative_prompt=negative_prompt,
    height=height,
    width=width,
    segment_frame_length=77,
    guidance_scale=1.0,
    prev_segment_conditioning_frames=1,  # refert_num in original code
    num_inference_steps=num_inference_steps,
    mode="animate",
).frames[0]
print("Exporting animation ....")
export_to_video(output, "output_animation__.mp4", fps=30)
output_animation__.mp4

  time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
- if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
+ if timestep.dtype != time_embedder_dtype and time_embedder_dtype not in [torch.int8, torch.uint8]:
      timestep = timestep.to(time_embedder_dtype)
@samadwar
Author

@dg845 Do you know why this line exists? It seems to cause the white-noise issue when the time_embedder weights are stored in uint8, and line 811 would have an issue if the timestep dtype does not match encoder_hidden_states. Maybe we need to remove lines 807 and 808?
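
For context on why the uint8 exclusion matters (a sketch; the values are only illustrative): with a GGUF-quantized checkpoint the time_embedder's parameters can be stored as packed uint8 blocks, so matching the timestep dtype to the parameter storage dtype destroys the timestep value, which would explain the noise-only outputs above.

import torch

# Timesteps are continuous floats, often in the hundreds, while uint8 can only
# hold integers 0-255.
timestep = torch.tensor([3.5])
print(timestep.to(torch.uint8))  # the fractional part is truncated
# Larger timesteps (e.g. 941.0) don't fit in uint8 at all, so the cast mangles
# them completely; guarding on quantized integer dtypes avoids this.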
