Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions configs/longcat_image/longcat_image_t2i_offload.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{
"infer_steps": 50,
"aspect_ratio": "16:9",
"attn_type": "sage_attn2",
"enable_cfg": true,
"sample_guide_scale": 4.0,
"enable_cfg_renorm": true,
"cfg_renorm_min": 0.0,
"axes_dims_rope": [16, 56, 56],
"dit_quant_scheme": "Default",
"rms_norm_type": "sgl-kernel",
"cpu_offload": true,
"offload_granularity": "block"
}
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import torch

from lightx2v.common.offload.manager import WeightAsyncStreamManager
from lightx2v.models.networks.longcat_image.infer.transformer_infer import LongCatImageTransformerInfer
from lightx2v_platform.base.global_var import AI_DEVICE

# Resolve the torch backend module for the active accelerator (e.g. torch.cuda
# when AI_DEVICE == "cuda"); used for stream creation/querying below.
torch_device_module = getattr(torch, AI_DEVICE)


class LongCatImageOffloadTransformerInfer(LongCatImageTransformerInfer):
    """Offload transformer inference for LongCat Image model.

    Supports block-level offload with double-buffer async prefetch for both
    double-stream blocks and single-stream blocks: while block ``i`` computes
    on a dedicated compute stream, block ``i+1``'s weights are copied
    host->device on a side stream, then the two device buffers are swapped.
    """

    def __init__(self, config):
        """Install the block-offload infer path and its stream managers.

        Args:
            config: Model configuration dict-like. Reads ``cpu_offload``
                (bool) and ``offload_granularity`` ("block" or "model").

        Raises:
            ValueError: If ``cpu_offload`` is enabled with an
                ``offload_granularity`` other than "block" or "model".
        """
        super().__init__(config)
        if self.config.get("cpu_offload", False):
            offload_granularity = self.config.get("offload_granularity", "block")
            if offload_granularity == "block":
                # Override the non-offload infer path installed by the parent.
                self.infer_func = self.infer_with_blocks_offload
                # One async stream manager per block family; each keeps a
                # double buffer of device-resident weights and prefetches the
                # next block while the current one computes.
                self.offload_manager_double = WeightAsyncStreamManager(offload_granularity=offload_granularity)
                self.offload_manager_single = WeightAsyncStreamManager(offload_granularity=offload_granularity)
            elif offload_granularity != "model":
                # Only "block" and "model" are implemented for this model.
                # Fail fast instead of silently falling back to the
                # non-offload path with weights still on CPU, which would
                # surface later as a confusing device-mismatch error.
                raise ValueError(
                    f"Unsupported offload_granularity {offload_granularity!r} for LongCat Image; "
                    "supported values are 'block' and 'model'."
                )

    def infer_with_blocks_offload(self, blocks, pre_infer_out):
        """Run transformer inference with block-level offload.

        Two-phase approach: first process all double-stream blocks, then all
        single-stream blocks, each with their own offload manager and device
        buffers.

        Args:
            blocks: Weights container exposing ``double_blocks`` and
                ``single_blocks`` lists.
            pre_infer_out: Pre-inference output carrying ``hidden_states``,
                ``encoder_hidden_states``, ``temb``, ``image_rotary_emb`` and,
                for I2I, ``input_image_latents`` / ``output_seq_len``.

        Returns:
            torch.Tensor: Final hidden states; for I2I tasks only the
            output-image portion (first ``output_seq_len`` rows).
        """
        hidden_states = pre_infer_out.hidden_states
        encoder_hidden_states = pre_infer_out.encoder_hidden_states
        temb = pre_infer_out.temb
        image_rotary_emb = pre_infer_out.image_rotary_emb

        # For I2I task: concatenate output latents with input image latents.
        output_seq_len = None
        if pre_infer_out.input_image_latents is not None:
            output_seq_len = pre_infer_out.output_seq_len
            hidden_states = torch.cat([hidden_states, pre_infer_out.input_image_latents], dim=0)

        # Stage 1: double-stream blocks.
        # Order the compute stream after the default stream so work from the
        # previous step (e.g. a still-pending swap_blocks) has completed.
        current_stream = torch_device_module.current_stream()
        self.offload_manager_double.compute_stream.wait_stream(current_stream)
        num_double = len(blocks.double_blocks)
        for block_idx in range(num_double):
            # NOTE(review): self.block_idx appears unused within this class
            # and its visible parent; kept in case external hooks/profilers
            # read it -- confirm and remove if not.
            self.block_idx = block_idx

            if self.offload_manager_double.need_init_first_buffer:
                self.offload_manager_double.init_first_buffer(blocks.double_blocks)

            # Prefetch the next block's weights; wraps around so block 0 is
            # already staged for the next denoising step.
            self.offload_manager_double.prefetch_weights((block_idx + 1) % num_double, blocks.double_blocks)

            with torch_device_module.stream(self.offload_manager_double.compute_stream):
                encoder_hidden_states, hidden_states = self.infer_double_stream_block(
                    self.offload_manager_double.cuda_buffers[0],
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                )

            self.offload_manager_double.swap_blocks()

        # Stage 2: single-stream blocks.
        # Order after stage 1's compute stream so its outputs are ready.
        self.offload_manager_single.compute_stream.wait_stream(self.offload_manager_double.compute_stream)
        num_single = len(blocks.single_blocks)
        for block_idx in range(num_single):
            self.block_idx = block_idx

            if self.offload_manager_single.need_init_first_buffer:
                self.offload_manager_single.init_first_buffer(blocks.single_blocks)

            self.offload_manager_single.prefetch_weights((block_idx + 1) % num_single, blocks.single_blocks)

            with torch_device_module.stream(self.offload_manager_single.compute_stream):
                encoder_hidden_states, hidden_states = self.infer_single_stream_block(
                    self.offload_manager_single.cuda_buffers[0],
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                )

            self.offload_manager_single.swap_blocks()

        # Outputs were produced on the single-block compute stream; order the
        # default stream after it so callers can safely consume the result.
        current_stream.wait_stream(self.offload_manager_single.compute_stream)

        # For I2I task: only return output image latents.
        if output_seq_len is not None:
            hidden_states = hidden_states[:output_seq_len]

        return hidden_states
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, config):
self.config = config
self.infer_conditional = True
self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)

