4 changes: 4 additions & 0 deletions src/diffusers/loaders/single_file_model.py
@@ -146,6 +146,10 @@
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"WanAnimateTransformer3DModel": {
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"AutoencoderKLWan": {
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
"default_subfolder": "vae",
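A hedged usage sketch (not part of the diff): with this mapping entry registered, a raw Wan Animate checkpoint should be loadable through the single-file API, assuming WanAnimateTransformer3DModel is exported from diffusers and supports from_single_file; the checkpoint path below is a placeholder.

import torch
from diffusers import WanAnimateTransformer3DModel

# Placeholder path; any local or Hub-hosted single-file Wan Animate checkpoint should work.
transformer = WanAnimateTransformer3DModel.from_single_file(
    "path/to/Wan2.2-Animate-14B.safetensors",
    torch_dtype=torch.bfloat16,
)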
79 changes: 76 additions & 3 deletions src/diffusers/loaders/single_file_utils.py
@@ -129,6 +129,7 @@
"wan": ["model.diffusion_model.head.modulation", "head.modulation"],
"wan_vae": "decoder.middle.0.residual.0.gamma",
"wan_vace": "vace_blocks.0.after_proj.bias",
"wan_animate": "motion_encoder.dec.direction.weight",
"hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
"cosmos-1.0": [
"net.x_embedder.proj.1.weight",
@@ -205,6 +206,7 @@
"wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
"wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
"wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
"wan-animate-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.2-Animate-14B-Diffusers"},
"wan-vace-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-1.3B-diffusers"},
"wan-vace-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-14B-diffusers"},
"hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
@@ -727,6 +729,9 @@ def infer_diffusers_model_type(checkpoint):
elif checkpoint[target_key].shape[0] == 5120:
model_type = "wan-vace-14B"

elif CHECKPOINT_KEY_NAMES["wan_animate"] in checkpoint:
model_type = "wan-animate-14B"

elif checkpoint[target_key].shape[0] == 1536:
model_type = "wan-t2v-1.3B"
elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
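A rough sketch of how the new branch is exercised (assumptions: the file name is a placeholder, and infer_diffusers_model_type is called on the raw state dict the way the single-file loader does internally).

from safetensors.torch import load_file
from diffusers.loaders.single_file_utils import infer_diffusers_model_type

state_dict = load_file("wan2.2_animate_14B.safetensors")  # placeholder path
# The marker key registered in CHECKPOINT_KEY_NAMES selects the new branch.
assert "motion_encoder.dec.direction.weight" in state_dict
print(infer_diffusers_model_type(state_dict))  # expected: "wan-animate-14B"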
@@ -3148,14 +3153,82 @@ def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
# For the VACE model
"before_proj": "proj_in",
"after_proj": "proj_out",
# For Wan Animate
"face_adapter.fuser_blocks": "face_adapter",
".k_norm.": ".norm_k.",
".q_norm.": ".norm_q.",
# Requires tensor split
".linear1_kv.": [".to_k.", ".to_v."],
".linear1_q.": ".to_q.",
".linear2.": ".to_out.",
"conv1_local.conv": "conv1_local",
"conv2.conv": "conv2",
"conv3.conv": "conv3",
"motion_encoder.dec.direction.weight": "motion_encoder.motion_synthesis_weight",
"motion_encoder.enc.net_app.convs.0.0.weight": "motion_encoder.conv_in.weight",
"motion_encoder.enc.net_app.convs.0.1.bias": "motion_encoder.conv_in.act_fn.bias",
"motion_encoder.enc.net_app.convs.8.weight": "motion_encoder.conv_out.weight",
"motion_encoder.enc.fc": "motion_encoder.motion_network",
"motion_encoder.enc.net_app.convs.7.conv1.0.weight": "motion_encoder.res_blocks.6.conv1.weight",
"motion_encoder.enc.net_app.convs.6.conv1.0.weight": "motion_encoder.res_blocks.5.conv1.weight",
"motion_encoder.enc.net_app.convs.5.conv1.0.weight": "motion_encoder.res_blocks.4.conv1.weight",
"motion_encoder.enc.net_app.convs.4.conv1.0.weight": "motion_encoder.res_blocks.3.conv1.weight",
"motion_encoder.enc.net_app.convs.3.conv1.0.weight": "motion_encoder.res_blocks.2.conv1.weight",
"motion_encoder.enc.net_app.convs.2.conv1.0.weight": "motion_encoder.res_blocks.1.conv1.weight",
"motion_encoder.enc.net_app.convs.1.conv1.0.weight": "motion_encoder.res_blocks.0.conv1.weight",
"motion_encoder.enc.net_app.convs.7.conv2.1.weight": "motion_encoder.res_blocks.6.conv2.weight",
"motion_encoder.enc.net_app.convs.6.conv2.1.weight": "motion_encoder.res_blocks.5.conv2.weight",
"motion_encoder.enc.net_app.convs.5.conv2.1.weight": "motion_encoder.res_blocks.4.conv2.weight",
"motion_encoder.enc.net_app.convs.4.conv2.1.weight": "motion_encoder.res_blocks.3.conv2.weight",
"motion_encoder.enc.net_app.convs.3.conv2.1.weight": "motion_encoder.res_blocks.2.conv2.weight",
"motion_encoder.enc.net_app.convs.2.conv2.1.weight": "motion_encoder.res_blocks.1.conv2.weight",
"motion_encoder.enc.net_app.convs.1.conv2.1.weight": "motion_encoder.res_blocks.0.conv2.weight",
"motion_encoder.enc.net_app.convs.7.conv1.1.bias": "motion_encoder.res_blocks.6.conv1.act_fn.bias",
"motion_encoder.enc.net_app.convs.6.conv1.1.bias": "motion_encoder.res_blocks.5.conv1.act_fn.bias",
"motion_encoder.enc.net_app.convs.5.conv1.1.bias": "motion_encoder.res_blocks.4.conv1.act_fn.bias",
"motion_encoder.enc.net_app.convs.4.conv1.1.bias": "motion_encoder.res_blocks.3.conv1.act_fn.bias",
"motion_encoder.enc.net_app.convs.3.conv1.1.bias": "motion_encoder.res_blocks.2.conv1.act_fn.bias",
"motion_encoder.enc.net_app.convs.2.conv1.1.bias": "motion_encoder.res_blocks.1.conv1.act_fn.bias",
"motion_encoder.enc.net_app.convs.1.conv1.1.bias": "motion_encoder.res_blocks.0.conv1.act_fn.bias",
"motion_encoder.enc.net_app.convs.7.conv2.2.bias": "motion_encoder.res_blocks.6.conv2.act_fn.bias",
"motion_encoder.enc.net_app.convs.6.conv2.2.bias": "motion_encoder.res_blocks.5.conv2.act_fn.bias",
"motion_encoder.enc.net_app.convs.5.conv2.2.bias": "motion_encoder.res_blocks.4.conv2.act_fn.bias",
"motion_encoder.enc.net_app.convs.4.conv2.2.bias": "motion_encoder.res_blocks.3.conv2.act_fn.bias",
"motion_encoder.enc.net_app.convs.3.conv2.2.bias": "motion_encoder.res_blocks.2.conv2.act_fn.bias",
"motion_encoder.enc.net_app.convs.2.conv2.2.bias": "motion_encoder.res_blocks.1.conv2.act_fn.bias",
"motion_encoder.enc.net_app.convs.1.conv2.2.bias": "motion_encoder.res_blocks.0.conv2.act_fn.bias",
"motion_encoder.enc.net_app.convs.7.skip.1.weight": "motion_encoder.res_blocks.6.conv_skip.weight",
"motion_encoder.enc.net_app.convs.6.skip.1.weight": "motion_encoder.res_blocks.5.conv_skip.weight",
"motion_encoder.enc.net_app.convs.5.skip.1.weight": "motion_encoder.res_blocks.4.conv_skip.weight",
"motion_encoder.enc.net_app.convs.4.skip.1.weight": "motion_encoder.res_blocks.3.conv_skip.weight",
"motion_encoder.enc.net_app.convs.3.skip.1.weight": "motion_encoder.res_blocks.2.conv_skip.weight",
"motion_encoder.enc.net_app.convs.2.skip.1.weight": "motion_encoder.res_blocks.1.conv_skip.weight",
"motion_encoder.enc.net_app.convs.1.skip.1.weight": "motion_encoder.res_blocks.0.conv_skip.weight",
}

for key in list(checkpoint.keys()):
    new_key = key[:]
    extra_key = ""
    index = 0
    for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
        if isinstance(rename_key, list):
            if replace_key in new_key:
                index = int(checkpoint[key].shape[0] / 2)
                new_key = new_key.replace(replace_key, rename_key[0])
                extra_key = new_key.replace(rename_key[0], rename_key[1])
        else:
            new_key = new_key.replace(replace_key, rename_key)
    if extra_key != "":
        converted_state_dict[new_key] = checkpoint[key][index:]
        converted_state_dict[extra_key] = checkpoint[key][:index]
        checkpoint.pop(key)
    else:
        if key == "motion_encoder.enc.net_app.convs.0.1.bias":
            converted_state_dict[new_key] = checkpoint.pop(key)[0, :, 0, 0]
        elif "motion_encoder.enc.net_app.convs." in key and ".bias" in key:
            converted_state_dict[new_key] = checkpoint.pop(key)[0, :, 0, 0]
        else:
            converted_state_dict[new_key] = checkpoint.pop(key)

return converted_state_dict
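A toy illustration (not from the PR) of the fused-KV split performed above: a ".linear1_kv." weight of shape [2 * dim, in_dim] is cut at the midpoint and the halves are assigned in the same order as the loop; the tensor and variable names here are purely illustrative.

import torch

fused_kv = torch.randn(8, 4)      # stands in for a "...linear1_kv.weight" tensor
index = fused_kv.shape[0] // 2
to_k_weight = fused_kv[index:]    # second half, mirroring checkpoint[key][index:]
to_v_weight = fused_kv[:index]    # first half, mirroring checkpoint[key][:index]
assert to_k_weight.shape == to_v_weight.shape == (4, 4)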

10 changes: 7 additions & 3 deletions src/diffusers/models/transformers/transformer_wan_animate.py
@@ -166,9 +166,11 @@ def forward(self, x: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
# NOTE: the original implementation uses a 2D upfirdn operation with the upsampling and downsampling rates
# set to 1, which should be equivalent to a 2D convolution
expanded_kernel = self.blur_kernel[None, None, :, :].expand(self.in_channels, 1, -1, -1)
x = x.to(expanded_kernel.dtype)
x = F.conv2d(x, expanded_kernel, padding=self.blur_padding, groups=self.in_channels)

# Main Conv2D with scaling
x = x.to(self.weight.dtype)
x = F.conv2d(x, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)

# Activation with fused bias, if using
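A minimal sketch of the failure mode the two added .to(...) casts guard against: F.conv2d requires the input and kernel dtypes to match. In the real model the mismatch would typically be fp16/bf16 activations against fp32 kernels (an assumption); float32 vs. float64 is used below only so the snippet runs on any backend.

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8, dtype=torch.float32)
w = torch.randn(3, 1, 3, 3, dtype=torch.float64)
try:
    F.conv2d(x, w, padding=1, groups=3)                     # dtype mismatch raises RuntimeError
except RuntimeError:
    out = F.conv2d(x.to(w.dtype), w, padding=1, groups=3)   # cast the input first, as the diff does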
@@ -338,8 +340,8 @@ def forward(self, face_image: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
weight = self.motion_synthesis_weight + 1e-8
# Upcast the QR orthogonalization operation to FP32
original_motion_dtype = motion_feat.dtype
motion_feat = motion_feat.to(weight.dtype)
# weight = weight.to(torch.float32)

Q = torch.linalg.qr(weight)[0].to(device=motion_feat.device)

@@ -802,8 +804,10 @@ def forward(
timestep = timestep.unflatten(0, (-1, timestep_seq_len))

time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
if timestep.dtype != time_embedder_dtype and time_embedder_dtype not in [torch.int8, torch.uint8]:
timestep = timestep.to(time_embedder_dtype)
Review comment (PR author): @dg845 Do you know why this line exists? It seems to cause the white noise issue when the time_embedder weights are in uint8, and line 811 would have an issue if the timestep dtype does not match encoder_hidden_states. Maybe we need to remove lines 807 and 808?
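A quick illustration of the concern raised above (assumption: a quantization backend stores the time_embedder parameters as uint8): casting a continuous timestep tensor to that dtype drops the fractional part and wraps out-of-range values, which would plausibly surface as the white noise described.

import torch

timestep = torch.tensor([3.75, 999.0])
corrupted = timestep.to(torch.uint8)  # fractional part lost; values above 255 overflow
print(timestep, corrupted)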

if timestep.dtype != encoder_hidden_states.dtype:
timestep = timestep.to(encoder_hidden_states.dtype)
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
timestep_proj = self.time_proj(self.act_fn(temb))
