Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions lightx2v/shot_runner/rs2v_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,12 @@ def load_audio(audio_path, target_sr):
pipe.check_stop()

# Calculate actual target_video_length for this segment based on audio length
segment_actual_video_frames = None # Track for last-segment trimming
if is_last and pad_len > 0:
# For the last segment with padding, calculate actual video frames needed
actual_audio_samples = audio_clip.shape[1] - pad_len
actual_video_frames = int(np.ceil(actual_audio_samples / audio_per_frame))
segment_actual_video_frames = actual_video_frames
# Apply the formula to ensure VAE stride constraint
segment_target_video_length = calculate_target_video_length_from_duration(actual_video_frames / target_fps, target_fps)
clip_input_info.target_video_length = segment_target_video_length
Expand Down Expand Up @@ -206,15 +208,18 @@ def load_audio(audio_path, target_sr):
gen_clip_video, audio_clip, gen_latents = rs2v.run_clip_main()
logger.info(f"Generated rs2v clip {idx}, pad_len {pad_len}, gen_clip_video shape: {gen_clip_video.shape}, audio_clip shape: {audio_clip.shape} gen_latents shape: {gen_latents.shape}")

video_pad_len = pad_len // audio_per_frame
audio_pad_len = video_pad_len * audio_per_frame
video_seg = gen_clip_video[:, :, : gen_clip_video.shape[2] - video_pad_len]
# Since audio_clip is now multidimensional (N, T), slice on dim 1 and sum on dim 0 to merge tracks
audio_seg = audio_clip[:, : audio_clip.shape[1] - audio_pad_len].sum(dim=0)
if segment_actual_video_frames is not None:
# Last segment: trim to exact actual frames needed
video_seg = gen_clip_video[:, :, :segment_actual_video_frames]
audio_seg = audio_clip[:, : segment_actual_video_frames * audio_per_frame].sum(dim=0)
else:
video_seg = gen_clip_video
audio_seg = audio_clip.sum(dim=0)

clip_input_info.overlap_latent = gen_latents[:, -1:]

if clip_input_info.return_result_tensor or not clip_input_info.stream_save_video:
gen_video_list.append(video_seg.clone().cpu().float())
gen_video_list.append(video_seg.clone().cpu())
cut_audio_list.append(audio_seg.cpu())
elif self.va_controller.recorder is not None:
video_seg = torch.clamp(video_seg, -1, 1).to(torch.float).cpu()
Expand Down
Loading