Skip to content

Commit d710f3d

Browse files
bjuncek, Bruno Korbar, and jdsgomes
authored
Bkorbar/pyavapi (#6943)
* Test: add backend parameter * VideoReader object now works on backend * Frame reading now passes * Keyframe seek now passes * Pyav backend now supports metadata * changes in test to reflect GPU decoder change * Linter? * Test GPU output * Addressing Joao's comments * lint * lint * Revert "Test GPU output" This reverts commit f62e955. * lint? * lint * lint * Address issues in build? * hopefully doc fix * Arrgh * arrgh * fix typos * fix input options * remove read from memory option in pyav * skip read from mem test for gpu and pyav backend * fix test * remove unused import * Hack to get reading from memory to work with pyav * patch audio test * gallery change in the hope that docs won't break * check video decoder inside io * adding missing lib loading code * remove unused input Co-authored-by: Bruno Korbar <[email protected]> Co-authored-by: Joao Gomes <[email protected]>
1 parent b1054cb commit d710f3d

File tree

8 files changed

+210
-105
lines changed

8 files changed

+210
-105
lines changed

gallery/plot_video_api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import torch
3333
import torchvision
3434
from torchvision.datasets.utils import download_url
35+
torchvision.set_video_backend("video_reader")
3536

3637
# Download the sample video
3738
download_url(

test/test_video_gpu_decoder.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import pytest
55
import torch
6+
import torchvision
67
from torchvision.io import _HAS_GPU_VIDEO_DECODER, VideoReader
78

89
try:
@@ -29,8 +30,9 @@ class TestVideoGPUDecoder:
2930
],
3031
)
3132
def test_frame_reading(self, video_file):
33+
torchvision.set_video_backend("cuda")
3234
full_path = os.path.join(VIDEO_DIR, video_file)
33-
decoder = VideoReader(full_path, device="cuda")
35+
decoder = VideoReader(full_path)
3436
with av.open(full_path) as container:
3537
for av_frame in container.decode(container.streams.video[0]):
3638
av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray())
@@ -54,7 +56,8 @@ def test_frame_reading(self, video_file):
5456
],
5557
)
5658
def test_seek_reading(self, keyframes, full_path, duration):
57-
decoder = VideoReader(full_path, device="cuda")
59+
torchvision.set_video_backend("cuda")
60+
decoder = VideoReader(full_path)
5861
time = duration / 2
5962
decoder.seek(time, keyframes_only=keyframes)
6063
with av.open(full_path) as container:
@@ -79,8 +82,9 @@ def test_seek_reading(self, keyframes, full_path, duration):
7982
],
8083
)
8184
def test_metadata(self, video_file):
85+
torchvision.set_video_backend("cuda")
8286
full_path = os.path.join(VIDEO_DIR, video_file)
83-
decoder = VideoReader(full_path, device="cuda")
87+
decoder = VideoReader(full_path)
8488
video_metadata = decoder.get_metadata()["video"]
8589
with av.open(full_path) as container:
8690
video = container.streams.video[0]

test/test_videoapi.py

Lines changed: 58 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ def fate(name, path="."):
5353
class TestVideoApi:
5454
@pytest.mark.skipif(av is None, reason="PyAV unavailable")
5555
@pytest.mark.parametrize("test_video", test_videos.keys())
56-
def test_frame_reading(self, test_video):
56+
@pytest.mark.parametrize("backend", ["video_reader", "pyav"])
57+
def test_frame_reading(self, test_video, backend):
58+
torchvision.set_video_backend(backend)
5759
full_path = os.path.join(VIDEO_DIR, test_video)
5860
with av.open(full_path) as av_reader:
5961
if av_reader.streams.video:
@@ -117,58 +119,70 @@ def test_frame_reading(self, test_video):
117119

118120
@pytest.mark.parametrize("stream", ["video", "audio"])
119121
@pytest.mark.parametrize("test_video", test_videos.keys())
120-
def test_frame_reading_mem_vs_file(self, test_video, stream):
122+
@pytest.mark.parametrize("backend", ["video_reader", "pyav"])
123+
def test_frame_reading_mem_vs_file(self, test_video, stream, backend):
124+
torchvision.set_video_backend(backend)
121125
full_path = os.path.join(VIDEO_DIR, test_video)
122126

