66from .._input import AudioInput , VideoInput
77import av
88import io
9+ import itertools
910import json
1011import numpy as np
1112import math
@@ -29,7 +30,6 @@ def container_to_output_format(container_format: str | None) -> str | None:
2930 formats = container_format .split ("," )
3031 return formats [0 ]
3132
32-
3333def get_open_write_kwargs (
3434 dest : str | io .BytesIO , container_format : str , to_format : str | None
3535) -> dict :
@@ -57,12 +57,14 @@ class VideoFromFile(VideoInput):
5757 Class representing video input from a file.
5858 """
5959
60- def __init__ (self , file : str | io .BytesIO ):
60+ def __init__ (self , file : str | io .BytesIO , * , start_time : float = 0 , duration : float = 0 ):
6161 """
6262 Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
6363 containing the file contents.
6464 """
6565 self .__file = file
66+ self .__start_time = start_time
67+ self .__duration = duration
6668
6769 def get_stream_source (self ) -> str | io .BytesIO :
6870 """
@@ -96,6 +98,16 @@ def get_duration(self) -> float:
9698 Returns:
9799 Duration in seconds
98100 """
101+ raw_duration = self ._get_raw_duration ()
102+ if self .__start_time < 0 :
103+ duration_from_start = min (raw_duration , - self .__start_time )
104+ else :
105+ duration_from_start = raw_duration - self .__start_time
106+ if self .__duration :
107+ return min (self .__duration , duration_from_start )
108+ return duration_from_start
109+
110+ def _get_raw_duration (self ) -> float :
99111 if isinstance (self .__file , io .BytesIO ):
100112 self .__file .seek (0 )
101113 with av .open (self .__file , mode = "r" ) as container :
@@ -113,9 +125,13 @@ def get_duration(self) -> float:
113125 if video_stream and video_stream .average_rate :
114126 frame_count = 0
115127 container .seek (0 )
116- for packet in container .demux (video_stream ):
117- for _ in packet .decode ():
118- frame_count += 1
128+ frame_iterator = (
129+ container .decode (video_stream )
130+ if video_stream .codec .capabilities & 0x100
131+ else container .demux (video_stream )
132+ )
133+ for packet in frame_iterator :
134+ frame_count += 1
119135 if frame_count > 0 :
120136 return float (frame_count / video_stream .average_rate )
121137
@@ -131,36 +147,54 @@ def get_frame_count(self) -> int:
131147
132148 with av .open (self .__file , mode = "r" ) as container :
133149 video_stream = self ._get_first_video_stream (container )
134- # 1. Prefer the frames field if available
135- if video_stream .frames and video_stream .frames > 0 :
150+ # 1. Prefer the frames field if available and usable
151+ if (
152+ video_stream .frames
153+ and video_stream .frames > 0
154+ and not self .__start_time
155+ and not self .__duration
156+ ):
136157 return int (video_stream .frames )
137158
138159 # 2. Try to estimate from duration and average_rate using only metadata
139- if container .duration is not None and video_stream .average_rate :
140- duration_seconds = float (container .duration / av .time_base )
141- estimated_frames = int (round (duration_seconds * float (video_stream .average_rate )))
142- if estimated_frames > 0 :
143- return estimated_frames
144-
145160 if (
146161 getattr (video_stream , "duration" , None ) is not None
147162 and getattr (video_stream , "time_base" , None ) is not None
148163 and video_stream .average_rate
149164 ):
150- duration_seconds = float (video_stream .duration * video_stream .time_base )
165+ raw_duration = float (video_stream .duration * video_stream .time_base )
166+ if self .__start_time < 0 :
167+ duration_from_start = min (raw_duration , - self .__start_time )
168+ else :
169+ duration_from_start = raw_duration - self .__start_time
170+ duration_seconds = min (self .__duration , duration_from_start )
151171 estimated_frames = int (round (duration_seconds * float (video_stream .average_rate )))
152172 if estimated_frames > 0 :
153173 return estimated_frames
154174
155175 # 3. Last resort: decode frames and count them (streaming)
156- frame_count = 0
157- container .seek (0 )
158- for packet in container .demux (video_stream ):
159- for _ in packet .decode ():
160- frame_count += 1
161-
162- if frame_count == 0 :
163- raise ValueError (f"Could not determine frame count for file '{ self .__file } '" )
176+ if self .__start_time < 0 :
177+ start_time = max (self ._get_raw_duration () + self .__start_time , 0 )
178+ else :
179+ start_time = self .__start_time
180+ frame_count = 1
181+ start_pts = int (start_time / video_stream .time_base )
182+ end_pts = int ((start_time + self .__duration ) / video_stream .time_base )
183+ container .seek (start_pts , stream = video_stream )
184+ frame_iterator = (
185+ container .decode (video_stream )
186+ if video_stream .codec .capabilities & 0x100
187+ else container .demux (video_stream )
188+ )
189+ for frame in frame_iterator :
190+ if frame .pts >= start_pts :
191+ break
192+ else :
193+ raise ValueError (f"Could not determine frame count for file '{ self .__file } '\n No frames exist for start_time { self .__start_time } " )
194+ for frame in frame_iterator :
195+ if frame .pts >= end_pts :
196+ break
197+ frame_count += 1
164198 return frame_count
165199
166200 def get_frame_rate (self ) -> Fraction :
@@ -199,41 +233,66 @@ def get_container_format(self) -> str:
199233 return container .format .name
200234
201235 def get_components_internal (self , container : InputContainer ) -> VideoComponents :
236+ video_stream = self ._get_first_video_stream (container )
237+ if self .__start_time < 0 :
238+ start_time = max (self ._get_raw_duration () + self .__start_time , 0 )
239+ else :
240+ start_time = self .__start_time
202241 # Get video frames
203242 frames = []
204- for frame in container .decode (video = 0 ):
243+ start_pts = int (start_time / video_stream .time_base )
244+ end_pts = int ((start_time + self .__duration ) / video_stream .time_base )
245+ container .seek (start_pts , stream = video_stream )
246+ for frame in container .decode (video_stream ):
247+ if frame .pts < start_pts :
248+ continue
249+ if self .__duration and frame .pts >= end_pts :
250+ break
205251 img = frame .to_ndarray (format = 'rgb24' ) # shape: (H, W, 3)
206252 img = torch .from_numpy (img ) / 255.0 # shape: (H, W, 3)
207253 frames .append (img )
208254
209255 images = torch .stack (frames ) if len (frames ) > 0 else torch .zeros (0 , 3 , 0 , 0 )
210256
211257 # Get frame rate
212- video_stream = next (s for s in container .streams if s .type == 'video' )
213- frame_rate = Fraction (video_stream .average_rate ) if video_stream and video_stream .average_rate else Fraction (1 )
258+ frame_rate = Fraction (video_stream .average_rate ) if video_stream .average_rate else Fraction (1 )
214259
215260 # Get audio if available
216261 audio = None
217- try :
218- container .seek (0 ) # Reset the container to the beginning
219- for stream in container .streams :
220- if stream .type != 'audio' :
221- continue
222- assert isinstance (stream , av .AudioStream )
223- audio_frames = []
224- for packet in container .demux (stream ):
225- for frame in packet .decode ():
226- assert isinstance (frame , av .AudioFrame )
227- audio_frames .append (frame .to_ndarray ()) # shape: (channels, samples)
228- if len (audio_frames ) > 0 :
229- audio_data = np .concatenate (audio_frames , axis = 1 ) # shape: (channels, total_samples)
230- audio_tensor = torch .from_numpy (audio_data ).unsqueeze (0 ) # shape: (1, channels, total_samples)
231- audio = AudioInput ({
232- "waveform" : audio_tensor ,
233- "sample_rate" : int (stream .sample_rate ) if stream .sample_rate else 1 ,
234- })
235- except StopIteration :
236- pass # No audio stream
262+ container .seek (start_pts , stream = video_stream )
263+ # Use last stream for consistency
264+ if len (container .streams .audio ):
265+ audio_stream = container .streams .audio [- 1 ]
266+ audio_frames = []
267+ resample = av .audio .resampler .AudioResampler (format = 'fltp' ).resample
268+ frames = itertools .chain .from_iterable (
269+ map (resample , container .decode (audio_stream ))
270+ )
271+
272+ has_first_frame = False
273+ for frame in frames :
274+ offset_seconds = start_time - frame .pts * audio_stream .time_base
275+ to_skip = int (offset_seconds * audio_stream .sample_rate )
276+ if to_skip < frame .samples :
277+ has_first_frame = True
278+ break
279+ if has_first_frame :
280+ audio_frames .append (frame .to_ndarray ()[..., to_skip :])
281+
282+ for frame in frames :
283+ if frame .time > start_time + self .__duration :
284+ break
285+ audio_frames .append (frame .to_ndarray ()) # shape: (channels, samples)
286+ if len (audio_frames ) > 0 :
287+ audio_data = np .concatenate (audio_frames , axis = 1 ) # shape: (channels, total_samples)
288+ if self .__duration :
289+ audio_data = audio_data [..., :int (self .__duration * audio_stream .sample_rate )]
290+
291+ audio_tensor = torch .from_numpy (audio_data ).unsqueeze (0 ) # shape: (1, channels, total_samples)
292+ audio = AudioInput ({
293+ "waveform" : audio_tensor ,
294+ "sample_rate" : int (audio_stream .sample_rate ) if audio_stream .sample_rate else 1 ,
295+ })
237296
238297 metadata = container .metadata
239298 return VideoComponents (images = images , audio = audio , frame_rate = frame_rate , metadata = metadata )
@@ -250,7 +309,7 @@ def save_to(
250309 path : str | io .BytesIO ,
251310 format : VideoContainer = VideoContainer .AUTO ,
252311 codec : VideoCodec = VideoCodec .AUTO ,
253- metadata : Optional [dict ] = None
312+ metadata : Optional [dict ] = None ,
254313 ):
255314 if isinstance (self .__file , io .BytesIO ):
256315 self .__file .seek (0 ) # Reset the BytesIO object to the beginning
@@ -262,15 +321,14 @@ def save_to(
262321 reuse_streams = False
263322 if codec != VideoCodec .AUTO and codec != video_encoding and video_encoding is not None :
264323 reuse_streams = False
324+ if self .__start_time or self .__duration :
325+ reuse_streams = False
265326
266327 if not reuse_streams :
267328 components = self .get_components_internal (container )
268329 video = VideoFromComponents (components )
269330 return video .save_to (
270- path ,
271- format = format ,
272- codec = codec ,
273- metadata = metadata
331+ path , format = format , codec = codec , metadata = metadata
274332 )
275333
276334 streams = container .streams
@@ -304,10 +362,21 @@ def save_to(
304362 output_container .mux (packet )
305363
306364 def _get_first_video_stream (self , container : InputContainer ):
307- video_stream = next ((s for s in container .streams if s .type == "video" ), None )
308- if video_stream is None :
309- raise ValueError (f"No video stream found in file '{ self .__file } '" )
310- return video_stream
365+ if len (container .streams .video ):
366+ return container .streams .video [0 ]
367+ raise ValueError (f"No video stream found in file '{ self .__file } '" )
368+
369+ def as_trimmed (
370+ self , start_time : float = 0 , duration : float = 0 , strict_duration : bool = True
371+ ) -> VideoInput | None :
372+ trimmed = VideoFromFile (
373+ self .get_stream_source (),
374+ start_time = start_time + self .__start_time ,
375+ duration = duration ,
376+ )
377+ if trimmed .get_duration () < duration and strict_duration :
378+ return None
379+ return trimmed
311380
312381
313382class VideoFromComponents (VideoInput ):
@@ -322,15 +391,15 @@ def get_components(self) -> VideoComponents:
322391 return VideoComponents (
323392 images = self .__components .images ,
324393 audio = self .__components .audio ,
325- frame_rate = self .__components .frame_rate
394+ frame_rate = self .__components .frame_rate ,
326395 )
327396
328397 def save_to (
329398 self ,
330399 path : str ,
331400 format : VideoContainer = VideoContainer .AUTO ,
332401 codec : VideoCodec = VideoCodec .AUTO ,
333- metadata : Optional [dict ] = None
402+ metadata : Optional [dict ] = None ,
334403 ):
335404 if format != VideoContainer .AUTO and format != VideoContainer .MP4 :
336405 raise ValueError ("Only MP4 format is supported for now" )
@@ -357,7 +426,10 @@ def save_to(
357426 audio_stream : Optional [av .AudioStream ] = None
358427 if self .__components .audio :
359428 audio_sample_rate = int (self .__components .audio ['sample_rate' ])
360- audio_stream = output .add_stream ('aac' , rate = audio_sample_rate )
429+ waveform = self .__components .audio ['waveform' ]
430+ waveform = waveform [0 , :, :math .ceil ((audio_sample_rate / frame_rate ) * self .__components .images .shape [0 ])]
431+ layout = {1 : 'mono' , 2 : 'stereo' , 6 : '5.1' }.get (waveform .shape [0 ], 'stereo' )
432+ audio_stream = output .add_stream ('aac' , rate = audio_sample_rate , layout = layout )
361433
362434 # Encode video
363435 for i , frame in enumerate (self .__components .images ):
@@ -372,12 +444,21 @@ def save_to(
372444 output .mux (packet )
373445
374446 if audio_stream and self .__components .audio :
375- waveform = self .__components .audio ['waveform' ]
376- waveform = waveform [:, :, :math .ceil ((audio_sample_rate / frame_rate ) * self .__components .images .shape [0 ])]
377- frame = av .AudioFrame .from_ndarray (waveform .movedim (2 , 1 ).reshape (1 , - 1 ).float ().cpu ().numpy (), format = 'flt' , layout = 'mono' if waveform .shape [1 ] == 1 else 'stereo' )
447+ frame = av .AudioFrame .from_ndarray (waveform .float ().cpu ().numpy (), format = 'fltp' , layout = layout )
378448 frame .sample_rate = audio_sample_rate
379449 frame .pts = 0
380450 output .mux (audio_stream .encode (frame ))
381451
382452 # Flush encoder
383453 output .mux (audio_stream .encode (None ))
454+
455+ def as_trimmed (
456+ self ,
457+ start_time : float | None = None ,
458+ duration : float | None = None ,
459+ strict_duration : bool = True ,
460+ ) -> VideoInput | None :
461+ if self .get_duration () < start_time + duration :
462+ return None
463+ #TODO Consider tracking duration and trimming at time of save?
464+ return VideoFromFile (self .get_stream_source (), start_time = start_time , duration = duration )
0 commit comments