Add MLIC series models #357

Open
Yiozolm wants to merge 3 commits into InterDigitalInc:master from Yiozolm:feat-mlicpp-series

Yiozolm commented Jun 23, 2026

Contributor

Summary

This PR adds the MLIC family of learned image compression models and the multi-context checkerboard latent codecs needed to express their entropy models in CompressAI:

  • MLIC
  • MLICPlus
  • MLICPlusPlus
  • MLICv2

The work is related to the following papers:

  • MLIC: Multi-Reference Entropy Model for Learned Image Compression
  • MLIC++: Linear Complexity Multi-Reference Entropy Modeling for Learned Image Compression
  • MLICv2: Enhanced Multi-Reference Entropy Modeling for Learned Image Compression
  • Improving Inference for Neural Image Compression

Only the MLIC++ implementation is adapted from the official JiangWeibeta/MLIC code.
The MLIC, MLICPlus, and MLICv2 implementations are paper-based reproductions written against CompressAI's model and latent-codec structure.

This PR intentionally does not include the SGA-based inference-time latent refinement used by MLICv2+.
That part is planned for a follow-up PR so this change can focus on the base MLIC-series models and reusable multi-context checkerboard codec components.

Changes

  • Add multi-context checkerboard latent codecs.
  • Add private MLIC and MLICv2 model helper modules.
  • Add MLIC, MLICPlus, MLICPlusPlus, and MLICv2 model classes.
  • Register MLIC-series constructors in the image zoo.
  • Add an MLIC++ checkpoint conversion example for official upstream checkpoints.
  • Add focused tests for codecs, helper layers, model construction, zoo registration, and state-dict loading.

Validation

  • uv run ruff check compressai tests examples
  • uv run ruff format --check compressai tests examples
  • Focused MLIC test suite: 86 passed, 1 skipped

@YodaEmbedding YodaEmbedding self-assigned this Jun 29, 2026
@YodaEmbedding YodaEmbedding self-requested a review June 29, 2026 20:49
@YodaEmbedding YodaEmbedding added the enhancement New feature or request label Jun 29, 2026
YodaEmbedding commented Jun 30, 2026

Collaborator

Thanks for the PR.

I have applied two commits:

  • refactor(latent_codecs): extract checkerboard helpers directly inside checkerboard.py
  • fix(latent_codecs): standardize y.shape[2:4] -> y.shape[1:]

Their diffs are shown below.


Please verify that both commits are correct, and I will then merge this PR.
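
One possible sanity check for the refactor is a round-trip test: packing a tensor into its two checkerboard chunks and unpacking it again should return the original values for either anchor parity. The sketch below copies the module-level `unembed` and `embed` from the diff in this comment so that it runs standalone. The tensor sizes are arbitrary and the width is kept even, since the packed tensor has `w // 2` columns. It is an illustration, not part of the PR's test suite.

```python
import torch
from torch import Tensor


def unembed(y: Tensor, *, anchor_parity: str) -> Tensor:
    """Separate a single tensor into two even/odd checkerboard chunks."""
    n, c, h, w = y.shape
    y_ = y.new_zeros((2, n, c, h, w // 2))
    if anchor_parity == "even":
        y_[0, ..., 0::2, :] = y[..., 0::2, 0::2]
        y_[0, ..., 1::2, :] = y[..., 1::2, 1::2]
        y_[1, ..., 0::2, :] = y[..., 0::2, 1::2]
        y_[1, ..., 1::2, :] = y[..., 1::2, 0::2]
    elif anchor_parity == "odd":
        y_[0, ..., 0::2, :] = y[..., 0::2, 1::2]
        y_[0, ..., 1::2, :] = y[..., 1::2, 0::2]
        y_[1, ..., 0::2, :] = y[..., 0::2, 0::2]
        y_[1, ..., 1::2, :] = y[..., 1::2, 1::2]
    else:
        raise ValueError(f'Invalid anchor_parity "{anchor_parity}"')
    return y_


def embed(y_: Tensor, *, anchor_parity: str) -> Tensor:
    """Combine two even/odd checkerboard chunks into a single tensor."""
    num_chunks, n, c, h, w_half = y_.shape
    assert num_chunks == 2
    y = y_.new_zeros((n, c, h, w_half * 2))
    if anchor_parity == "even":
        y[..., 0::2, 0::2] = y_[0, ..., 0::2, :]
        y[..., 1::2, 1::2] = y_[0, ..., 1::2, :]
        y[..., 0::2, 1::2] = y_[1, ..., 0::2, :]
        y[..., 1::2, 0::2] = y_[1, ..., 1::2, :]
    elif anchor_parity == "odd":
        y[..., 0::2, 1::2] = y_[0, ..., 0::2, :]
        y[..., 1::2, 0::2] = y_[0, ..., 1::2, :]
        y[..., 0::2, 0::2] = y_[1, ..., 0::2, :]
        y[..., 1::2, 1::2] = y_[1, ..., 1::2, :]
    else:
        raise ValueError(f'Invalid anchor_parity "{anchor_parity}"')
    return y


# Arbitrary (N, C, H, W) test tensor with distinct values and an even width.
y = torch.arange(2 * 3 * 4 * 6, dtype=torch.float32).reshape(2, 3, 4, 6)

for parity in ("even", "odd"):
    y_packed = unembed(y, anchor_parity=parity)
    # Two chunks, each half as wide as the input.
    assert y_packed.shape == (2, 2, 3, 4, 3)
    # embed() should exactly undo unembed().
    assert torch.equal(embed(y_packed, anchor_parity=parity), y)
```

Because the input values are all distinct, a swapped index in either branch would show up as a failed `torch.equal`.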


commit 6b9ef05d77232173c2395ab87474577c1bfa5689
Author: Mateen Ulhaq <mateen.ulhaq@interdigital.com>
Date:   Mon Jun 29 22:43:17 2026 -0700

    refactor(latent_codecs): extract checkerboard helpers directly inside checkerboard.py

diff --git a/compressai/latent_codecs/_checkerboard_helpers.py b/compressai/latent_codecs/_checkerboard_helpers.py
deleted file mode 100644
index 84b65a3a..00000000
--- a/compressai/latent_codecs/_checkerboard_helpers.py
+++ /dev/null
@@ -1,156 +0,0 @@
-# Copyright (c) 2021-2025, InterDigital Communications, Inc
-# All rights reserved.
-
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted (subject to the limitations in the disclaimer
-# below) provided that the following conditions are met:
-
-# * Redistributions of source code must retain the above copyright notice,
-#   this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above copyright notice,
-#   this list of conditions and the following disclaimer in the documentation
-#   and/or other materials provided with the distribution.
-# * Neither the name of InterDigital Communications, Inc nor the names of its
-#   contributors may be used to endorse or promote products derived from this
-#   software without specific prior written permission.
-
-# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
-# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
-# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
-# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
-# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
-# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
-# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
-# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
-# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
-# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
-# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
-# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Pure functional helpers shared by checkerboard latent codecs.
-
-These are extracted from :class:`CheckerboardLatentCodec` so that sibling
-codecs (e.g. :class:`MultiContextCheckerboardLatentCodec`) can reuse the
-exact same checkerboard split / merge / mask logic without duplicating it.
-A single source of truth here also means an anchor-parity boundary fix
-applies to every checkerboard codec at once.
-"""
-
-from __future__ import annotations
-
-import torch
-
-from torch import Tensor
-
-__all__ = [
-    "embed",
-    "embed_step",
-    "mask_all",
-    "mask_all_but_step",
-    "merge",
-    "step_parity",
-    "unembed",
-    "write_step",
-]
-
-
-def step_parity(step: str, anchor_parity: str) -> str:
-    """Resolve a ``step`` ('anchor' / 'non_anchor') to a parity string."""
-    if step == "anchor":
-        return anchor_parity
-    if step == "non_anchor":
-        return "odd" if anchor_parity == "even" else "even"
-    raise ValueError(f'Invalid "step" value "{step}"')
-
-
-def unembed(y: Tensor, *, anchor_parity: str) -> Tensor:
-    """Separate single tensor into two even/odd checkerboard chunks.
-
-    .. code-block:: none
-
-        ■ □ ■ □         ■ ■   □ □
-        □ ■ □ ■   --->  ■ ■   □ □
-        ■ □ ■ □         ■ ■   □ □
-    """
-    n, c, h, w = y.shape
-    y_packed = y.new_zeros((2, n, c, h, w // 2))
-    if anchor_parity == "even":
-        y_packed[0, ..., 0::2, :] = y[..., 0::2, 0::2]
-        y_packed[0, ..., 1::2, :] = y[..., 1::2, 1::2]
-        y_packed[1, ..., 0::2, :] = y[..., 0::2, 1::2]
-        y_packed[1, ..., 1::2, :] = y[..., 1::2, 0::2]
-    else:
-        y_packed[0, ..., 0::2, :] = y[..., 0::2, 1::2]
-        y_packed[0, ..., 1::2, :] = y[..., 1::2, 0::2]
-        y_packed[1, ..., 0::2, :] = y[..., 0::2, 0::2]
-        y_packed[1, ..., 1::2, :] = y[..., 1::2, 1::2]
-    return y_packed
-
-
-def embed(y_packed: Tensor, *, anchor_parity: str) -> Tensor:
-    """Combine two even/odd checkerboard chunks into single tensor.
-
-    .. code-block:: none
-
-        ■ ■   □ □         ■ □ ■ □
-        ■ ■   □ □   --->  □ ■ □ ■
-        ■ ■   □ □         ■ □ ■ □
-    """
-    num_chunks, n, c, h, w_half = y_packed.shape
-    assert num_chunks == 2
-    y = y_packed.new_zeros((n, c, h, w_half * 2))
-    if anchor_parity == "even":
-        y[..., 0::2, 0::2] = y_packed[0, ..., 0::2, :]
-        y[..., 1::2, 1::2] = y_packed[0, ..., 1::2, :]
-        y[..., 0::2, 1::2] = y_packed[1, ..., 0::2, :]
-        y[..., 1::2, 0::2] = y_packed[1, ..., 1::2, :]
-    else:
-        y[..., 0::2, 1::2] = y_packed[0, ..., 0::2, :]
-        y[..., 1::2, 0::2] = y_packed[0, ..., 1::2, :]
-        y[..., 0::2, 0::2] = y_packed[1, ..., 0::2, :]
-        y[..., 1::2, 1::2] = y_packed[1, ..., 1::2, :]
-    return y
-
-
-def embed_step(
-    step_index: int, y_i: Tensor, width: int, *, anchor_parity: str
-) -> Tensor:
-    """Embed a per-step half-width tensor back into a full-grid tensor."""
-    n, c, h, _ = y_i.shape
-    y_packed = y_i.new_zeros((2, n, c, h, width // 2))
-    y_packed[step_index] = y_i
-    return embed(y_packed, anchor_parity=anchor_parity)
-
-
-def write_step(dest: Tensor, src: Tensor, step: str, *, anchor_parity: str) -> None:
-    """Copy ``src`` pixels at the current step's positions into ``dest`` in-place."""
-    parity = step_parity(step, anchor_parity)
-    if parity == "even":
-        dest[..., 0::2, 0::2] = src[..., 0::2, 0::2]
-        dest[..., 1::2, 1::2] = src[..., 1::2, 1::2]
-    else:
-        dest[..., 0::2, 1::2] = src[..., 0::2, 1::2]
-        dest[..., 1::2, 0::2] = src[..., 1::2, 0::2]
-
-
-def mask_all_but_step(y: Tensor, step: str, *, anchor_parity: str) -> Tensor:
-    """Keep only pixels in the current step, and zero out the rest."""
-    y = y.clone()
-    parity = step_parity(step, anchor_parity)
-    if parity == "even":
-        y[..., 0::2, 1::2] = 0
-        y[..., 1::2, 0::2] = 0
-    else:
-        y[..., 0::2, 0::2] = 0
-        y[..., 1::2, 1::2] = 0
-    return y
-
-
-def mask_all(y: Tensor) -> Tensor:
-    """Return a zero tensor with the same shape, dtype and device as ``y``."""
-    return torch.zeros_like(y)
-
-
-def merge(*args: Tensor) -> Tensor:
-    """Concatenate tensors along the channel dimension."""
-    return torch.cat(args, dim=1)
diff --git a/compressai/latent_codecs/_selective_checkerboard.py b/compressai/latent_codecs/_selective_checkerboard.py
index f4813cea..2674c27a 100644
--- a/compressai/latent_codecs/_selective_checkerboard.py
+++ b/compressai/latent_codecs/_selective_checkerboard.py
@@ -36,7 +36,7 @@ from torch import Tensor
 
 from compressai.entropy_models import GaussianConditional
 
-from . import _checkerboard_helpers as _ckb
+from . import checkerboard as _ckb
 
 __all__ = [
     "apply_selective_y_hat",
diff --git a/compressai/latent_codecs/checkerboard.py b/compressai/latent_codecs/checkerboard.py
index 0be77f4d..c4b4625b 100644
--- a/compressai/latent_codecs/checkerboard.py
+++ b/compressai/latent_codecs/checkerboard.py
@@ -43,6 +43,14 @@ from .base import LatentCodec
 
 __all__ = [
     "CheckerboardLatentCodec",
+    "embed",
+    "embed_step",
+    "mask_all",
+    "mask_all_but_step",
+    "merge",
+    "step_parity",
+    "unembed",
+    "write_step",
 ]
 
 
@@ -297,84 +305,134 @@ class CheckerboardLatentCodec(LatentCodec):
         }
 
     def unembed(self, y: Tensor) -> Tensor:
-        """Separate single tensor into two even/odd checkerboard chunks.
-
-        .. code-block:: none
-
-            ■ □ ■ □         ■ ■   □ □
-            □ ■ □ ■   --->  ■ ■   □ □
-            ■ □ ■ □         ■ ■   □ □
-        """
-        n, c, h, w = y.shape
-        y_ = y.new_zeros((2, n, c, h, w // 2))
-        if self.anchor_parity == "even":
-            y_[0, ..., 0::2, :] = y[..., 0::2, 0::2]
-            y_[0, ..., 1::2, :] = y[..., 1::2, 1::2]
-            y_[1, ..., 0::2, :] = y[..., 0::2, 1::2]
-            y_[1, ..., 1::2, :] = y[..., 1::2, 0::2]
-        else:
-            y_[0, ..., 0::2, :] = y[..., 0::2, 1::2]
-            y_[0, ..., 1::2, :] = y[..., 1::2, 0::2]
-            y_[1, ..., 0::2, :] = y[..., 0::2, 0::2]
-            y_[1, ..., 1::2, :] = y[..., 1::2, 1::2]
-        return y_
+        return unembed(y, anchor_parity=self.anchor_parity)
 
     def embed(self, y_: Tensor) -> Tensor:
-        """Combine two even/odd checkerboard chunks into single tensor.
-
-        .. code-block:: none
-
-            ■ ■   □ □         ■ □ ■ □
-            ■ ■   □ □   --->  □ ■ □ ■
-            ■ ■   □ □         ■ □ ■ □
-        """
-        num_chunks, n, c, h, w_half = y_.shape
-        assert num_chunks == 2
-        y = y_.new_zeros((n, c, h, w_half * 2))
-        if self.anchor_parity == "even":
-            y[..., 0::2, 0::2] = y_[0, ..., 0::2, :]
-            y[..., 1::2, 1::2] = y_[0, ..., 1::2, :]
-            y[..., 0::2, 1::2] = y_[1, ..., 0::2, :]
-            y[..., 1::2, 0::2] = y_[1, ..., 1::2, :]
-        else:
-            y[..., 0::2, 1::2] = y_[0, ..., 0::2, :]
-            y[..., 1::2, 0::2] = y_[0, ..., 1::2, :]
-            y[..., 0::2, 0::2] = y_[1, ..., 0::2, :]
-            y[..., 1::2, 1::2] = y_[1, ..., 1::2, :]
-        return y
+        return embed(y_, anchor_parity=self.anchor_parity)
 
     def _copy(self, dest: Tensor, src: Tensor, step: str) -> None:
-        """Copy pixels in the current step."""
-        assert step in ("anchor", "non_anchor")
-        parity = self.anchor_parity if step == "anchor" else self.non_anchor_parity
-        if parity == "even":
-            dest[..., 0::2, 0::2] = src[..., 0::2, 0::2]
-            dest[..., 1::2, 1::2] = src[..., 1::2, 1::2]
-        else:
-            dest[..., 0::2, 1::2] = src[..., 0::2, 1::2]
-            dest[..., 1::2, 0::2] = src[..., 1::2, 0::2]
+        return write_step(dest, src, step, anchor_parity=self.anchor_parity)
 
     def _mask_all_but_step(self, y: Tensor, step: str) -> Tensor:
-        """Keep only pixels in the current step, and zero out the rest."""
-        y = y.clone()
-        parity = self.anchor_parity if step == "anchor" else self.non_anchor_parity
-        if parity == "even":
-            y[..., 0::2, 1::2] = 0
-            y[..., 1::2, 0::2] = 0
-        elif parity == "odd":
-            y[..., 0::2, 0::2] = 0
-            y[..., 1::2, 1::2] = 0
-        return y
+        return mask_all_but_step(y, step, anchor_parity=self.anchor_parity)
 
     def _mask_all(self, y: Tensor) -> Tensor:
-        y = y.clone()
-        y[:] = 0
-        return y
+        return mask_all(y)
 
     def merge(self, *args: Tensor) -> Tensor:
-        return torch.cat(args, dim=1)
+        return merge(*args)
 
     def quantize(self, y: Tensor) -> Tensor:
         mode = "noise" if self.training else "dequantize"
         y_hat = EntropyModel.quantize(None, y, mode)
         return y_hat
+
+
+def unembed(y: Tensor, *, anchor_parity: str) -> Tensor:
+    """Separate single tensor into two even/odd checkerboard chunks.
+
+    .. code-block:: none
+
+        ■ □ ■ □         ■ ■   □ □
+        □ ■ □ ■   --->  ■ ■   □ □
+        ■ □ ■ □         ■ ■   □ □
+    """
+    n, c, h, w = y.shape
+    y_ = y.new_zeros((2, n, c, h, w // 2))
+    if anchor_parity == "even":
+        y_[0, ..., 0::2, :] = y[..., 0::2, 0::2]
+        y_[0, ..., 1::2, :] = y[..., 1::2, 1::2]
+        y_[1, ..., 0::2, :] = y[..., 0::2, 1::2]
+        y_[1, ..., 1::2, :] = y[..., 1::2, 0::2]
+    elif anchor_parity == "odd":
+        y_[0, ..., 0::2, :] = y[..., 0::2, 1::2]
+        y_[0, ..., 1::2, :] = y[..., 1::2, 0::2]
+        y_[1, ..., 0::2, :] = y[..., 0::2, 0::2]
+        y_[1, ..., 1::2, :] = y[..., 1::2, 1::2]
+    else:
+        raise ValueError(f'Invalid anchor_parity "{anchor_parity}"')
+    return y_
+
+
+def embed(y_: Tensor, *, anchor_parity: str) -> Tensor:
+    """Combine two even/odd checkerboard chunks into single tensor.
+
+    .. code-block:: none
+
+        ■ ■   □ □         ■ □ ■ □
+        ■ ■   □ □   --->  □ ■ □ ■
+        ■ ■   □ □         ■ □ ■ □
+    """
+    num_chunks, n, c, h, w_half = y_.shape
+    assert num_chunks == 2
+    y = y_.new_zeros((n, c, h, w_half * 2))
+    if anchor_parity == "even":
+        y[..., 0::2, 0::2] = y_[0, ..., 0::2, :]
+        y[..., 1::2, 1::2] = y_[0, ..., 1::2, :]
+        y[..., 0::2, 1::2] = y_[1, ..., 0::2, :]
+        y[..., 1::2, 0::2] = y_[1, ..., 1::2, :]
+    elif anchor_parity == "odd":
+        y[..., 0::2, 1::2] = y_[0, ..., 0::2, :]
+        y[..., 1::2, 0::2] = y_[0, ..., 1::2, :]
+        y[..., 0::2, 0::2] = y_[1, ..., 0::2, :]
+        y[..., 1::2, 1::2] = y_[1, ..., 1::2, :]
+    else:
+        raise ValueError(f'Invalid anchor_parity "{anchor_parity}"')
+    return y
+
+
+def embed_step(
+    step_index: int, y_i: Tensor, width: int, *, anchor_parity: str
+) -> Tensor:
+    """Embed a per-step half-width tensor back into a full-grid tensor."""
+    n, c, h, _ = y_i.shape
+    y_ = y_i.new_zeros((2, n, c, h, width // 2))
+    y_[step_index] = y_i
+    return embed(y_, anchor_parity=anchor_parity)
+
+
+def step_parity(step: str, anchor_parity: str) -> str:
+    """Resolve a ``step`` ('anchor' / 'non_anchor') to a parity string."""
+    if anchor_parity not in ("even", "odd"):
+        raise ValueError(f'Invalid anchor_parity "{anchor_parity}"')
+    if step == "anchor":
+        return anchor_parity
+    if step == "non_anchor":
+        return "odd" if anchor_parity == "even" else "even"
+    raise ValueError(f'Invalid "step" value "{step}"')
+
+
+def write_step(dest: Tensor, src: Tensor, step: str, *, anchor_parity: str) -> None:
+    """Copy ``src`` pixels at the current step's positions into ``dest`` in-place."""
+    parity = step_parity(step, anchor_parity)
+    if parity == "even":
+        dest[..., 0::2, 0::2] = src[..., 0::2, 0::2]
+        dest[..., 1::2, 1::2] = src[..., 1::2, 1::2]
+    else:
+        dest[..., 0::2, 1::2] = src[..., 0::2, 1::2]
+        dest[..., 1::2, 0::2] = src[..., 1::2, 0::2]
+
+
+def mask_all_but_step(y: Tensor, step: str, *, anchor_parity: str) -> Tensor:
+    """Keep only pixels in the current step, and zero out the rest."""
+    y = y.clone()
+    parity = step_parity(step, anchor_parity)
+    if parity == "even":
+        y[..., 0::2, 1::2] = 0
+        y[..., 1::2, 0::2] = 0
+    elif parity == "odd":
+        y[..., 0::2, 0::2] = 0
+        y[..., 1::2, 1::2] = 0
+    return y
+
+
+def mask_all(y: Tensor) -> Tensor:
+    """Return a zero tensor with the same shape, dtype and device as ``y``."""
+    y = y.clone()
+    y[:] = 0
+    return y
+
+
+def merge(*args: Tensor) -> Tensor:
+    """Concatenate tensors along the channel dimension."""
+    return torch.cat(args, dim=1)
diff --git a/compressai/latent_codecs/multi_context_checkerboard.py b/compressai/latent_codecs/multi_context_checkerboard.py
index 3def49f8..ebeee7d6 100644
--- a/compressai/latent_codecs/multi_context_checkerboard.py
+++ b/compressai/latent_codecs/multi_context_checkerboard.py
@@ -38,8 +38,8 @@ from compressai.entropy_models import GaussianConditional
 from compressai.ops import quantize_ste
 from compressai.registry import register_module
 
-from . import _checkerboard_helpers as _ckb
 from . import _selective_checkerboard as _sel
+from . import checkerboard as _ckb
 from .base import LatentCodec
 from .gaussian_conditional import GaussianConditionalLatentCodec
 
commit 173b94487d6119bb97e23d4b65cfdec2bddde072
Author: Mateen Ulhaq <mateen.ulhaq@interdigital.com>
Date:   Mon Jun 29 22:47:10 2026 -0700

    fix(latent_codecs): standardize y.shape[2:4] -> y.shape[1:]

diff --git a/compressai/latent_codecs/_selective_checkerboard.py b/compressai/latent_codecs/_selective_checkerboard.py
index 2674c27a..657b167b 100644
--- a/compressai/latent_codecs/_selective_checkerboard.py
+++ b/compressai/latent_codecs/_selective_checkerboard.py
@@ -190,7 +190,7 @@ def compress_selected(
         y_hat[sample_index].reshape(-1)[mask] = y_hat_i.reshape(-1).to(y_hat.dtype)
         y_strings.append(y_string)
 
-    return {"strings": [y_strings], "shape": y.shape[2:4], "y_hat": y_hat}
+    return {"strings": [y_strings], "shape": y.shape[1:], "y_hat": y_hat}
 
 
 def decompress_selected(
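
To make the effect of the second commit concrete, here is a small sketch of what the two slices return. It assumes `y` is a 4-D `(N, C, H, W)` tensor, matching the `n, c, h, w = y.shape` unpacking in the first diff. The sizes are hypothetical, and a plain tuple stands in for `y.shape`, which slices the same way.

```python
# Hypothetical (N, C, H, W) latent shape, chosen only for illustration.
y_shape = (1, 192, 16, 16)

old_shape = y_shape[2:4]  # (H, W): spatial size only, as returned before the fix
new_shape = y_shape[1:]   # (C, H, W): channels plus spatial size, as returned after

print(old_shape)  # (16, 16)
print(new_shape)  # (192, 16, 16)
```

Any caller that reads the returned `"shape"` as `(H, W)` would need to account for the extra leading channel entry. Whether `decompress_selected` already does so is not visible in this diff.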
