41 changes: 41 additions & 0 deletions diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json
@@ -0,0 +1,41 @@
{
"diffusers": {
"global_rename_dict": {
"patch_embedding": "patch_embedding",
"condition_embedder.text_embedder.linear_1": "text_embedding.0",
"condition_embedder.text_embedder.linear_2": "text_embedding.2",
"condition_embedder.time_embedder.linear_1": "time_embedding.0",
"condition_embedder.time_embedder.linear_2": "time_embedding.2",
"condition_embedder.time_proj": "time_projection.1",
"condition_embedder.image_embedder.norm1": "img_emb.proj.0",
"condition_embedder.image_embedder.ff.net.0.proj": "img_emb.proj.1",
"condition_embedder.image_embedder.ff.net.2": "img_emb.proj.3",
"condition_embedder.image_embedder.norm2": "img_emb.proj.4",
"condition_embedder.image_embedder.pos_embed": "img_emb.emb_pos",
"proj_out": "head.head",
"scale_shift_table": "head.modulation"
},
"rename_dict": {
"attn1.to_q": "self_attn.q",
"attn1.to_k": "self_attn.k",
"attn1.to_v": "self_attn.v",
"attn1.to_out.0": "self_attn.o",
"attn1.norm_q": "self_attn.norm_q",
"attn1.norm_k": "self_attn.norm_k",
"to_gate_compress": "self_attn.gate_compress",
"attn2.to_q": "cross_attn.q",
"attn2.to_k": "cross_attn.k",
"attn2.to_v": "cross_attn.v",
"attn2.to_out.0": "cross_attn.o",
"attn2.norm_q": "cross_attn.norm_q",
"attn2.norm_k": "cross_attn.norm_k",
"attn2.add_k_proj": "cross_attn.k_img",
"attn2.add_v_proj": "cross_attn.v_img",
"attn2.norm_added_k": "cross_attn.norm_k_img",
"norm2": "norm3",
"ffn.net.0.proj": "ffn.0",
"ffn.net.2": "ffn.2",
"scale_shift_table": "modulation"
}
}
}
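For reference, a key map like this is normally consumed when converting a Diffusers-format checkpoint: global_rename_dict renames top-level modules and rename_dict renames the per-block submodules. The helper below is a hypothetical sketch of that conversion (convert_diffusers_state_dict and the "blocks.<i>." key layout are assumptions, not code from this PR):

import json
import re

def convert_diffusers_state_dict(state_dict, keymap_path):
    # Hypothetical helper: rename Diffusers keys to the engine's WanDiT keys.
    with open(keymap_path) as f:
        keymap = json.load(f)["diffusers"]
    converted = {}
    for key, value in state_dict.items():
        new_key = key
        for src, dst in keymap["global_rename_dict"].items():
            if new_key == src or new_key.startswith(src + "."):
                new_key = dst + new_key[len(src):]
                break
        for src, dst in keymap["rename_dict"].items():
            # per-block keys are assumed to look like "blocks.<i>.<src>.weight"
            new_key = re.sub(rf"(?<=\.){re.escape(src)}(?=\.|$)", dst, new_key)
        converted[new_key] = value
    return converted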
38 changes: 33 additions & 5 deletions diffsynth_engine/configs/pipeline.py
Member commented:
AttentionConfig has already forked for SPARGE vs. VSA here; you could define two subclasses and move prepare_attn_kwargs into a to_attn_kwargs method on each subclass.

Contributor Author replied:
Added an attn_params field to accommodate the different parameter sets.
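For context, the reviewer's suggestion would look roughly like the sketch below. The subclass names are hypothetical and to_attn_kwargs follows the wording of the comment; this is not the code that was merged (AttentionConfig, AttnImpl, and get_vsa_kwargs are from the diff that follows):

from dataclasses import dataclass

@dataclass
class SpargeAttentionConfig(AttentionConfig):
    smooth_k: bool = True
    cdfthreshd: float = 0.6
    simthreshd1: float = 0.98
    pvthreshd: float = 50.0

    def to_attn_kwargs(self, latents, device):
        # Each backend serializes its own kwargs.
        return {
            "attn_impl": self.dit_attn_impl.value,
            "smooth_k": self.smooth_k,
            "cdfthreshd": self.cdfthreshd,
            "simthreshd1": self.simthreshd1,
            "pvthreshd": self.pvthreshd,
        }

@dataclass
class VideoSparseAttentionConfig(AttentionConfig):
    sparsity: float = 0.9

    def to_attn_kwargs(self, latents, device):
        kwargs = {"attn_impl": self.dit_attn_impl.value}
        kwargs.update(get_vsa_kwargs(latents.shape[2:], (1, 2, 2), self.sparsity, device=device))
        return kwargs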

@@ -5,6 +5,7 @@
from typing import List, Dict, Tuple, Optional

from diffsynth_engine.configs.controlnet import ControlType
from diffsynth_engine.models.basic.video_sparse_attention import get_vsa_kwargs


@dataclass
@@ -30,16 +31,43 @@ class AttnImpl(Enum):
SDPA = "sdpa" # Scaled Dot Product Attention
SAGE = "sage" # Sage Attention
SPARGE = "sparge" # Sparge Attention
VSA = "vsa" # Video Sparse Attention


@dataclass
class SpargeAttentionParams:
smooth_k: bool = True
cdfthreshd: float = 0.6
simthreshd1: float = 0.98
pvthreshd: float = 50.0


@dataclass
class VideoSparseAttentionParams:
sparsity: float = 0.9


@dataclass
class AttentionConfig:
dit_attn_impl: AttnImpl = AttnImpl.AUTO
# Sparge Attention
sparge_smooth_k: bool = True
sparge_cdfthreshd: float = 0.6
sparge_simthreshd1: float = 0.98
sparge_pvthreshd: float = 50.0
attn_params: Optional[SpargeAttentionParams | VideoSparseAttentionParams] = None

def get_attn_kwargs(self, latents: torch.Tensor, device: str) -> Dict:
attn_kwargs = {"attn_impl": self.dit_attn_impl.value}
if isinstance(self.attn_params, SpargeAttentionParams):
assert self.dit_attn_impl == AttnImpl.SPARGE
attn_kwargs.update(
{
"smooth_k": self.attn_params.smooth_k,
"simthreshd1": self.attn_params.simthreshd1,
"cdfthreshd": self.attn_params.cdfthreshd,
"pvthreshd": self.attn_params.pvthreshd,
}
)
elif isinstance(self.attn_params, VideoSparseAttentionParams):
assert self.dit_attn_impl == AttnImpl.VSA
attn_kwargs.update(get_vsa_kwargs(latents.shape[2:], (1, 2, 2), self.attn_params.sparsity, device=device))
return attn_kwargs


@dataclass
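With the merged design, a caller pairs the backend enum with its parameter dataclass. A usage sketch, assuming latents is the usual (B, C, T, H, W) video latent tensor so that latents.shape[2:] is (T, H, W):

vsa_config = AttentionConfig(
    dit_attn_impl=AttnImpl.VSA,
    attn_params=VideoSparseAttentionParams(sparsity=0.9),
)
# Derives the VSA tiling kwargs from the latent shape.
attn_kwargs = vsa_config.get_attn_kwargs(latents, device="cuda")

sparge_config = AttentionConfig(
    dit_attn_impl=AttnImpl.SPARGE,
    attn_params=SpargeAttentionParams(cdfthreshd=0.6, pvthreshd=50.0),
)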
79 changes: 59 additions & 20 deletions diffsynth_engine/models/basic/attention.py
@@ -12,6 +12,7 @@
SDPA_AVAILABLE,
SAGE_ATTN_AVAILABLE,
SPARGE_ATTN_AVAILABLE,
VIDEO_SPARSE_ATTN_AVAILABLE,
)
from diffsynth_engine.utils.platform import DTYPE_FP8

@@ -20,19 +21,18 @@
logger = logging.get_logger(__name__)


def memory_align(x: torch.Tensor, dim=-1, alignment: int = 8):
padding_size = (alignment - x.shape[dim] % alignment) % alignment
padded_x = F.pad(x, (0, padding_size), "constant", 0)
return padded_x[..., : x.shape[dim]]


if FLASH_ATTN_3_AVAILABLE:
from flash_attn_interface import flash_attn_func as flash_attn3
if FLASH_ATTN_2_AVAILABLE:
from flash_attn import flash_attn_func as flash_attn2
if XFORMERS_AVAILABLE:
from xformers.ops import memory_efficient_attention

def memory_align(x: torch.Tensor, dim=-1, alignment: int = 8):
padding_size = (alignment - x.shape[dim] % alignment) % alignment
padded_x = F.pad(x, (0, padding_size), "constant", 0)
return padded_x[..., : x.shape[dim]]

def xformers_attn(q, k, v, attn_mask=None, scale=None):
if attn_mask is not None:
if attn_mask.ndim == 2:
@@ -94,6 +94,13 @@ def sparge_attn(
return out.transpose(1, 2)


if VIDEO_SPARSE_ATTN_AVAILABLE:
from diffsynth_engine.models.basic.video_sparse_attention import (
video_sparse_attn,
distributed_video_sparse_attn,
)


def eager_attn(q, k, v, attn_mask=None, scale=None):
q = q.transpose(1, 2)
k = k.transpose(1, 2)
@@ -109,9 +116,10 @@ def eager_attn(q, k, v, attn_mask=None, scale=None):


def attention(
q,
k,
v,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: Optional[torch.Tensor] = None,
attn_impl: Optional[str] = "auto",
attn_mask: Optional[torch.Tensor] = None,
scale: Optional[float] = None,
@@ -133,6 +141,7 @@
"sdpa",
"sage",
"sparge",
"vsa",
]
flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
if attn_impl is None or attn_impl == "auto":
@@ -189,10 +198,24 @@
v,
attn_mask=attn_mask,
scale=scale,
smooth_k=kwargs.get("sparge_smooth_k", True),
simthreshd1=kwargs.get("sparge_simthreshd1", 0.6),
cdfthreshd=kwargs.get("sparge_cdfthreshd", 0.98),
pvthreshd=kwargs.get("sparge_pvthreshd", 50),
smooth_k=kwargs.get("smooth_k", True),
simthreshd1=kwargs.get("simthreshd1", 0.6),
cdfthreshd=kwargs.get("cdfthreshd", 0.98),
pvthreshd=kwargs.get("pvthreshd", 50),
)
if attn_impl == "vsa":
return video_sparse_attn(
q,
k,
v,
g,
sparsity=kwargs.get("sparsity"),
num_tiles=kwargs.get("num_tiles"),
total_seq_length=kwargs.get("total_seq_length"),
tile_partition_indices=kwargs.get("tile_partition_indices"),
reverse_tile_partition_indices=kwargs.get("reverse_tile_partition_indices"),
variable_block_sizes=kwargs.get("variable_block_sizes"),
non_pad_index=kwargs.get("non_pad_index"),
)
raise ValueError(f"Invalid attention implementation: {attn_impl}")

@@ -242,9 +265,10 @@ def forward(


def long_context_attention(
q,
k,
v,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: Optional[torch.Tensor] = None,
attn_impl: Optional[str] = None,
attn_mask: Optional[torch.Tensor] = None,
scale: Optional[float] = None,
@@ -267,6 +291,7 @@
"sdpa",
"sage",
"sparge",
"vsa",
]
assert attn_mask is None, "long context attention does not support attention mask"
flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
@@ -307,11 +332,25 @@
if attn_impl == "sparge":
attn_processor = SparseAttentionMeansim()
# default args from spas_sage2_attn_meansim_cuda
attn_processor.smooth_k = torch.tensor(kwargs.get("sparge_smooth_k", True))
attn_processor.simthreshd1 = torch.tensor(kwargs.get("sparge_simthreshd1", 0.6))
attn_processor.cdfthreshd = torch.tensor(kwargs.get("sparge_cdfthreshd", 0.98))
attn_processor.pvthreshd = torch.tensor(kwargs.get("sparge_pvthreshd", 50))
attn_processor.smooth_k = torch.tensor(kwargs.get("smooth_k", True))
attn_processor.simthreshd1 = torch.tensor(kwargs.get("simthreshd1", 0.6))
attn_processor.cdfthreshd = torch.tensor(kwargs.get("cdfthreshd", 0.98))
attn_processor.pvthreshd = torch.tensor(kwargs.get("pvthreshd", 50))
return LongContextAttention(attn_type=AttnType.SPARSE_SAGE, attn_processor=attn_processor)(
q, k, v, softmax_scale=scale
)
if attn_impl == "vsa":
return distributed_video_sparse_attn(
q,
k,
v,
g,
sparsity=kwargs.get("sparsity"),
num_tiles=kwargs.get("num_tiles"),
total_seq_length=kwargs.get("total_seq_length"),
tile_partition_indices=kwargs.get("tile_partition_indices"),
reverse_tile_partition_indices=kwargs.get("reverse_tile_partition_indices"),
variable_block_sizes=kwargs.get("variable_block_sizes"),
non_pad_index=kwargs.get("non_pad_index"),
)
raise ValueError(f"Invalid long context attention implementation: {attn_impl}")
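Downstream, the kwargs produced by AttentionConfig.get_attn_kwargs are forwarded into these entry points. A sketch of the dispatch, where q, k, v, g, and latents are whatever tensors the WanDiT block already has for the other backends:

config = AttentionConfig(dit_attn_impl=AttnImpl.VSA, attn_params=VideoSparseAttentionParams(sparsity=0.9))
attn_kwargs = config.get_attn_kwargs(latents, device="cuda")  # includes attn_impl="vsa" plus the VSA tiling kwargs

# Single-device path: routes to video_sparse_attn when attn_impl == "vsa".
out = attention(q, k, v, g, **attn_kwargs)

# Sequence-parallel path: routes to distributed_video_sparse_attn.
out = long_context_attention(q, k, v, g, **attn_kwargs)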