-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit a4ab9b6
Showing
12 changed files
with
1,029 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
build/ | ||
dist/ | ||
*.egg-info/ | ||
*/**/__pycache__ | ||
*/__pycache__ | ||
*/*.pyc | ||
*/**/*.pyc | ||
*/**/**/*.pyc | ||
*/**/*~ | ||
*~ | ||
|
||
docs/build | ||
# sphinx-gallery | ||
docs/source/auto_examples/ | ||
docs/source/gen_modules/ | ||
docs/source/generated/ | ||
docs/source/models/generated/ | ||
# pytorch-sphinx-theme gets installed here | ||
docs/src | ||
|
||
.coverage | ||
htmlcov | ||
.*.swp | ||
*.so* | ||
*.dylib* | ||
*/*.so* | ||
*/*.dylib* | ||
*.swp | ||
*.swo | ||
gen.yml | ||
.mypy_cache | ||
.vscode/ | ||
.idea/ | ||
*.orig | ||
*-checkpoint.ipynb | ||
*.venv | ||
|
||
## Xcode User settings | ||
xcuserdata/ | ||
|
||
# direnv | ||
.direnv | ||
.envrc |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
Disclaimer: this is still WIP!! | ||
|
||
|
||
An extension of `torchvision` for decoding AVIF and HEIC images. | ||
|
||
Usage: | ||
|
||
```bash | ||
$ pip install torchvision-extra-decoders | ||
``` | ||
|
||
Then | ||
|
||
```py | ||
from torchvision.io import decode_image, decode_heic, decode_avif | ||
|
||
img = decode_image("image.heic") | ||
img = decode_image("image.avif") | ||
|
||
img = decode_heic("image.heic") | ||
img = decode_avif("image.avif") | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
|
||
# This software may be used and distributed according to the terms of the | ||
# GNU Lesser General Public License version 2. | ||
|
||
import sys | ||
import os | ||
from pathlib import Path | ||
from setuptools import find_packages, setup | ||
|
||
from torch.utils.cpp_extension import BuildExtension, CppExtension | ||
|
||
def find_library(header): | ||
# returns (found, include dir, library dir) | ||
# if include dir or library dir is None, it means that the library is in | ||
# standard paths and don't need to be added to compiler / linker search | ||
# paths | ||
|
||
searching_for = f"Searching for {header}" | ||
|
||
# Try conda-related prefixes. If BUILD_PREFIX is set it means conda-build is | ||
# being run. If CONDA_PREFIX is set then we're in a conda environment. | ||
for prefix_env_var in ("BUILD_PREFIX", "CONDA_PREFIX"): | ||
if (prefix := os.environ.get(prefix_env_var)) is not None: | ||
prefix = Path(prefix) | ||
if sys.platform == "win32": | ||
prefix = prefix / "Library" | ||
include_dir = prefix / "include" | ||
library_dir = prefix / "lib" | ||
if (include_dir / header).exists(): | ||
print(f"{searching_for}. Found in {prefix_env_var}.") | ||
return True, str(include_dir), str(library_dir) | ||
print(f"{searching_for}. Didn't find in {prefix_env_var}.") | ||
|
||
if sys.platform == "linux": | ||
for prefix in (Path("/usr/include"), Path("/usr/local/include")): | ||
if (prefix / header).exists(): | ||
print(f"{searching_for}. Found in {prefix}.") | ||
return True, None, None | ||
print(f"{searching_for}. Didn't find in {prefix}") | ||
|
||
return False, None, None | ||
|
||
|
||
def make_extension(): | ||
|
||
heic_found, heic_include_dir, heic_library_dir = find_library(header="libheif/heif_cxx.h") | ||
if not heic_found: | ||
raise RuntimeError("Couldn't find libheic!") | ||
|
||
print(f"{heic_include_dir = }") | ||
print(f"{heic_library_dir = }") | ||
|
||
avif_found, avif_include_dir, avif_library_dir = find_library(header="avif/avif.h") | ||
if not avif_found: | ||
raise RuntimeError("Couldn't find libavif!") | ||
|
||
print(f"{heic_include_dir = }") | ||
print(f"{heic_library_dir = }") | ||
|
||
sources = list(Path("torchvision_extra_decoders/csrc/").glob("*.cpp")) | ||
print(f"{sources = }") | ||
|
||
return CppExtension( | ||
name="torchvision_extra_decoders.extra_decoders_lib", | ||
sources=sorted(str(s) for s in sources), | ||
include_dirs=[heic_include_dir, avif_include_dir], | ||
library_dirs=[heic_library_dir, avif_library_dir], | ||
libraries=["heif", "avif"], | ||
extra_compile_args={"cxx": ["-g0"]}, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
|
||
with open("README.md") as f: | ||
readme = f.read() | ||
|
||
PACKAGE_NAME = "torchvision-extra-decoders" | ||
|
||
setup( | ||
name=PACKAGE_NAME, | ||
version="0.0.1.dev", | ||
author="PyTorch Team", | ||
author_email="[email protected]", | ||
url="TODO", | ||
description="TODO", | ||
long_description=readme, | ||
long_description_content_type="text/markdown", | ||
license="LGPLv2.1", | ||
packages=find_packages(exclude=("test",)), | ||
package_data={PACKAGE_NAME: ["*.dll", "*.dylib", "*.so"]}, | ||
zip_safe=False, | ||
install_requires=[], | ||
python_requires=">=3.9", | ||
ext_modules=[make_extension()], | ||
cmdclass={ | ||
"build_ext": BuildExtension.with_options(no_python_abi_suffix=True), | ||
}, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
|
||
# This software may be used and distributed according to the terms of the | ||
# GNU Lesser General Public License version 2. | ||
|
||
from pathlib import Path | ||
import torch | ||
|
||
|
||
def expose_extra_decoders(): | ||
suffix = ".so" # TODO: make this cross-platform | ||
torch.ops.load_library(Path(__file__).parent / f"extra_decoders_lib{suffix}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
// Copyright (c) Meta Platforms, Inc. and affiliates. | ||
|
||
// This software may be used and distributed according to the terms of the | ||
// GNU Lesser General Public License version 2. | ||
|
||
#include "common.h" | ||
#include <torch/torch.h> | ||
|
||
namespace extra_decoders_ns { | ||
|
||
void validate_encoded_data(const torch::Tensor& encoded_data) { | ||
TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous."); | ||
TORCH_CHECK( | ||
encoded_data.dtype() == torch::kU8, | ||
"Input tensor must have uint8 data type, got ", | ||
encoded_data.dtype()); | ||
TORCH_CHECK( | ||
encoded_data.dim() == 1 && encoded_data.numel() > 0, | ||
"Input tensor must be 1-dimensional and non-empty, got ", | ||
encoded_data.dim(), | ||
" dims and ", | ||
encoded_data.numel(), | ||
" numels."); | ||
} | ||
|
||
bool should_this_return_rgb_or_rgba( | ||
ImageReadMode mode, | ||
bool has_alpha) { | ||
// Return true if the calling decoding function should return a 3D RGB tensor, | ||
// and false if it should return a 4D RGBA tensor. | ||
// This function ignores the requested "grayscale" modes and treats it as | ||
// "unchanged", so it should only used on decoders who don't support grayscale | ||
// outputs. | ||
|
||
if (mode == IMAGE_READ_MODE_RGB) { | ||
return true; | ||
} | ||
if (mode == IMAGE_READ_MODE_RGB_ALPHA) { | ||
return false; | ||
} | ||
// From here we assume mode is "unchanged", even for grayscale ones. | ||
return !has_alpha; | ||
} | ||
|
||
} // namespace extra_decoders_ns |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
// Copyright (c) Meta Platforms, Inc. and affiliates. | ||
|
||
// This software may be used and distributed according to the terms of the | ||
// GNU Lesser General Public License version 2. | ||
|
||
#pragma once | ||
|
||
#include <stdint.h> | ||
#include <torch/torch.h> | ||
|
||
namespace extra_decoders_ns{ | ||
|
||
/* Should be kept in-sync with Python ImageReadMode enum */ | ||
using ImageReadMode = int64_t; | ||
const ImageReadMode IMAGE_READ_MODE_UNCHANGED = 0; | ||
const ImageReadMode IMAGE_READ_MODE_GRAY = 1; | ||
const ImageReadMode IMAGE_READ_MODE_GRAY_ALPHA = 2; | ||
const ImageReadMode IMAGE_READ_MODE_RGB = 3; | ||
const ImageReadMode IMAGE_READ_MODE_RGB_ALPHA = 4; | ||
|
||
void validate_encoded_data(const torch::Tensor& encoded_data); | ||
|
||
bool should_this_return_rgb_or_rgba( | ||
ImageReadMode mode, | ||
bool has_alpha); | ||
|
||
} // namespace extra_decoders_ns |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
// Copyright (c) Meta Platforms, Inc. and affiliates. | ||
|
||
// This software may be used and distributed according to the terms of the | ||
// GNU Lesser General Public License version 2. | ||
|
||
#include <torch/torch.h> | ||
#include "decode_avif.h" | ||
#include "common.h" | ||
#include "avif/avif.h" | ||
|
||
namespace extra_decoders_ns { | ||
|
||
// This normally comes from avif_cxx.h, but it's not always present when | ||
// installing libavif. So we just copy/paste it here. | ||
struct UniquePtrDeleter { | ||
void operator()(avifDecoder* decoder) const { | ||
avifDecoderDestroy(decoder); | ||
} | ||
}; | ||
using DecoderPtr = std::unique_ptr<avifDecoder, UniquePtrDeleter>; | ||
|
||
torch::Tensor decode_avif( | ||
const torch::Tensor& encoded_data, | ||
ImageReadMode mode) { | ||
// This is based on | ||
// https://github.com/AOMediaCodec/libavif/blob/main/examples/avif_example_decode_memory.c | ||
// Refer there for more detail about what each function does, and which | ||
// structure/data is available after which call. | ||
|
||
validate_encoded_data(encoded_data); | ||
|
||
DecoderPtr decoder(avifDecoderCreate()); | ||
TORCH_CHECK(decoder != nullptr, "Failed to create avif decoder."); | ||
|
||
auto result = AVIF_RESULT_UNKNOWN_ERROR; | ||
result = avifDecoderSetIOMemory( | ||
decoder.get(), encoded_data.data_ptr<uint8_t>(), encoded_data.numel()); | ||
TORCH_CHECK( | ||
result == AVIF_RESULT_OK, | ||
"avifDecoderSetIOMemory failed:", | ||
avifResultToString(result)); | ||
|
||
result = avifDecoderParse(decoder.get()); | ||
TORCH_CHECK( | ||
result == AVIF_RESULT_OK, | ||
"avifDecoderParse failed: ", | ||
avifResultToString(result)); | ||
TORCH_CHECK( | ||
decoder->imageCount == 1, "Avif file contains more than one image"); | ||
|
||
result = avifDecoderNextImage(decoder.get()); | ||
TORCH_CHECK( | ||
result == AVIF_RESULT_OK, | ||
"avifDecoderNextImage failed:", | ||
avifResultToString(result)); | ||
|
||
avifRGBImage rgb; | ||
memset(&rgb, 0, sizeof(rgb)); | ||
avifRGBImageSetDefaults(&rgb, decoder->image); | ||
|
||
// images encoded as 10 or 12 bits will be decoded as uint16. The rest are | ||
// decoded as uint8. | ||
auto use_uint8 = (decoder->image->depth <= 8); | ||
rgb.depth = use_uint8 ? 8 : 16; | ||
|
||
auto return_rgb = | ||
should_this_return_rgb_or_rgba( | ||
mode, decoder->alphaPresent); | ||
|
||
auto num_channels = return_rgb ? 3 : 4; | ||
rgb.format = return_rgb ? AVIF_RGB_FORMAT_RGB : AVIF_RGB_FORMAT_RGBA; | ||
rgb.ignoreAlpha = return_rgb ? AVIF_TRUE : AVIF_FALSE; | ||
|
||
auto out = torch::empty( | ||
{rgb.height, rgb.width, num_channels}, | ||
use_uint8 ? torch::kUInt8 : at::kUInt16); | ||
rgb.pixels = (uint8_t*)out.data_ptr(); | ||
rgb.rowBytes = rgb.width * avifRGBImagePixelSize(&rgb); | ||
|
||
result = avifImageYUVToRGB(decoder->image, &rgb); | ||
TORCH_CHECK( | ||
result == AVIF_RESULT_OK, | ||
"avifImageYUVToRGB failed: ", | ||
avifResultToString(result)); | ||
|
||
return out.permute({2, 0, 1}); // return CHW, channels-last | ||
} | ||
|
||
} // namespace extra_decoders_ns |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
// Copyright (c) Meta Platforms, Inc. and affiliates. | ||
|
||
// This software may be used and distributed according to the terms of the | ||
// GNU Lesser General Public License version 2. | ||
|
||
#pragma once | ||
|
||
#include <torch/types.h> | ||
#include "common.h" | ||
|
||
namespace extra_decoders_ns { | ||
|
||
C10_EXPORT torch::Tensor decode_avif( | ||
const torch::Tensor& encoded_data, | ||
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED); | ||
|
||
} // namespace extra_decoders_ns |
Oops, something went wrong.