Cosmos #10660

Open · wants to merge 45 commits into main
Conversation

a-r-r-o-w (Member) commented on Jan 27, 2025

The cosmos is within us. We are made of star-stuff. We are a way for the universe to know itself.

WIP.

Transformer

test attention
from typing import Optional
from einops import rearrange

import torch
import torch.nn as nn


class RMSNorm(torch.nn.Module):
    def __init__(
        self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None
    ):
        super().__init__()
        self.eps = eps
        self.learnable_scale = elementwise_affine
        if self.learnable_scale:
            self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
        else:
            self.register_parameter("weight", None)

    def forward(self, x):
        r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
        if self.weight is None:
            return r
        else:
            return r * self.weight.to(dtype=x.dtype, device=x.device)


def get_normalization(name: str, channels: int):
    if name == "I":
        return nn.Identity()
    elif name == "R":
    #     return te.pytorch.RMSNorm(channels, eps=1e-6)
        return RMSNorm(channels, eps=1e-6)
    else:
        raise ValueError(f"Normalization {name} not found")


class Attention(nn.Module):
    def __init__(
        self,
        query_dim: int,
        context_dim=None,
        heads=8,
        dim_head=64,
        dropout=0.0,
        qkv_bias: bool = False,
        out_bias: bool = False,
        qkv_norm: str = "SSI",
        qkv_norm_mode: str = "per_head",
        backend: str = "transformer_engine",
        qkv_format: str = "bshd",
    ) -> None:
        super().__init__()

        self.is_selfattn = context_dim is None  # self attention

        inner_dim = dim_head * heads
        context_dim = query_dim if context_dim is None else context_dim

        self.heads = heads
        self.dim_head = dim_head
        self.qkv_norm_mode = qkv_norm_mode
        self.qkv_format = qkv_format

        if self.qkv_norm_mode == "per_head":
            norm_dim = dim_head
        else:
            raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")

        self.backend = backend

        self.to_q = nn.Sequential(
            nn.Linear(query_dim, inner_dim, bias=qkv_bias),
            get_normalization(qkv_norm[0], norm_dim),
        )
        self.to_k = nn.Sequential(
            nn.Linear(context_dim, inner_dim, bias=qkv_bias),
            get_normalization(qkv_norm[1], norm_dim),
        )
        self.to_v = nn.Sequential(
            nn.Linear(context_dim, inner_dim, bias=qkv_bias),
            get_normalization(qkv_norm[2], norm_dim),
        )

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim, bias=out_bias),
            nn.Dropout(dropout),
        )

    def cal_qkv(
        self, x, context=None, mask=None, rope_emb=None, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        q = self.to_q[0](x)
        context = x if context is None else context
        k = self.to_k[0](context)
        v = self.to_v[0](context)
        q, k, v = map(
            # lambda t: rearrange(t, "b ... (n c) -> b ... n c", n=self.heads, c=self.dim_head),
            lambda t: rearrange(t, "s b (n c) -> b n s c", n=self.heads, c=self.dim_head),
            (q, k, v),
        )

        q = self.to_q[1](q)
        k = self.to_k[1](k)
        v = self.to_v[1](v)
        if self.is_selfattn and rope_emb is not None:  # only apply to self-attention!
            print("here")
            # q = apply_rotary_pos_emb(q, rope_emb, tensor_format=self.qkv_format, fused=True)
            # k = apply_rotary_pos_emb(k, rope_emb, tensor_format=self.qkv_format, fused=True)
            # apply_rotary_pos_emb inlined
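            # WIP: rope_emb is assumed to stack its two rotary components along the last
            # dim (rope_emb[..., 0] / rope_emb[..., 1]); q/k channels are regrouped into
            # two halves, mixed with those components, then reshaped back. Not exercised
            # by match_attention below (it passes rope_emb=None).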
            q_shape = q.shape
            q = q.reshape(*q.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
            q = torch.cat([rope_emb[..., 0] * q[..., 0], rope_emb[..., 1] * q[..., 1]], dim=-1)
            # q = rope_emb[..., 0] * q[..., 0] + rope_emb[..., 1] * q[..., 1]
            q = q.movedim(-1, -2).reshape(*q_shape).to(x.dtype)

            # apply_rotary_pos_emb inlined
            k_shape = k.shape
            k = k.reshape(*k.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
            k = torch.cat([rope_emb[..., 0] * k[..., 0], rope_emb[..., 1] * k[..., 1]], dim=-1)
            # k = rope_emb[..., 0] * k[..., 0] + rope_emb[..., 1] * k[..., 1]
            k = k.movedim(-1, -2).reshape(*k_shape).to(x.dtype)
        return q, k, v

    def cal_attn(self, q, k, v, mask=None):
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        out = rearrange(out, "b n s c -> s b (n c)")
        out = self.to_out(out)
        return out

    def forward(
        self,
        x,
        context=None,
        mask=None,
        rope_emb=None,
        **kwargs,
    ):
        """
        Args:
            x (Tensor): The query tensor of shape [B, Mq, K]
            context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
        """
        q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
        return self.cal_attn(q, k, v, mask)


@torch.no_grad()
def match_rms_norm():
    from diffusers.models.normalization import RMSNorm as DiffusersRMSNorm

    theirs_rmsnorm = RMSNorm(128, elementwise_affine=True, eps=1e-6)
    ours_rmsnorm = DiffusersRMSNorm(128, eps=1e-6, elementwise_affine=True)
    ours_rmsnorm.weight.data.copy_(theirs_rmsnorm.weight.data)

    input = torch.randn(1, 128)
    theirs_output = theirs_rmsnorm(input)
    ours_output = ours_rmsnorm(input)

    print(sum(p.numel() for p in theirs_rmsnorm.parameters()))
    print(sum(p.numel() for p in ours_rmsnorm.parameters()))
    print(torch.allclose(theirs_output, ours_output))


@torch.no_grad()
def match_attention():
    from diffusers.models.attention import Attention as DiffusersAttention

    theirs_attention = Attention(128, 128, heads=8, dim_head=16, qkv_bias=False, out_bias=False, qkv_norm="RRI")
    ours_attention = DiffusersAttention(128, 128, heads=8, dim_head=16, qk_norm="rms_norm", out_bias=False, elementwise_affine=False)
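    # qkv_norm="RRI" => RMSNorm on q and k, Identity on v (see get_normalization above),
    # which maps to diffusers' qk_norm="rms_norm" with elementwise_affine=False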
    ours_attention.to_q.weight.data.copy_(theirs_attention.to_q[0].weight.data)
    ours_attention.to_k.weight.data.copy_(theirs_attention.to_k[0].weight.data)
    ours_attention.to_v.weight.data.copy_(theirs_attention.to_v[0].weight.data)
    ours_attention.to_out[0].weight.data.copy_(theirs_attention.to_out[0].weight.data)

    input = torch.randn(1, 42, 128)
    theirs_output = rearrange(theirs_attention(rearrange(input, "b s c -> s b c")), "s b c -> b s c")
    ours_output = ours_attention(input)

    print(sum(p.numel() for p in theirs_attention.parameters()))
    print(sum(p.numel() for p in ours_attention.parameters()))
    print(torch.allclose(theirs_output, ours_output, atol=1e-3))


match_rms_norm()
match_attention()
test ff
from typing import Optional

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class FeedForward(nn.Module):
    def __init__(
        self,
        d_model: int,
        d_ff: int,
        dropout: float = 0.1,
        activation=nn.ReLU(),
        is_gated: bool = False,
        bias: bool = False,
    ) -> None:
        super().__init__()

        self.layer1 = nn.Linear(d_model, d_ff, bias=bias)
        self.layer2 = nn.Linear(d_ff, d_model, bias=bias)

        self.dropout = nn.Dropout(dropout)
        self.activation = activation
        self.is_gated = is_gated
        if is_gated:
            self.linear_gate = nn.Linear(d_model, d_ff, bias=False)

    def forward(self, x: torch.Tensor):
        g = self.activation(self.layer1(x))
        if self.is_gated:
            x = g * self.linear_gate(x)
        else:
            x = g
        assert self.dropout.p == 0.0, "we skip dropout"
        return self.layer2(x)


class GPT2FeedForward(FeedForward):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, bias: bool = False):
        super().__init__(
            d_model=d_model,
            d_ff=d_ff,
            dropout=dropout,
            activation=nn.GELU(),
            is_gated=False,
            bias=bias,
        )

    def forward(self, x: torch.Tensor):
        assert self.dropout.p == 0.0, "we skip dropout"

        x = self.layer1(x)

        def activation_layer2_forward(x):
            x = self.activation(x)
            x = self.layer2(x)
            return x

        x = checkpoint(activation_layer2_forward, x, use_reentrant=False)
        return x


@torch.no_grad()
def match_ff():
    from diffusers.models.attention import FeedForward as DiffusersFeedForward

    theirs_ff = FeedForward(128, 512, 0.0, activation=nn.GELU(), is_gated=True, bias=False)
    ours_ff = DiffusersFeedForward(128, mult=4, dropout=0.0, activation_fn="geglu", bias=False)
    ours_ff.net[0].proj.weight.data[:512, :].copy_(theirs_ff.linear_gate.weight.data)
    ours_ff.net[0].proj.weight.data[512:, :].copy_(theirs_ff.layer1.weight.data)
    ours_ff.net[2].weight.data.copy_(theirs_ff.layer2.weight.data)
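    # Note: diffusers' GEGLU packs both projections into a single Linear (net[0].proj) of
    # width 2 * d_ff and chunks it into (value, gate): the first half multiplies the output
    # directly (theirs linear_gate) and the second half goes through GELU (theirs layer1),
    # hence the [:512] / [512:] row split above.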

    input = torch.randn(1, 128)
    theirs_output = theirs_ff(input)
    ours_output = ours_ff(input)

    print(sum(p.numel() for p in theirs_ff.parameters()))
    print(sum(p.numel() for p in ours_ff.parameters()))
    print(torch.allclose(theirs_output, ours_output))


match_ff()
test timesteps
import itertools
import math

import torch
import torch.nn as nn

class Timesteps(nn.Module):
    def __init__(self, num_channels):
        super().__init__()
        self.num_channels = num_channels

    def forward(self, timesteps):
        in_dtype = timesteps.dtype
        half_dim = self.num_channels // 2
        exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
        exponent = exponent / (half_dim - 0.0)
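        # dividing by (half_dim - 0.0) mirrors diffusers' downscale_freq_shift=0.0,
        # which is what match_timestep below passes for the comparison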

        emb = torch.exp(exponent)
        emb = timesteps[:, None].float() * emb[None, :]

        sin_emb = torch.sin(emb)
        cos_emb = torch.cos(emb)
        emb = torch.cat([cos_emb, sin_emb], dim=-1)

        # return emb.to(in_dtype)
        return emb


class TimestepEmbedding(nn.Module):
    def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False):
        super().__init__()
        self.linear_1 = nn.Linear(in_features, out_features, bias=not use_adaln_lora)
        self.activation = nn.SiLU()
        self.use_adaln_lora = use_adaln_lora
        if use_adaln_lora:
            self.linear_2 = nn.Linear(out_features, 3 * out_features, bias=False)
        else:
            self.linear_2 = nn.Linear(out_features, out_features, bias=True)

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        sample = sample.to(self.linear_1.weight.dtype)
        emb = self.linear_1(sample)
        emb = self.activation(emb)
        emb = self.linear_2(emb)

        if self.use_adaln_lora:
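            # AdaLN-LoRA mode: return the raw sinusoidal embedding as emb_B_D and use the
            # MLP output as the 3 * out_features low-rank modulation term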
            emb_B_D = sample
            adaln_lora_B_3D = emb
        else:
            emb_B_D = emb
            adaln_lora_B_3D = None

        return emb_B_D, adaln_lora_B_3D


class CosmosTimestepEmbedding(nn.Module):
    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(in_features, out_features, bias=False)
        self.activation = nn.SiLU()
        self.linear_2 = nn.Linear(out_features, 3 * out_features, bias=False)
    
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        emb = self.linear_1(hidden_states)
        emb = self.activation(emb)
        emb = self.linear_2(emb)
        return hidden_states, emb


@torch.no_grad()
def match_timestep():
    from diffusers.models.embeddings import Timesteps as DiffusersTimesteps

    theirs_timesteps = Timesteps(256)
    ours_timesteps = DiffusersTimesteps(256, flip_sin_to_cos=True, downscale_freq_shift=0.0)

    input = torch.tensor([1000.0], dtype=torch.float32)
    theirs_output = theirs_timesteps(input)
    ours_output = ours_timesteps(input)

    print(torch.allclose(theirs_output, ours_output))


@torch.no_grad()
def match_timestep_embedding():
    theirs_temb = TimestepEmbedding(256, 256, use_adaln_lora=True)
    ours_temb = CosmosTimestepEmbedding(256, 256)
    ours_temb.linear_1.weight.data.copy_(theirs_temb.linear_1.weight.data)
    ours_temb.linear_2.weight.data.copy_(theirs_temb.linear_2.weight.data)

    input = torch.randn(1, 256)
    theirs_output = theirs_temb(input)
    ours_output = ours_temb(input)

    print(sum(p.numel() for p in theirs_temb.parameters()))
    print(sum(p.numel() for p in ours_temb.parameters()))
    print(torch.allclose(theirs_output[0], ours_output[0]))
    print(torch.allclose(theirs_output[1], ours_output[1]))


@torch.no_grad()
def match_timestep_embedding_2():
    from diffusers.models.transformers.transformer_cosmos import CosmosTimestepEmbedding
    theirs_temb = TimestepEmbedding(256, 256, use_adaln_lora=True)
    ours_temb = CosmosTimestepEmbedding(256, 256)
    ours_temb.linear_1.weight.data.copy_(theirs_temb.linear_1.weight.data)
    ours_temb.linear_2.weight.data.copy_(theirs_temb.linear_2.weight.data)

    input = torch.randn(1, 256)
    theirs_output = theirs_temb(input)
    ours_output = ours_temb(input)

    print(sum(p.numel() for p in theirs_temb.parameters()))
    print(sum(p.numel() for p in ours_temb.parameters()))
    print(torch.allclose(theirs_output[1], ours_output))


@torch.no_grad()
def match_timestep_prepare_embedding():
    from diffusers.models.transformers.transformer_cosmos import CosmosEmbedding
    from diffusers.models.normalization import RMSNorm
    theirs_t_embedder = nn.Sequential(
        Timesteps(4096),
        TimestepEmbedding(4096, 4096, use_adaln_lora=True),
    )
    theirs_norm = RMSNorm(4096, 1e-6, True)
    ours_t_embedder = CosmosEmbedding(4096, 4096)
    ours_t_embedder.t_embedder.linear_1.weight.data.copy_(theirs_t_embedder[1].linear_1.weight.data)
    ours_t_embedder.t_embedder.linear_2.weight.data.copy_(theirs_t_embedder[1].linear_2.weight.data)
    ours_t_embedder.norm.weight.data.copy_(theirs_norm.weight.data)

    hidden_states = torch.randn(1, 1, 4096)
    input = torch.randint(0, 1000, (1,)).long()
    theirs_output = theirs_t_embedder(input)
    ours_output = ours_t_embedder(hidden_states, input)

    print(sum(p.numel() for p in itertools.chain(theirs_t_embedder.parameters(), theirs_norm.parameters())))
    print(sum(p.numel() for p in ours_t_embedder.parameters()))
    print(torch.allclose(theirs_output[1], ours_output[0]))
    print(torch.allclose(theirs_norm(theirs_output[0]), ours_output[1]))


match_timestep()
print()

match_timestep_embedding()
print()

match_timestep_embedding_2()
print()

match_timestep_prepare_embedding()
print()
test patch embed
import torch
import torch.nn as nn

from einops.layers.torch import Rearrange

class PatchEmbed(nn.Module):
    def __init__(
        self,
        spatial_patch_size,
        temporal_patch_size,
        in_channels=3,
        out_channels=768,
        bias=True,
    ):
        super().__init__()
        self.spatial_patch_size = spatial_patch_size
        self.temporal_patch_size = temporal_patch_size

        self.proj = nn.Sequential(
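            # fold each (temporal_patch_size x spatial_patch_size x spatial_patch_size)
            # space-time patch into the channel dim, then project it with a Linear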
            Rearrange(
                "b c (t r) (h m) (w n) -> b t h w (c r m n)",
                r=temporal_patch_size,
                m=spatial_patch_size,
                n=spatial_patch_size,
            ),
            nn.Linear(
                in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=bias
            ),
        )
        self.out = nn.Identity()

    def forward(self, x):
        assert x.dim() == 5
        _, _, T, H, W = x.shape
        assert H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
        assert T % self.temporal_patch_size == 0
        x = self.proj(x)
        return self.out(x)


@torch.no_grad()
def match_patch_embed():
    from diffusers.models.transformers.transformer_cosmos import CosmosPatchEmbed

    theirs_patch_embed = PatchEmbed(2, 1, 16, 4096, bias=False)
    ours_patch_embed = CosmosPatchEmbed(16, 4096, (1, 2, 2), bias=False)
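    # patch_size on the diffusers side is ordered (T, H, W), so (1, 2, 2) corresponds to
    # temporal_patch_size=1 and spatial_patch_size=2 above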

    ours_patch_embed.proj.weight.data.copy_(theirs_patch_embed.proj[1].weight.data)

    input = torch.randn(1, 16, 128, 240, 240)

    theirs_output = theirs_patch_embed(input)
    ours_output = ours_patch_embed(input)

    print(torch.allclose(theirs_output, ours_output))

match_patch_embed()
test positional embed
import math
from typing import Optional, List

import numpy as np
import torch
from einops import rearrange, repeat


def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor:
    if dim is None:
        dim = list(range(1, x.ndim))
    norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
    norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
    return x / norm.to(x.dtype)


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)



class VideoPositionEmb(torch.nn.Module):
    def forward(self, x_B_T_H_W_C: torch.Tensor, fps: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Delegates embedding generation to the generate_embeddings method.
        """
        B_T_H_W_C = x_B_T_H_W_C.shape
        embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps)

        return embeddings

    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor] = None):
        raise NotImplementedError


class VideoRopePosition3DEmb(VideoPositionEmb):
    def __init__(
        self,
        *,  # enforce keyword arguments
        head_dim: int,
        len_h: int,
        len_w: int,
        len_t: int,
        base_fps: int = 24,
        h_extrapolation_ratio: float = 1.0,
        w_extrapolation_ratio: float = 1.0,
        t_extrapolation_ratio: float = 1.0,
        **kwargs,  # used for compatibility with other positional embeddings; unused in this class
    ):
        del kwargs
        super().__init__()
        self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float))
        self.base_fps = base_fps
        self.max_h = len_h
        self.max_w = len_w

        dim = head_dim
        dim_h = dim // 6 * 2
        dim_w = dim_h
        dim_t = dim - 2 * dim_h
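        # split head_dim across the three axes: H and W each get (dim // 6) * 2 rotary
        # channels and T takes the remainder, so dim_h + dim_w + dim_t == head_dim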
        assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}"
        self.register_buffer(
            "dim_spatial_range",
            # torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().cuda() / dim_h,
            torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h,
            persistent=False,
        )
        self.register_buffer(
            "dim_temporal_range",
            # torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().cuda() / dim_t,
            torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t,
            persistent=False,
        )

        self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2))
        self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2))
        self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2))

    def generate_embeddings(
        self,
        B_T_H_W_C: torch.Size,
        fps: Optional[torch.Tensor] = None,
        h_ntk_factor: Optional[float] = None,
        w_ntk_factor: Optional[float] = None,
        t_ntk_factor: Optional[float] = None,
    ):
        """
        Generate embeddings for the given input size.

        Args:
            B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels).
            fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None.
            h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor.
            w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor.
            t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor.

        Returns:
            torch.Tensor: Rotary angle embeddings of shape (T * H * W, 1, 1, head_dim).
        """
        h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor
        w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor
        t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor

        h_theta = 10000.0 * h_ntk_factor
        w_theta = 10000.0 * w_ntk_factor
        t_theta = 10000.0 * t_ntk_factor

        h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range)
        w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range)
        temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range)

        B, T, H, W, _ = B_T_H_W_C
        uniform_fps = (fps is None) or (fps.min() == fps.max())
        assert (
            uniform_fps or B == 1 or T == 1
        ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1"
        assert (
            H <= self.max_h and W <= self.max_w
        ), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})"
        half_emb_h = torch.outer(self.seq[:H], h_spatial_freqs)
        half_emb_w = torch.outer(self.seq[:W], w_spatial_freqs)

        # apply sequence scaling in temporal dimension
        if fps is None:  # image case
            assert T == 1, "T should be 1 for image batch."
            half_emb_t = torch.outer(self.seq[:T], temporal_freqs)
        else:
            half_emb_t = torch.outer(self.seq[:T] / fps[:1] * self.base_fps, temporal_freqs)

        em_T_H_W_D = torch.cat(
            [
                repeat(half_emb_t, "t d -> t h w d", h=H, w=W),
                repeat(half_emb_h, "h d -> t h w d", t=T, w=W),
                repeat(half_emb_w, "w d -> t h w d", t=T, h=H),
            ]
            * 2,
            dim=-1,
        )
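        # em_T_H_W_D holds per-position rotary angles; callers take cos/sin of the
        # flattened (t h w) sequence (see match_rope below)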

        return rearrange(em_T_H_W_D, "t h w d -> (t h w) 1 1 d").float()


class LearnablePosEmbAxis(VideoPositionEmb):
    def __init__(
        self,
        *,  # enforce keyword arguments
        interpolation: str,
        model_channels: int,
        len_h: int,
        len_w: int,
        len_t: int,
        **kwargs,
    ):
        """
        Args:
            interpolation (str): Currently only "crop" is supported. When extrapolation
                capacity is needed, frequency adjustment or other more advanced methods
                should be used, but they are not implemented yet.
        """
        del kwargs  # unused
        super().__init__()
        self.interpolation = interpolation
        assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"

        self.pos_emb_h = torch.nn.Parameter(torch.zeros(len_h, model_channels))
        self.pos_emb_w = torch.nn.Parameter(torch.zeros(len_w, model_channels))
        self.pos_emb_t = torch.nn.Parameter(torch.zeros(len_t, model_channels))

        trunc_normal_(self.pos_emb_h, std=0.02)
        trunc_normal_(self.pos_emb_w, std=0.02)
        trunc_normal_(self.pos_emb_t, std=0.02)

    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, T, H, W, _ = B_T_H_W_C
        if self.interpolation == "crop":
            emb_h_H = self.pos_emb_h[:H]
            emb_w_W = self.pos_emb_w[:W]
            emb_t_T = self.pos_emb_t[:T]
            emb = (
                repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W)
                + repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W)
                + repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H)
            )
            assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}"
        else:
            raise ValueError(f"Unknown interpolation method {self.interpolation}")

        return normalize(emb, dim=-1, eps=1e-6)


@torch.no_grad()
def match_rope():
    from diffusers.models.transformers.transformer_cosmos import CosmosRotaryPosEmbed

    theirs_rope = VideoRopePosition3DEmb(head_dim=128, len_h=240 // 2, len_w=240 // 2, len_t=128 // 1, base_fps=24, h_extrapolation_ratio=1.0, w_extrapolation_ratio=1.0, t_extrapolation_ratio=2.0)
    ours_rope = CosmosRotaryPosEmbed(hidden_size=128, max_size=(128, 240, 240), patch_size=(1, 2, 2), base_fps=24, rope_scale=(2.0, 1.0, 1.0))
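    # rope_scale is ordered (T, H, W) here, mirroring the (t, h, w) extrapolation ratios above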

    hidden_states = torch.randn(2, 2, 32, 32, 16)
    fps = 30

    theirs_output = theirs_rope(hidden_states[:, :, :16, :16, :], fps=torch.tensor([fps]))  # the input slicing is to replicate patchification operation
    ours_output = ours_rope(hidden_states.permute(0, 4, 1, 2, 3), fps=fps)

    theirs_cos, theirs_sin = torch.cos(theirs_output), torch.sin(theirs_output)
    print(torch.allclose(ours_output[0][:, None, None, :], theirs_cos))
    print(torch.allclose(ours_output[1][:, None, None, :], theirs_sin))


@torch.no_grad()
def match_learnable_pe():
    from diffusers.models.transformers.transformer_cosmos import CosmosLearnablePositionalEmbed

    theirs_pe = LearnablePosEmbAxis(interpolation="crop", model_channels=4096, len_h=240 // 2, len_w=240 // 2, len_t=128 // 1)
    ours_pe = CosmosLearnablePositionalEmbed(4096, max_size=(128, 240, 240), patch_size=(1, 2, 2), eps=1e-6)

    ours_pe.pos_emb_t.data.copy_(theirs_pe.pos_emb_t.data)
    ours_pe.pos_emb_h.data.copy_(theirs_pe.pos_emb_h.data)
    ours_pe.pos_emb_w.data.copy_(theirs_pe.pos_emb_w.data)

    hidden_states = torch.randn(2, 2, 32, 32, 16)

    theirs_output = theirs_pe(hidden_states[:, :, :16, :16, :])
    ours_output = ours_pe(hidden_states.permute(0, 4, 1, 2, 3))

    theirs_output = theirs_output.flatten(1, 3)
    print(torch.allclose(ours_output, theirs_output))


# match_rope()
match_learnable_pe()
test transformer block
import sys
sys.path.append("/raid/aryan/cosmos-code/")

import torch
from cosmos1.models.diffusion.module.blocks import GeneralDITTransformerBlock


@torch.no_grad()
def match_transformer_block():
    from diffusers.models.transformers.transformer_cosmos import CosmosTransformerBlock

    theirs_transformer_block = GeneralDITTransformerBlock(
        x_dim=4096,
        context_dim=1024,
        num_heads=32,
        block_config="FA-CA-MLP",
        mlp_ratio=4.0,
        x_format="BTHWD",
        use_adaln_lora=True,
        adaln_lora_dim=256,
    )

    ours_transformer_block = CosmosTransformerBlock(
        num_attention_heads=32,
        attention_head_dim=128,
        cross_attention_dim=1024,
        mlp_ratio=4,
        adaln_lora_dim=256,
        qk_norm="rms_norm",
        out_bias=False,
    )
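    # block_config "FA-CA-MLP": blocks[0] is the self-attention (FA) sub-block -> norm1/attn1,
    # blocks[1] the cross-attention (CA) sub-block -> norm2/attn2, blocks[2] the MLP -> norm3/ff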

    ours_transformer_block.norm1.linear_1.weight.data.copy_(theirs_transformer_block.blocks[0].adaLN_modulation[1].weight.data)
    ours_transformer_block.norm1.linear_2.weight.data.copy_(theirs_transformer_block.blocks[0].adaLN_modulation[2].weight.data)
    
    ours_transformer_block.attn1.to_q.weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_q[0].weight.data)
    ours_transformer_block.attn1.to_k.weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_k[0].weight.data)
    ours_transformer_block.attn1.to_v.weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_v[0].weight.data)
    ours_transformer_block.attn1.to_out[0].weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_out[0].weight.data)
    ours_transformer_block.attn1.norm_q.weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_q[1].weight.data)
    ours_transformer_block.attn1.norm_k.weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_k[1].weight.data)

    ours_transformer_block.norm2.linear_1.weight.data.copy_(theirs_transformer_block.blocks[1].adaLN_modulation[1].weight.data)
    ours_transformer_block.norm2.linear_2.weight.data.copy_(theirs_transformer_block.blocks[1].adaLN_modulation[2].weight.data)

    ours_transformer_block.attn2.to_q.weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_q[0].weight.data)
    ours_transformer_block.attn2.to_k.weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_k[0].weight.data)
    ours_transformer_block.attn2.to_v.weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_v[0].weight.data)
    ours_transformer_block.attn2.to_out[0].weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_out[0].weight.data)
    ours_transformer_block.attn2.norm_q.weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_q[1].weight.data)
    ours_transformer_block.attn2.norm_k.weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_k[1].weight.data)

    ours_transformer_block.norm3.linear_1.weight.data.copy_(theirs_transformer_block.blocks[2].adaLN_modulation[1].weight.data)
    ours_transformer_block.norm3.linear_2.weight.data.copy_(theirs_transformer_block.blocks[2].adaLN_modulation[2].weight.data)

    ours_transformer_block.ff.net[0].proj.weight.data.copy_(theirs_transformer_block.blocks[2].block.layer1.weight.data)
    ours_transformer_block.ff.net[2].weight.data.copy_(theirs_transformer_block.blocks[2].block.layer2.weight.data)

    # ============
    batch_size = 1
    latent_num_frames = 2
    latent_height = 16
    latent_width = 16
    embedding_dim = 4096
    encoder_seq_length = 64
    encoder_dim = 1024
    

    hidden_states = torch.randn(batch_size, latent_num_frames, latent_height, latent_width, embedding_dim)
    temb = torch.randn(batch_size, embedding_dim)
    encoder_hidden_states = torch.randn(batch_size, encoder_seq_length, encoder_dim)
    attention_mask = None
    freqs = torch.randn(1, 1, latent_num_frames * latent_height * latent_width, 128)
    embedded_timestep = torch.randn(batch_size, 3 * embedding_dim)
    extra_per_block_emb = torch.randn(batch_size, latent_num_frames, latent_height, latent_width, embedding_dim)
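
    # the reference block is fed sequence-first [S, B, D] tensors (THWBD flattened) while the
    # diffusers block takes batch-first [B, S, D]; the permutes/flattens below convert layouts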

    theirs_output = theirs_transformer_block(
        x=hidden_states.flatten(1, 3).permute(1, 0, 2),
        emb_B_D=temb,
        crossattn_emb=encoder_hidden_states.permute(1, 0, 2),
        crossattn_mask=attention_mask,
        rope_emb_L_1_1_D=freqs.permute(2, 0, 1, 3),
        adaln_lora_B_3D=embedded_timestep,
        extra_per_block_pos_emb=extra_per_block_emb.flatten(1, 3).permute(1, 0, 2),
    )
    ours_output = ours_transformer_block(
        hidden_states=hidden_states.flatten(1, 3),
        encoder_hidden_states=encoder_hidden_states,
        temb=temb,
        embedded_timestep=embedded_timestep,
        image_rotary_emb=(torch.cos(freqs.flatten(0, 2)), torch.sin(freqs.flatten(0, 2))),
        extra_pos_emb=extra_per_block_emb.flatten(1, 3),
        attention_mask=attention_mask,
    )

    theirs_output = theirs_output.flatten(0, 2).permute(1, 0, 2)
    print(sum(p.numel() for p in theirs_transformer_block.parameters()))
    print(sum(p.numel() for p in ours_transformer_block.parameters()))
    print(torch.allclose(theirs_output.flatten(), ours_output.flatten(), atol=1e-4))


match_transformer_block()

# GeneralDITTransformerBlock(
#   (blocks): ModuleList(
#     (0): DITBuildingBlock(
#       (block): VideoAttn(
#         (attn): Attention(
#           (to_q): Sequential(
#             (0): Linear(in_features=4096, out_features=4096, bias=False)
#             (1): RMSNorm()
#           )
#           (to_k): Sequential(
#             (0): Linear(in_features=4096, out_features=4096, bias=False)
#             (1): RMSNorm()
#           )
#           (to_v): Sequential(
#             (0): Linear(in_features=4096, out_features=4096, bias=False)
#             (1): Identity()
#           )
#           (to_out): Sequential(
#             (0): Linear(in_features=4096, out_features=4096, bias=False)
#             (1): Dropout(p=0.0, inplace=False)
#           )
#         )
#       )
#       (norm_state): LayerNorm((4096,), eps=1e-06, elementwise_affine=False)
#       (adaLN_modulation): Sequential(
#         (0): SiLU()
#         (1): Linear(in_features=4096, out_features=256, bias=False)
#         (2): Linear(in_features=256, out_features=12288, bias=False)
#       )
#     )
#     (1): DITBuildingBlock(
#       (block): VideoAttn(
#         (attn): Attention(
#           (to_q): Sequential(
#             (0): Linear(in_features=4096, out_features=4096, bias=False)
#             (1): RMSNorm()
#           )
#           (to_k): Sequential(
#             (0): Linear(in_features=1024, out_features=4096, bias=False)
#             (1): RMSNorm()
#           )
#           (to_v): Sequential(
#             (0): Linear(in_features=1024, out_features=4096, bias=False)
#             (1): Identity()
#           )
#           (to_out): Sequential(
#             (0): Linear(in_features=4096, out_features=4096, bias=False)
#             (1): Dropout(p=0.0, inplace=False)
#           )
#         )
#       )
#       (norm_state): LayerNorm((4096,), eps=1e-06, elementwise_affine=False)
#       (adaLN_modulation): Sequential(
#         (0): SiLU()
#         (1): Linear(in_features=4096, out_features=256, bias=False)
#         (2): Linear(in_features=256, out_features=12288, bias=False)
#       )
#     )
#     (2): DITBuildingBlock(
#       (block): GPT2FeedForward(
#         (layer1): Linear(in_features=4096, out_features=16384, bias=False)
#         (layer2): Linear(in_features=16384, out_features=4096, bias=False)
#         (dropout): Dropout(p=0.0, inplace=False)
#         (activation): GELU(approximate='none')
#       )
#       (norm_state): LayerNorm((4096,), eps=1e-06, elementwise_affine=False)
#       (adaLN_modulation): Sequential(
#         (0): SiLU()
#         (1): Linear(in_features=4096, out_features=256, bias=False)
#         (2): Linear(in_features=256, out_features=12288, bias=False)
#       )
#     )
#   )
# )
# CosmosTransformerBlock(
#   (norm1): CosmosAdaLayerNormZero(
#     (norm): LayerNorm((4096,), eps=1e-06, elementwise_affine=False)
#     (activation): SiLU()
#     (linear_1): Linear(in_features=4096, out_features=256, bias=False)
#     (linear_2): Linear(in_features=256, out_features=12288, bias=False)
#   )
#   (attn1): Attention(
#     (norm_q): RMSNorm()
#     (norm_k): RMSNorm()
#     (to_q): Linear(in_features=4096, out_features=4096, bias=False)
#     (to_k): Linear(in_features=4096, out_features=4096, bias=False)
#     (to_v): Linear(in_features=4096, out_features=4096, bias=False)
#     (to_out): ModuleList(
#       (0): Linear(in_features=4096, out_features=4096, bias=False)
#       (1): Dropout(p=0.0, inplace=False)
#     )
#   )
#   (norm2): CosmosAdaLayerNormZero(
#     (norm): LayerNorm((4096,), eps=1e-06, elementwise_affine=False)
#     (activation): SiLU()
#     (linear_1): Linear(in_features=4096, out_features=256, bias=False)
#     (linear_2): Linear(in_features=256, out_features=12288, bias=False)
#   )
#   (attn2): Attention(
#     (norm_q): RMSNorm()
#     (norm_k): RMSNorm()
#     (to_q): Linear(in_features=4096, out_features=4096, bias=False)
#     (to_k): Linear(in_features=1024, out_features=4096, bias=False)
#     (to_v): Linear(in_features=1024, out_features=4096, bias=False)
#     (to_out): ModuleList(
#       (0): Linear(in_features=4096, out_features=4096, bias=False)
#       (1): Dropout(p=0.0, inplace=False)
#     )
#   )
#   (norm3): CosmosAdaLayerNormZero(
#     (norm): LayerNorm((4096,), eps=1e-06, elementwise_affine=False)
#     (activation): SiLU()
#     (linear_1): Linear(in_features=4096, out_features=256, bias=False)
#     (linear_2): Linear(in_features=256, out_features=12288, bias=False)
#   )
#   (ff): FeedForward(
#     (net): ModuleList(
#       (0): GELU(
#         (proj): Linear(in_features=4096, out_features=16384, bias=False)
#       )
#       (1): Dropout(p=0.0, inplace=False)
#       (2): Linear(in_features=16384, out_features=4096, bias=False)
#     )
#   )
# )
test transformer
import sys
sys.path.append("/raid/aryan/cosmos-code/")

import torch
from cosmos1.models.diffusion.networks.general_dit import GeneralDIT


@torch.no_grad()
def match_transformer():
    from diffusers.models.transformers.transformer_cosmos import CosmosTransformer3DModel

    theirs_transformer = GeneralDIT(
        max_img_h=240,
        max_img_w=240,
        max_frames=128,
        in_channels=16,
        out_channels=16,
        patch_spatial=2,
        patch_temporal=1,
        concat_padding_mask=True,
        block_config="FA-CA-MLP",
        model_channels=4096,
        num_blocks=2,
        num_heads=32,
        mlp_ratio=4,
        block_x_format="THWBD",
        crossattn_emb_channels=1024,
        use_cross_attn_mask=False,
        pos_emb_cls="rope3d",
        pos_emb_learnable=True,
        pos_emb_interpolation="crop",
        affline_emb_norm=True,
        use_adaln_lora=True,
        adaln_lora_dim=256,
        rope_h_extrapolation_ratio=1.0,
        rope_w_extrapolation_ratio=1.0,
        rope_t_extrapolation_ratio=2.0,
        extra_per_block_abs_pos_emb=True,
        extra_per_block_abs_pos_emb_type="learnable",
    )

    ours_transformer = CosmosTransformer3DModel(
        in_channels=16,
        out_channels=16,
        num_attention_heads=32,
        attention_head_dim=128,
        num_layers=2,
        mlp_ratio=4,
        text_embed_dim=1024,
        adaln_lora_dim=256,
        max_size=(128, 240, 240),
        patch_size=(1, 2, 2),
        rope_scale=(2.0, 1.0, 1.0),
        concat_padding_mask=True,
        extra_pos_embed_type="learnable",
    )

    # Patch embedding
    ours_transformer.patch_embed.proj.weight.data.copy_(theirs_transformer.x_embedder.proj[1].weight.data)

    # Timestep embedding
    ours_t_embedder = ours_transformer.time_embed
    theirs_t_embedder = theirs_transformer.t_embedder
    theirs_norm = theirs_transformer.affline_norm
    ours_t_embedder.t_embedder.linear_1.weight.data.copy_(theirs_t_embedder[1].linear_1.weight.data)
    ours_t_embedder.t_embedder.linear_2.weight.data.copy_(theirs_t_embedder[1].linear_2.weight.data)
    ours_t_embedder.norm.weight.data.copy_(theirs_norm.weight.data)

    # Learnable position embedding
    ours_pe = ours_transformer.learnable_pos_embed
    theirs_pe = theirs_transformer.extra_pos_embedder
    ours_pe.pos_emb_t.data.copy_(theirs_pe.pos_emb_t.data)
    ours_pe.pos_emb_h.data.copy_(theirs_pe.pos_emb_h.data)
    ours_pe.pos_emb_w.data.copy_(theirs_pe.pos_emb_w.data)

    # Transformer blocks
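    # (mapping per block_config "FA-CA-MLP": blocks[0] -> norm1/attn1, blocks[1] -> norm2/attn2, blocks[2] -> norm3/ff)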
    for i in range(2):
        ours_transformer_block = ours_transformer.transformer_blocks[i]
        theirs_transformer_block = theirs_transformer.blocks[f"block{i}"]
        
        ours_transformer_block.norm1.linear_1.weight.data.copy_(theirs_transformer_block.blocks[0].adaLN_modulation[1].weight.data)
        ours_transformer_block.norm1.linear_2.weight.data.copy_(theirs_transformer_block.blocks[0].adaLN_modulation[2].weight.data)
            
        ours_transformer_block.attn1.to_q.weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_q[0].weight.data)
        ours_transformer_block.attn1.to_k.weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_k[0].weight.data)
        ours_transformer_block.attn1.to_v.weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_v[0].weight.data)
        ours_transformer_block.attn1.to_out[0].weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_out[0].weight.data)
        ours_transformer_block.attn1.norm_q.weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_q[1].weight.data)
        ours_transformer_block.attn1.norm_k.weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_k[1].weight.data)

        ours_transformer_block.norm2.linear_1.weight.data.copy_(theirs_transformer_block.blocks[1].adaLN_modulation[1].weight.data)
        ours_transformer_block.norm2.linear_2.weight.data.copy_(theirs_transformer_block.blocks[1].adaLN_modulation[2].weight.data)

        ours_transformer_block.attn2.to_q.weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_q[0].weight.data)
        ours_transformer_block.attn2.to_k.weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_k[0].weight.data)
        ours_transformer_block.attn2.to_v.weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_v[0].weight.data)
        ours_transformer_block.attn2.to_out[0].weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_out[0].weight.data)
        ours_transformer_block.attn2.norm_q.weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_q[1].weight.data)
        ours_transformer_block.attn2.norm_k.weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_k[1].weight.data)

        ours_transformer_block.norm3.linear_1.weight.data.copy_(theirs_transformer_block.blocks[2].adaLN_modulation[1].weight.data)
        ours_transformer_block.norm3.linear_2.weight.data.copy_(theirs_transformer_block.blocks[2].adaLN_modulation[2].weight.data)

        ours_transformer_block.ff.net[0].proj.weight.data.copy_(theirs_transformer_block.blocks[2].block.layer1.weight.data)
        ours_transformer_block.ff.net[2].weight.data.copy_(theirs_transformer_block.blocks[2].block.layer2.weight.data)
    
    # Output layers
    ours_transformer.norm_out.linear_1.weight.data.copy_(theirs_transformer.final_layer.adaLN_modulation[1].weight.data)
    ours_transformer.norm_out.linear_2.weight.data.copy_(theirs_transformer.final_layer.adaLN_modulation[2].weight.data)
    ours_transformer.proj_out.weight.data.copy_(theirs_transformer.final_layer.linear.weight.data)

    for name, param in theirs_transformer.named_parameters():
        if "bias" in name:
            print(name, param.shape)
    for name, param in ours_transformer.named_parameters():
        if "bias" in name:
            print(name, param.shape)


    # ============
    batch_size = 1
    latent_num_frames = 2
    latent_height = 16
    latent_width = 16
    encoder_seq_length = 64
    encoder_dim = 1024
    fps = 30.0
    
    hidden_states = torch.randn(batch_size, latent_num_frames, latent_height, latent_width, 16)
    timestep = torch.randint(0, 1000, (batch_size,)).float()
    encoder_hidden_states = torch.randn(batch_size, encoder_seq_length, encoder_dim)
    attention_mask = None
    padding_mask = torch.zeros((1, 1, latent_height * 8, latent_width * 8))

    theirs_output = theirs_transformer(
        x=hidden_states.permute(0, 4, 1, 2, 3),
        timesteps=timestep,
        crossattn_emb=encoder_hidden_states,
        crossattn_mask=attention_mask,
        fps=torch.tensor([fps]),
        padding_mask=padding_mask,
    )
    print()
    ours_output = ours_transformer(
        hidden_states=hidden_states.permute(0, 4, 1, 2, 3),
        timestep=timestep.long(),
        encoder_hidden_states=encoder_hidden_states,
        attention_mask=attention_mask,
        fps=fps,
        padding_mask=padding_mask,
    )[0]

    print(torch.allclose(theirs_output, ours_output, atol=1e-4))

match_transformer()
test transformer video
import sys
sys.path.append("/raid/aryan/cosmos-code/")

import torch
from cosmos1.models.diffusion.networks.general_dit_video_conditioned import VideoExtendGeneralDIT


@torch.no_grad()
def match_transformer():
    from diffusers.models.transformers.transformer_cosmos import CosmosTransformer3DModel

    theirs_transformer = VideoExtendGeneralDIT(
        max_img_h=240,
        max_img_w=240,
        max_frames=128,
        in_channels=16 + 1,
        out_channels=16,
        patch_spatial=2,
        patch_temporal=1,
        concat_padding_mask=True,
        block_config="FA-CA-MLP",
        model_channels=4096,
        num_blocks=2,
        num_heads=32,
        mlp_ratio=4,
        block_x_format="THWBD",
        crossattn_emb_channels=1024,
        use_cross_attn_mask=False,
        pos_emb_cls="rope3d",
        pos_emb_learnable=True,
        pos_emb_interpolation="crop",
        affline_emb_norm=True,
        use_adaln_lora=True,
        adaln_lora_dim=256,
        rope_h_extrapolation_ratio=1.0,
        rope_w_extrapolation_ratio=1.0,
        rope_t_extrapolation_ratio=2.0,
        extra_per_block_abs_pos_emb=True,
        extra_per_block_abs_pos_emb_type="learnable",
    )

    ours_transformer = CosmosTransformer3DModel(
        in_channels=16 + 1,
        out_channels=16,
        num_attention_heads=32,
        attention_head_dim=128,
        num_layers=2,
        mlp_ratio=4,
        text_embed_dim=1024,
        adaln_lora_dim=256,
        max_size=(128, 240, 240),
        patch_size=(1, 2, 2),
        rope_scale=(2.0, 1.0, 1.0),
        concat_padding_mask=True,
        extra_pos_embed_type="learnable",
    )

    # Patch embedding
    ours_transformer.patch_embed.proj.weight.data.copy_(theirs_transformer.x_embedder.proj[1].weight.data)

    # Timestep embedding
    ours_t_embedder = ours_transformer.time_embed
    theirs_t_embedder = theirs_transformer.t_embedder
    theirs_norm = theirs_transformer.affline_norm
    ours_t_embedder.t_embedder.linear_1.weight.data.copy_(theirs_t_embedder[1].linear_1.weight.data)
    ours_t_embedder.t_embedder.linear_2.weight.data.copy_(theirs_t_embedder[1].linear_2.weight.data)
    ours_t_embedder.norm.weight.data.copy_(theirs_norm.weight.data)

    # Learnable position embedding
    ours_pe = ours_transformer.learnable_pos_embed
    theirs_pe = theirs_transformer.extra_pos_embedder
    ours_pe.pos_emb_t.data.copy_(theirs_pe.pos_emb_t.data)
    ours_pe.pos_emb_h.data.copy_(theirs_pe.pos_emb_h.data)
    ours_pe.pos_emb_w.data.copy_(theirs_pe.pos_emb_w.data)

    # Transformer blocks
    for i in range(2):
        ours_transformer_block = ours_transformer.transformer_blocks[i]
        theirs_transformer_block = theirs_transformer.blocks[f"block{i}"]
        
        ours_transformer_block.norm1.linear_1.weight.data.copy_(theirs_transformer_block.blocks[0].adaLN_modulation[1].weight.data)
        ours_transformer_block.norm1.linear_2.weight.data.copy_(theirs_transformer_block.blocks[0].adaLN_modulation[2].weight.data)
            
        ours_transformer_block.attn1.to_q.weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_q[0].weight.data)
        ours_transformer_block.attn1.to_k.weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_k[0].weight.data)
        ours_transformer_block.attn1.to_v.weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_v[0].weight.data)
        ours_transformer_block.attn1.to_out[0].weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_out[0].weight.data)
        ours_transformer_block.attn1.norm_q.weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_q[1].weight.data)
        ours_transformer_block.attn1.norm_k.weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_k[1].weight.data)

        ours_transformer_block.norm2.linear_1.weight.data.copy_(theirs_transformer_block.blocks[1].adaLN_modulation[1].weight.data)
        ours_transformer_block.norm2.linear_2.weight.data.copy_(theirs_transformer_block.blocks[1].adaLN_modulation[2].weight.data)

        ours_transformer_block.attn2.to_q.weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_q[0].weight.data)
        ours_transformer_block.attn2.to_k.weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_k[0].weight.data)
        ours_transformer_block.attn2.to_v.weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_v[0].weight.data)
        ours_transformer_block.attn2.to_out[0].weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_out[0].weight.data)
        ours_transformer_block.attn2.norm_q.weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_q[1].weight.data)
        ours_transformer_block.attn2.norm_k.weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_k[1].weight.data)

        ours_transformer_block.norm3.linear_1.weight.data.copy_(theirs_transformer_block.blocks[2].adaLN_modulation[1].weight.data)
        ours_transformer_block.norm3.linear_2.weight.data.copy_(theirs_transformer_block.blocks[2].adaLN_modulation[2].weight.data)

        ours_transformer_block.ff.net[0].proj.weight.data.copy_(theirs_transformer_block.blocks[2].block.layer1.weight.data)
        ours_transformer_block.ff.net[2].weight.data.copy_(theirs_transformer_block.blocks[2].block.layer2.weight.data)
    
    # Output layers
    ours_transformer.norm_out.linear_1.weight.data.copy_(theirs_transformer.final_layer.adaLN_modulation[1].weight.data)
    ours_transformer.norm_out.linear_2.weight.data.copy_(theirs_transformer.final_layer.adaLN_modulation[2].weight.data)
    ours_transformer.proj_out.weight.data.copy_(theirs_transformer.final_layer.linear.weight.data)

    for name, param in theirs_transformer.named_parameters():
        if "bias" in name:
            print(name, param.shape)
    for name, param in ours_transformer.named_parameters():
        if "bias" in name:
            print(name, param.shape)


    # ============
    batch_size = 1
    latent_num_frames = 2
    latent_height = 16
    latent_width = 16
    encoder_seq_length = 64
    encoder_dim = 1024
    fps = 30.0
    
    hidden_states = torch.randn(batch_size, latent_num_frames, latent_height, latent_width, 16)
    timestep = torch.randint(0, 1000, (batch_size,)).float()
    encoder_hidden_states = torch.randn(batch_size, encoder_seq_length, encoder_dim)
    attention_mask = None
    condition_mask = torch.ones(batch_size, 1, latent_num_frames, latent_height, latent_width)
    padding_mask = torch.zeros((1, 1, latent_height * 8, latent_width * 8))

    theirs_output = theirs_transformer(
        x=hidden_states.permute(0, 4, 1, 2, 3),
        timesteps=timestep,
        crossattn_emb=encoder_hidden_states,
        crossattn_mask=attention_mask,
        fps=torch.tensor([fps]),
        condition_video_input_mask=condition_mask,
        padding_mask=padding_mask,
    )
    print()
    ours_output = ours_transformer(
        hidden_states=hidden_states.permute(0, 4, 1, 2, 3),
        timestep=timestep.long(),
        encoder_hidden_states=encoder_hidden_states,
        attention_mask=attention_mask,
        fps=fps,
        condition_mask=condition_mask,
        padding_mask=padding_mask,
    )[0]

    print(torch.allclose(theirs_output, ours_output, atol=1e-4))

match_transformer()

VAE

test vae attention
import sys
sys.path.append("/raid/aryan/cosmos-tokenizer-code/")

from cosmos_tokenizer.modules.layers3d import CausalAttnBlock, CausalTemporalAttnBlock

from typing import Union, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.autoencoders.autoencoder_kl_cosmos import CosmosCausalGroupNorm, CosmosCausalConv3d


class CosmosCausalAttention(nn.Module):
    def __init__(
        self,
        num_attention_heads: int,
        attention_head_dim: int,
        num_groups: int = 1,
        dropout: float = 0.0,
        processor: Optional[
            Union["CosmosSpatialAttentionProcessor2_0", "CosmosTemporalAttentionProcessor2_0"]
        ] = None,
    ) -> None:
        super().__init__()
        self.num_attention_heads = num_attention_heads

        self.norm = CosmosCausalGroupNorm(attention_head_dim, num_groups=num_groups)
        self.to_q = CosmosCausalConv3d(attention_head_dim, attention_head_dim, kernel_size=1, stride=1, padding=0)
        self.to_k = CosmosCausalConv3d(attention_head_dim, attention_head_dim, kernel_size=1, stride=1, padding=0)
        self.to_v = CosmosCausalConv3d(attention_head_dim, attention_head_dim, kernel_size=1, stride=1, padding=0)
        self.to_out = nn.ModuleList([])
        self.to_out.append(CosmosCausalConv3d(attention_head_dim, attention_head_dim, kernel_size=1, stride=1, padding=0))
        self.to_out.append(nn.Dropout(dropout))

        self.processor = processor
        if self.processor is None:
            raise ValueError("CosmosCausalAttention requires a processor.")
    
    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.processor(self, hidden_states=hidden_states, attention_mask=attention_mask)


class CosmosSpatialAttentionProcessor2_0:
    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CosmosSpatialAttentionProcessor2_0 requires PyTorch 2.0 or higher. To use it, please upgrade PyTorch.")
    
    def __call__(self, attn: CosmosCausalAttention, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        residual = hidden_states
        
        hidden_states = attn.norm(hidden_states)
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # [B, C, T, H, W] -> [B * T, H * W, C]
        query = query.permute(0, 2, 3, 4, 1).flatten(2, 3).flatten(0, 1)
        key = key.permute(0, 2, 3, 4, 1).flatten(2, 3).flatten(0, 1)
        value = value.permute(0, 2, 3, 4, 1).flatten(2, 3).flatten(0, 1)

        # [B * T, H * W, C] -> [B * T, N, H * W, C // N]
        query = query.unflatten(2, (attn.num_attention_heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.num_attention_heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.num_attention_heads, -1)).transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).type_as(query)
        hidden_states = hidden_states.unflatten(1, (height, width)).unflatten(0, (batch_size, num_frames))
        hidden_states = hidden_states.permute(0, 4, 1, 2, 3)

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        
        return hidden_states + residual


class CosmosTemporalAttentionProcessor2_0:
    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CosmosSpatialAttentionProcessor2_0 requires PyTorch 2.0 or higher. To use it, please upgrade PyTorch.")
    
    def __call__(self, attn: CosmosCausalAttention, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        residual = hidden_states
        
        hidden_states = attn.norm(hidden_states)
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # [B, C, T, H, W] -> [B * H * W, T, C]
        query = query.permute(0, 3, 4, 2, 1).flatten(0, 2)
        key = key.permute(0, 3, 4, 2, 1).flatten(0, 2)
        value = value.permute(0, 3, 4, 2, 1).flatten(0, 2)

        # [B * H * W, T, C] -> [B * H * W, N, T, C // N]
        query = query.unflatten(2, (attn.num_attention_heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.num_attention_heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.num_attention_heads, -1)).transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).type_as(query)
        hidden_states = hidden_states.unflatten(0, (batch_size, height, width))
        hidden_states = hidden_states.permute(0, 4, 3, 1, 2)

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        
        return hidden_states + residual


@torch.no_grad()
def test_causal_attn_block_spatial():
    in_channels = 128

    torch.manual_seed(0)
    theirs_attn = CausalAttnBlock(in_channels, num_groups=1)
    ours_attn = CosmosCausalAttention(num_attention_heads=1, attention_head_dim=in_channels, num_groups=1, dropout=0.0, processor=CosmosSpatialAttentionProcessor2_0())

    ours_attn.to_q.conv.weight.data.copy_(theirs_attn.q.conv3d.weight.data)
    ours_attn.to_k.conv.weight.data.copy_(theirs_attn.k.conv3d.weight.data)
    ours_attn.to_v.conv.weight.data.copy_(theirs_attn.v.conv3d.weight.data)
    ours_attn.to_out[0].conv.weight.data.copy_(theirs_attn.proj_out.conv3d.weight.data)

    ours_attn.to_q.conv.bias.data.copy_(theirs_attn.q.conv3d.bias.data)
    ours_attn.to_k.conv.bias.data.copy_(theirs_attn.k.conv3d.bias.data)
    ours_attn.to_v.conv.bias.data.copy_(theirs_attn.v.conv3d.bias.data)
    ours_attn.to_out[0].conv.bias.data.copy_(theirs_attn.proj_out.conv3d.bias.data)

    ours_attn.norm.norm.weight.data.copy_(theirs_attn.norm.norm.weight.data)
    ours_attn.norm.norm.bias.data.copy_(theirs_attn.norm.norm.bias.data)

    batch_size = 2
    num_frames = 16
    height = 8
    width = 8

    hidden_states = torch.randn(batch_size, in_channels, num_frames, height, width)
    theirs_output = theirs_attn(hidden_states)
    ours_output = ours_attn(hidden_states)

    diff = theirs_output - ours_output
    print(f"absmax diff: {diff.abs().max()}")
    print(f"absmean diff: {diff.abs().mean()}")


@torch.no_grad()
def test_causal_attn_block_temporal():
    in_channels = 128

    torch.manual_seed(0)
    theirs_attn = CausalTemporalAttnBlock(in_channels, num_groups=1)
    ours_attn = CosmosCausalAttention(num_attention_heads=1, attention_head_dim=in_channels, num_groups=1, dropout=0.0, processor=CosmosTemporalAttentionProcessor2_0())

    ours_attn.to_q.conv.weight.data.copy_(theirs_attn.q.conv3d.weight.data)
    ours_attn.to_k.conv.weight.data.copy_(theirs_attn.k.conv3d.weight.data)
    ours_attn.to_v.conv.weight.data.copy_(theirs_attn.v.conv3d.weight.data)
    ours_attn.to_out[0].conv.weight.data.copy_(theirs_attn.proj_out.conv3d.weight.data)

    ours_attn.to_q.conv.bias.data.copy_(theirs_attn.q.conv3d.bias.data)
    ours_attn.to_k.conv.bias.data.copy_(theirs_attn.k.conv3d.bias.data)
    ours_attn.to_v.conv.bias.data.copy_(theirs_attn.v.conv3d.bias.data)
    ours_attn.to_out[0].conv.bias.data.copy_(theirs_attn.proj_out.conv3d.bias.data)

    ours_attn.norm.norm.weight.data.copy_(theirs_attn.norm.norm.weight.data)
    ours_attn.norm.norm.bias.data.copy_(theirs_attn.norm.norm.bias.data)

    batch_size = 2
    num_frames = 16
    height = 8
    width = 8

    hidden_states = torch.randn(batch_size, in_channels, num_frames, height, width)
    theirs_output = theirs_attn(hidden_states)
    
    attn_mask = torch.tril(hidden_states.new_ones(num_frames, num_frames)).bool()
    ours_output = ours_attn(hidden_states, attn_mask)

    diff = theirs_output - ours_output
    print(f"absmax diff: {diff.abs().max()}")
    print(f"absmean diff: {diff.abs().mean()}")


test_causal_attn_block_temporal()
test_causal_attn_block_spatial()
test vae
import sys
sys.path.append("/raid/aryan/cosmos-tokenizer-code/")

from cosmos_tokenizer.modules import (
    ContinuousFormulation,
    Encoder3DType,
    Decoder3DType,
)
from cosmos_tokenizer.networks.continuous_video import CausalContinuousVideoTokenizer

import torch
from accelerate import init_empty_weights
from typing import Dict, Any

def remove_keys_(key: str, state_dict: Dict[str, Any]):
    state_dict.pop(key)


def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    state_dict[new_key] = state_dict.pop(old_key)


VAE_KEYS_RENAME_DICT = {
    "down.0": "down_blocks.0",
    "down.1": "down_blocks.1",
    "down.2": "down_blocks.2",
    "up.0": "up_blocks.2",
    "up.1": "up_blocks.1",
    "up.2": "up_blocks.0",
    ".block.": ".resnets.",
    "downsample": "downsamplers.0",
    "upsample": "upsamplers.0",
    "mid.block_1": "mid_block.resnets.0",
    "mid.attn_1.0": "mid_block.attentions.0",
    "mid.attn_1.1": "mid_block.temp_attentions.0",
    "mid.block_2": "mid_block.resnets.1",
    ".q.conv3d": ".to_q",
    ".k.conv3d": ".to_k",
    ".v.conv3d": ".to_v",
    ".proj_out.conv3d": ".to_out.0",
    ".0.conv3d": ".conv_s",
    ".1.conv3d": ".conv_t",
    "conv1.conv3d": "conv1",
    "conv2.conv3d": "conv2",
    "conv3.conv3d": "conv3",
    "nin_shortcut.conv3d": "conv_shortcut",
    "quant_conv.conv3d": "quant_conv",
    "post_quant_conv.conv3d": "post_quant_conv",
}

VAE_SPECIAL_KEYS_REMAP = {}


@torch.no_grad()
def test_vae():
    from diffusers import AutoencoderKLCosmos
    
    torch.manual_seed(0)
    
    theirs_config = dict(
        attn_resolutions=[32],
        channels=128,
        channels_mult=[2, 4, 4],
        dropout=0.0,
        in_channels=3,
        num_res_blocks=2,
        out_channels=3,
        resolution=1024,
        patch_size=4,
        patch_method="haar",
        latent_channels=16,
        z_channels=16,
        z_factor=1,
        num_groups=1,
        legacy_mode=False,
        spatial_compression=8,
        temporal_compression=8,
        formulation=ContinuousFormulation.AE.name,
        encoder=Encoder3DType.FACTORIZED.name,
        decoder=Decoder3DType.FACTORIZED.name,
        name="CV",
    )
    theirs_model = CausalContinuousVideoTokenizer(**theirs_config)

    ours_model = AutoencoderKLCosmos()

    # print(theirs_model.decoder)
    # print()
    # print()
    # print()
    # print()
    # print(ours_model.decoder)

    theirs_num_params = sum(p.numel() for p in theirs_model.parameters())
    ours_num_params = sum(p.numel() for p in ours_model.parameters())

    print(f"theirs_num_params: {theirs_num_params}")
    print(f"ours_num_params: {ours_num_params}")

    PREFIX_KEY = ""
    original_state_dict = theirs_model.state_dict()
    
    for key in list(original_state_dict.keys()):
        new_key = key[:]
        if new_key.startswith(PREFIX_KEY):
            new_key = new_key.removeprefix(PREFIX_KEY)
        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(original_state_dict, key, new_key)

    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)
    
    ours_model.load_state_dict(original_state_dict, strict=True, assign=True)

    batch_size = 2
    num_channels = 3
    num_frames = 49
    height = 256
    width = 256

    hidden_states = torch.randn(batch_size, num_channels, num_frames, height, width)
    theirs_output = theirs_model(hidden_states)["reconstructions"]
    ours_output = ours_model(hidden_states)[0]

    # torch.Size([2, 3, 49, 256, 256]) torch.Size([2, 3, 97, 512, 512])

    print(theirs_output.shape, ours_output.shape)
    diff = theirs_output - ours_output
    print(f"absmax diff: {diff.abs().max()}")
    print(f"absmean diff: {diff.abs().mean()}")


test_vae()

Text-to-World:

import torch
from diffusers import CosmosPipeline
from diffusers.utils import export_to_video

model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Text2World"
pipe = CosmosPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A sleek, humanoid robot stands in a vast warehouse filled with neatly stacked cardboard boxes on industrial shelves. The robot's metallic body gleams under the bright, even lighting, highlighting its futuristic design and intricate joints. A glowing blue light emanates from its chest, adding a touch of advanced technology. The background is dominated by rows of boxes, suggesting a highly organized storage system. The floor is lined with wooden pallets, enhancing the industrial setting. The camera remains static, capturing the robot's poised stance amidst the orderly environment, with a shallow depth of field that keeps the focus on the robot while subtly blurring the background for a cinematic effect."

output = pipe(prompt=prompt).frames[0]
export_to_video(output, "output.mp4", fps=30)

Video-to-World (image-conditioning):

import torch
from diffusers import CosmosVideoToWorldPipeline
from diffusers.utils import export_to_video, load_image

model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Video2World"
pipe = CosmosVideoToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "The video depicts a long, straight highway stretching into the distance, flanked by metal guardrails. The road is divided into multiple lanes, with a few vehicles visible in the far distance. The surrounding landscape features dry, grassy fields on one side and rolling hills on the other. The sky is mostly clear with a few scattered clouds, suggesting a bright, sunny day."
image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input.jpg"
)

video = pipe(image=image, prompt=prompt).frames[0]
export_to_video(video, "output.mp4", fps=30)

Video-to-World (video-conditioning):

import torch
from diffusers import CosmosVideoToWorldPipeline
from diffusers.utils import export_to_video, load_video

model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Video2World"
pipe = CosmosVideoToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe.transformer = torch.compile(pipe.transformer)
pipe.to("cuda")

prompt = "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
video = load_video(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4"
)[:21]  # This example uses only the first 21 frames

video = pipe(video=video, prompt=prompt).frames[0]
export_to_video(video, "output.mp4", fps=30)

Note that the model repos are not yet compatible with Diffusers-loading. I'll open PRs for the weights once the nvidia team gives the thumbs up.

Inference code (old)
import os
from typing import Any, Dict

import torch
from diffusers import CosmosTransformer3DModel, CosmosPipeline, EDMEulerScheduler, EDMDPMSolverMultistepScheduler
from diffusers.utils import export_to_video
from transformers import T5EncoderModel, T5TokenizerFast


def remove_keys_(key: str, state_dict: Dict[str, Any]):
    state_dict.pop(key)


def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    state_dict[new_key] = state_dict.pop(old_key)


def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
    block_index = int(key.split(".")[1].removeprefix("block"))
    new_key = key

    old_prefix = f"blocks.block{block_index}"
    new_prefix = f"transformer_blocks.{block_index}"
    new_key = new_prefix + new_key.removeprefix(old_prefix)
    
    state_dict[new_key] = state_dict.pop(key)


TRANSFORMER_KEYS_RENAME_DICT = {
    "t_embedder.1": "time_embed.t_embedder",
    "affline_norm": "time_embed.norm",
    ".blocks.0.block.attn": ".attn1",
    ".blocks.1.block.attn": ".attn2",
    ".blocks.2.block": ".ff",
    ".blocks.0.adaLN_modulation.1": ".norm1.linear_1",
    ".blocks.0.adaLN_modulation.2": ".norm1.linear_2",
    ".blocks.1.adaLN_modulation.1": ".norm2.linear_1",
    ".blocks.1.adaLN_modulation.2": ".norm2.linear_2",
    ".blocks.2.adaLN_modulation.1": ".norm3.linear_1",
    ".blocks.2.adaLN_modulation.2": ".norm3.linear_2",
    "to_q.0": "to_q",
    "to_q.1": "norm_q",
    "to_k.0": "to_k",
    "to_k.1": "norm_k",
    "to_v.0": "to_v",
    "layer1": "net.0.proj",
    "layer2": "net.2",
    "proj.1": "proj",
    "x_embedder": "patch_embed",
    "extra_pos_embedder": "learnable_pos_embed",
    "final_layer.adaLN_modulation.1": "norm_out.linear_1",
    "final_layer.adaLN_modulation.2": "norm_out.linear_2",
    "final_layer.linear": "proj_out",
}

TRANSFORMER_SPECIAL_KEYS_REMAP = {
    "blocks.block": rename_transformer_blocks_,
    "logvar.0.freqs": remove_keys_,
    "logvar.0.phases": remove_keys_,
    "logvar.1.weight": remove_keys_,
    "pos_embedder.seq": remove_keys_,
}


def convert_transformer(state_dict):
    PREFIX_KEY = "net."
    for key in list(state_dict.keys()):
        new_key = key[:]
        if new_key.startswith(PREFIX_KEY):
            new_key = key[len(PREFIX_KEY) :]
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(state_dict, key, new_key)

    for key in list(state_dict.keys()):
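        # Keys that need more than a plain rename: renumber "blocks.blockN" -> "transformer_blocks.N"
        # and drop the logvar / pos_embedder buffers.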
        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, state_dict)
    
    return state_dict


torch.manual_seed(0)
device = "cuda"
dtype = torch.bfloat16

with torch.no_grad():
    with torch.device("meta"):
        transformer = CosmosTransformer3DModel()
    num_parameters = sum(p.numel() for p in transformer.parameters())
    print(f"{num_parameters=}")

    checkpoint_file = "/raid/aryan/cosmos-code/checkpoints/Cosmos-1.0-Diffusion-7B-Text2World/model.pt"
    checkpoint = torch.load(checkpoint_file, map_location="cpu", weights_only=True)
    checkpoint = convert_transformer(checkpoint)
    transformer.load_state_dict(checkpoint, strict=True, assign=True)

    text_encoder = T5EncoderModel.from_pretrained("google-t5/t5-11b", torch_dtype=dtype, cache_dir="/raid/aryan/cosmos-code/checkpoints")
    tokenizer = T5TokenizerFast.from_pretrained("google-t5/t5-11b", cache_dir="/raid/aryan/cosmos-code/checkpoints")

    vae_dir = "/raid/aryan/cosmos-code/checkpoints/Cosmos-1.0-Tokenizer-CV8x8x8"
    decoder = torch.jit.load(os.path.join(vae_dir, "decoder.jit")).to(device=device, dtype=dtype)
    latent_mean, latent_std = torch.load(os.path.join(vae_dir, "mean_std.pt"), weights_only=True)

    scheduler = EDMEulerScheduler(final_sigmas_type="sigma_min")
    
    pipe = CosmosPipeline(text_encoder, tokenizer, transformer, vae=None, scheduler=scheduler)
    pipe.to(device, dtype=dtype)

    prompt = "A sleek, humanoid robot stands in a vast warehouse filled with neatly stacked cardboard boxes on industrial shelves. The robot's metallic body gleams under the bright, even lighting, highlighting its futuristic design and intricate joints. A glowing blue light emanates from its chest, adding a touch of advanced technology. The background is dominated by rows of boxes, suggesting a highly organized storage system. The floor is lined with wooden pallets, enhancing the industrial setting. The camera remains static, capturing the robot's poised stance amidst the orderly environment, with a shallow depth of field that keeps the focus on the robot while subtly blurring the background for a cinematic effect."
    negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."

    latents = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=704,
        width=960,
        # width=1280,
        num_frames=121,
        num_inference_steps=36,
        output_type="latent",
    ).frames
    torch.save(latents, "latents.pt")

    latent_mean = latent_mean.to(device=device).reshape(1, 16, 16, 1, 1).float()[:, :, :latents.size(2)]
    latent_std = latent_std.to(device=device).reshape(1, 16, 16, 1, 1).float()[:, :, :latents.size(2)]

    sigma_data = 0.5
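    # Map the pipeline latents back to the tokenizer's latent space before decoding:
    # divide by sigma_data, then de-normalize with the tokenizer's per-channel mean/std.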
    latents = latents / sigma_data
    latents = (latents.float() * latent_std + latent_mean).type_as(latents)
    output = decoder(latents.to(device=device, dtype=dtype))

    video = pipe.video_processor.postprocess_video(output, output_type="pil")[0]

    export_to_video(video, "output.mp4", fps=30)

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w a-r-r-o-w added the roadmap Add to current release roadmap label Feb 4, 2025
@a-r-r-o-w
Member Author

a-r-r-o-w commented Feb 18, 2025

To match our sigmas to the original exactly, without any rounding errors, I had to use torch.float64. This change may not be strictly required since the values are nearly the same, but it's something to keep in mind.

# theirs:             [80.0, 68.32506, 58.14207, 49.28863, 41.61683, 34.99219, 29.29279, 24.40834, 20.23932, 16.69618, 13.69857, 11.17463, 9.06026, 7.29851, 5.83895, 4.63707, 3.65381, 2.85496, 2.21074, 1.69537, 1.28661, 0.96542, 0.71556, 0.52331, 0.37715, 0.26748, 0.18636, 0.12731, 0.08509, 0.05548, 0.03519, 0.02162, 0.01281, 0.00727, 0.00393, 0.002]
# ours_original:      [79.99998, 68.00508, 57.58597, 48.56622, 40.78557, 34.09878, 28.37458, 23.49461, 19.35245, 15.8527, 12.91008, 10.44864, 8.40094, 6.70731, 5.31519, 4.17847, 3.25682, 2.51522, 1.92334, 1.4551, 1.08817, 0.80359, 0.58535, 0.42002, 0.29644, 0.20544, 0.13952, 0.09262, 0.05995, 0.03769, 0.02293, 0.01343, 0.00753, 0.004, 0.002, 0.0]
# ours_modified:      [80.0, 68.32506, 58.14206, 49.28863, 41.61682, 34.99219, 29.29279, 24.40834, 20.23932, 16.69618, 13.69858, 11.17463, 9.06026, 7.29851, 5.83895, 4.63708, 3.65381, 2.85496, 2.21074, 1.69537, 1.28661, 0.96542, 0.71556, 0.52331, 0.37715, 0.26748, 0.18636, 0.12731, 0.08509, 0.05548, 0.03519, 0.02162, 0.01281, 0.00727, 0.00393, 0.002, 0.0]

Also, we only match the sigmas if we set our num_inference_steps to their num_inference_steps + 1. This is because they do an extra inference step without a scheduler step (effectively the same as setting our final_sigmas_type="sigma_min").
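
For context, here's a minimal standalone sketch of a Karras/EDM-style sigma schedule computed in float64 (assuming the usual sigma_min=0.002, sigma_max=80.0, rho=7.0 defaults; this is only an illustration, not the diffusers scheduler code):

import torch

def karras_sigmas(num_steps: int, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: float = 7.0) -> torch.Tensor:
    # Interpolate in rho-space (the Karras et al. schedule) and append a terminal 0.0,
    # doing all arithmetic in float64 to avoid rounding drift in the printed values.
    ramp = torch.linspace(0, 1, num_steps, dtype=torch.float64)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return torch.cat([sigmas, torch.zeros(1, dtype=torch.float64)])

print([round(s, 5) for s in karras_sigmas(36).tolist()])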

@a-r-r-o-w a-r-r-o-w marked this pull request as ready for review February 25, 2025 11:38
@a-r-r-o-w
Member Author

The latest push makes it so that the Video2World models can run end-to-end with diffusers. The T2W pipeline produces good outputs, but the V2W pipeline still generates garbage -- I'm debugging it. I've matched the transformers for both T2W and V2W though (everything matches and I've updated the description with test code), so the bug is most likely in the pipeline implementation.

Comment on lines 626 to 694
for i, t in enumerate(timesteps):
    if self.interrupt:
        continue

    self._current_timestep = t
    timestep = t.expand(latents.shape[0]).to(transformer_dtype)

    current_sigma = self.scheduler.sigmas[i]
    is_augment_sigma_greater = augment_sigma >= current_sigma

    current_cond_indicator = cond_indicator * 0 if is_augment_sigma_greater else cond_indicator
    cond_noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=torch.float32)
    cond_latent = conditioning_latents + cond_noise * augment_sigma[:, None, None, None, None]
    cond_latent = current_cond_indicator * cond_latent + (1 - current_cond_indicator) * latents
    cond_latent = self.scheduler.scale_model_input(cond_latent, t)
    cond_latent = cond_latent.to(transformer_dtype)

    noise_pred = self.transformer(
        hidden_states=cond_latent,
        timestep=timestep,
        encoder_hidden_states=prompt_embeds,
        fps=fps,
        condition_mask=cond_mask,
        padding_mask=padding_mask,
        return_dict=False,
    )[0]

    if self.do_classifier_free_guidance:
        current_uncond_indicator = uncond_indicator * 0 if is_augment_sigma_greater else uncond_indicator
        uncond_noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=torch.float32)
        uncond_latent = conditioning_latents + uncond_noise * augment_sigma[:, None, None, None, None]
        uncond_latent = current_uncond_indicator * uncond_latent + (1 - current_uncond_indicator) * latents
        uncond_latent = self.scheduler.scale_model_input(uncond_latent, t)
        uncond_latent = uncond_latent.to(transformer_dtype)

        noise_pred_uncond = self.transformer(
            hidden_states=uncond_latent,
            timestep=timestep,
            encoder_hidden_states=negative_prompt_embeds,
            fps=fps,
            condition_mask=uncond_mask,
            padding_mask=padding_mask,
            return_dict=False,
        )[0]
        noise_pred = torch.cat([noise_pred_uncond, noise_pred])

    # pred_original_sample (x0)
    noise_pred = self.scheduler.step(noise_pred, t, latents, return_dict=False)[1]
    self.scheduler._step_index -= 1

    if self.do_classifier_free_guidance:
        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
        noise_pred_uncond = (
            current_uncond_indicator * conditioning_latents
            + (1 - current_uncond_indicator) * noise_pred_uncond
        )
        noise_pred_cond = (
            current_cond_indicator * conditioning_latents + (1 - current_cond_indicator) * noise_pred_cond
        )
        noise_pred = noise_pred_cond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
    else:
        noise_pred = (
            current_cond_indicator * conditioning_latents + (1 - current_cond_indicator) * noise_pred
        )

    # pred_sample (eps)
    latents = self.scheduler.step(
        noise_pred, t, latents, return_dict=False, pred_original_sample=noise_pred
    )[0]
Member Author

Not too happy about this hack around our scheduler, but this seems like the only way to make it work (the outputs are at least no longer random garbage; following the conditioning is still not fixed).

The original code seems to be applying CFG on the x0-prediction, followed by obtaining the eps-prediction. I've made the same change related to CFG on x0-pred in the Text-to-World pipeline as well. It would be great if a second set of eyes could give it a look, but I do think this part of the implementation is right. The hack around scheduler._step_index is necessary to make sure we can compute the eps-pred using the augmented x0-pred.
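
As a rough standalone illustration of that ordering (not the Cosmos or diffusers code; the shapes, sigma, and guidance scale below are made up), guidance is applied in x0-space first and only then converted back to the eps used by the Euler update:

import torch

def cfg_on_x0(x0_cond: torch.Tensor, x0_uncond: torch.Tensor, sample: torch.Tensor, sigma: float, guidance_scale: float) -> torch.Tensor:
    # Guide the denoised (x0) predictions ...
    x0 = x0_cond + guidance_scale * (x0_cond - x0_uncond)
    # ... then recover the eps-prediction the Euler step consumes: eps = (x_t - x0) / sigma.
    return (sample - x0) / sigma

sample = torch.randn(1, 16, 16, 88, 160)  # hypothetical noisy latents x_t
x0_c, x0_u = torch.randn_like(sample), torch.randn_like(sample)
eps = cfg_on_x0(x0_c, x0_u, sample, sigma=80.0, guidance_scale=7.0)
print(eps.shape)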

Here's the relevant code:

I'm still trying to figure out the bug related to conditioning:

output2.mp4

Member Author

@a-r-r-o-w a-r-r-o-w Mar 12, 2025

After even more hacking around our scheduler design, I seem to get something decent finally. It's still not following the conditioning fully but I don't really see any glaring differences any more 😅

Edit: this video looks buggy because I'm using a completely different prompt compared to the input image. But the latest version handles the conditioning correctly when given a related image/prompt pair.

output2.mp4

@a-r-r-o-w
Member Author

Okay, it's working!

Image condition:

output2.mp4

Video condition:

cosmos-video2world-input-vid.mp4
output3.mp4

@a-r-r-o-w
Member Author

@yiyixuxu @hlky Made some more changes to the schedulers. Please take a look when you can. The pipeline forward implementations are also a bit hacky, but this is the only way I was able to make it work after extensively playing around with fitting it to our scheduler design.

@pjannaty

Well done! This will go a long way in supporting the wider adoption of Cosmos.

Member

@hlky hlky left a comment

Left a comment @a-r-r-o-w, changes are good either way though!

Comment on lines +707 to +708
noise_pred = self.scheduler.step(noise_pred, t, sample, return_dict=False)[1]
self.scheduler._step_index -= 1
Member

Can we do

noise_pred = self.scheduler.precondition_outputs(sample, noise_pred, current_sigma)

It's safe to use sigma as sigma_hat. s_tmin and s_tmax are rarely used (I've never seen them used myself) and are not supported in some other schedulers for that reason; in turn, gamma is 0 and sigma_hat is the same as sigma.
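
For context, a standalone sketch of the EDM output preconditioning this maps to (assuming the standard c_skip/c_out form with sigma_data=0.5 and an epsilon-style model output; the actual scheduler method may differ in details):

import torch

def edm_precondition_outputs(sample: torch.Tensor, model_output: torch.Tensor, sigma: torch.Tensor, sigma_data: float = 0.5) -> torch.Tensor:
    # EDM denoiser parameterization: x0 = c_skip * x_t + c_out * F(x_t).
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
    return c_skip * sample + c_out * model_output

sample = torch.randn(1, 16, 16, 88, 160)  # hypothetical latent shape
model_output = torch.randn_like(sample)
x0_pred = edm_precondition_outputs(sample, model_output, sigma=torch.tensor(80.0))
print(x0_pred.shape)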

@asfiyab-nvidia
Contributor

This PR doesn't seem to include the guardrail model: https://huggingface.co/nvidia/Cosmos-1.0-Guardrail
Will this be included in a follow-up PR?

@a-r-r-o-w
Member Author

@asfiyab-nvidia I didn't think to add the guardrail models because they essentially work as preprocessors/postprocessors outside the core diffusion-related aspects. I can definitely do a follow-up adding support for them.

Additionally, the prompt upsampler isn't added for similar reasons. The upsampling can be run via any language model (independent of diffusers), but I'll update the docs to point to Pixtral-12B, as used in the original codebase, as an example.

This PR contains only the parts relevant for running the diffusion sampling and generating videos.

@asfiyab-nvidia
Contributor

@a-r-r-o-w Not including the guardrail model violates the License in https://huggingface.co/nvidia/Cosmos-1.0-Diffusion-7B-Text2World. cc @pjannaty for comment on this

@a-r-r-o-w
Member Author

@asfiyab-nvidia Thanks for the notice! I didn't check the license until now. In that case, I'll implement the guardrails tomorrow.

@a-r-r-o-w
Member Author

@asfiyab-nvidia @pjannaty The CosmosGuardrail has been integrated as well. The relevant class to review is CosmosSafetyChecker. The 4 guardrails have been taken directly from the Cosmos codebase with minimal modification, so they run directly without users first needing to manually download the checkpoints.

@a-r-r-o-w
Member Author

a-r-r-o-w commented Mar 21, 2025

Thanks for patiently reviewing this! If everything looks good to merge, please let us know.

We plan to do a diffusers release over the weekend or on Monday. It would be great to ship the Cosmos integration as well for this release cycle. In order to proceed with that, we'll have to host diffusers-format weights for the following repositories:

To host the weights, none of the existing files will be modified apart from README.md (which we can use to showcase how to run inference with diffusers). The diffusers-format folder structure would look something like:

I've opened an example PR for the 7B Text-to-World weights here: https://huggingface.co/nvidia/Cosmos-1.0-Diffusion-7B-Text2World/discussions/9

Once I have the go-ahead from your end that these changes are good, I can open up PRs to all the other repositories.

@asfiyab-nvidia
Contributor

@a-r-r-o-w I'm running into the below issue during pipeline load FYI. Is this expected?

    pipe = CosmosPipeline.from_pretrained(model_id, revision="refs/pr/9", torch_dtype=torch.bfloat16)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/diffusers/pipelines/pipeline_utils.py", line 1023, in from_pretrained
    model = pipeline_class(**init_kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/diffusers/pipelines/cosmos/pipeline_cosmos.py", line 162, in __init__
    safety_checker = CosmosSafetyChecker()
                     ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/diffusers/pipelines/cosmos/cosmos_guardrail.py", line 716, in __init__
    Blocklist(checkpoint_id),
    ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/diffusers/pipelines/cosmos/cosmos_guardrail.py", line 287, in __init__
    self.profanity = profanity
                     ^^^^^^^^^
NameError: name 'profanity' is not defined

@asfiyab-nvidia
Contributor

Another note re the attention definition here. Enabling GQA breaks ONNX export due to https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset14.py#L152. Can this be addressed?

@a-r-r-o-w
Member Author

@asfiyab-nvidia I'm testing a non-enable_gqa version and will try to update asap.
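
For reference, a common way to avoid enable_gqa while keeping the same result (a sketch under assumed head counts, not necessarily what ends up in this PR) is to repeat the KV heads up to the query head count before calling SDPA:

import torch
import torch.nn.functional as F

batch, num_q_heads, num_kv_heads, seq_len, head_dim = 2, 32, 8, 128, 128
query = torch.randn(batch, num_q_heads, seq_len, head_dim)
key = torch.randn(batch, num_kv_heads, seq_len, head_dim)
value = torch.randn(batch, num_kv_heads, seq_len, head_dim)

# Expand each KV head over its query-head group instead of passing enable_gqa=True,
# which keeps the attention op exportable to ONNX.
key = key.repeat_interleave(num_q_heads // num_kv_heads, dim=1)
value = value.repeat_interleave(num_q_heads // num_kv_heads, dim=1)

out = F.scaled_dot_product_attention(query, key, value)
print(out.shape)  # torch.Size([2, 32, 128, 128])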

I'm not sure why you get the error about profanity not being defined. The following code seems to work for me without errors:

import torch
from diffusers import CosmosPipeline
from diffusers.utils import export_to_video

model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Text2World"
pipe = CosmosPipeline.from_pretrained(model_id, revision="refs/pr/9", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A sleek, humanoid robot stands in a vast warehouse filled with neatly stacked cardboard boxes on industrial shelves. The robot's metallic body gleams under the bright, even lighting, highlighting its futuristic design and intricate joints. A glowing blue light emanates from its chest, adding a touch of advanced technology. The background is dominated by rows of boxes, suggesting a highly organized storage system. The floor is lined with wooden pallets, enhancing the industrial setting. The camera remains static, capturing the robot's poised stance amidst the orderly environment, with a shallow depth of field that keeps the focus on the robot while subtly blurring the background for a cinematic effect."

output = pipe(prompt=prompt).frames[0]
export_to_video(output, "output.mp4", fps=30)
output.mp4

I'll try to dig in more soon to see if it errors out with a different environment.
