Skip to content

Commit

Permalink
Initial POC
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Oct 3, 2024
0 parents commit a4ab9b6
Show file tree
Hide file tree
Showing 12 changed files with 1,029 additions and 0 deletions.
43 changes: 43 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
build/
dist/
*.egg-info/
*/**/__pycache__
*/__pycache__
*/*.pyc
*/**/*.pyc
*/**/**/*.pyc
*/**/*~
*~

docs/build
# sphinx-gallery
docs/source/auto_examples/
docs/source/gen_modules/
docs/source/generated/
docs/source/models/generated/
# pytorch-sphinx-theme gets installed here
docs/src

.coverage
htmlcov
.*.swp
*.so*
*.dylib*
*/*.so*
*/*.dylib*
*.swp
*.swo
gen.yml
.mypy_cache
.vscode/
.idea/
*.orig
*-checkpoint.ipynb
*.venv

## Xcode User settings
xcuserdata/

# direnv
.direnv
.envrc
503 changes: 503 additions & 0 deletions LICENSE

Large diffs are not rendered by default.

22 changes: 22 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
Disclaimer: this is still WIP!!


An extension of `torchvision` for decoding AVIF and HEIC images.

Usage:

```bash
$ pip install torchvision-extra-decoders
```

Then

```py
from torchvision.io import decode_image, decode_heic, decode_avif

img = decode_image("image.heic")
img = decode_image("image.avif")

img = decode_heic("image.heic")
img = decode_avif("image.avif")
```
100 changes: 100 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# This software may be used and distributed according to the terms of the
# GNU Lesser General Public License version 2.

import sys
import os
from pathlib import Path
from setuptools import find_packages, setup

from torch.utils.cpp_extension import BuildExtension, CppExtension

def find_library(header):
    """Look for a native library by probing for one of its header files.

    Args:
        header: Path of the header relative to an include directory,
            e.g. ``"avif/avif.h"``.

    Returns:
        A ``(found, include_dir, library_dir)`` tuple. When the header is found
        in a standard system location, both directories are ``None``: nothing
        has to be added to the compiler / linker search paths in that case.
    """
    banner = f"Searching for {header}"

    # Conda locations first: BUILD_PREFIX is set while conda-build is running,
    # CONDA_PREFIX is set inside an activated conda environment.
    for env_name in ("BUILD_PREFIX", "CONDA_PREFIX"):
        env_value = os.environ.get(env_name)
        if env_value is not None:
            root = Path(env_value)
            if sys.platform == "win32":
                # On Windows the probe happens under the "Library" sub-directory.
                root = root / "Library"
            include_dir = root / "include"
            if (include_dir / header).exists():
                print(f"{banner}. Found in {env_name}.")
                return True, str(include_dir), str(root / "lib")
        print(f"{banner}. Didn't find in {env_name}.")

    # Standard system include directories are only probed on Linux.
    if sys.platform != "linux":
        return False, None, None

    for system_dir in map(Path, ("/usr/include", "/usr/local/include")):
        if (system_dir / header).exists():
            print(f"{banner}. Found in {system_dir}.")
            return True, None, None
        print(f"{banner}. Didn't find in {system_dir}")

    return False, None, None


