Commit dbe70b6

Add a VideoSlice node (Comfy-Org#12107)
* Base TrimVideo implementation
* Raise error if as_trimmed call fails
* Bigger max start_time, tooltips, and formatting
* Count packets unless codec has subframes
* Remove incorrect nested decode
* Add null check for audio streams
* Support non-strict duration
* Added strict_duration bool to node definition
* Empty commit for approval
* Fix duration
* Support 5.1 audio layout on save

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
1 parent 00fff60 commit dbe70b6

3 files changed: +207 -60 lines changed

comfy_api/latest/_input/video_types.py

Lines changed: 15 additions & 0 deletions

@@ -34,6 +34,21 @@ def save_to(
         """
         pass

+    @abstractmethod
+    def as_trimmed(
+        self,
+        start_time: float | None = None,
+        duration: float | None = None,
+        strict_duration: bool = False,
+    ) -> VideoInput | None:
+        """
+        Create a new VideoInput which is trimmed to have the corresponding start_time and duration
+
+        Returns:
+            A new VideoInput, or None if the result would have negative duration
+        """
+        pass
+
     def get_stream_source(self) -> Union[str, io.BytesIO]:
         """
         Get a streamable source for the video. This allows processing without
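A minimal caller-side sketch of the contract the new abstract method defines. The slice_video helper, its parameter names, and the error message are illustrative assumptions; only the as_trimmed signature and its None-on-failure behaviour come from the diff above.

    def slice_video(video: VideoInput, start_time: float, duration: float,
                    strict_duration: bool = False) -> VideoInput:
        # Ask the input for a trimmed view of itself; implementations return None
        # when the requested slice cannot be produced (e.g. negative duration).
        trimmed = video.as_trimmed(
            start_time=start_time,
            duration=duration,
            strict_duration=strict_duration,
        )
        if trimmed is None:
            raise ValueError("Requested slice falls outside the source video")
        return trimmed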

comfy_api/latest/_input_impl/video_types.py

Lines changed: 141 additions & 60 deletions
@@ -6,6 +6,7 @@
 from .._input import AudioInput, VideoInput
 import av
 import io
+import itertools
 import json
 import numpy as np
 import math
@@ -29,7 +30,6 @@ def container_to_output_format(container_format: str | None) -> str | None:
     formats = container_format.split(",")
     return formats[0]

-
 def get_open_write_kwargs(
     dest: str | io.BytesIO, container_format: str, to_format: str | None
 ) -> dict:
@@ -57,12 +57,14 @@ class VideoFromFile(VideoInput):
     Class representing video input from a file.
     """

-    def __init__(self, file: str | io.BytesIO):
+    def __init__(self, file: str | io.BytesIO, *, start_time: float=0, duration: float=0):
         """
         Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
         containing the file contents.
         """
         self.__file = file
+        self.__start_time = start_time
+        self.__duration = duration

     def get_stream_source(self) -> str | io.BytesIO:
         """
@@ -96,6 +98,16 @@ def get_duration(self) -> float:
         Returns:
             Duration in seconds
         """
+        raw_duration = self._get_raw_duration()
+        if self.__start_time < 0:
+            duration_from_start = min(raw_duration, -self.__start_time)
+        else:
+            duration_from_start = raw_duration - self.__start_time
+        if self.__duration:
+            return min(self.__duration, duration_from_start)
+        return duration_from_start
+
+    def _get_raw_duration(self) -> float:
         if isinstance(self.__file, io.BytesIO):
             self.__file.seek(0)
         with av.open(self.__file, mode="r") as container:
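The effective duration above treats a negative start_time as an offset from the end of the source and a zero duration as "to the end". A small self-contained restatement of that arithmetic with made-up numbers (not code from the commit):

    def effective_duration(raw_duration: float, start_time: float, duration: float) -> float:
        # Mirrors VideoFromFile.get_duration: negative start_time counts back from
        # the end of the source; a zero duration means "up to the end".
        if start_time < 0:
            duration_from_start = min(raw_duration, -start_time)
        else:
            duration_from_start = raw_duration - start_time
        return min(duration, duration_from_start) if duration else duration_from_start

    assert effective_duration(10.0, -3.0, 5.0) == 3.0   # last 3 s of a 10 s clip
    assert effective_duration(10.0, 2.0, 5.0) == 5.0    # 5 s starting at 2 s
    assert effective_duration(10.0, 2.0, 0.0) == 8.0    # from 2 s to the end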
@@ -113,9 +125,13 @@ def get_duration(self) -> float:
             if video_stream and video_stream.average_rate:
                 frame_count = 0
                 container.seek(0)
-                for packet in container.demux(video_stream):
-                    for _ in packet.decode():
-                        frame_count += 1
+                frame_iterator = (
+                    container.decode(video_stream)
+                    if video_stream.codec.capabilities & 0x100
+                    else container.demux(video_stream)
+                )
+                for packet in frame_iterator:
+                    frame_count += 1
                 if frame_count > 0:
                     return float(frame_count / video_stream.average_rate)

@@ -131,36 +147,54 @@ def get_frame_count(self) -> int:

         with av.open(self.__file, mode="r") as container:
             video_stream = self._get_first_video_stream(container)
-            # 1. Prefer the frames field if available
-            if video_stream.frames and video_stream.frames > 0:
+            # 1. Prefer the frames field if available and usable
+            if (
+                video_stream.frames
+                and video_stream.frames > 0
+                and not self.__start_time
+                and not self.__duration
+            ):
                 return int(video_stream.frames)

             # 2. Try to estimate from duration and average_rate using only metadata
-            if container.duration is not None and video_stream.average_rate:
-                duration_seconds = float(container.duration / av.time_base)
-                estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
-                if estimated_frames > 0:
-                    return estimated_frames
-
             if (
                 getattr(video_stream, "duration", None) is not None
                 and getattr(video_stream, "time_base", None) is not None
                 and video_stream.average_rate
             ):
-                duration_seconds = float(video_stream.duration * video_stream.time_base)
+                raw_duration = float(video_stream.duration * video_stream.time_base)
+                if self.__start_time < 0:
+                    duration_from_start = min(raw_duration, -self.__start_time)
+                else:
+                    duration_from_start = raw_duration - self.__start_time
+                duration_seconds = min(self.__duration, duration_from_start)
                 estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
                 if estimated_frames > 0:
                     return estimated_frames

             # 3. Last resort: decode frames and count them (streaming)
-            frame_count = 0
-            container.seek(0)
-            for packet in container.demux(video_stream):
-                for _ in packet.decode():
-                    frame_count += 1
-
-            if frame_count == 0:
-                raise ValueError(f"Could not determine frame count for file '{self.__file}'")
+            if self.__start_time < 0:
+                start_time = max(self._get_raw_duration() + self.__start_time, 0)
+            else:
+                start_time = self.__start_time
+            frame_count = 1
+            start_pts = int(start_time / video_stream.time_base)
+            end_pts = int((start_time + self.__duration) / video_stream.time_base)
+            container.seek(start_pts, stream=video_stream)
+            frame_iterator = (
+                container.decode(video_stream)
+                if video_stream.codec.capabilities & 0x100
+                else container.demux(video_stream)
+            )
+            for frame in frame_iterator:
+                if frame.pts >= start_pts:
+                    break
+            else:
+                raise ValueError(f"Could not determine frame count for file '{self.__file}'\nNo frames exist for start_time {self.__start_time}")
+            for frame in frame_iterator:
+                if frame.pts >= end_pts:
+                    break
+                frame_count += 1
             return frame_count

     def get_frame_rate(self) -> Fraction:
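The 0x100 mask checked in both counting paths above appears to correspond to FFmpeg's AV_CODEC_CAP_SUBFRAMES capability bit (1 << 8), which is what the commit message's "Count packets unless codec has subframes" refers to: codecs that may emit several frames per packet must be fully decoded, while for other codecs counting demuxed packets is sufficient. A sketch that simply names the constant (the constant name comes from FFmpeg; it is not defined in this commit):

    AV_CODEC_CAP_SUBFRAMES = 1 << 8  # 0x100 in the codec capability bitmask

    def frame_iterator_for(container, video_stream):
        # Decode when a packet may hold more than one frame; otherwise counting
        # demuxed packets is enough and avoids the cost of decoding.
        if video_stream.codec.capabilities & AV_CODEC_CAP_SUBFRAMES:
            return container.decode(video_stream)
        return container.demux(video_stream)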
@@ -199,41 +233,66 @@ def get_container_format(self) -> str:
             return container.format.name

     def get_components_internal(self, container: InputContainer) -> VideoComponents:
+        video_stream = self._get_first_video_stream(container)
+        if self.__start_time < 0:
+            start_time = max(self._get_raw_duration() + self.__start_time, 0)
+        else:
+            start_time = self.__start_time
         # Get video frames
         frames = []
-        for frame in container.decode(video=0):
+        start_pts = int(start_time / video_stream.time_base)
+        end_pts = int((start_time + self.__duration) / video_stream.time_base)
+        container.seek(start_pts, stream=video_stream)
+        for frame in container.decode(video_stream):
+            if frame.pts < start_pts:
+                continue
+            if self.__duration and frame.pts >= end_pts:
+                break
             img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
             img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
             frames.append(img)

         images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)

         # Get frame rate
-        video_stream = next(s for s in container.streams if s.type == 'video')
-        frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
+        frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1)

         # Get audio if available
         audio = None
-        try:
-            container.seek(0) # Reset the container to the beginning
-            for stream in container.streams:
-                if stream.type != 'audio':
-                    continue
-                assert isinstance(stream, av.AudioStream)
-                audio_frames = []
-                for packet in container.demux(stream):
-                    for frame in packet.decode():
-                        assert isinstance(frame, av.AudioFrame)
-                        audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
-                if len(audio_frames) > 0:
-                    audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
-                    audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
-                    audio = AudioInput({
-                        "waveform": audio_tensor,
-                        "sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
-                    })
-        except StopIteration:
-            pass # No audio stream
+        container.seek(start_pts, stream=video_stream)
+        # Use last stream for consistency
+        if len(container.streams.audio):
+            audio_stream = container.streams.audio[-1]
+            audio_frames = []
+            resample = av.audio.resampler.AudioResampler(format='fltp').resample
+            frames = itertools.chain.from_iterable(
+                map(resample, container.decode(audio_stream))
+            )
+
+            has_first_frame = False
+            for frame in frames:
+                offset_seconds = start_time - frame.pts * audio_stream.time_base
+                to_skip = int(offset_seconds * audio_stream.sample_rate)
+                if to_skip < frame.samples:
+                    has_first_frame = True
+                    break
+            if has_first_frame:
+                audio_frames.append(frame.to_ndarray()[..., to_skip:])
+
+            for frame in frames:
+                if frame.time > start_time + self.__duration:
+                    break
+                audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
+            if len(audio_frames) > 0:
+                audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
+                if self.__duration:
+                    audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)]
+
+                audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
+                audio = AudioInput({
+                    "waveform": audio_tensor,
+                    "sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1,
+                })

         metadata = container.metadata
         return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
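The audio path above aligns the first decoded audio frame with the trim point by converting the frame's pts into seconds via the stream's time_base and dropping the samples that fall before start_time. A worked example with assumed numbers (48 kHz audio, a 1/48000 time base, trim starting at 2.0 s):

    from fractions import Fraction

    sample_rate = 48000
    time_base = Fraction(1, 48000)   # assumed audio stream time base
    start_time = 2.0                 # seconds into the source
    frame_pts = 95040                # this frame starts at 1.98 s

    offset_seconds = start_time - frame_pts * time_base   # ~0.02 s of audio to drop
    to_skip = int(offset_seconds * sample_rate)            # 960 samples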
@@ -250,7 +309,7 @@ def save_to(
         path: str | io.BytesIO,
         format: VideoContainer = VideoContainer.AUTO,
         codec: VideoCodec = VideoCodec.AUTO,
-        metadata: Optional[dict] = None
+        metadata: Optional[dict] = None,
     ):
         if isinstance(self.__file, io.BytesIO):
             self.__file.seek(0) # Reset the BytesIO object to the beginning
@@ -262,15 +321,14 @@ def save_to(
                 reuse_streams = False
             if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
                 reuse_streams = False
+            if self.__start_time or self.__duration:
+                reuse_streams = False

             if not reuse_streams:
                 components = self.get_components_internal(container)
                 video = VideoFromComponents(components)
                 return video.save_to(
-                    path,
-                    format=format,
-                    codec=codec,
-                    metadata=metadata
+                    path, format=format, codec=codec, metadata=metadata
                 )

             streams = container.streams
@@ -304,10 +362,21 @@ def save_to(
                 output_container.mux(packet)

     def _get_first_video_stream(self, container: InputContainer):
-        video_stream = next((s for s in container.streams if s.type == "video"), None)
-        if video_stream is None:
-            raise ValueError(f"No video stream found in file '{self.__file}'")
-        return video_stream
+        if len(container.streams.video):
+            return container.streams.video[0]
+        raise ValueError(f"No video stream found in file '{self.__file}'")
+
+    def as_trimmed(
+        self, start_time: float = 0, duration: float = 0, strict_duration: bool = True
+    ) -> VideoInput | None:
+        trimmed = VideoFromFile(
+            self.get_stream_source(),
+            start_time=start_time + self.__start_time,
+            duration=duration,
+        )
+        if trimmed.get_duration() < duration and strict_duration:
+            return None
+        return trimmed


 class VideoFromComponents(VideoInput):
@@ -322,15 +391,15 @@ def get_components(self) -> VideoComponents:
         return VideoComponents(
             images=self.__components.images,
             audio=self.__components.audio,
-            frame_rate=self.__components.frame_rate
+            frame_rate=self.__components.frame_rate,
         )

     def save_to(
         self,
         path: str,
         format: VideoContainer = VideoContainer.AUTO,
         codec: VideoCodec = VideoCodec.AUTO,
-        metadata: Optional[dict] = None
+        metadata: Optional[dict] = None,
     ):
         if format != VideoContainer.AUTO and format != VideoContainer.MP4:
             raise ValueError("Only MP4 format is supported for now")
@@ -357,7 +426,10 @@ def save_to(
             audio_stream: Optional[av.AudioStream] = None
             if self.__components.audio:
                 audio_sample_rate = int(self.__components.audio['sample_rate'])
-                audio_stream = output.add_stream('aac', rate=audio_sample_rate)
+                waveform = self.__components.audio['waveform']
+                waveform = waveform[0, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
+                layout = {1: 'mono', 2: 'stereo', 6: '5.1'}.get(waveform.shape[0], 'stereo')
+                audio_stream = output.add_stream('aac', rate=audio_sample_rate, layout=layout)

             # Encode video
             for i, frame in enumerate(self.__components.images):
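The layout above is looked up from the channel count of the trimmed waveform, with any unmapped count falling back to stereo; this is the "Support 5.1 audio layout on save" item from the commit message. For illustration:

    layouts = {1: 'mono', 2: 'stereo', 6: '5.1'}
    assert layouts.get(6, 'stereo') == '5.1'      # 6-channel source -> 5.1 AAC layout
    assert layouts.get(4, 'stereo') == 'stereo'   # unmapped channel counts fall back to stereo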
@@ -372,12 +444,21 @@ def save_to(
                 output.mux(packet)

             if audio_stream and self.__components.audio:
-                waveform = self.__components.audio['waveform']
-                waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
-                frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
+                frame = av.AudioFrame.from_ndarray(waveform.float().cpu().numpy(), format='fltp', layout=layout)
                 frame.sample_rate = audio_sample_rate
                 frame.pts = 0
                 output.mux(audio_stream.encode(frame))

                 # Flush encoder
                 output.mux(audio_stream.encode(None))
+
+    def as_trimmed(
+        self,
+        start_time: float | None = None,
+        duration: float | None = None,
+        strict_duration: bool = True,
+    ) -> VideoInput | None:
+        if self.get_duration() < start_time + duration:
+            return None
+        #TODO Consider tracking duration and trimming at time of save?
+        return VideoFromFile(self.get_stream_source(), start_time=start_time, duration=duration)