self.infer_func = self.infer_without_offload
# Sequence parallel settings
if self.config.get("seq_parallel", False):
self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
Expand Down Expand Up @@ -279,7 +279,7 @@ def infer_single_stream_block(

return encoder_hidden_states, hidden_states

def infer(self, block_weights, pre_infer_out):
def infer_without_offload(self, block_weights, pre_infer_out):
"""Run transformer inference through all blocks.

Args:
Expand Down Expand Up @@ -325,3 +325,7 @@ def infer(self, block_weights, pre_infer_out):
hidden_states = hidden_states[:output_seq_len]

return hidden_states

def infer(self, block_weights, pre_infer_out):
hidden_states = self.infer_func(block_weights, pre_infer_out)
return hidden_states
28 changes: 24 additions & 4 deletions lightx2v/models/networks/longcat_image/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch.distributed as dist

from lightx2v.models.networks.base_model import BaseTransformerModel
from lightx2v.models.networks.longcat_image.infer.offload.transformer_infer import LongCatImageOffloadTransformerInfer
from lightx2v.models.networks.longcat_image.infer.post_infer import LongCatImagePostInfer
from lightx2v.models.networks.longcat_image.infer.pre_infer import LongCatImagePreInfer
from lightx2v.models.networks.longcat_image.infer.transformer_infer import LongCatImageTransformerInfer
Expand All @@ -23,7 +24,7 @@ class LongCatImageTransformerModel(BaseTransformerModel):
transformer_weight_class = LongCatImageTransformerWeights
post_weight_class = LongCatImagePostWeights

def __init__(self, config, model_path, device):
def __init__(self, model_path, config, device):
super().__init__(model_path, config, device)
# Use transformer_in_channels to avoid conflict with VAE's in_channels
self.in_channels = self.config.get("transformer_in_channels", self.config.get("in_channels", 64))
Expand All @@ -35,17 +36,25 @@ def __init__(self, config, model_path, device):
self._init_infer()

def _init_infer_class(self):
self.transformer_infer_class = LongCatImageTransformerInfer
if self.cpu_offload and self.offload_granularity == "block":
self.transformer_infer_class = LongCatImageOffloadTransformerInfer
else:
self.transformer_infer_class = LongCatImageTransformerInfer
Comment on lines +39 to +42
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic here only handles offload_granularity == "block". If cpu_offload is enabled but offload_granularity is set to something else (e.g., "phase"), it falls back to the base LongCatImageTransformerInfer. However, the infer method (lines 92-98) only handles "model" and "block" granularities. If a different granularity is provided, the weights will remain on CPU while computation is attempted on GPU, leading to a device mismatch error. Consider adding a check or defaulting to a supported mode.

self.pre_infer_class = LongCatImagePreInfer
self.post_infer_class = LongCatImagePostInfer

def _init_infer(self):
self.transformer_infer = self.transformer_infer_class(self.config)
self.pre_infer = self.pre_infer_class(self.config)
self.post_infer = self.post_infer_class(self.config)
if hasattr(self.transformer_infer, "offload_manager"):
if hasattr(self.transformer_infer, "offload_manager_double") and hasattr(self.transformer_infer, "offload_manager_single"):
self._init_offload_manager()

def _init_offload_manager(self):
"""Initialize offload managers for double and single block buffers."""
self.transformer_infer.offload_manager_double.init_cuda_buffer(blocks_cuda_buffer=self.transformer_weights.offload_double_block_cuda_buffers)
self.transformer_infer.offload_manager_single.init_cuda_buffer(blocks_cuda_buffer=self.transformer_weights.offload_single_block_cuda_buffers)

@torch.no_grad()
def _infer_cond_uncond(self, latents_input, prompt_embeds, infer_condition=True):
self.scheduler.infer_condition = infer_condition
Expand Down Expand Up @@ -77,7 +86,11 @@ def _seq_parallel_post_process(self, x):
@torch.no_grad()
def infer(self, inputs):
if self.cpu_offload:
self.to_cuda()
if self.offload_granularity == "model":
self.to_cuda()
elif self.offload_granularity == "block":
self.pre_weight.to_cuda()
self.post_weight.to_cuda()

latents = self.scheduler.latents

Expand Down Expand Up @@ -129,3 +142,10 @@ def infer(self, inputs):
# ==================== No CFG Processing ====================
noise_pred = self._infer_cond_uncond(latents, inputs["text_encoder_output"]["prompt_embeds"], infer_condition=True)
self.scheduler.noise_pred = noise_pred

if self.cpu_offload:
if self.offload_granularity == "model":
self.to_cpu()
elif self.offload_granularity == "block":
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
Loading