Skip to content

Commit be7cdf1

Browse files
authored
Add webp decoder (#8527)
1 parent edb1c33 commit be7cdf1

File tree

18 files changed

+212
-42
lines changed

18 files changed

+212
-42
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#!/bin/sh
2+
3+
export IS_M1_CONDA_BUILD_JOB=1

.github/scripts/setup-env.sh

+1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ conda create \
3030
python="${PYTHON_VERSION}" pip \
3131
ninja cmake \
3232
libpng \
33+
libwebp \
3334
'ffmpeg<4.3'
3435
conda activate ci
3536
conda install --quiet --yes libjpeg-turbo -c pytorch

.github/workflows/build-conda-m1.yml

+1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ jobs:
4242
test-infra-repository: pytorch/test-infra
4343
test-infra-ref: main
4444
build-matrix: ${{ needs.generate-matrix.outputs.matrix }}
45+
env-var-script: ./.github/scripts/export_IS_M1_CONDA_BUILD_JOB.sh
4546
pre-script: ${{ matrix.pre-script }}
4647
post-script: ${{ matrix.post-script }}
4748
package-name: ${{ matrix.package-name }}

CMakeLists.txt

+17
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ option(WITH_CUDA "Enable CUDA support" OFF)
77
option(WITH_MPS "Enable MPS support" OFF)
88
option(WITH_PNG "Enable features requiring LibPNG." ON)
99
option(WITH_JPEG "Enable features requiring LibJPEG." ON)
10+
# Libwebp is disabled by default, which means enabling it from cmake is largely
11+
# untested. Since building from cmake is very low pri anyway, this is OK. If
12+
# you're a user and you need this, please open an issue (and a PR!).
13+
option(WITH_WEBP "Enable features requiring LibWEBP." OFF)
1014

1115
if(WITH_CUDA)
1216
enable_language(CUDA)
@@ -32,6 +36,11 @@ if (WITH_JPEG)
3236
find_package(JPEG REQUIRED)
3337
endif()
3438

39+
if (WITH_WEBP)
40+
add_definitions(-DWEBP_FOUND)
41+
find_package(WEBP REQUIRED)
42+
endif()
43+
3544
function(CUDA_CONVERT_FLAGS EXISTING_TARGET)
3645
get_property(old_flags TARGET ${EXISTING_TARGET} PROPERTY INTERFACE_COMPILE_OPTIONS)
3746
if(NOT "${old_flags}" STREQUAL "")
@@ -104,6 +113,10 @@ if (WITH_JPEG)
104113
target_link_libraries(${PROJECT_NAME} PRIVATE ${JPEG_LIBRARIES})
105114
endif()
106115

116+
if (WITH_WEBP)
117+
target_link_libraries(${PROJECT_NAME} PRIVATE ${WEBP_LIBRARIES})
118+
endif()
119+
107120
set_target_properties(${PROJECT_NAME} PROPERTIES
108121
EXPORT_NAME TorchVision
109122
INSTALL_RPATH ${TORCH_INSTALL_PREFIX}/lib)
@@ -118,6 +131,10 @@ if (WITH_JPEG)
118131
include_directories(${JPEG_INCLUDE_DIRS})
119132
endif()
120133

134+
if (WITH_WEBP)
135+
include_directories(${WEBP_INCLUDE_DIRS})
136+
endif()
137+
121138
set(TORCHVISION_CMAKECONFIG_INSTALL_DIR "share/cmake/TorchVision" CACHE STRING "install path for TorchVisionConfig.cmake")
122139

123140
configure_package_config_file(cmake/TorchVisionConfig.cmake.in

docs/source/io.rst

+6
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@ videos.
1010
Images
1111
------
1212

13+
Torchvision currently supports decoding JPEG, PNG, WEBP and GIF images. JPEG
14+
decoding can also be done on CUDA GPUs.
15+
16+
For encoding, JPEG (cpu and CUDA) and PNG are supported.
17+
1318
.. autosummary::
1419
:toctree: generated/
1520
:template: function.rst
@@ -20,6 +25,7 @@ Images
2025
decode_jpeg
2126
write_jpeg
2227
decode_gif
28+
decode_webp
2329
encode_png
2430
decode_png
2531
write_png

packaging/pre_build_script.sh

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#!/bin/bash
2+
23
if [[ "$(uname)" == Darwin ]]; then
34
# Uninstall Conflicting jpeg brew formulae
45
jpeg_packages=$(brew list | grep jpeg)
@@ -12,8 +13,10 @@ if [[ "$(uname)" == Darwin ]]; then
1213
fi
1314

1415
if [[ "$(uname)" == Darwin || "$OSTYPE" == "msys" ]]; then
15-
# Install libpng from Anaconda (defaults)
16-
conda install libpng -yq
16+
conda install libpng libwebp -yq
17+
# Installing webp also installs a non-turbo jpeg, so we uninstall jpeg stuff
18+
# before re-installing them
19+
conda uninstall libjpeg-turbo libjpeg -y
1720
conda install -yq ffmpeg=4.2 libjpeg-turbo -c pytorch
1821

1922
# Copy binaries to be included in the wheel distribution
@@ -29,7 +32,7 @@ else
2932
conda install -yq ffmpeg=4.2 libjpeg-turbo -c pytorch-nightly
3033
fi
3134

32-
yum install -y libjpeg-turbo-devel freetype gnutls
35+
yum install -y libjpeg-turbo-devel libwebp-devel freetype gnutls
3336
pip install auditwheel
3437
fi
3538

packaging/torchvision/meta.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ requirements:
1111
- {{ compiler('c') }} # [win]
1212
- libpng
1313
- libjpeg-turbo
14+
- libwebp
1415
- ffmpeg >=4.2.2, <5.0.0 # [linux]
1516

1617
host:
@@ -28,6 +29,7 @@ requirements:
2829
- libpng
2930
- ffmpeg >=4.2.2, <5.0.0 # [linux]
3031
- libjpeg-turbo
32+
- libwebp
3133
- pillow >=5.3.0, !=8.3.*
3234
- pytorch-mutex 1.0 {{ build_variant }} # [not osx ]
3335
{{ environ.get('CONDA_PYTORCH_CONSTRAINT', 'pytorch') }}

setup.py

+18
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
DEBUG = os.getenv("DEBUG", "0") == "1"
1919
USE_PNG = os.getenv("TORCHVISION_USE_PNG", "1") == "1"
2020
USE_JPEG = os.getenv("TORCHVISION_USE_JPEG", "1") == "1"
21+
USE_WEBP = os.getenv("TORCHVISION_USE_WEBP", "1") == "1"
2122
USE_NVJPEG = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1"
2223
NVCC_FLAGS = os.getenv("NVCC_FLAGS", None)
2324
USE_FFMPEG = os.getenv("TORCHVISION_USE_FFMPEG", "1") == "1"
@@ -41,6 +42,7 @@
4142
print(f"{DEBUG = }")
4243
print(f"{USE_PNG = }")
4344
print(f"{USE_JPEG = }")
45+
print(f"{USE_WEBP = }")
4446
print(f"{USE_NVJPEG = }")
4547
print(f"{NVCC_FLAGS = }")
4648
print(f"{USE_FFMPEG = }")
@@ -308,6 +310,22 @@ def make_image_extension():
308310
else:
309311
warnings.warn("Building torchvision without JPEG support")
310312

313+
if USE_WEBP:
314+
webp_found, webp_include_dir, webp_library_dir = find_library(header="webp/decode.h")
315+
if webp_found:
316+
print("Building torchvision with WEBP support")
317+
print(f"{webp_include_dir = }")
318+
print(f"{webp_library_dir = }")
319+
if webp_include_dir is not None and webp_library_dir is not None:
320+
# if those are None it means they come from standard paths that are already in the search paths, which we don't need to re-add.
321+
include_dirs.append(webp_include_dir)
322+
library_dirs.append(webp_library_dir)
323+
webp_library = "libwebp" if sys.platform == "win32" else "webp"
324+
libraries.append(webp_library)
325+
define_macros += [("WEBP_FOUND", 1)]
326+
else:
327+
warnings.warn("Building torchvision without WEBP support")
328+
311329
if USE_NVJPEG and torch.cuda.is_available():
312330
nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists()
313331

952 Bytes
Binary file not shown.

test/smoke_test.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
"""Run smoke tests"""
22

3+
import os
34
import sys
45
from pathlib import Path
56

67
import torch
78
import torchvision
8-
from torchvision.io import decode_jpeg, read_file, read_image
9+
from torchvision.io import decode_jpeg, decode_webp, read_file, read_image
910
from torchvision.models import resnet50, ResNet50_Weights
1011

12+
1113
SCRIPT_DIR = Path(__file__).parent
1214

1315

@@ -25,6 +27,9 @@ def smoke_test_torchvision_read_decode() -> None:
2527
img_png = read_image(str(SCRIPT_DIR / "assets" / "interlaced_png" / "wizard_low.png"))
2628
if img_png.shape != (4, 471, 354):
2729
raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
30+
img_webp = read_image(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch.webp"))
31+
if img_webp.shape != (3, 100, 100):
32+
raise RuntimeError(f"Unexpected shape of img_webp: {img_webp.shape}")
2833

2934

3035
def smoke_test_torchvision_decode_jpeg(device: str = "cpu"):
@@ -77,11 +82,16 @@ def main() -> None:
7782
print(f"torchvision: {torchvision.__version__}")
7883
print(f"torch.cuda.is_available: {torch.cuda.is_available()}")
7984

80-
# Turn 1.11.0aHASH into 1.11 (major.minor only)
81-
version = ".".join(torchvision.__version__.split(".")[:2])
82-
if version >= "0.16":
83-
print(f"{torch.ops.image._jpeg_version() = }")
84-
assert torch.ops.image._is_compiled_against_turbo()
85+
print(f"{torch.ops.image._jpeg_version() = }")
86+
if not torch.ops.image._is_compiled_against_turbo():
87+
msg = "Torchvision wasn't compiled against libjpeg-turbo"
88+
if os.getenv("IS_M1_CONDA_BUILD_JOB") == "1":
89+
# When building the conda package on M1, it's difficult to enforce
90+
# that we build against turbo due to interactions with the libwebp
91+
# package. So we just accept it, instead of raising an error.
92+
print(msg)
93+
else:
94+
raise ValueError(msg)
8595

8696
smoke_test_torchvision()
8797
smoke_test_torchvision_read_decode()

test/test_image.py

+23-6
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
decode_image,
1919
decode_jpeg,
2020
decode_png,
21+
decode_webp,
2122
encode_jpeg,
2223
encode_png,
2324
ImageReadMode,
@@ -861,16 +862,32 @@ def test_decode_gif(tmpdir, name, scripted):
861862
torch.testing.assert_close(tv_frame, pil_frame, atol=0, rtol=0)
862863

863864

864-
def test_decode_gif_errors():
865+
@pytest.mark.parametrize("decode_fun", (decode_gif, decode_webp))
866+
def test_decode_gif_webp_errors(decode_fun):
865867
encoded_data = torch.randint(0, 256, (100,), dtype=torch.uint8)
866868
with pytest.raises(RuntimeError, match="Input tensor must be 1-dimensional"):
867-
decode_gif(encoded_data[None])
869+
decode_fun(encoded_data[None])
868870
with pytest.raises(RuntimeError, match="Input tensor must have uint8 data type"):
869-
decode_gif(encoded_data.float())
871+
decode_fun(encoded_data.float())
870872
with pytest.raises(RuntimeError, match="Input tensor must be contiguous"):
871-
decode_gif(encoded_data[::2])
872-
with pytest.raises(RuntimeError, match=re.escape("DGifOpenFileName() failed - 103")):
873-
decode_gif(encoded_data)
873+
decode_fun(encoded_data[::2])
874+
if decode_fun is decode_gif:
875+
expected_match = re.escape("DGifOpenFileName() failed - 103")
876+
else:
877+
expected_match = "WebPDecodeRGB failed."
878+
with pytest.raises(RuntimeError, match=expected_match):
879+
decode_fun(encoded_data)
880+
881+
882+
@pytest.mark.parametrize("decode_fun", (decode_webp, decode_image))
883+
@pytest.mark.parametrize("scripted", (False, True))
884+
def test_decode_webp(decode_fun, scripted):
885+
encoded_bytes = read_file(next(get_images(FAKEDATA_DIR, ".webp")))
886+
if scripted:
887+
decode_fun = torch.jit.script(decode_fun)
888+
img = decode_fun(encoded_bytes)
889+
assert img.shape == (3, 100, 100)
890+
assert img[None].is_contiguous(memory_format=torch.channels_last)
874891

875892

876893
if __name__ == "__main__":

torchvision/csrc/io/image/cpu/decode_image.cpp

+27-12
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "decode_gif.h"
44
#include "decode_jpeg.h"
55
#include "decode_png.h"
6+
#include "decode_webp.h"
67

78
namespace vision {
89
namespace image {
@@ -20,29 +21,43 @@ torch::Tensor decode_image(
2021
data.dim() == 1 && data.numel() > 0,
2122
"Expected a non empty 1-dimensional tensor");
2223

24+
auto err_msg =
25+
"Unsupported image file. Only jpeg, png and gif are currently supported.";
26+
2327
auto datap = data.data_ptr<uint8_t>();
2428

2529
const uint8_t jpeg_signature[3] = {255, 216, 255}; // == "\xFF\xD8\xFF"
30+
TORCH_CHECK(data.numel() >= 3, err_msg);
31+
if (memcmp(jpeg_signature, datap, 3) == 0) {
32+
return decode_jpeg(data, mode, apply_exif_orientation);
33+
}
34+
2635
const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG"
36+
TORCH_CHECK(data.numel() >= 4, err_msg);
37+
if (memcmp(png_signature, datap, 4) == 0) {
38+
return decode_png(data, mode, apply_exif_orientation);
39+
}
40+
2741
const uint8_t gif_signature_1[6] = {
2842
0x47, 0x49, 0x46, 0x38, 0x39, 0x61}; // == "GIF89a"
2943
const uint8_t gif_signature_2[6] = {
3044
0x47, 0x49, 0x46, 0x38, 0x37, 0x61}; // == "GIF87a"
31-
32-
if (memcmp(jpeg_signature, datap, 3) == 0) {
33-
return decode_jpeg(data, mode, apply_exif_orientation);
34-
} else if (memcmp(png_signature, datap, 4) == 0) {
35-
return decode_png(data, mode, apply_exif_orientation);
36-
} else if (
37-
memcmp(gif_signature_1, datap, 6) == 0 ||
45+
TORCH_CHECK(data.numel() >= 6, err_msg);
46+
if (memcmp(gif_signature_1, datap, 6) == 0 ||
3847
memcmp(gif_signature_2, datap, 6) == 0) {
3948
return decode_gif(data);
40-
} else {
41-
TORCH_CHECK(
42-
false,
43-
"Unsupported image file. Only jpeg, png and gif ",
44-
"are currently supported.");
4549
}
50+
51+
const uint8_t webp_signature_begin[4] = {0x52, 0x49, 0x46, 0x46}; // == "RIFF"
52+
const uint8_t webp_signature_end[7] = {
53+
0x57, 0x45, 0x42, 0x50, 0x56, 0x50, 0x38}; // == "WEBPVP8"
54+
TORCH_CHECK(data.numel() >= 15, err_msg);
55+
if ((memcmp(webp_signature_begin, datap, 4) == 0) &&
56+
(memcmp(webp_signature_end, datap + 8, 7) == 0)) {
57+
return decode_webp(data);
58+
}
59+
60+
TORCH_CHECK(false, err_msg);
4661
}
4762

4863
} // namespace image
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#include "decode_webp.h"
2+
3+
#if WEBP_FOUND
4+
#include "webp/decode.h"
5+
#endif // WEBP_FOUND
6+
7+
namespace vision {
8+
namespace image {
9+
10+
#if !WEBP_FOUND
11+
torch::Tensor decode_webp(const torch::Tensor& data) {
12+
TORCH_CHECK(
13+
false, "decode_webp: torchvision not compiled with libwebp support");
14+
}
15+
#else
16+
17+
torch::Tensor decode_webp(const torch::Tensor& encoded_data) {
18+
TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous.");
19+
TORCH_CHECK(
20+
encoded_data.dtype() == torch::kU8,
21+
"Input tensor must have uint8 data type, got ",
22+
encoded_data.dtype());
23+
TORCH_CHECK(
24+
encoded_data.dim() == 1,
25+
"Input tensor must be 1-dimensional, got ",
26+
encoded_data.dim(),
27+
" dims.");
28+
29+
int width = 0;
30+
int height = 0;
31+
auto decoded_data = WebPDecodeRGB(
32+
encoded_data.data_ptr<uint8_t>(), encoded_data.numel(), &width, &height);
33+
TORCH_CHECK(decoded_data != nullptr, "WebPDecodeRGB failed.");
34+
auto out = torch::from_blob(decoded_data, {height, width, 3}, torch::kUInt8);
35+
return out.permute({2, 0, 1}); // return CHW, channels-last
36+
}
37+
#endif // WEBP_FOUND
38+
39+
} // namespace image
40+
} // namespace vision
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#pragma once
2+
3+
#include <torch/types.h>
4+
5+
namespace vision {
6+
namespace image {
7+
8+
C10_EXPORT torch::Tensor decode_webp(const torch::Tensor& data);
9+
10+
} // namespace image
11+
} // namespace vision

torchvision/csrc/io/image/image.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ static auto registry =
2121
.op("image::encode_png", &encode_png)
2222
.op("image::decode_jpeg(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor",
2323
&decode_jpeg)
24+
.op("image::decode_webp", &decode_webp)
2425
.op("image::encode_jpeg", &encode_jpeg)
2526
.op("image::read_file", &read_file)
2627
.op("image::write_file", &write_file)

torchvision/csrc/io/image/image.h

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "cpu/decode_image.h"
55
#include "cpu/decode_jpeg.h"
66
#include "cpu/decode_png.h"
7+
#include "cpu/decode_webp.h"
78
#include "cpu/encode_jpeg.h"
89
#include "cpu/encode_png.h"
910
#include "cpu/read_write_file.h"

0 commit comments

Comments
 (0)