Skip to content

[wip] Update VideoDecoder init #799

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/torchcodec/_core/SingleStreamDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ void SingleStreamDecoder::addVideoStream(
if (seekMode_ == SeekMode::custom_frame_mappings) {
TORCH_CHECK(
customFrameMappings.has_value(),
"Please provide frame mappings when using custom_frame_mappings seek mode.");
"Missing frame mappings when custom_frame_mappings seek mode is set.");
readCustomFrameMappingsUpdateMetadataAndIndex(
streamIndex, customFrameMappings.value());
}
Expand Down
49 changes: 48 additions & 1 deletion src/torchcodec/decoders/_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import io
import json
import numbers
from pathlib import Path
from typing import Literal, Optional, Tuple, Union
Expand Down Expand Up @@ -78,14 +79,27 @@ def __init__(
dimension_order: Literal["NCHW", "NHWC"] = "NCHW",
num_ffmpeg_threads: int = 1,
device: Optional[Union[str, torch_device]] = "cpu",
seek_mode: Literal["exact", "approximate"] = "exact",
seek_mode: Literal["exact", "approximate", "custom_frame_mappings"] = "exact",
custom_frame_mappings: Optional[Union[bytes, bytearray, str]] = None,
):
allowed_seek_modes = ("exact", "approximate")
if seek_mode not in allowed_seek_modes:
raise ValueError(
f"Invalid seek mode ({seek_mode}). "
f"Supported values are {', '.join(allowed_seek_modes)}."
)
if custom_frame_mappings:
if seek_mode not in ("exact", "custom_frame_mappings"):
raise ValueError(
"While setting custom frame mappings, do not set `seek_mode`."
)
# Set seek mode to avoid exact mode scan
seek_mode = "custom_frame_mappings"
custom_frame_mappings_data = (
read_custom_frame_mappings(custom_frame_mappings)
if custom_frame_mappings is not None
else None
)

self._decoder = create_decoder(source=source, seek_mode=seek_mode)

Expand All @@ -108,6 +122,7 @@ def __init__(
dimension_order=dimension_order,
num_threads=num_ffmpeg_threads,
device=device,
custom_frame_mappings=custom_frame_mappings_data,
)

