Skip to content

Commit

Permalink
More caching in automorphism group calculation (#210)
Browse files Browse the repository at this point in the history
* cache automorphism group data

* even more caching

* fix cache coverage

* move group generator caching to abstract.py

* linting fix
  • Loading branch information
perlinm authored Jan 29, 2025
1 parent 1918b0c commit 046f119
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 5 deletions.
4 changes: 4 additions & 0 deletions qldpc/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,10 @@ def from_name(name: str, field: int | None = None) -> Group:
generators = [GroupMember(gen) for gen in external.groups.get_generators(standardized_name)]
return Group(*generators, name=standardized_name, field=field)

def hashable_generators(self) -> tuple[tuple[int, ...], ...]:
"""Generators of this group in a hashable form."""
return tuple(tuple(generator) for generator in self.generators)


################################################################################
# elements of a group algebra
Expand Down
2 changes: 2 additions & 0 deletions qldpc/abstract_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def test_permutation_group() -> None:
gen = galois.GF(2)([[1]])
abstract.Group.from_generating_mats(gen, field=3)

assert isinstance(hash(group.hashable_generators()), int)


def test_trivial_group() -> None:
"""Trivial group tests."""
Expand Down
10 changes: 8 additions & 2 deletions qldpc/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ def get_disk_cache(cache_name: str, *, cache_dir: str | None = None) -> diskcach


def use_disk_cache(
cache_name: str, *, cache_dir: str | None = None
cache_name: str,
*,
cache_dir: str | None = None,
key_func: Callable[..., Hashable] | None = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""Decorator to cache results to disk."""

Expand All @@ -49,7 +52,10 @@ def decorator(function: Callable[..., Any]) -> Callable[..., Any]:
def function_with_cache(*args: Hashable, **kwargs: Hashable) -> Any:
# retrieve results from cache, if available
cache = get_disk_cache(cache_name, cache_dir=cache_dir)
key = args + tuple(kwargs.items())
if key_func is not None:
key = key_func(*args, **kwargs)
else:
key = args + tuple(kwargs.items())
if key in cache:
return cache[key]

Expand Down
7 changes: 7 additions & 0 deletions qldpc/cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,10 @@ def get_five(_: str) -> int:
get_five("test_arg") # save results to cache
assert cache == {("test_arg",): 5} # check cache
assert cache[("test_arg",)] == get_five("test_arg") # retrieve results

@qldpc.cache.use_disk_cache("test_name", key_func=lambda _: None)
def get_six(_: str) -> int:
return 6

assert get_six("test_arg") == 6
assert cache == {("test_arg",): 5, None: 6}
20 changes: 17 additions & 3 deletions qldpc/circuits.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@
import numpy as np
import stim

from qldpc import abstract, codes
from qldpc import abstract, cache, codes
from qldpc.objects import Pauli, conjugate_xz, op_to_string

CACHE_NAME = "qldpc_automorphisms"


def restrict_to_qubits(func: Callable[..., stim.Circuit]) -> Callable[..., stim.Circuit]:
"""Restrict a circuit constructor to qubit-based codes."""
Expand Down Expand Up @@ -242,8 +244,20 @@ def get_transversal_automorphism_group(
group_gates = abstract.Group(*map(abstract.GroupMember, column_perms))

# intersect the groups above to find the group generated by a transversal gate set
group_aut_sympy = group_code.to_sympy().subgroup_search(group_gates.to_sympy().contains)
return abstract.Group.from_sympy(group_aut_sympy, field=code.field.order)
group_aut_gens = _sympy_group_intersection_generators(group_code, group_gates)
return abstract.Group(*map(abstract.GroupMember, group_aut_gens))


@cache.use_disk_cache(
CACHE_NAME,
key_func=lambda xx, yy: (xx.hashable_generators(), yy.hashable_generators()),
)
def _sympy_group_intersection_generators(
group_a: abstract.Group, group_b: abstract.Group
) -> tuple[tuple[int, ...], ...]:
"""Get the generators of the intersection of two Sympy permutation groups."""
group_sympy = group_a.to_sympy().subgroup_search(group_b.to_sympy().contains)
return abstract.Group(group_sympy).hashable_generators()


@restrict_to_qubits
Expand Down

0 comments on commit 046f119

Please sign in to comment.