def make_extension():
    """Create the ``CppExtension`` that builds the HEIC / AVIF decoders.

    Locates libheif and libavif via ``find_library`` and compiles every
    ``*.cpp`` file found in ``torchvision_extra_decoders/csrc/``.

    Returns:
        The configured ``CppExtension`` instance.

    Raises:
        RuntimeError: if the headers of libheif or libavif cannot be found.
    """
    heic_found, heic_include_dir, heic_library_dir = find_library(header="libheif/heif_cxx.h")
    if not heic_found:
        raise RuntimeError("Couldn't find libheif!")

    print(f"{heic_include_dir = }")
    print(f"{heic_library_dir = }")

    avif_found, avif_include_dir, avif_library_dir = find_library(header="avif/avif.h")
    if not avif_found:
        raise RuntimeError("Couldn't find libavif!")

    print(f"{avif_include_dir = }")
    print(f"{avif_library_dir = }")

    sources = list(Path("torchvision_extra_decoders/csrc/").glob("*.cpp"))
    print(f"{sources = }")

    # find_library() reports None for libraries living in standard system
    # paths: those must not be forwarded to the compiler / linker. Duplicates
    # (both libraries in the same prefix) are dropped while keeping the order.
    include_dirs = list(
        dict.fromkeys(d for d in (heic_include_dir, avif_include_dir) if d is not None)
    )
    library_dirs = list(
        dict.fromkeys(d for d in (heic_library_dir, avif_library_dir) if d is not None)
    )

    return CppExtension(
        name="torchvision_extra_decoders.extra_decoders_lib",
        sources=sorted(str(s) for s in sources),
        include_dirs=include_dirs,
        library_dirs=library_dirs,
        libraries=["heif", "avif"],
        extra_compile_args={"cxx": ["-g0"]},
    )


if __name__ == "__main__":

    # The README doubles as the long description shown on the package index.
    with open("README.md", encoding="utf-8") as f:
        readme = f.read()

    # Distribution name (what gets pip-installed) vs. importable package name:
    # package_data must be keyed by the latter, otherwise its patterns never
    # match any package.
    PACKAGE_NAME = "torchvision-extra-decoders"
    IMPORT_NAME = "torchvision_extra_decoders"

    setup(
        name=PACKAGE_NAME,
        version="0.0.1.dev",
        author="PyTorch Team",
        author_email="[email protected]",
        url="TODO",
        description="TODO",
        long_description=readme,
        long_description_content_type="text/markdown",
        license="LGPLv2.1",
        packages=find_packages(exclude=("test",)),
        package_data={IMPORT_NAME: ["*.dll", "*.dylib", "*.so"]},
        zip_safe=False,
        install_requires=[],
        python_requires=">=3.9",
        ext_modules=[make_extension()],
        cmdclass={
            # Build the library without the Python ABI tag in its file name.
            "build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
        },
    )
12 changes: 12 additions & 0 deletions torchvision_extra_decoders/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# This software may be used and distributed according to the terms of the
# GNU Lesser General Public License version 2.

from pathlib import Path
import torch


def expose_extra_decoders():
    """Load the compiled ``extra_decoders_lib`` library via ``torch.ops.load_library``.

    The library is expected to sit next to this file.
    """
    import sys

    # setup.py builds the extension with no_python_abi_suffix=True, so the file
    # name is "extra_decoders_lib" followed by the platform's extension-module
    # suffix: ".pyd" on Windows, ".so" elsewhere (Linux and macOS).
    suffix = ".pyd" if sys.platform == "win32" else ".so"
    torch.ops.load_library(Path(__file__).parent / f"extra_decoders_lib{suffix}")
45 changes: 45 additions & 0 deletions torchvision_extra_decoders/csrc/common.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.

// This software may be used and distributed according to the terms of the
// GNU Lesser General Public License version 2.

#include "common.h"
#include <torch/torch.h>

