Skip to content

Commit

Permalink
minor bugfix for balanced CSS codes and warnings about runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
perlinm committed Jan 30, 2025
1 parent 3dbbb98 commit 233b374
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 8 deletions.
9 changes: 8 additions & 1 deletion checks/pytest_.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@

import checks_superstaq

EXCLUDE = ("checks/*.py", "experiments/*.py", "*/__init__.py", "docs/source/conf.py")
EXCLUDE = (
"*/__init__.py",
"checks/*.py",
"examples/*.py",
"experiments/*.py",
"build-cython.py",
"docs/source/conf.py",
)

if __name__ == "__main__":
exit(checks_superstaq.pytest_.run(*sys.argv[1:], exclude=EXCLUDE))
8 changes: 8 additions & 0 deletions qldpc/codes/_distance.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ cdef extern from *:
"""

# Python imports
import warnings
from typing import Iterator

import numpy as np
cimport numpy as cnp

Expand Down Expand Up @@ -66,6 +68,9 @@ cdef cnp.ndarray[cnp.uint64_t, ndim=2] rows_to_uint64(

def get_distance_classical(cnp.ndarray[cnp.uint8_t, ndim=2] generator) -> int:
"""Distance of a classical linear binary code."""
if generator.shape[0] > 30:
warnings.warn("Computing the exact distance of a large code may take a (very) long time")

cdef uint64_t num_bits = generator.shape[1]
if num_bits <= 64:
return get_distance_classical_64(generator)
Expand Down Expand Up @@ -168,6 +173,9 @@ def get_distance_quantum(
X and Z support of the corresponding Pauli string. The weight of a Pauli string is then the
symplectic weight of the corresponding bitstring.
"""
if logical_ops.shape[0] + stabilizers.shape[0] > 30:
warnings.warn("Computing the exact distance of a large code may take a (very) long time")

cdef uint64_t num_bits = logical_ops.shape[1]

if not homogeneous:
Expand Down
22 changes: 18 additions & 4 deletions qldpc/codes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import itertools
import math
import random
import warnings
from collections.abc import Callable, Mapping, Sequence
from typing import Any, Iterator, Literal, cast

Expand Down Expand Up @@ -353,6 +354,9 @@ def get_distance_exact(
if self.field.order == 2:
distance = get_distance_classical(self.generator.view(np.ndarray).astype(np.uint8))
else:
warnings.warn(
"Computing the exact distance of a non-binary code may take a (very) long time"
)
distance = min(np.count_nonzero(word) for word in self.iter_words(skip_zero=True))
self._exact_distance = int(distance)
return self._exact_distance
Expand Down Expand Up @@ -903,8 +907,17 @@ def get_distance_exact(self) -> int | float:
stabilizers.view(np.ndarray).astype(np.uint8),
)
else:
warnings.warn(
"Computing the exact distance of a non-binary code may take a (very) long time"
)
distance = len(self)
for word in ClassicalCode(self.matrix).iter_words(skip_zero=True):
code_logical_ops = ClassicalCode.from_generator(self.get_logical_ops())
code_stabilizers = ClassicalCode.from_generator(self.matrix)
for word_l, word_s in itertools.product(
code_logical_ops.iter_words(skip_zero=True),
code_stabilizers.iter_words(),
):
word = word_l + word_s
support_x = word[: len(self)].view(np.ndarray)
support_z = word[len(self) :].view(np.ndarray)
support = support_x + support_z # nonzero wherever a word addresses a qudit
Expand Down Expand Up @@ -1326,9 +1339,7 @@ def rank(self) -> int:
Equivalently, the number of linearly independent parity checks in this code.
"""
rank_x = self.code_x.rank
rank_z = rank_x if self._balanced_codes else self.code_z.rank
return rank_x + rank_z
return self.code_x.rank + self.code_z.rank

def get_distance(
self, pauli: PauliXZ | None = None, *, bound: int | bool | None = None, **decoder_args: Any
Expand Down Expand Up @@ -1369,6 +1380,9 @@ def get_distance_exact(self, pauli: PauliXZ | None = None) -> int | float:
homogeneous=True,
)
else:
warnings.warn(
"Computing the exact distance of a non-binary code may take a (very) long time"
)
code_x = self.code_x if pauli == Pauli.X else self.code_z
code_z = self.code_z if pauli == Pauli.X else self.code_x
dual_code_x = ~code_x
Expand Down
9 changes: 6 additions & 3 deletions qldpc/codes/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ def test_distance_classical(bits: int = 3) -> None:
# compute distance of a trinary repetition code
rep_code = codes.RepetitionCode(bits, field=3)
rep_code._exact_distance = None
assert rep_code.get_distance_exact() == 3
with pytest.warns(UserWarning, match="may take a very long time"):
assert rep_code.get_distance_exact() == 3


def test_conversions_classical(bits: int = 5, checks: int = 3) -> None:
Expand Down Expand Up @@ -308,7 +309,8 @@ def test_distance_qudit() -> None:
# fallback pythonic brute-force distance calculation
surface_code = codes.SurfaceCode(2, field=3)
surface_code._exact_distance_x = surface_code._exact_distance_z = None
assert codes.QuditCode.get_distance_exact(surface_code) == 2
with pytest.warns(UserWarning, match="may take a very long time"):
assert codes.QuditCode.get_distance_exact(surface_code) == 2


@pytest.mark.parametrize("field", [2, 3])
Expand Down Expand Up @@ -462,7 +464,8 @@ def test_distance_css() -> None:
code = codes.HGPCode(codes.RepetitionCode(2, field=3))
assert code.get_distance_bound(cutoff=len(code)) == len(code)
assert code.get_distance(bound=True) <= len(code)
assert code.get_distance(bound=False) == 2
with pytest.warns(UserWarning, match="may take a very long time"):
assert code.get_distance(bound=False) == 2

# qubit code distance
code = codes.HGPCode(codes.RepetitionCode(2, field=2))
Expand Down

0 comments on commit 233b374

Please sign in to comment.