[From Single File] support from_single_file method for WanAnimateTransformer3DModel
#12691
Conversation
Hi @samadwar, do you have a single-file version of Wan Animate that we can use to test this PR?
Yes: https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q4_K_M.gguf, or any file from here: https://huggingface.co/Kijai/WanVideo_comfy_fp8_scaled/tree/main/Wan22Animate
Hi @samadwar, thanks for the PR! Would you be able to share an example of a code snippet which uses `from_single_file`? For reference, here is the script I tried:

```python
import os

import torch

from diffusers import GGUFQuantizationConfig, WanAnimatePipeline, WanAnimateTransformer3DModel
from diffusers.utils import export_to_video, load_image, load_video

single_file_ckpt = "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q4_K_M.gguf"
# single_file_ckpt = "https://huggingface.co/Kijai/WanVideo_comfy_fp8_scaled/blob/main/Wan22Animate/Wan2_2-Animate-14B_fp8_scaled_e4m3fn_KJ_v2.safetensors"
model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
device = "cuda:0"
dtype = torch.bfloat16
seed = 42

# Only GGUF checkpoints need a GGUF quantization config.
transformer_kwargs = {}
_, single_file_ext = os.path.splitext(single_file_ckpt)
if single_file_ext == ".gguf":
    quantization_config = GGUFQuantizationConfig(compute_dtype=dtype)
    transformer_kwargs["quantization_config"] = quantization_config

transformer = WanAnimateTransformer3DModel.from_single_file(
    single_file_ckpt,
    config=model_id,
    subfolder="transformer",
    **transformer_kwargs,
)
pipe = WanAnimatePipeline.from_pretrained(
    model_id,
    transformer=transformer,
    torch_dtype=dtype,
)
pipe.to(device)

image = load_image("/path/to/reference_image.png")
pose_video = load_video("/path/to/pose_video.mp4")
face_video = load_video("/path/to/face_video.mp4")

video = pipe(
    image=image,
    pose_video=pose_video,
    face_video=face_video,
    prompt="People in the video are doing actions.",
    height=720,
    width=1280,
    mode="animate",
    guidance_scale=1.0,
    num_inference_steps=20,
    generator=torch.Generator(device=device).manual_seed(seed),
    output_type="np",
).frames[0]
export_to_video(video, "wan_animate_single_file.mp4", fps=30)
```

Using a checkpoint from QuantStack/Wan2.2-Animate-14B-GGUF:

wan_animate_single_file_gguf_20_step.mp4

If I instead try a checkpoint from Kijai/WanVideo_comfy_fp8_scaled, I run into an issue.
Yeah, I am experiencing the same issue today. I had it working before, so I will check and get back to you. For the GGUF checkpoint I am using an AWS ml.g6e.4xlarge instance, which comes with 45 GB of VRAM; I don't have access to more GPU VRAM to test fp8. One way to check, though, would be to load the file with the safetensors package and compare the actual weight values to see whether they match (see the sketch below).
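A minimal sketch of that kind of comparison, assuming hypothetical file paths and key names (the real key mapping depends on `convert_wan_transformer_to_diffusers`):

```python
# Rough sketch: load the raw single-file checkpoint with the safetensors package
# and compare a tensor against the model produced by from_single_file.
# Paths and key names below are placeholders, not taken from this PR.
import torch
from safetensors.torch import load_file

from diffusers import WanAnimateTransformer3DModel

ckpt_path = "/path/to/Wan2_2-Animate-14B_fp8_scaled_e4m3fn_KJ_v2.safetensors"  # placeholder

# Raw state dict with the original (non-diffusers) key names.
raw_state_dict = load_file(ckpt_path)

# Model converted through from_single_file, with diffusers key names.
transformer = WanAnimateTransformer3DModel.from_single_file(
    ckpt_path,
    config="Wan-AI/Wan2.2-Animate-14B-Diffusers",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
converted_state_dict = transformer.state_dict()

# Pick a pair of keys that should refer to the same weight after conversion.
raw_key = "patch_embedding.weight"        # placeholder original key
diffusers_key = "patch_embedding.weight"  # placeholder diffusers key

raw = raw_state_dict[raw_key].to(torch.float32)
converted = converted_state_dict[diffusers_key].to(torch.float32)
print("max abs diff:", (raw - converted).abs().max().item())
```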
@dg845 I fixed the issue; can you try now?
Code I am using:

```python
import torch
import numpy as np
from diffusers import AutoencoderKLWan, GGUFQuantizationConfig
from diffusers import WanAnimatePipeline, WanAnimateTransformer3DModel
from diffusers.utils import export_to_video, load_image, load_video
import os
from diffusers.utils import logging
from safetensors.torch import load_file

LoRA = True

device_cpu = torch.device("cpu")
device_gpu = torch.device("cuda")

original_model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
lora_model_id = "Kijai/WanVideo_comfy"
lora_model_path = "Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank64_bf16.safetensors"

print("Loading transformer ....")
transformer = WanAnimateTransformer3DModel.from_single_file(
    "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q8_0.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    config=original_model_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    offload_device="cpu",
    device=device_gpu,
)
print("Transformer loaded successfully ....")

print("Loading pipeline ....")
pipe = WanAnimatePipeline.from_pretrained(
    original_model_id,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
if LoRA:
    pipe.load_lora_weights(
        lora_model_id,
        weight_name=lora_model_path,
        adapter_name="lightning",
        offload_device="cpu",
        device=device_gpu,
    )
pipe.enable_model_cpu_offload()
print("Pipeline loaded successfully ....")

# Load the character image
image = load_image("Wan2.2/examples/wan_animate/animate/image.jpeg")

# Load pose and face videos (preprocessed from reference video)
# Note: Videos should be preprocessed to extract pose keypoints and face features
# Refer to the Wan-Animate preprocessing documentation for details
pose_video = load_video("Wan2.2/examples/wan_animate/animate/process_results/src_pose.mp4")
face_video = load_video("Wan2.2/examples/wan_animate/animate/process_results/src_face.mp4")

# Calculate optimal dimensions based on VAE constraints
max_area = 1280 * 720
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))

prompt = "People in the video are doing actions."

# Animation mode: animate the character with the motion from the pose/face videos
print("Generating animation ....")
if LoRA:
    output = pipe(
        image=image,
        pose_video=pose_video,
        face_video=face_video,
        prompt=prompt,
        # negative_prompt=negative_prompt,
        height=height,
        width=width,
        segment_frame_length=77,
        guidance_scale=1.0,
        prev_segment_conditioning_frames=1,  # refert_num in original code
        num_inference_steps=4,
        mode="animate",
    ).frames[0]
else:
    output = pipe(
        image=image,
        pose_video=pose_video,
        face_video=face_video,
        prompt=prompt,
        # negative_prompt=negative_prompt,
        height=height,
        width=width,
        segment_frame_length=77,
        guidance_scale=1.0,
        prev_segment_conditioning_frames=1,  # refert_num in original code
        num_inference_steps=20,
        mode="animate",
    ).frames[0]

print("Exporting animation ....")
export_to_video(output, "output_animation__.mp4", fps=30)
```

output_animation__.mp4
```diff
 time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
-if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
+if timestep.dtype != time_embedder_dtype and time_embedder_dtype not in [torch.int8, torch.uint8]:
     timestep = timestep.to(time_embedder_dtype)
```
@dg845 Do you know why this line exists? It seems to cause the white noise issue when the time_embedder weights are in uint8, and line 811 would have an issue if the timestep dtype does not match encoder_hidden_states. Maybe we need to remove lines 807 and 808?
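As a standalone illustration of the concern (not pipeline code): casting a float timestep to a quantized storage dtype such as uint8 discards the fractional value, so any embedding computed from it afterwards is not meaningful.

```python
# Minimal illustration: casting a float timestep to an integer storage dtype
# (e.g. the uint8 storage used by quantized weights) truncates it.
import torch

timestep = torch.tensor([3.5], dtype=torch.float32)
print(timestep.to(torch.uint8))  # tensor([3], dtype=torch.uint8)
```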
What does this PR do?
Adds support for loading checkpoints from a single file. Some modifications to the `convert_wan_transformer_to_diffusers` method were required for it to work with `WanAnimateTransformer3DModel`; a rough sketch of the kind of key remapping involved is included below.

Best regards,
Sam
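For context, a minimal sketch of this kind of state-dict key remapping, using hypothetical key prefixes rather than the actual mapping in this PR:

```python
# Hypothetical illustration of a single-file -> diffusers key remapping step.
# The real logic lives in convert_wan_transformer_to_diffusers; the prefixes
# below are placeholders for the example.
import torch


def rename_keys(state_dict: dict) -> dict:
    prefix_map = {
        "time_embedding.": "condition_embedder.time_embedder.",  # placeholder
        "text_embedding.": "condition_embedder.text_embedder.",  # placeholder
    }
    converted = {}
    for key, value in state_dict.items():
        new_key = key
        for old_prefix, new_prefix in prefix_map.items():
            if new_key.startswith(old_prefix):
                new_key = new_prefix + new_key[len(old_prefix):]
                break
        converted[new_key] = value
    return converted


# Example usage with a dummy tensor:
dummy = {"time_embedding.0.weight": torch.zeros(4, 4)}
print(rename_keys(dummy).keys())  # dict_keys(['condition_embedder.time_embedder.0.weight'])
```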
Fixes # (issue)
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.