123-
# Test video reading from file vs from memory
124-
vr_frames, vr_frames_mem = [], []
125-
vr_pts, vr_pts_mem = [], []
126-
# get vr frames
127-
video_reader = VideoReader(full_path, stream)
128-
for vr_frame in video_reader:
129-
vr_frames.append(vr_frame["data"])
130-
vr_pts.append(vr_frame["pts"])
131-
132-
# get vr frames = read from memory
133-
f = open(full_path, "rb")
134-
fbytes = f.read()
135-
f.close()
136-
video_reader_from_mem = VideoReader(fbytes, stream)
137-
138-
for vr_frame_from_mem in video_reader_from_mem:
139-
vr_frames_mem.append(vr_frame_from_mem["data"])
140-
vr_pts_mem.append(vr_frame_from_mem["pts"])
141-
142-
# same number of frames
143-
assert len(vr_frames) == len(vr_frames_mem)
144-
assert len(vr_pts) == len(vr_pts_mem)
145-
146-
# compare the frames and ptss
147-
for i in range(len(vr_frames)):
148-
assert vr_pts[i] == vr_pts_mem[i]
149-
mean_delta = torch.mean(torch.abs(vr_frames[i].float() - vr_frames_mem[i].float()))
150-
# on average the difference is very small and caused
151-
# by decoding (around 1%)
152-
# TODO: assess empirically how to set this? atm it's 1%
153-
# averaged over all frames
154-
assert mean_delta.item() < 2.55
155-
156-
del vr_frames, vr_pts, vr_frames_mem, vr_pts_mem
127+
reader = VideoReader(full_path)
128+
reader_md = reader.get_metadata()
129+
130+
if stream in reader_md:
131+
# Test video reading from file vs from memory
132+
vr_frames, vr_frames_mem = [], []
133+
vr_pts, vr_pts_mem = [], []
134+
# get vr frames
135+
video_reader = VideoReader(full_path, stream)
136+
for vr_frame in video_reader:
137+
vr_frames.append(vr_frame["data"])
138+
vr_pts.append(vr_frame["pts"])
139+
140+
# get vr frames = read from memory
141+
f = open(full_path, "rb")
142+
fbytes = f.read()
143+
f.close()
144+
video_reader_from_mem = VideoReader(fbytes, stream)
145+
146+
for vr_frame_from_mem in video_reader_from_mem:
147+
vr_frames_mem.append(vr_frame_from_mem["data"])
148+
vr_pts_mem.append(vr_frame_from_mem["pts"])
149+
150+
# same number of frames
151+
assert len(vr_frames) == len(vr_frames_mem)
152+
assert len(vr_pts) == len(vr_pts_mem)
153+
154+
# compare the frames and ptss
155+
for i in range(len(vr_frames)):
156+
assert vr_pts[i] == vr_pts_mem[i]
157+
mean_delta = torch.mean(torch.abs(vr_frames[i].float() - vr_frames_mem[i].float()))
158+
# on average the difference is very small and caused
159+
# by decoding (around 1%)
160+
# TODO: assess empirically how to set this? atm it's 1%
161+
# averaged over all frames
162+
assert mean_delta.item() < 2.55
163+
164+
del vr_frames, vr_pts, vr_frames_mem, vr_pts_mem
165+
else:
166+
del reader, reader_md
157167

158168
@pytest.mark.parametrize("test_video,config", test_videos.items())
159-
def test_metadata(self, test_video, config):
169+
@pytest.mark.parametrize("backend", ["video_reader", "pyav"])
170+
def test_metadata(self, test_video, config, backend):
160171
"""
161172
Test that the metadata returned via pyav corresponds to the one returned
162173
by the new video decoder API
163174
"""
175+
torchvision.set_video_backend(backend)
164176
full_path = os.path.join(VIDEO_DIR, test_video)
165177
reader = VideoReader(full_path, "video")
166178
reader_md = reader.get_metadata()
167179
assert config.video_fps == approx(reader_md["video"]["fps"][0], abs=0.0001)
168180
assert config.duration == approx(reader_md["video"]["duration"][0], abs=0.5)
169181

