2 changes: 2 additions & 0 deletions diffsynth_engine/configs/__init__.py
@@ -1,7 +1,9 @@
from .base import PipelineConfig
from .qwen_image import QwenImagePipelineConfig
from .wan import WanPipelineConfig

__all__ = [
"PipelineConfig",
"QwenImagePipelineConfig",
"WanPipelineConfig",
]
8 changes: 8 additions & 0 deletions diffsynth_engine/configs/wan.py
@@ -0,0 +1,8 @@
from dataclasses import dataclass

from diffsynth_engine.configs.base import PipelineConfig


@dataclass
class WanPipelineConfig(PipelineConfig):
    pass
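
For context (not part of the PR): because configs/__init__.py above now re-exports the new class, callers can import it from diffsynth_engine.configs directly, and since the class body is just pass it inherits all of PipelineConfig's behavior unchanged, presumably giving Wan pipelines their own config type for future Wan-specific options. A minimal usage sketch:

# Sketch only; not part of the diff.
from diffsynth_engine.configs import PipelineConfig, WanPipelineConfig

# WanPipelineConfig currently adds no fields of its own; it only introduces
# a distinct config type for Wan pipelines.
assert issubclass(WanPipelineConfig, PipelineConfig)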
32 changes: 30 additions & 2 deletions diffsynth_engine/models/base.py
@@ -15,6 +15,12 @@
class DiffusionModel(nn.Module, ConfigMixin):
    config_name = CONFIG_NAME

    # This is identical to diffusers' ModelMixin._keep_in_fp32_modules.
    _keep_in_fp32_modules: list[str] | None = None

    # Mirrors diffusers' ModelMixin._keys_to_ignore_on_load_unexpected.
    _keys_to_ignore_on_load_unexpected: list[str] | None = None

    @classmethod
    def from_pretrained(
        cls,
@@ -30,8 +36,30 @@ def from_pretrained(
        with init_empty_weights():
            model = cls.from_config(config_dict)

        # load model weights
        state_dict = load_model_weights(model_path, subfolder, device, dtype)
        # keep weights of modules listed in _keep_in_fp32_modules in float32 to avoid precision loss
        if dtype is not None and dtype != torch.float32 and cls._keep_in_fp32_modules:
            state_dict = load_model_weights(model_path, subfolder, device, dtype=None)
            for key in state_dict:
                if any(m in key.split(".") for m in cls._keep_in_fp32_modules):
                    state_dict[key] = state_dict[key].to(device=device, dtype=torch.float32)
                else:
                    state_dict[key] = state_dict[key].to(device=device, dtype=dtype)
        else:
            state_dict = load_model_weights(model_path, subfolder, device, dtype)

        # Filter out unexpected keys that the model explicitly ignores
        if cls._keys_to_ignore_on_load_unexpected:
            keys_to_remove = [
                key for key in state_dict if any(pattern in key for pattern in cls._keys_to_ignore_on_load_unexpected)
            ]
            for key in keys_to_remove:
                del state_dict[key]
            if keys_to_remove:
                logger.info(
                    f"Dropped {len(keys_to_remove)} unexpected key(s) matching "
                    f"{cls._keys_to_ignore_on_load_unexpected} from state_dict."
                )

        model.load_state_dict(state_dict, strict=True, assign=True)
        model.to(device=device)
        return model
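For illustration (not part of the PR): a model subclass opts into the two new hooks by setting the class attributes; the subclass name, module names, and key pattern below are hypothetical. A minimal sketch:

# Hypothetical example of the new DiffusionModel class attributes; names are illustrative.
from diffsynth_engine.models.base import DiffusionModel

class ExampleTransformer(DiffusionModel):
    # Parameters whose dotted key contains one of these names as an exact
    # segment (via key.split(".")) stay in float32 even when dtype is fp16/bf16.
    _keep_in_fp32_modules = ["norm", "time_embedding"]

    # Checkpoint keys containing any of these substrings are dropped before
    # the strict load_state_dict call.
    _keys_to_ignore_on_load_unexpected = ["relative_position_index"]

Note the asymmetry in from_pretrained: _keep_in_fp32_modules entries are matched as whole dotted-path segments, while _keys_to_ignore_on_load_unexpected patterns are matched as substrings anywhere in the key.
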
6 changes: 6 additions & 0 deletions diffsynth_engine/models/wan/__init__.py
@@ -0,0 +1,6 @@
from .autoencoder_kl_wan import AutoencoderKLWan
from .transformer_wan import WanTransformer3DModel
from .transformer_wan_animate import WanAnimateTransformer3DModel
from .transformer_wan_vace import WanVACETransformer3DModel

__all__ = ["AutoencoderKLWan", "WanTransformer3DModel", "WanAnimateTransformer3DModel", "WanVACETransformer3DModel"]