Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lightx2v/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def main():
parser.add_argument("--save_result_path", type=str, default=None, help="The path to save video path/file")
parser.add_argument("--return_result_tensor", action="store_true", help="Whether to return result tensor. (Useful for comfyui)")
parser.add_argument("--target_shape", type=int, nargs="+", default=[], help="Set return video or image shape")
parser.add_argument("--target_video_length", type=int, default=81, help="The target video length for each generated clip")
parser.add_argument("--aspect_ratio", type=str, default="")
parser.add_argument("--video_path", type=str, default=None, help="input video path(for sr/v2v task)")
parser.add_argument("--sr_ratio", type=float, default=2.0, help="super resolution ratio for sr task")
Expand Down
11 changes: 9 additions & 2 deletions lightx2v/models/runners/wan/wan_audio_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
from lightx2v.server.metrics import monitor_cli
from lightx2v.utils.envs import *
from lightx2v.utils.input_info import UNSET
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import find_torch_model_path, fixed_shape_resize, get_optimal_patched_size_with_sp, isotropic_crop_resize, load_weights, wan_vae_to_comfy
Expand Down Expand Up @@ -315,8 +316,14 @@ def read_audio_input(self, audio_path):
if expected_frames < int(self.video_duration * target_fps):
logger.warning(f"Input video duration is greater than actual audio duration, using audio duration instead: audio_duration={audio_len / target_fps}, video_duration={self.video_duration}")

# Segment audio
audio_segments = self._audio_processor.segment_audio(audio_array, expected_frames, self.config.get("target_video_length", 81), self.prev_frame_length)
# Segment audio (CLI / input_info wins over config_json; target_video_length is not merged into config)
target_video_length = self.config.get("target_video_length", 81)
ii = getattr(self, "input_info", None)
if ii is not None and hasattr(ii, "target_video_length"):
tvl = ii.target_video_length
if tvl is not None and tvl is not UNSET and tvl > 0:
target_video_length = tvl
Comment on lines +320 to +325
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To improve conciseness and readability, this logic can be refactored. The current nested if statements are somewhat verbose and can be simplified into a more compact, more Pythonic form.

Suggested change
target_video_length = self.config.get("target_video_length", 81)
ii = getattr(self, "input_info", None)
if ii is not None and hasattr(ii, "target_video_length"):
tvl = ii.target_video_length
if tvl is not None and tvl is not UNSET and tvl > 0:
target_video_length = tvl
target_video_length = self.config.get("target_video_length", 81)
ii = getattr(self, "input_info", None)
if ii is not None:
tvl = getattr(ii, "target_video_length", None)
if tvl not in (None, UNSET) and tvl > 0:
target_video_length = tvl

audio_segments = self._audio_processor.segment_audio(audio_array, expected_frames, target_video_length, self.prev_frame_length)

# Mask latent for multi-person s2v
if mask_files is not None:
Expand Down
2 changes: 2 additions & 0 deletions lightx2v/utils/input_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class S2VInputInfo:
resized_shape: list = field(default_factory=list)
latent_shape: list = field(default_factory=list)
target_shape: list = field(default_factory=list)
target_video_length: int = field(default_factory=int)

# prev info
overlap_frame: torch.Tensor = field(default_factory=lambda: None)
Expand Down Expand Up @@ -148,6 +149,7 @@ class RS2VInputInfo:
resized_shape: list = field(default_factory=list)
latent_shape: list = field(default_factory=list)
target_shape: list = field(default_factory=list)
target_video_length: int = field(default_factory=int)

# prev info
overlap_frame: torch.Tensor = field(default_factory=lambda: None)
Expand Down
Loading