Skip to content

Commit 9e78fe2

Browse files
authoredAug 19, 2024
Add AVIF decoder (Part 1- this is not public or available yet) (#8596)
1 parent 0a0f34b commit 9e78fe2

File tree

11 files changed

+174
-1
lines changed

11 files changed

+174
-1
lines changed
 

‎CMakeLists.txt

+15
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ option(WITH_JPEG "Enable features requiring LibJPEG." ON)
1111
# untested. Since building from cmake is very low pri anyway, this is OK. If
1212
# you're a user and you need this, please open an issue (and a PR!).
1313
option(WITH_WEBP "Enable features requiring LibWEBP." OFF)
14+
# Same here
15+
option(WITH_AVIF "Enable features requiring LibAVIF." OFF)
1416

1517
if(WITH_CUDA)
1618
enable_language(CUDA)
@@ -41,6 +43,11 @@ if (WITH_WEBP)
4143
find_package(WEBP REQUIRED)
4244
endif()
4345

46+
if (WITH_AVIF)
47+
add_definitions(-DAVIF_FOUND)
48+
find_package(AVIF REQUIRED)
49+
endif()
50+
4451
function(CUDA_CONVERT_FLAGS EXISTING_TARGET)
4552
get_property(old_flags TARGET ${EXISTING_TARGET} PROPERTY INTERFACE_COMPILE_OPTIONS)
4653
if(NOT "${old_flags}" STREQUAL "")
@@ -117,6 +124,10 @@ if (WITH_WEBP)
117124
target_link_libraries(${PROJECT_NAME} PRIVATE ${WEBP_LIBRARIES})
118125
endif()
119126

127+
if (WITH_AVIF)
128+
target_link_libraries(${PROJECT_NAME} PRIVATE ${AVIF_LIBRARIES})
129+
endif()
130+
120131
set_target_properties(${PROJECT_NAME} PROPERTIES
121132
EXPORT_NAME TorchVision
122133
INSTALL_RPATH ${TORCH_INSTALL_PREFIX}/lib)
@@ -135,6 +146,10 @@ if (WITH_WEBP)
135146
include_directories(${WEBP_INCLUDE_DIRS})
136147
endif()
137148

149+
if (WITH_AVIF)
150+
include_directories(${AVIF_INCLUDE_DIRS})
151+
endif()
152+
138153
set(TORCHVISION_CMAKECONFIG_INSTALL_DIR "share/cmake/TorchVision" CACHE STRING "install path for TorchVisionConfig.cmake")
139154

140155
configure_package_config_file(cmake/TorchVisionConfig.cmake.in

‎setup.py

+17
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
USE_PNG = os.getenv("TORCHVISION_USE_PNG", "1") == "1"
2020
USE_JPEG = os.getenv("TORCHVISION_USE_JPEG", "1") == "1"
2121
USE_WEBP = os.getenv("TORCHVISION_USE_WEBP", "1") == "1"
22+
USE_AVIF = os.getenv("TORCHVISION_USE_AVIF", "0") == "1" # TODO enable by default!
2223
USE_NVJPEG = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1"
2324
NVCC_FLAGS = os.getenv("NVCC_FLAGS", None)
2425
# Note: the GPU video decoding stuff used to be called "video codec", which
@@ -49,6 +50,7 @@
4950
print(f"{USE_PNG = }")
5051
print(f"{USE_JPEG = }")
5152
print(f"{USE_WEBP = }")
53+
print(f"{USE_AVIF = }")
5254
print(f"{USE_NVJPEG = }")
5355
print(f"{NVCC_FLAGS = }")
5456
print(f"{USE_CPU_VIDEO_DECODER = }")
@@ -332,6 +334,21 @@ def make_image_extension():
332334
else:
333335
warnings.warn("Building torchvision without WEBP support")
334336

337+
if USE_AVIF:
338+
avif_found, avif_include_dir, avif_library_dir = find_library(header="avif/avif.h")
339+
if avif_found:
340+
print("Building torchvision with AVIF support")
341+
print(f"{avif_include_dir = }")
342+
print(f"{avif_library_dir = }")
343+
if avif_include_dir is not None and avif_library_dir is not None:
344+
# if those are None it means they come from standard paths that are already in the search paths, which we don't need to re-add.
345+
include_dirs.append(avif_include_dir)
346+
library_dirs.append(avif_library_dir)
347+
libraries.append("avif")
348+
define_macros += [("AVIF_FOUND", 1)]
349+
else:
350+
warnings.warn("Building torchvision without AVIF support")
351+
335352
if USE_NVJPEG and torch.cuda.is_available():
336353
nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists()
337354

693 Bytes
Binary file not shown.

‎test/test_image.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from common_utils import assert_equal, cpu_and_cuda, IN_OSS_CI, needs_cuda
1515
from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence
1616
from torchvision.io.image import (
17+
_decode_avif,
1718
decode_gif,
1819
decode_image,
1920
decode_jpeg,
@@ -873,7 +874,7 @@ def test_decode_gif_webp_errors(decode_fun):
873874
decode_fun(encoded_data[::2])
874875
if decode_fun is decode_gif:
875876
expected_match = re.escape("DGifOpenFileName() failed - 103")
876-
else:
877+
elif decode_fun is decode_webp:
877878
expected_match = "WebPDecodeRGB failed."
878879
with pytest.raises(RuntimeError, match=expected_match):
879880
decode_fun(encoded_data)
@@ -890,5 +891,17 @@ def test_decode_webp(decode_fun, scripted):
890891
assert img[None].is_contiguous(memory_format=torch.channels_last)
891892

892893

894+
@pytest.mark.xfail(reason="AVIF support not enabled yet.")
895+
@pytest.mark.parametrize("decode_fun", (_decode_avif, decode_image))
896+
@pytest.mark.parametrize("scripted", (False, True))
897+
def test_decode_avif(decode_fun, scripted):
898+
encoded_bytes = read_file(next(get_images(FAKEDATA_DIR, ".avif")))
899+
if scripted:
900+
decode_fun = torch.jit.script(decode_fun)
901+
img = decode_fun(encoded_bytes)
902+
assert img.shape == (3, 100, 100)
903+
assert img[None].is_contiguous(memory_format=torch.channels_last)
904+
905+
893906
if __name__ == "__main__":
894907
pytest.main([__file__])
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
#include "decode_avif.h"
2+
3+
#if AVIF_FOUND
4+
#include "avif/avif.h"
5+
#endif // AVIF_FOUND
6+
7+
namespace vision {
8+
namespace image {
9+
10+
#if !AVIF_FOUND
11+
torch::Tensor decode_avif(const torch::Tensor& data) {
12+
TORCH_CHECK(
13+
false, "decode_avif: torchvision not compiled with libavif support");
14+
}
15+
#else
16+
17+
// This normally comes from avif_cxx.h, but it's not always present when
18+
// installing libavif. So we just copy/paste it here.
19+
struct UniquePtrDeleter {
20+
void operator()(avifDecoder* decoder) const {
21+
avifDecoderDestroy(decoder);
22+
}
23+
};
24+
using DecoderPtr = std::unique_ptr<avifDecoder, UniquePtrDeleter>;
25+
26+
torch::Tensor decode_avif(const torch::Tensor& encoded_data) {
27+
// This is based on
28+
// https://github.com/AOMediaCodec/libavif/blob/main/examples/avif_example_decode_memory.c
29+
// Refer there for more detail about what each function does, and which
30+
// structure/data is available after which call.
31+
32+
TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous.");
33+
TORCH_CHECK(
34+
encoded_data.dtype() == torch::kU8,
35+
"Input tensor must have uint8 data type, got ",
36+
encoded_data.dtype());
37+
TORCH_CHECK(
38+
encoded_data.dim() == 1,
39+
"Input tensor must be 1-dimensional, got ",
40+
encoded_data.dim(),
41+
" dims.");
42+
43+
DecoderPtr decoder(avifDecoderCreate());
44+
TORCH_CHECK(decoder != nullptr, "Failed to create avif decoder.");
45+
46+
auto result = AVIF_RESULT_UNKNOWN_ERROR;
47+
result = avifDecoderSetIOMemory(
48+
decoder.get(), encoded_data.data_ptr<uint8_t>(), encoded_data.numel());
49+
TORCH_CHECK(
50+
result == AVIF_RESULT_OK,
51+
"avifDecoderSetIOMemory failed:",
52+
avifResultToString(result));
53+
54+
result = avifDecoderParse(decoder.get());
55+
TORCH_CHECK(
56+
result == AVIF_RESULT_OK,
57+
"avifDecoderParse failed: ",
58+
avifResultToString(result));
59+
TORCH_CHECK(
60+
decoder->imageCount == 1, "Avif file contains more than one image");
61+
TORCH_CHECK(
62+
decoder->image->depth <= 8,
63+
"avif images with bitdepth > 8 are not supported");
64+
65+
result = avifDecoderNextImage(decoder.get());
66+
TORCH_CHECK(
67+
result == AVIF_RESULT_OK,
68+
"avifDecoderNextImage failed:",
69+
avifResultToString(result));
70+
71+
auto out = torch::empty(
72+
{decoder->image->height, decoder->image->width, 3}, torch::kUInt8);
73+
74+
avifRGBImage rgb;
75+
memset(&rgb, 0, sizeof(rgb));
76+
avifRGBImageSetDefaults(&rgb, decoder->image);
77+
rgb.format = AVIF_RGB_FORMAT_RGB;
78+
rgb.pixels = out.data_ptr<uint8_t>();
79+
rgb.rowBytes = rgb.width * avifRGBImagePixelSize(&rgb);
80+
81+
result = avifImageYUVToRGB(decoder->image, &rgb);
82+
TORCH_CHECK(
83+
result == AVIF_RESULT_OK,
84+
"avifImageYUVToRGB failed: ",
85+
avifResultToString(result));
86+
87+
return out.permute({2, 0, 1}); // return CHW, channels-last
88+
}
89+
#endif // AVIF_FOUND
90+
91+
} // namespace image
92+
} // namespace vision
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#pragma once
2+
3+
#include <torch/types.h>
4+
5+
namespace vision {
6+
namespace image {
7+
8+
C10_EXPORT torch::Tensor decode_avif(const torch::Tensor& data);
9+
10+
} // namespace image
11+
} // namespace vision

‎torchvision/csrc/io/image/cpu/decode_image.cpp

+13
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "decode_image.h"
22

3+
#include "decode_avif.h"
34
#include "decode_gif.h"
45
#include "decode_jpeg.h"
56
#include "decode_png.h"
@@ -48,6 +49,18 @@ torch::Tensor decode_image(
4849
return decode_gif(data);
4950
}
5051

52+
// We assume the signature of an avif file is
53+
// 0000 0020 6674 7970 6176 6966
54+
// xxxx xxxx f t y p a v i f
55+
// We only check for the "ftyp avif" part.
56+
// This is probably not perfect, but hopefully this should cover most files.
57+
const uint8_t avif_signature[8] = {
58+
0x66, 0x74, 0x79, 0x70, 0x61, 0x76, 0x69, 0x66}; // == "ftypavif"
59+
TORCH_CHECK(data.numel() >= 12, err_msg);
60+
if ((memcmp(avif_signature, datap + 4, 8) == 0)) {
61+
return decode_avif(data);
62+
}
63+
5164
const uint8_t webp_signature_begin[4] = {0x52, 0x49, 0x46, 0x46}; // == "RIFF"
5265
const uint8_t webp_signature_end[7] = {
5366
0x57, 0x45, 0x42, 0x50, 0x56, 0x50, 0x38}; // == "WEBPVP8"

‎torchvision/csrc/io/image/image.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ static auto registry =
2222
.op("image::decode_jpeg(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor",
2323
&decode_jpeg)
2424
.op("image::decode_webp", &decode_webp)
25+
.op("image::decode_avif", &decode_avif)
2526
.op("image::encode_jpeg", &encode_jpeg)
2627
.op("image::read_file", &read_file)
2728
.op("image::write_file", &write_file)

‎torchvision/csrc/io/image/image.h

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include "cpu/decode_avif.h"
34
#include "cpu/decode_gif.h"
45
#include "cpu/decode_image.h"
56
#include "cpu/decode_jpeg.h"

‎torchvision/io/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@
6161
"decode_image",
6262
"decode_jpeg",
6363
"decode_png",
64+
"decode_webp",
65+
"decode_gif",
6466
"encode_jpeg",
6567
"encode_png",
6668
"read_file",

‎torchvision/io/image.py

+8
Original file line numberDiff line numberDiff line change
@@ -382,3 +382,11 @@ def decode_webp(
382382
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
383383
_log_api_usage_once(decode_webp)
384384
return torch.ops.image.decode_webp(input)
385+
386+
387+
def _decode_avif(
388+
input: torch.Tensor,
389+
) -> torch.Tensor:
390+
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
391+
_log_api_usage_once(decode_webp)
392+
return torch.ops.image.decode_avif(input)

0 commit comments

Comments
 (0)
Please sign in to comment.