Skip to content

Commit b6efb03

Browse files
authored
feat(hugr-py): move in result classes from guppylang (#2084)
Guppy PR: CQCL/guppylang#918
1 parent a4f5196 commit b6efb03

File tree

6 files changed

+650
-1
lines changed

6 files changed

+650
-1
lines changed

hugr-py/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ dependencies = [
3737

3838
[project.optional-dependencies]
3939
docs = ["sphinx>=8.1.3,<9.0.0", "sphinx-book-theme>=1.1.2"]
40+
pytket = ["pytket >= 1.34.0"]
4041

4142
[project.urls]
4243
homepage = "https://github.com/CQCL/hugr/tree/main/hugr-py"

hugr-py/src/hugr/qsystem/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Quantinuum system results and utilities."""

hugr-py/src/hugr/qsystem/result.py

Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
r"""Quantinuum system results.
2+
3+
Includes conversions to traditional distributions over bitstrings if a tagging
4+
convention is used, including conversion to a pytket BackendResult.
5+
6+
Under this convention, tags are assumed to be a name of a bit register unless they fit
7+
the regex pattern `^([a-z][\w_]*)\[(\d+)\]$` (like `my_Reg[12]`) in which case they
8+
are assumed to refer to the nth element of a bit register.
9+
10+
For results of the form ``` result("<register>", value) ``` `value` can be `{0, 1}`,
11+
wherein the register is assumed to be length 1, or lists over those values,
12+
wherein the list is taken to be the value of the entire register.
13+
14+
For results of the form ``` result("<register>[n]", value) ``` `value` can only be
15+
`{0,1}`.
16+
The register is assumed to be at least `n+1` in size and unset
17+
elements are assumed to be `0`.
18+
19+
Subsequent writes to the same register/element in the same shot will overwrite.
20+
21+
To convert to a `BackendResult` all registers must be present in all shots, and register
22+
sizes cannot change between shots.
23+
24+
"""
25+
26+
from __future__ import annotations
27+
28+
import re
29+
from collections import Counter, defaultdict
30+
from dataclasses import dataclass, field
31+
from typing import TYPE_CHECKING, Literal
32+
33+
from typing_extensions import deprecated
34+
35+
if TYPE_CHECKING:
36+
from collections.abc import Iterable
37+
38+
from pytket.backends.backendresult import BackendResult
39+
40+
41+
#: Primitive data types that can be returned by a result
42+
DataPrimitive = int | float | bool
43+
#: Data value that can be returned by a result: a primitive or a list of primitives
44+
DataValue = DataPrimitive | list[DataPrimitive]
45+
TaggedResult = tuple[str, DataValue]
46+
# Pattern to match register index in tag, e.g. "reg[0]"
47+
REG_INDEX_PATTERN = re.compile(r"^([a-z][\w_]*)\[(\d+)\]$")
48+
49+
BitChar = Literal["0", "1"]
50+
51+
52+
@dataclass
53+
class QsysShot:
54+
"""Results from a single shot execution."""
55+
56+
entries: list[TaggedResult] = field(default_factory=list)
57+
58+
def __init__(self, entries: Iterable[TaggedResult] | None = None):
59+
self.entries = list(entries or [])
60+
61+
def append(self, tag: str, data: DataValue) -> None:
62+
self.entries.append((tag, data))
63+
64+
def as_dict(self) -> dict[str, DataValue]:
65+
"""Convert results to a dictionary.
66+
67+
For duplicate tags, the last value is used.
68+
69+
Returns:
70+
dict: A dictionary where the keys are the tags and the
71+
values are the data.
72+
73+
Example:
74+
>>> results = QsysShot()
75+
>>> results.append("tag1", 1)
76+
>>> results.append("tag2", 2)
77+
>>> results.append("tag2", 3)
78+
>>> results.as_dict()
79+
{'tag1': 1, 'tag2': 3}
80+
"""
81+
return dict(self.entries)
82+
83+
def to_register_bits(self) -> dict[str, str]:
84+
"""Convert results to a dictionary of register bit values."""
85+
reg_bits: dict[str, list[BitChar]] = {}
86+
87+
res_dict = self.as_dict()
88+
# relies on the fact that dict preserves insertion order
89+
for tag, data in res_dict.items():
90+
match = re.match(REG_INDEX_PATTERN, tag)
91+
if match is not None:
92+
reg_name, reg_index_str = match.groups()
93+
reg_index = int(reg_index_str)
94+
95+
if reg_name not in reg_bits:
96+
# Initialize register counts to False
97+
reg_bits[reg_name] = ["0"] * (reg_index + 1)
98+
bitlst = reg_bits[reg_name]
99+
if reg_index >= len(bitlst):
100+
# Extend register counts with "0"
101+
bitlst += ["0"] * (reg_index - len(bitlst) + 1)
102+
103+
bitlst[reg_index] = _cast_primitive_bit(data)
104+
continue
105+
match data:
106+
case list(vs):
107+
reg_bits[tag] = [_cast_primitive_bit(v) for v in vs]
108+
case _:
109+
reg_bits[tag] = [_cast_primitive_bit(data)]
110+
111+
return {reg: "".join(bits) for reg, bits in reg_bits.items()}
112+
113+
def collate_tags(self) -> dict[str, list[DataValue]]:
114+
"""Collate all the entries with the same tag in to a dictionary with a list
115+
containing all the data for that tag.
116+
"""
117+
tags: dict[str, list[DataValue]] = defaultdict(list)
118+
for tag, data in self.entries:
119+
tags[tag].append(data)
120+
return dict(tags)
121+
122+
123+
@deprecated("Use QsysShot instead.")
124+
class HResult(QsysShot):
125+
"""Deprecated alias for QsysShot."""
126+
127+
128+
def _cast_primitive_bit(data: DataValue) -> BitChar:
129+
if isinstance(data, int) and data in {0, 1}:
130+
return str(data) # type: ignore[return-value]
131+
msg = f"Expected bit data for register value found {data}"
132+
raise ValueError(msg)
133+
134+
135+
@dataclass
136+
class QsysResult:
137+
"""Results accumulated over multiple shots."""
138+
139+
results: list[QsysShot]
140+
141+
def __init__(
142+
self, results: Iterable[QsysShot | Iterable[TaggedResult]] | None = None
143+
):
144+
self.results = [
145+
res if isinstance(res, QsysShot) else QsysShot(res) for res in results or []
146+
]
147+
148+
def register_counts(
149+
self, strict_names: bool = False, strict_lengths: bool = False
150+
) -> dict[str, Counter[str]]:
151+
"""Convert results to a dictionary of register counts.
152+
153+
Returns:
154+
dict: A dictionary where the keys are the register names
155+
and the values are the counts of the register bitstrings.
156+
"""
157+
return {
158+
reg: Counter(bitstrs)
159+
for reg, bitstrs in self.register_bitstrings(
160+
strict_lengths=strict_lengths, strict_names=strict_names
161+
).items()
162+
}
163+
164+
def register_bitstrings(
165+
self, strict_names: bool = False, strict_lengths: bool = False
166+
) -> dict[str, list[str]]:
167+
"""Convert results to a dictionary from register name to list of bitstrings over
168+
the shots.
169+
170+
Args:
171+
strict_names: Whether to enforce that all shots have the same
172+
registers.
173+
strict_lengths: Whether to enforce that all register bitstrings have
174+
the same length.
175+
176+
"""
177+
shot_dct: dict[str, list[str]] = defaultdict(list)
178+
for shot in self.results:
179+
bitstrs = shot.to_register_bits()
180+
for reg, bitstr in bitstrs.items():
181+
if (
182+
strict_lengths
183+
and reg in shot_dct
184+
and len(shot_dct[reg][0]) != len(bitstr)
185+
):
186+
msg = "All register bitstrings must have the same length."
187+
raise ValueError(msg)
188+
shot_dct[reg].append(bitstr)
189+
if strict_names and bitstrs.keys() != shot_dct.keys():
190+
msg = "All shots must have the same registers."
191+
raise ValueError(msg)
192+
return dict(shot_dct)
193+
194+
def to_pytket(self) -> BackendResult:
195+
"""Convert results to a pytket BackendResult.
196+
197+
Returns:
198+
BackendResult: A BackendResult object with the shots.
199+
200+
Raises:
201+
ImportError: If pytket is not installed.
202+
ValueError: If a register's bitstrings have different lengths or not all
203+
registers are present in all shots.
204+
"""
205+
try:
206+
from pytket._tket.unit_id import Bit
207+
from pytket.backends.backendresult import BackendResult
208+
from pytket.utils.outcomearray import OutcomeArray
209+
except ImportError as e:
210+
msg = "Pytket is an optional dependency, install with the `pytket` extra"
211+
raise ImportError(msg) from e
212+
reg_shots = self.register_bitstrings(strict_lengths=True, strict_names=True)
213+
reg_sizes: dict[str, int] = {
214+
reg: len(next(iter(reg_shots[reg]), "")) for reg in reg_shots
215+
}
216+
registers = list(reg_shots.keys())
217+
bits = [Bit(reg, i) for reg in registers for i in range(reg_sizes[reg])]
218+
219+
int_shots = [
220+
int("".join(reg_shots[reg][i] for reg in registers), 2)
221+
for i in range(len(self.results))
222+
]
223+
return BackendResult(
224+
shots=OutcomeArray.from_ints(int_shots, width=len(bits)), c_bits=bits
225+
)
226+
227+
def _collated_shots_iter(self) -> Iterable[dict[str, list[DataValue]]]:
228+
for shot in self.results:
229+
yield shot.collate_tags()
230+
231+
def collated_shots(self) -> list[dict[str, list[DataValue]]]:
232+
"""For each shot generate a dictionary of tags to collated data."""
233+
return list(self._collated_shots_iter())
234+
235+
def collated_counts(self) -> Counter[tuple[tuple[str, str], ...]]:
236+
"""Calculate counts of bit strings for each tag by collating across shots using
237+
`QsysResult.tag_collated_shots`. Each `result` entry per shot is seen to be
238+
appending to the bitstring for that tag.
239+
240+
If the result value is a list, it is flattened and appended to the bitstring.
241+
242+
Example:
243+
>>> shots = [QsysShot([("a", 1), ("a", 0)]), QsysShot([("a", [0, 1])])]
244+
>>> res = QsysResult(shots)
245+
>>> res.collated_counts()
246+
Counter({(('a', '10'),): 1, (('a', '01'),): 1})
247+
248+
Raises:
249+
ValueError: If any value is a float.
250+
"""
251+
return Counter(
252+
tuple((tag, _flat_bitstring(data)) for tag, data in d.items())
253+
for d in self._collated_shots_iter()
254+
)
255+
256+
257+
@deprecated("Use QsysResult instead.")
258+
class HShots(QsysResult):
259+
"""Deprecated alias for QsysResult."""
260+
261+
262+
def _flat_bitstring(data: Iterable[DataValue]) -> str:
263+
return "".join(_cast_primitive_bit(prim) for prim in _flatten(data))
264+
265+
266+
def _flatten(itr: Iterable[DataValue]) -> Iterable[DataPrimitive]:
267+
for i in itr:
268+
if isinstance(i, list):
269+
yield from _flatten(i)
270+
else:
271+
yield i

0 commit comments

Comments
 (0)