170182
@pytest.mark.parametrize("test_video", test_videos.keys())
171-
def test_seek_start(self, test_video):
183+
@pytest.mark.parametrize("backend", ["video_reader", "pyav"])
184+
def test_seek_start(self, test_video, backend):
185+
torchvision.set_video_backend(backend)
172186
full_path = os.path.join(VIDEO_DIR, test_video)
173187
video_reader = VideoReader(full_path, "video")
174188
num_frames = 0
@@ -194,7 +208,9 @@ def test_seek_start(self, test_video):
194208
assert start_num_frames == num_frames
195209

196210
@pytest.mark.parametrize("test_video", test_videos.keys())
197-
def test_accurateseek_middle(self, test_video):
211+
@pytest.mark.parametrize("backend", ["video_reader"])
212+
def test_accurateseek_middle(self, test_video, backend):
213+
torchvision.set_video_backend(backend)
198214
full_path = os.path.join(VIDEO_DIR, test_video)
199215
stream = "video"
200216
video_reader = VideoReader(full_path, stream)
@@ -233,7 +249,9 @@ def test_fate_suite(self):
233249

234250
@pytest.mark.skipif(av is None, reason="PyAV unavailable")
235251
@pytest.mark.parametrize("test_video,config", test_videos.items())
236-
def test_keyframe_reading(self, test_video, config):
252+
@pytest.mark.parametrize("backend", ["pyav", "video_reader"])
253+
def test_keyframe_reading(self, test_video, config, backend):
254+
torchvision.set_video_backend(backend)
237255
full_path = os.path.join(VIDEO_DIR, test_video)
238256

239257
av_reader = av.open(full_path)

torchvision/__init__.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import warnings
3+
from modulefinder import Module
34

45
import torch
56
from torchvision import datasets, io, models, ops, transforms, utils
@@ -11,6 +12,7 @@
1112
except ImportError:
1213
pass
1314

15+
1416
# Check if torchvision is being imported within the root folder
1517
if not _HAS_OPS and os.path.dirname(os.path.realpath(__file__)) == os.path.join(
1618
os.path.realpath(os.getcwd()), "torchvision"
@@ -66,11 +68,16 @@ def set_video_backend(backend):
6668
backend, please compile torchvision from source.
6769
"""
6870
global _video_backend
69-
if backend not in ["pyav", "video_reader"]:
70-
raise ValueError("Invalid video backend '%s'. Options are 'pyav' and 'video_reader'" % backend)
71+
if backend not in ["pyav", "video_reader", "cuda"]:
72+
raise ValueError("Invalid video backend '%s'. Options are 'pyav', 'video_reader' and 'cuda'" % backend)
7173
if backend == "video_reader" and not io._HAS_VIDEO_OPT:
74+
# TODO: better messages
7275
message = "video_reader video backend is not available. Please compile torchvision from source and try again"
73-
warnings.warn(message)
76+
raise RuntimeError(message)
77+
elif backend == "cuda" and not io._HAS_GPU_VIDEO_DECODER:
78+
# TODO: better messages
79+
message = "cuda video backend is not available."
80+
raise RuntimeError(message)
7481
else:
7582
_video_backend = backend
7683

torchvision/io/__init__.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,6 @@
44

55
from ..utils import _log_api_usage_once
66

7-
try:
8-
from ._load_gpu_decoder import _HAS_GPU_VIDEO_DECODER
9-
except ModuleNotFoundError:
10-
_HAS_GPU_VIDEO_DECODER = False
117
from ._video_opt import (
128
_HAS_VIDEO_OPT,
139
_probe_video_from_file,
@@ -32,7 +28,7 @@
3228
write_jpeg,
3329
write_png,
3430
)
35-
from .video import read_video, read_video_timestamps, write_video
31+
from .video import _HAS_GPU_VIDEO_DECODER, read_video, read_video_timestamps, write_video
3632
from .video_reader import VideoReader
3733

3834

torchvision/io/_load_gpu_decoder.py

Lines changed: 0 additions & 8 deletions
This file was deleted.

torchvision/io/video.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,16 @@
99
import numpy as np
1010
import torch
1111

12+
from ..extension import _load_library
13+
1214
from ..utils import _log_api_usage_once
1315
from . import _video_opt
1416

17+
try:
18+
_load_library("Decoder")
19+
_HAS_GPU_VIDEO_DECODER = True
20+
except (ImportError, OSError, ModuleNotFoundError):
21+
_HAS_GPU_VIDEO_DECODER = False
1522

1623
try:
1724
import av

0 commit comments

Comments
 (0)