namespace extra_decoders_ns {

// Checks that `encoded_data` is usable as a raw encoded-image buffer: it must
// be contiguous, of dtype uint8, one-dimensional and non-empty. Each
// requirement is enforced with TORCH_CHECK, in that order.
void validate_encoded_data(const torch::Tensor& encoded_data) {
  TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous.");

  const auto dtype = encoded_data.dtype();
  TORCH_CHECK(
      dtype == torch::kU8,
      "Input tensor must have uint8 data type, got ",
      dtype);

  const auto num_dims = encoded_data.dim();
  const auto num_elements = encoded_data.numel();
  const bool is_flat_and_non_empty = (num_dims == 1) && (num_elements > 0);
  TORCH_CHECK(
      is_flat_and_non_empty,
      "Input tensor must be 1-dimensional and non-empty, got ",
      num_dims,
      " dims and ",
      num_elements,
      " numels.");
}

// Tells a decoder whether its output should be RGB (returns true) or RGBA
// (returns false).
//
// An explicit RGB or RGB_ALPHA request is honoured as-is. Every other mode,
// including the grayscale ones, is handled like "unchanged": the output is
// RGBA only when the image has an alpha channel. Because grayscale requests
// are ignored, this helper should only be used by decoders that don't support
// grayscale outputs.
bool should_this_return_rgb_or_rgba(
    ImageReadMode mode,
    bool has_alpha) {
  switch (mode) {
    case IMAGE_READ_MODE_RGB:
      return true;
    case IMAGE_READ_MODE_RGB_ALPHA:
      return false;
    default:
      // "unchanged" (and grayscale modes): follow the image itself.
      return !has_alpha;
  }
}

} // namespace extra_decoders_ns
27 changes: 27 additions & 0 deletions torchvision_extra_decoders/csrc/common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.

// This software may be used and distributed according to the terms of the
// GNU Lesser General Public License version 2.

#pragma once

#include <stdint.h>
#include <torch/torch.h>

namespace extra_decoders_ns{

/* Should be kept in-sync with Python ImageReadMode enum */
// The read mode is passed around as a plain 64-bit integer; the constants
// below name the supported values.
using ImageReadMode = int64_t;
const ImageReadMode IMAGE_READ_MODE_UNCHANGED = 0;
const ImageReadMode IMAGE_READ_MODE_GRAY = 1;
const ImageReadMode IMAGE_READ_MODE_GRAY_ALPHA = 2;
const ImageReadMode IMAGE_READ_MODE_RGB = 3;
const ImageReadMode IMAGE_READ_MODE_RGB_ALPHA = 4;

// Checks (via TORCH_CHECK) that `encoded_data` is a contiguous, non-empty,
// 1-dimensional uint8 tensor.
void validate_encoded_data(const torch::Tensor& encoded_data);

// Returns true when the caller should produce an RGB output and false when it
// should produce an RGBA one. IMAGE_READ_MODE_RGB and
// IMAGE_READ_MODE_RGB_ALPHA are honoured directly; any other mode (grayscale
// ones included) yields `!has_alpha`.
bool should_this_return_rgb_or_rgba(
ImageReadMode mode,
bool has_alpha);

} // namespace extra_decoders_ns
89 changes: 89 additions & 0 deletions torchvision_extra_decoders/csrc/decode_avif.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.

// This software may be used and distributed according to the terms of the
// GNU Lesser General Public License version 2.

#include <torch/torch.h>
#include "decode_avif.h"
#include "common.h"
#include "avif/avif.h"