(
Expand Down Expand Up @@ -377,3 +392,35 @@ def _get_and_validate_stream_metadata(
end_stream_seconds,
num_frames,
)


def read_custom_frame_mappings(
    custom_frame_mappings: Union[bytes, bytearray, str, io.IOBase]
) -> tuple[Tensor, Tensor, Tensor]:
    """Parse JSON frame mappings into per-frame tensors.

    The JSON document must be an object with a ``"frames"`` list in which
    every entry carries ``"pts"``, ``"key_frame"`` and ``"duration"`` (for
    example the output of ``ffprobe -show_frames -of json``).

    Args:
        custom_frame_mappings: The JSON document, either as text/bytes or as
            an object with a ``read`` method (e.g. an open file).

    Returns:
        A tuple ``(all_frames, is_key_frame, duration)`` of 1-D tensors with
        one element per frame, in the order the frames appear in the JSON.

    Raises:
        ValueError: If the input is not valid JSON, has no non-empty
            ``"frames"`` list, or a frame is missing a required field or
            holds a non-numeric ``pts``/``duration``.
    """
    try:
        # Anything exposing ``read`` is treated as a file object.
        if hasattr(custom_frame_mappings, "read"):
            input_data = json.load(custom_frame_mappings)
        else:
            input_data = json.loads(custom_frame_mappings)
    except json.JSONDecodeError as err:
        raise ValueError(
            "Invalid custom frame mappings. "
            "It should be a valid JSON string or a JSON file object."
        ) from err

    # Valid JSON can still be the wrong shape (a list, a string, an object
    # without "frames", ...); report that as a ValueError as well instead of
    # leaking a KeyError/TypeError to the caller.
    frames = input_data.get("frames") if isinstance(input_data, dict) else None
    if not isinstance(frames, list) or not frames:
        raise ValueError(
            "Invalid custom frame mappings. "
            'The JSON must be an object with a non-empty "frames" list.'
        )

    try:
        rows = [
            (float(frame["pts"]), frame["key_frame"], float(frame["duration"]))
            for frame in frames
        ]
    except (KeyError, TypeError, ValueError) as err:
        raise ValueError(
            "Invalid custom frame mappings. "
            'Every frame needs numeric "pts" and "duration" values '
            'and a "key_frame" field.'
        ) from err

    # zip(*rows) transposes the per-frame triples into three columns, which
    # are equal-length by construction.
    all_frames, is_key_frame, duration = zip(*rows)
    # NOTE(review): Tensor(...) builds default-dtype float tensors; very large
    # pts values may lose precision -- confirm what the decoder core expects.
    return Tensor(all_frames), Tensor(is_key_frame), Tensor(duration)
81 changes: 81 additions & 0 deletions test/test_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import contextlib
import gc
import json
from functools import partial
from unittest.mock import patch

import numpy
Expand Down Expand Up @@ -1252,6 +1253,86 @@ def test_10bit_videos_cpu(self, asset):
decoder = VideoDecoder(asset.path)
decoder.get_frame_at(10)

def setup_frame_mappings(tmp_path: str, file: bool, stream_index: int):
    """Produce NASA_VIDEO frame mappings as a JSON string or a JSON file.

    When ``file`` is false the JSON text is returned directly. Otherwise the
    text is written to ``custom_frame_mappings.json`` under ``tmp_path`` and
    the path of that file is returned.
    """
    mappings_json = NASA_VIDEO.generate_custom_frame_mappings(stream_index)
    if not file:
        # Caller wants the raw JSON string.
        return mappings_json

    # Caller wants a file on disk: persist the JSON and hand back its path.
    destination = tmp_path / "custom_frame_mappings.json"
    destination.write_text(mappings_json)
    return destination

@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("stream_index", [0, 3])
@pytest.mark.parametrize(
    "method",
    (
        partial(setup_frame_mappings, file=True),
        partial(setup_frame_mappings, file=False),
    ),
)
def test_custom_frame_mappings(self, tmp_path, device, stream_index, method):
    """Decode with custom frame mappings given as a file object or a string.

    Frames 0 and 5 and the range between them are compared against the
    NASA_VIDEO reference frames; a second decoder checks time-based lookup
    on H265_VIDEO.
    """
    custom_frame_mappings = method(tmp_path=tmp_path, stream_index=stream_index)
    # ``method`` returns either the JSON text itself or the path of a JSON
    # file. A path has no ``read`` attribute, so the two are told apart by
    # type. The string is wrapped in ``nullcontext(value)`` so that the
    # ``as`` target is the string rather than ``None``; the path is opened
    # so the decoder receives a file object.
    with (
        contextlib.nullcontext(custom_frame_mappings)
        if isinstance(custom_frame_mappings, str)
        else open(custom_frame_mappings, "r")
    ) as custom_frame_mappings:
        decoder = VideoDecoder(
            NASA_VIDEO.path,
            stream_index=stream_index,
            device=device,
            custom_frame_mappings=custom_frame_mappings,
        )
    frame_0 = decoder.get_frame_at(0)
    frame_5 = decoder.get_frame_at(5)
    assert_frames_equal(
        frame_0.data,
        NASA_VIDEO.get_frame_data_by_index(0, stream_index=stream_index).to(
            device
        ),
    )
    assert_frames_equal(
        frame_5.data,
        NASA_VIDEO.get_frame_data_by_index(5, stream_index=stream_index).to(
            device
        ),
    )
    frames0_5 = decoder.get_frames_played_in_range(
        frame_0.pts_seconds, frame_5.pts_seconds
    )
    assert_frames_equal(
        frames0_5.data,
        NASA_VIDEO.get_frame_data_by_range(0, 5, stream_index=stream_index).to(
            device
        ),
    )

    # Time-based lookup on a second asset, using string mappings.
    decoder = VideoDecoder(
        H265_VIDEO.path,
        stream_index=0,
        custom_frame_mappings=H265_VIDEO.generate_custom_frame_mappings(0),
    )
    ref_frame6 = H265_VIDEO.get_frame_data_by_index(5)
    assert_frames_equal(ref_frame6, decoder.get_frame_played_at(0.5).data)

@pytest.mark.parametrize("device", cpu_and_cuda())
def test_custom_frame_mappings_init_fails(self, device):
    """Combining custom frame mappings with "approximate" seek mode is rejected."""
    stream_index = 3
    with pytest.raises(ValueError, match="seek_mode"):
        VideoDecoder(
            NASA_VIDEO.path,
            stream_index=stream_index,
            device=device,
            seek_mode="approximate",
            custom_frame_mappings=NASA_VIDEO.generate_custom_frame_mappings(
                stream_index
            ),
        )


class TestAudioDecoder:
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3, SINE_MONO_S32))
Expand Down
2 changes: 1 addition & 1 deletion test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ def test_seek_mode_custom_frame_mappings_fails(self):
)
with pytest.raises(
RuntimeError,
match="Please provide frame mappings when using custom_frame_mappings seek mode.",
match="Missing frame mappings when custom_frame_mappings seek mode is set.",
):
add_video_stream(decoder, stream_index=0, custom_frame_mappings=None)

Expand Down
40 changes: 21 additions & 19 deletions test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,27 +236,29 @@ def get_custom_frame_mappings(
if stream_index is None:
stream_index = self.default_stream_index
if self._custom_frame_mappings_data.get(stream_index) is None:
self.generate_custom_frame_mappings(stream_index)
self.create_custom_frame_mappings(stream_index)
return self._custom_frame_mappings_data[stream_index]

def generate_custom_frame_mappings(self, stream_index: int) -> None:
result = json.loads(
subprocess.run(
[
"ffprobe",
"-i",
f"{self.path}",
"-select_streams",
f"{stream_index}",
"-show_frames",
"-of",
"json",
],
check=True,
capture_output=True,
text=True,
).stdout
)
def generate_custom_frame_mappings(self, stream_index: int) -> str:
    """Return ffprobe's JSON frame report for one stream of this asset.

    Runs ``ffprobe -show_frames -of json`` restricted to ``stream_index`` on
    ``self.path`` and returns the captured standard output as text. A
    non-zero ffprobe exit status raises ``subprocess.CalledProcessError``
    because the command is run with ``check=True``.
    """
    command = [
        "ffprobe",
        "-i",
        f"{self.path}",
        "-select_streams",
        f"{stream_index}",
        "-show_frames",
        "-of",
        "json",
    ]
    completed = subprocess.run(
        command,
        check=True,
        capture_output=True,
        text=True,
    )
    return completed.stdout

def create_custom_frame_mappings(self, stream_index: int) -> None:
result = json.loads(self.generate_custom_frame_mappings(stream_index))
all_frames = torch.tensor([float(frame["pts"]) for frame in result["frames"]])
is_key_frame = torch.tensor([frame["key_frame"] for frame in result["frames"]])
duration = torch.tensor(
Expand Down
Loading