Add MLIC series models #357
Open
Yiozolm wants to merge 3 commits into
Open
Conversation
Collaborator
|
Thanks for the PR. I have applied two commits:
Their diffs are shown below. Please verify correctness, and I will merge this PR.
commit 6b9ef05d77232173c2395ab87474577c1bfa5689
Author: Mateen Ulhaq <mateen.ulhaq@interdigital.com>
Date: Mon Jun 29 22:43:17 2026 -0700
refactor(latent_codecs): extract checkerboard helpers directly inside checkerboard.py
diff --git a/compressai/latent_codecs/_checkerboard_helpers.py b/compressai/latent_codecs/_checkerboard_helpers.py
deleted file mode 100644
index 84b65a3a..00000000
--- a/compressai/latent_codecs/_checkerboard_helpers.py
+++ /dev/null
@@ -1,156 +0,0 @@
-# Copyright (c) 2021-2025, InterDigital Communications, Inc
-# All rights reserved.
-
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted (subject to the limitations in the disclaimer
-# below) provided that the following conditions are met:
-
-# * Redistributions of source code must retain the above copyright notice,
-# this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above copyright notice,
-# this list of conditions and the following disclaimer in the documentation
-# and/or other materials provided with the distribution.
-# * Neither the name of InterDigital Communications, Inc nor the names of its
-# contributors may be used to endorse or promote products derived from this
-# software without specific prior written permission.
-
-# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
-# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
-# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
-# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
-# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
-# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
-# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
-# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
-# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
-# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
-# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
-# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Pure functional helpers shared by checkerboard latent codecs.
-
-These are extracted from :class:`CheckerboardLatentCodec` so that sibling
-codecs (e.g. :class:`MultiContextCheckerboardLatentCodec`) can reuse the
-exact same checkerboard split / merge / mask logic without duplicating it.
-A single source of truth here also means an anchor-parity boundary fix
-applies to every checkerboard codec at once.
-"""
-
-from __future__ import annotations
-
-import torch
-
-from torch import Tensor
-
-__all__ = [
- "embed",
- "embed_step",
- "mask_all",
- "mask_all_but_step",
- "merge",
- "step_parity",
- "unembed",
- "write_step",
-]
-
-
-def step_parity(step: str, anchor_parity: str) -> str:
- """Resolve a ``step`` ('anchor' / 'non_anchor') to a parity string."""
- if step == "anchor":
- return anchor_parity
- if step == "non_anchor":
- return "odd" if anchor_parity == "even" else "even"
- raise ValueError(f'Invalid "step" value "{step}"')
-
-
-def unembed(y: Tensor, *, anchor_parity: str) -> Tensor:
- """Separate single tensor into two even/odd checkerboard chunks.
-
- .. code-block:: none
-
- ■ □ ■ □ ■ ■ □ □
- □ ■ □ ■ ---> ■ ■ □ □
- ■ □ ■ □ ■ ■ □ □
- """
- n, c, h, w = y.shape
- y_packed = y.new_zeros((2, n, c, h, w // 2))
- if anchor_parity == "even":
- y_packed[0, ..., 0::2, :] = y[..., 0::2, 0::2]
- y_packed[0, ..., 1::2, :] = y[..., 1::2, 1::2]
- y_packed[1, ..., 0::2, :] = y[..., 0::2, 1::2]
- y_packed[1, ..., 1::2, :] = y[..., 1::2, 0::2]
- else:
- y_packed[0, ..., 0::2, :] = y[..., 0::2, 1::2]
- y_packed[0, ..., 1::2, :] = y[..., 1::2, 0::2]
- y_packed[1, ..., 0::2, :] = y[..., 0::2, 0::2]
- y_packed[1, ..., 1::2, :] = y[..., 1::2, 1::2]
- return y_packed
-
-
-def embed(y_packed: Tensor, *, anchor_parity: str) -> Tensor:
- """Combine two even/odd checkerboard chunks into single tensor.
-
- .. code-block:: none
-
- ■ ■ □ □ ■ □ ■ □
- ■ ■ □ □ ---> □ ■ □ ■
- ■ ■ □ □ ■ □ ■ □
- """
- num_chunks, n, c, h, w_half = y_packed.shape
- assert num_chunks == 2
- y = y_packed.new_zeros((n, c, h, w_half * 2))
- if anchor_parity == "even":
- y[..., 0::2, 0::2] = y_packed[0, ..., 0::2, :]
- y[..., 1::2, 1::2] = y_packed[0, ..., 1::2, :]
- y[..., 0::2, 1::2] = y_packed[1, ..., 0::2, :]
- y[..., 1::2, 0::2] = y_packed[1, ..., 1::2, :]
- else:
- y[..., 0::2, 1::2] = y_packed[0, ..., 0::2, :]
- y[..., 1::2, 0::2] = y_packed[0, ..., 1::2, :]
- y[..., 0::2, 0::2] = y_packed[1, ..., 0::2, :]
- y[..., 1::2, 1::2] = y_packed[1, ..., 1::2, :]
- return y
-
-
-def embed_step(
- step_index: int, y_i: Tensor, width: int, *, anchor_parity: str
-) -> Tensor:
- """Embed a per-step half-width tensor back into a full-grid tensor."""
- n, c, h, _ = y_i.shape
- y_packed = y_i.new_zeros((2, n, c, h, width // 2))
- y_packed[step_index] = y_i
- return embed(y_packed, anchor_parity=anchor_parity)
-
-
-def write_step(dest: Tensor, src: Tensor, step: str, *, anchor_parity: str) -> None:
- """Copy ``src`` pixels at the current step's positions into ``dest`` in-place."""
- parity = step_parity(step, anchor_parity)
- if parity == "even":
- dest[..., 0::2, 0::2] = src[..., 0::2, 0::2]
- dest[..., 1::2, 1::2] = src[..., 1::2, 1::2]
- else:
- dest[..., 0::2, 1::2] = src[..., 0::2, 1::2]
- dest[..., 1::2, 0::2] = src[..., 1::2, 0::2]
-
-
-def mask_all_but_step(y: Tensor, step: str, *, anchor_parity: str) -> Tensor:
- """Keep only pixels in the current step, and zero out the rest."""
- y = y.clone()
- parity = step_parity(step, anchor_parity)
- if parity == "even":
- y[..., 0::2, 1::2] = 0
- y[..., 1::2, 0::2] = 0
- else:
- y[..., 0::2, 0::2] = 0
- y[..., 1::2, 1::2] = 0
- return y
-
-
-def mask_all(y: Tensor) -> Tensor:
- """Return a zero tensor with the same shape, dtype and device as ``y``."""
- return torch.zeros_like(y)
-
-
-def merge(*args: Tensor) -> Tensor:
- """Concatenate tensors along the channel dimension."""
- return torch.cat(args, dim=1)
diff --git a/compressai/latent_codecs/_selective_checkerboard.py b/compressai/latent_codecs/_selective_checkerboard.py
index f4813cea..2674c27a 100644
--- a/compressai/latent_codecs/_selective_checkerboard.py
+++ b/compressai/latent_codecs/_selective_checkerboard.py
@@ -36,7 +36,7 @@ from torch import Tensor
from compressai.entropy_models import GaussianConditional
-from . import _checkerboard_helpers as _ckb
+from . import checkerboard as _ckb
__all__ = [
"apply_selective_y_hat",
diff --git a/compressai/latent_codecs/checkerboard.py b/compressai/latent_codecs/checkerboard.py
index 0be77f4d..c4b4625b 100644
--- a/compressai/latent_codecs/checkerboard.py
+++ b/compressai/latent_codecs/checkerboard.py
@@ -43,6 +43,14 @@ from .base import LatentCodec
__all__ = [
"CheckerboardLatentCodec",
+ "embed",
+ "embed_step",
+ "mask_all",
+ "mask_all_but_step",
+ "merge",
+ "step_parity",
+ "unembed",
+ "write_step",
]
@@ -297,84 +305,134 @@ class CheckerboardLatentCodec(LatentCodec):
}
def unembed(self, y: Tensor) -> Tensor:
- """Separate single tensor into two even/odd checkerboard chunks.
-
- .. code-block:: none
-
- ■ □ ■ □ ■ ■ □ □
- □ ■ □ ■ ---> ■ ■ □ □
- ■ □ ■ □ ■ ■ □ □
- """
- n, c, h, w = y.shape
- y_ = y.new_zeros((2, n, c, h, w // 2))
- if self.anchor_parity == "even":
- y_[0, ..., 0::2, :] = y[..., 0::2, 0::2]
- y_[0, ..., 1::2, :] = y[..., 1::2, 1::2]
- y_[1, ..., 0::2, :] = y[..., 0::2, 1::2]
- y_[1, ..., 1::2, :] = y[..., 1::2, 0::2]
- else:
- y_[0, ..., 0::2, :] = y[..., 0::2, 1::2]
- y_[0, ..., 1::2, :] = y[..., 1::2, 0::2]
- y_[1, ..., 0::2, :] = y[..., 0::2, 0::2]
- y_[1, ..., 1::2, :] = y[..., 1::2, 1::2]
- return y_
+ return unembed(y, anchor_parity=self.anchor_parity)
def embed(self, y_: Tensor) -> Tensor:
- """Combine two even/odd checkerboard chunks into single tensor.
-
- .. code-block:: none
-
- ■ ■ □ □ ■ □ ■ □
- ■ ■ □ □ ---> □ ■ □ ■
- ■ ■ □ □ ■ □ ■ □
- """
- num_chunks, n, c, h, w_half = y_.shape
- assert num_chunks == 2
- y = y_.new_zeros((n, c, h, w_half * 2))
- if self.anchor_parity == "even":
- y[..., 0::2, 0::2] = y_[0, ..., 0::2, :]
- y[..., 1::2, 1::2] = y_[0, ..., 1::2, :]
- y[..., 0::2, 1::2] = y_[1, ..., 0::2, :]
- y[..., 1::2, 0::2] = y_[1, ..., 1::2, :]
- else:
- y[..., 0::2, 1::2] = y_[0, ..., 0::2, :]
- y[..., 1::2, 0::2] = y_[0, ..., 1::2, :]
- y[..., 0::2, 0::2] = y_[1, ..., 0::2, :]
- y[..., 1::2, 1::2] = y_[1, ..., 1::2, :]
- return y
+ return embed(y_, anchor_parity=self.anchor_parity)
def _copy(self, dest: Tensor, src: Tensor, step: str) -> None:
- """Copy pixels in the current step."""
- assert step in ("anchor", "non_anchor")
- parity = self.anchor_parity if step == "anchor" else self.non_anchor_parity
- if parity == "even":
- dest[..., 0::2, 0::2] = src[..., 0::2, 0::2]
- dest[..., 1::2, 1::2] = src[..., 1::2, 1::2]
- else:
- dest[..., 0::2, 1::2] = src[..., 0::2, 1::2]
- dest[..., 1::2, 0::2] = src[..., 1::2, 0::2]
+ return write_step(dest, src, step, anchor_parity=self.anchor_parity)
def _mask_all_but_step(self, y: Tensor, step: str) -> Tensor:
- """Keep only pixels in the current step, and zero out the rest."""
- y = y.clone()
- parity = self.anchor_parity if step == "anchor" else self.non_anchor_parity
- if parity == "even":
- y[..., 0::2, 1::2] = 0
- y[..., 1::2, 0::2] = 0
- elif parity == "odd":
- y[..., 0::2, 0::2] = 0
- y[..., 1::2, 1::2] = 0
- return y
+ return mask_all_but_step(y, step, anchor_parity=self.anchor_parity)
def _mask_all(self, y: Tensor) -> Tensor:
- y = y.clone()
- y[:] = 0
- return y
+ return mask_all(y)
def merge(self, *args: Tensor) -> Tensor:
- return torch.cat(args, dim=1)
+ return merge(*args)
def quantize(self, y: Tensor) -> Tensor:
mode = "noise" if self.training else "dequantize"
y_hat = EntropyModel.quantize(None, y, mode)
return y_hat
+
+
+def unembed(y: Tensor, *, anchor_parity: str) -> Tensor:
+ """Separate single tensor into two even/odd checkerboard chunks.
+
+ .. code-block:: none
+
+ ■ □ ■ □ ■ ■ □ □
+ □ ■ □ ■ ---> ■ ■ □ □
+ ■ □ ■ □ ■ ■ □ □
+ """
+ n, c, h, w = y.shape
+ y_ = y.new_zeros((2, n, c, h, w // 2))
+ if anchor_parity == "even":
+ y_[0, ..., 0::2, :] = y[..., 0::2, 0::2]
+ y_[0, ..., 1::2, :] = y[..., 1::2, 1::2]
+ y_[1, ..., 0::2, :] = y[..., 0::2, 1::2]
+ y_[1, ..., 1::2, :] = y[..., 1::2, 0::2]
+ elif anchor_parity == "odd":
+ y_[0, ..., 0::2, :] = y[..., 0::2, 1::2]
+ y_[0, ..., 1::2, :] = y[..., 1::2, 0::2]
+ y_[1, ..., 0::2, :] = y[..., 0::2, 0::2]
+ y_[1, ..., 1::2, :] = y[..., 1::2, 1::2]
+ else:
+ raise ValueError(f'Invalid anchor_parity "{anchor_parity}"')
+ return y_
+
+
+def embed(y_: Tensor, *, anchor_parity: str) -> Tensor:
+ """Combine two even/odd checkerboard chunks into single tensor.
+
+ .. code-block:: none
+
+ ■ ■ □ □ ■ □ ■ □
+ ■ ■ □ □ ---> □ ■ □ ■
+ ■ ■ □ □ ■ □ ■ □
+ """
+ num_chunks, n, c, h, w_half = y_.shape
+ assert num_chunks == 2
+ y = y_.new_zeros((n, c, h, w_half * 2))
+ if anchor_parity == "even":
+ y[..., 0::2, 0::2] = y_[0, ..., 0::2, :]
+ y[..., 1::2, 1::2] = y_[0, ..., 1::2, :]
+ y[..., 0::2, 1::2] = y_[1, ..., 0::2, :]
+ y[..., 1::2, 0::2] = y_[1, ..., 1::2, :]
+ elif anchor_parity == "odd":
+ y[..., 0::2, 1::2] = y_[0, ..., 0::2, :]
+ y[..., 1::2, 0::2] = y_[0, ..., 1::2, :]
+ y[..., 0::2, 0::2] = y_[1, ..., 0::2, :]
+ y[..., 1::2, 1::2] = y_[1, ..., 1::2, :]
+ else:
+ raise ValueError(f'Invalid anchor_parity "{anchor_parity}"')
+ return y
+
+
+def embed_step(
+ step_index: int, y_i: Tensor, width: int, *, anchor_parity: str
+) -> Tensor:
+ """Embed a per-step half-width tensor back into a full-grid tensor."""
+ n, c, h, _ = y_i.shape
+ y_ = y_i.new_zeros((2, n, c, h, width // 2))
+ y_[step_index] = y_i
+ return embed(y_, anchor_parity=anchor_parity)
+
+
+def step_parity(step: str, anchor_parity: str) -> str:
+ """Resolve a ``step`` ('anchor' / 'non_anchor') to a parity string."""
+ if anchor_parity not in ("even", "odd"):
+ raise ValueError(f'Invalid anchor_parity "{anchor_parity}"')
+ if step == "anchor":
+ return anchor_parity
+ if step == "non_anchor":
+ return "odd" if anchor_parity == "even" else "even"
+ raise ValueError(f'Invalid "step" value "{step}"')
+
+
+def write_step(dest: Tensor, src: Tensor, step: str, *, anchor_parity: str) -> None:
+ """Copy ``src`` pixels at the current step's positions into ``dest`` in-place."""
+ parity = step_parity(step, anchor_parity)
+ if parity == "even":
+ dest[..., 0::2, 0::2] = src[..., 0::2, 0::2]
+ dest[..., 1::2, 1::2] = src[..., 1::2, 1::2]
+ else:
+ dest[..., 0::2, 1::2] = src[..., 0::2, 1::2]
+ dest[..., 1::2, 0::2] = src[..., 1::2, 0::2]
+
+
+def mask_all_but_step(y: Tensor, step: str, *, anchor_parity: str) -> Tensor:
+ """Keep only pixels in the current step, and zero out the rest."""
+ y = y.clone()
+ parity = step_parity(step, anchor_parity)
+ if parity == "even":
+ y[..., 0::2, 1::2] = 0
+ y[..., 1::2, 0::2] = 0
+ elif parity == "odd":
+ y[..., 0::2, 0::2] = 0
+ y[..., 1::2, 1::2] = 0
+ return y
+
+
+def mask_all(y: Tensor) -> Tensor:
+ """Return a zero tensor with the same shape, dtype and device as ``y``."""
+ y = y.clone()
+ y[:] = 0
+ return y
+
+
+def merge(*args: Tensor) -> Tensor:
+ """Concatenate tensors along the channel dimension."""
+ return torch.cat(args, dim=1)
diff --git a/compressai/latent_codecs/multi_context_checkerboard.py b/compressai/latent_codecs/multi_context_checkerboard.py
index 3def49f8..ebeee7d6 100644
--- a/compressai/latent_codecs/multi_context_checkerboard.py
+++ b/compressai/latent_codecs/multi_context_checkerboard.py
@@ -38,8 +38,8 @@ from compressai.entropy_models import GaussianConditional
from compressai.ops import quantize_ste
from compressai.registry import register_module
-from . import _checkerboard_helpers as _ckb
from . import _selective_checkerboard as _sel
+from . import checkerboard as _ckb
from .base import LatentCodec
from .gaussian_conditional import GaussianConditionalLatentCodec
commit 173b94487d6119bb97e23d4b65cfdec2bddde072
Author: Mateen Ulhaq <mateen.ulhaq@interdigital.com>
Date: Mon Jun 29 22:47:10 2026 -0700
fix(latent_codecs): standardize y.shape[2:4] -> y.shape[1:]
diff --git a/compressai/latent_codecs/_selective_checkerboard.py b/compressai/latent_codecs/_selective_checkerboard.py
index 2674c27a..657b167b 100644
--- a/compressai/latent_codecs/_selective_checkerboard.py
+++ b/compressai/latent_codecs/_selective_checkerboard.py
@@ -190,7 +190,7 @@ def compress_selected(
y_hat[sample_index].reshape(-1)[mask] = y_hat_i.reshape(-1).to(y_hat.dtype)
y_strings.append(y_string)
- return {"strings": [y_strings], "shape": y.shape[2:4], "y_hat": y_hat}
+ return {"strings": [y_strings], "shape": y.shape[1:], "y_hat": y_hat}
def decompress_selected(
|
YodaEmbedding
approved these changes
Jun 30, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit. This suggestion is invalid because no changes were made to the code. Suggestions cannot be applied while the pull request is closed. Suggestions cannot be applied while viewing a subset of changes. Only one suggestion per line can be applied in a batch. Add this suggestion to a batch that can be applied as a single commit. Applying suggestions on deleted lines is not supported. You must change the existing code in this line in order to create a valid suggestion. Outdated suggestions cannot be applied. This suggestion has been applied or marked resolved. Suggestions cannot be applied from pending reviews. Suggestions cannot be applied on multi-line comments. Suggestions cannot be applied while the pull request is queued to merge. Suggestion cannot be applied right now. Please check back later.
Summary
This PR adds the MLIC family of learned image compression models and the multi-context checkerboard latent codecs needed to express their entropy models in CompressAI:
- MLIC
- MLICPlus
- MLICPlusPlus
- MLICv2
The work is related to the following papers:
Only the MLIC++ implementation is adapted from the official JiangWeibeta/MLIC code. The MLIC, MLICPlus, and MLICv2 implementations are paper-based reproductions written against CompressAI's model and latent-codec structure.
This PR intentionally does not include the SGA-based inference-time latent refinement used by MLICv2+.
That part is planned for a follow-up PR so this change can focus on the base MLIC-series models and reusable multi-context checkerboard codec components.
Changes
MLIC, MLICPlus, MLICPlusPlus, and MLICv2 model classes.
Validation
uv run ruff check compressai tests examples
uv run ruff format --check compressai tests examples
86 passed, 1 skipped