-
Notifications
You must be signed in to change notification settings - Fork 178
support longcat-image block offload with 2 mgr #977
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
| { | ||
| "infer_steps": 50, | ||
| "aspect_ratio": "16:9", | ||
| "attn_type": "sage_attn2", | ||
| "enable_cfg": true, | ||
| "sample_guide_scale": 4.0, | ||
| "enable_cfg_renorm": true, | ||
| "cfg_renorm_min": 0.0, | ||
| "axes_dims_rope": [16, 56, 56], | ||
| "dit_quant_scheme": "Default", | ||
| "rms_norm_type": "sgl-kernel", | ||
| "cpu_offload": true, | ||
| "offload_granularity": "block" | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,93 @@ | ||
| import torch | ||
|
|
||
| from lightx2v.common.offload.manager import WeightAsyncStreamManager | ||
| from lightx2v.models.networks.longcat_image.infer.transformer_infer import LongCatImageTransformerInfer | ||
| from lightx2v_platform.base.global_var import AI_DEVICE | ||
|
|
||
| torch_device_module = getattr(torch, AI_DEVICE) | ||
|
|
||
|
|
||
class LongCatImageOffloadTransformerInfer(LongCatImageTransformerInfer):
    """Offload transformer inference for the LongCat Image model.

    Supports block-level offload with double-buffer async prefetch for both
    double-stream blocks and single-stream blocks, each driven by its own
    ``WeightAsyncStreamManager``.
    """

    def __init__(self, config):
        """Set up per-block-family offload managers when CPU offload is enabled.

        Args:
            config: dict-like model configuration; reads ``cpu_offload`` and
                ``offload_granularity``.
        """
        super().__init__(config)
        if self.config.get("cpu_offload", False):
            offload_granularity = self.config.get("offload_granularity", "block")
            if offload_granularity == "block":
                self.infer_func = self.infer_with_blocks_offload
            # "model" granularity moves the whole model at once and needs no
            # stream managers; anything finer gets one manager per block
            # family (double-stream and single-stream).
            if offload_granularity != "model":
                self.offload_manager_double = WeightAsyncStreamManager(offload_granularity=offload_granularity)
                self.offload_manager_single = WeightAsyncStreamManager(offload_granularity=offload_granularity)

    def infer_with_blocks_offload(self, blocks, pre_infer_out):
        """Run transformer inference with block-level offload.

        Two-phase approach: first process all double-stream blocks, then all
        single-stream blocks, each with their own offload manager and device
        buffers.

        Args:
            blocks: container exposing ``double_blocks`` and ``single_blocks``.
            pre_infer_out: pre-inference bundle carrying ``hidden_states``,
                ``encoder_hidden_states``, ``temb``, ``image_rotary_emb`` and,
                for I2I, ``input_image_latents`` / ``output_seq_len``.

        Returns:
            The hidden-states tensor; for an I2I task, only the leading
            ``output_seq_len`` rows (the output-image latents).
        """
        hidden_states = pre_infer_out.hidden_states
        encoder_hidden_states = pre_infer_out.encoder_hidden_states
        temb = pre_infer_out.temb
        image_rotary_emb = pre_infer_out.image_rotary_emb

        # For I2I task: concatenate output latents with input image latents.
        output_seq_len = None
        if pre_infer_out.input_image_latents is not None:
            output_seq_len = pre_infer_out.output_seq_len
            hidden_states = torch.cat([hidden_states, pre_infer_out.input_image_latents], dim=0)

        # Stage 1: double-stream blocks with offload.
        # Wait for the default stream so its pending work is visible.
        # NOTE(review): this wait likely also guards against the previous
        # step's swap_blocks() not having completed; re-initializing the first
        # buffer at block 0 of every step would make it unnecessary — confirm
        # before removing.
        current_stream = torch_device_module.current_stream()
        self.offload_manager_double.compute_stream.wait_stream(current_stream)
        num_double_blocks = len(blocks.double_blocks)  # hoisted loop invariant
        for block_idx in range(num_double_blocks):
            # NOTE(review): self.block_idx appears unused within this class;
            # kept in case parent-class or profiling hooks read it — TODO
            # confirm and remove.
            self.block_idx = block_idx

            if self.offload_manager_double.need_init_first_buffer:
                self.offload_manager_double.init_first_buffer(blocks.double_blocks)

            # Asynchronously prefetch the next block's weights (wraps around
            # so the last iteration primes block 0 for the next step).
            self.offload_manager_double.prefetch_weights((block_idx + 1) % num_double_blocks, blocks.double_blocks)

            with torch_device_module.stream(self.offload_manager_double.compute_stream):
                encoder_hidden_states, hidden_states = self.infer_double_stream_block(
                    self.offload_manager_double.cuda_buffers[0],
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                )

            self.offload_manager_double.swap_blocks()

        # Stage 2: single-stream blocks with offload.
        # Wait for the double-block compute stream to finish before reusing
        # its outputs.
        self.offload_manager_single.compute_stream.wait_stream(self.offload_manager_double.compute_stream)
        num_single_blocks = len(blocks.single_blocks)
        for block_idx in range(num_single_blocks):
            self.block_idx = block_idx  # NOTE(review): see note above — appears unused

            if self.offload_manager_single.need_init_first_buffer:
                self.offload_manager_single.init_first_buffer(blocks.single_blocks)

            self.offload_manager_single.prefetch_weights((block_idx + 1) % num_single_blocks, blocks.single_blocks)

            with torch_device_module.stream(self.offload_manager_single.compute_stream):
                encoder_hidden_states, hidden_states = self.infer_single_stream_block(
                    self.offload_manager_single.cuda_buffers[0],
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                )

            self.offload_manager_single.swap_blocks()

        # For I2I task: return only the output image latents.
        if output_seq_len is not None:
            hidden_states = hidden_states[:output_seq_len]

        return hidden_states
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
| import torch.distributed as dist | ||
|
|
||
| from lightx2v.models.networks.base_model import BaseTransformerModel | ||
| from lightx2v.models.networks.longcat_image.infer.offload.transformer_infer import LongCatImageOffloadTransformerInfer | ||
| from lightx2v.models.networks.longcat_image.infer.post_infer import LongCatImagePostInfer | ||
| from lightx2v.models.networks.longcat_image.infer.pre_infer import LongCatImagePreInfer | ||
| from lightx2v.models.networks.longcat_image.infer.transformer_infer import LongCatImageTransformerInfer | ||
|
|
@@ -23,7 +24,7 @@ class LongCatImageTransformerModel(BaseTransformerModel): | |
| transformer_weight_class = LongCatImageTransformerWeights | ||
| post_weight_class = LongCatImagePostWeights | ||
|
|
||
| def __init__(self, config, model_path, device): | ||
| def __init__(self, model_path, config, device): | ||
| super().__init__(model_path, config, device) | ||
| # Use transformer_in_channels to avoid conflict with VAE's in_channels | ||
| self.in_channels = self.config.get("transformer_in_channels", self.config.get("in_channels", 64)) | ||
|
|
@@ -35,17 +36,25 @@ def __init__(self, config, model_path, device): | |
| self._init_infer() | ||
|
|
||
| def _init_infer_class(self): | ||
| self.transformer_infer_class = LongCatImageTransformerInfer | ||
| if self.cpu_offload and self.offload_granularity == "block": | ||
| self.transformer_infer_class = LongCatImageOffloadTransformerInfer | ||
| else: | ||
| self.transformer_infer_class = LongCatImageTransformerInfer | ||
|
Comment on lines
+39
to
+42
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The logic here only handles |
||
| self.pre_infer_class = LongCatImagePreInfer | ||
| self.post_infer_class = LongCatImagePostInfer | ||
|
|
||
| def _init_infer(self): | ||
| self.transformer_infer = self.transformer_infer_class(self.config) | ||
| self.pre_infer = self.pre_infer_class(self.config) | ||
| self.post_infer = self.post_infer_class(self.config) | ||
| if hasattr(self.transformer_infer, "offload_manager"): | ||
| if hasattr(self.transformer_infer, "offload_manager_double") and hasattr(self.transformer_infer, "offload_manager_single"): | ||
| self._init_offload_manager() | ||
|
|
||
| def _init_offload_manager(self): | ||
| """Initialize offload managers for double and single block buffers.""" | ||
| self.transformer_infer.offload_manager_double.init_cuda_buffer(blocks_cuda_buffer=self.transformer_weights.offload_double_block_cuda_buffers) | ||
| self.transformer_infer.offload_manager_single.init_cuda_buffer(blocks_cuda_buffer=self.transformer_weights.offload_single_block_cuda_buffers) | ||
|
|
||
| @torch.no_grad() | ||
| def _infer_cond_uncond(self, latents_input, prompt_embeds, infer_condition=True): | ||
| self.scheduler.infer_condition = infer_condition | ||
|
|
@@ -77,7 +86,11 @@ def _seq_parallel_post_process(self, x): | |
| @torch.no_grad() | ||
| def infer(self, inputs): | ||
| if self.cpu_offload: | ||
| self.to_cuda() | ||
| if self.offload_granularity == "model": | ||
| self.to_cuda() | ||
| elif self.offload_granularity == "block": | ||
| self.pre_weight.to_cuda() | ||
| self.post_weight.to_cuda() | ||
|
|
||
| latents = self.scheduler.latents | ||
|
|
||
|
|
@@ -129,3 +142,10 @@ def infer(self, inputs): | |
| # ==================== No CFG Processing ==================== | ||
| noise_pred = self._infer_cond_uncond(latents, inputs["text_encoder_output"]["prompt_embeds"], infer_condition=True) | ||
| self.scheduler.noise_pred = noise_pred | ||
|
|
||
| if self.cpu_offload: | ||
| if self.offload_granularity == "model": | ||
| self.to_cpu() | ||
| elif self.offload_granularity == "block": | ||
| self.pre_weight.to_cpu() | ||
| self.post_weight.to_cpu() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The attribute
`self.block_idx` is set here but does not appear to be used anywhere within this class or its parent `LongCatImageTransformerInfer`. If it's not required for external hooks or profiling, it should be removed to avoid confusion.