namespace extra_decoders_ns {

// Deleter that releases an avifDecoder through avifDecoderDestroy(), so that a
// decoder can be owned by a std::unique_ptr. libavif ships an equivalent in
// avif_cxx.h, but that header is not always installed alongside the library,
// hence this local copy.
struct UniquePtrDeleter {
  void operator()(avifDecoder* decoder_to_destroy) const {
    avifDecoderDestroy(decoder_to_destroy);
  }
};

// Owning handle to an avifDecoder; the decoder is destroyed automatically.
using DecoderPtr = std::unique_ptr<avifDecoder, UniquePtrDeleter>;

// Decodes a single-image AVIF file held in memory.
//
// Parameters:
//   encoded_data: the raw bytes of the AVIF file, as a contiguous, non-empty,
//     1-dimensional uint8 tensor (enforced by validate_encoded_data()).
//   mode: requested read mode. Only RGB and RGB_ALPHA are honoured; every other
//     value is treated as "unchanged" (RGBA if the image has alpha, else RGB).
//
// Returns a tensor of logical shape (C, H, W) with C = 3 or 4. Its dtype is
// uint8 for images with a bit depth of at most 8, and uint16 otherwise.
//
// Any libavif failure, or a file holding more than one image, is reported
// through TORCH_CHECK.
//
// This is based on
// https://github.com/AOMediaCodec/libavif/blob/main/examples/avif_example_decode_memory.c
// Refer there for more detail about what each function does, and which
// structure/data is available after which call.
torch::Tensor decode_avif(
    const torch::Tensor& encoded_data,
    ImageReadMode mode) {
  validate_encoded_data(encoded_data);

  DecoderPtr decoder(avifDecoderCreate());
  TORCH_CHECK(decoder != nullptr, "Failed to create avif decoder.");

  avifResult result = avifDecoderSetIOMemory(
      decoder.get(), encoded_data.data_ptr<uint8_t>(), encoded_data.numel());
  TORCH_CHECK(
      result == AVIF_RESULT_OK,
      "avifDecoderSetIOMemory failed: ",
      avifResultToString(result));

  result = avifDecoderParse(decoder.get());
  TORCH_CHECK(
      result == AVIF_RESULT_OK,
      "avifDecoderParse failed: ",
      avifResultToString(result));
  TORCH_CHECK(
      decoder->imageCount == 1, "Avif file contains more than one image");

  result = avifDecoderNextImage(decoder.get());
  TORCH_CHECK(
      result == AVIF_RESULT_OK,
      "avifDecoderNextImage failed: ",
      avifResultToString(result));

  // Start from an all-zero struct, then let libavif fill in the defaults
  // (dimensions, depth, format...) derived from the decoded image.
  avifRGBImage rgb{};
  avifRGBImageSetDefaults(&rgb, decoder->image);

  // images encoded as 10 or 12 bits will be decoded as uint16. The rest are
  // decoded as uint8.
  const bool use_uint8 = (decoder->image->depth <= 8);
  rgb.depth = use_uint8 ? 8 : 16;

  const bool return_rgb =
      should_this_return_rgb_or_rgba(mode, decoder->alphaPresent);

  const int64_t num_channels = return_rgb ? 3 : 4;
  rgb.format = return_rgb ? AVIF_RGB_FORMAT_RGB : AVIF_RGB_FORMAT_RGBA;
  rgb.ignoreAlpha = return_rgb ? AVIF_TRUE : AVIF_FALSE;

  // libavif writes the converted pixels straight into the output tensor's
  // storage, laid out as tightly-packed HWC rows.
  auto out = torch::empty(
      {rgb.height, rgb.width, num_channels},
      use_uint8 ? torch::kUInt8 : at::kUInt16);
  rgb.pixels = static_cast<uint8_t*>(out.data_ptr());
  rgb.rowBytes = rgb.width * avifRGBImagePixelSize(&rgb);

  result = avifImageYUVToRGB(decoder->image, &rgb);
  TORCH_CHECK(
      result == AVIF_RESULT_OK,
      "avifImageYUVToRGB failed: ",
      avifResultToString(result));

  return out.permute({2, 0, 1}); // return CHW, channels-last
}

} // namespace extra_decoders_ns
17 changes: 17 additions & 0 deletions torchvision_extra_decoders/csrc/decode_avif.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.

// This software may be used and distributed according to the terms of the
// GNU Lesser General Public License version 2.

#pragma once

#include <torch/types.h>
#include "common.h"

namespace extra_decoders_ns {

// Decodes an in-memory AVIF file containing exactly one image.
//
// `encoded_data` holds the raw file bytes and must be a contiguous, non-empty,
// 1-dimensional uint8 tensor. `mode` selects the output channels: RGB and
// RGB_ALPHA are honoured, any other value is treated as "unchanged".
//
// Returns a (C, H, W) tensor with 3 (RGB) or 4 (RGBA) channels, of dtype uint8
// for images with a bit depth of at most 8 and uint16 otherwise.
C10_EXPORT torch::Tensor decode_avif(
const torch::Tensor& encoded_data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);

} // namespace extra_decoders_ns
Loading

0 comments on commit a4ab9b6

Please sign in to comment.