@@ -53,7 +53,9 @@ def fate(name, path="."):
53
53
class TestVideoApi :
54
54
@pytest .mark .skipif (av is None , reason = "PyAV unavailable" )
55
55
@pytest .mark .parametrize ("test_video" , test_videos .keys ())
56
- def test_frame_reading (self , test_video ):
56
+ @pytest .mark .parametrize ("backend" , ["video_reader" , "pyav" ])
57
+ def test_frame_reading (self , test_video , backend ):
58
+ torchvision .set_video_backend (backend )
57
59
full_path = os .path .join (VIDEO_DIR , test_video )
58
60
with av .open (full_path ) as av_reader :
59
61
if av_reader .streams .video :
@@ -117,58 +119,70 @@ def test_frame_reading(self, test_video):
117
119
118
120
@pytest.mark.parametrize("stream", ["video", "audio"])
@pytest.mark.parametrize("test_video", test_videos.keys())
@pytest.mark.parametrize("backend", ["video_reader", "pyav"])
def test_frame_reading_mem_vs_file(self, test_video, stream, backend):
    """Check that decoding from a path and from in-memory bytes agree.

    For the selected backend, the requested ``stream`` of ``test_video`` is
    decoded twice with ``VideoReader``: once from the file path and once
    from the raw bytes of the same file. Both passes must yield the same
    number of frames, identical pts values, and frame data whose mean
    absolute difference stays below 2.55.

    If the reader's metadata has no entry for ``stream``, the test returns
    without comparing anything.
    """
    torchvision.set_video_backend(backend)
    full_path = os.path.join(VIDEO_DIR, test_video)

    # Guard clause: nothing to compare when the file lacks this stream type.
    if stream not in VideoReader(full_path).get_metadata():
        return

    # Pass 1: decode every frame straight from the file on disk.
    vr_frames, vr_pts = [], []
    for vr_frame in VideoReader(full_path, stream):
        vr_frames.append(vr_frame["data"])
        vr_pts.append(vr_frame["pts"])

    # Pass 2: decode the same content from a bytes object. The context
    # manager closes the handle even if the read raises.
    with open(full_path, "rb") as f:
        fbytes = f.read()

    vr_frames_mem, vr_pts_mem = [], []
    for vr_frame_from_mem in VideoReader(fbytes, stream):
        vr_frames_mem.append(vr_frame_from_mem["data"])
        vr_pts_mem.append(vr_frame_from_mem["pts"])

    # Same number of frames and timestamps from both sources.
    assert len(vr_frames) == len(vr_frames_mem)
    assert len(vr_pts) == len(vr_pts_mem)

    # Compare the frames and pts values pairwise.
    for pts, pts_mem, data, data_mem in zip(vr_pts, vr_pts_mem, vr_frames, vr_frames_mem):
        assert pts == pts_mem
        mean_delta = torch.mean(torch.abs(data.float() - data_mem.float()))
        # On average the difference is very small and caused by decoding.
        # The 2.55 bound is 1% of 255 (presumably 8-bit sample data).
        # TODO: assess empirically how to set this threshold.
        assert mean_delta.item() < 2.55
157
167
158
168
@pytest.mark.parametrize("test_video,config", test_videos.items())
@pytest.mark.parametrize("backend", ["video_reader", "pyav"])
def test_metadata(self, test_video, config, backend):
    """
    Check the video metadata reported by ``VideoReader`` under each backend.

    The first ``fps`` and ``duration`` entries of the "video" metadata must
    match ``config.video_fps`` (within 0.0001) and ``config.duration``
    (within 0.5) respectively.
    """
    torchvision.set_video_backend(backend)

    video_path = os.path.join(VIDEO_DIR, test_video)
    video_md = VideoReader(video_path, "video").get_metadata()["video"]

    reported_fps = video_md["fps"][0]
    assert config.video_fps == approx(reported_fps, abs=0.0001)

    reported_duration = video_md["duration"][0]
    assert config.duration == approx(reported_duration, abs=0.5)
169
181
170
182
@pytest .mark .parametrize ("test_video" , test_videos .keys ())
171
- def test_seek_start (self , test_video ):
183
+ @pytest .mark .parametrize ("backend" , ["video_reader" , "pyav" ])
184
+ def test_seek_start (self , test_video , backend ):
185
+ torchvision .set_video_backend (backend )
172
186
full_path = os .path .join (VIDEO_DIR , test_video )
173
187
video_reader = VideoReader (full_path , "video" )
174
188
num_frames = 0
@@ -194,7 +208,9 @@ def test_seek_start(self, test_video):
194
208
assert start_num_frames == num_frames
195
209
196
210
@pytest .mark .parametrize ("test_video" , test_videos .keys ())
197
- def test_accurateseek_middle (self , test_video ):
211
+ @pytest .mark .parametrize ("backend" , ["video_reader" ])
212
+ def test_accurateseek_middle (self , test_video , backend ):
213
+ torchvision .set_video_backend (backend )
198
214
full_path = os .path .join (VIDEO_DIR , test_video )
199
215
stream = "video"
200
216
video_reader = VideoReader (full_path , stream )
@@ -233,7 +249,9 @@ def test_fate_suite(self):
233
249
234
250
@pytest .mark .skipif (av is None , reason = "PyAV unavailable" )
235
251
@pytest .mark .parametrize ("test_video,config" , test_videos .items ())
236
- def test_keyframe_reading (self , test_video , config ):
252
+ @pytest .mark .parametrize ("backend" , ["pyav" , "video_reader" ])
253
+ def test_keyframe_reading (self , test_video , config , backend ):
254
+ torchvision .set_video_backend (backend )
237
255
full_path = os .path .join (VIDEO_DIR , test_video )
238
256
239
257
av_reader = av .open (full_path )
0 commit comments