|
| 1 | +r"""Quantinuum system results. |
| 2 | +
|
| 3 | +Includes conversions to traditional distributions over bitstrings if a tagging |
| 4 | +convention is used, including conversion to a pytket BackendResult. |
| 5 | +
|
| 6 | +Under this convention, tags are assumed to be a name of a bit register unless they fit |
| 7 | +the regex pattern `^([a-z][\w_]*)\[(\d+)\]$` (like `my_Reg[12]`) in which case they |
| 8 | +are assumed to refer to the nth element of a bit register. |
| 9 | +
|
| 10 | +For results of the form ``` result("<register>", value) ``` `value` can be `{0, 1}`, |
| 11 | +wherein the register is assumed to be length 1, or lists over those values, |
| 12 | +wherein the list is taken to be the value of the entire register. |
| 13 | +
|
| 14 | +For results of the form ``` result("<register>[n]", value) ``` `value` can only be |
| 15 | +`{0,1}`. |
| 16 | +The register is assumed to be at least `n+1` in size and unset |
| 17 | +elements are assumed to be `0`. |
| 18 | +
|
| 19 | +Subsequent writes to the same register/element in the same shot will overwrite. |
| 20 | +
|
| 21 | +To convert to a `BackendResult` all registers must be present in all shots, and register |
| 22 | +sizes cannot change between shots. |
| 23 | +
|
| 24 | +""" |
| 25 | + |
| 26 | +from __future__ import annotations |
| 27 | + |
| 28 | +import re |
| 29 | +from collections import Counter, defaultdict |
| 30 | +from dataclasses import dataclass, field |
| 31 | +from typing import TYPE_CHECKING, Literal |
| 32 | + |
| 33 | +from typing_extensions import deprecated |
| 34 | + |
| 35 | +if TYPE_CHECKING: |
| 36 | + from collections.abc import Iterable |
| 37 | + |
| 38 | + from pytket.backends.backendresult import BackendResult |
| 39 | + |
| 40 | + |
| 41 | +#: Primitive data types that can be returned by a result |
| 42 | +DataPrimitive = int | float | bool |
| 43 | +#: Data value that can be returned by a result: a primitive or a list of primitives |
| 44 | +DataValue = DataPrimitive | list[DataPrimitive] |
| 45 | +TaggedResult = tuple[str, DataValue] |
| 46 | +# Pattern to match register index in tag, e.g. "reg[0]" |
| 47 | +REG_INDEX_PATTERN = re.compile(r"^([a-z][\w_]*)\[(\d+)\]$") |
| 48 | + |
| 49 | +BitChar = Literal["0", "1"] |
| 50 | + |
| 51 | + |
| 52 | +@dataclass |
| 53 | +class QsysShot: |
| 54 | + """Results from a single shot execution.""" |
| 55 | + |
| 56 | + entries: list[TaggedResult] = field(default_factory=list) |
| 57 | + |
| 58 | + def __init__(self, entries: Iterable[TaggedResult] | None = None): |
| 59 | + self.entries = list(entries or []) |
| 60 | + |
| 61 | + def append(self, tag: str, data: DataValue) -> None: |
| 62 | + self.entries.append((tag, data)) |
| 63 | + |
| 64 | + def as_dict(self) -> dict[str, DataValue]: |
| 65 | + """Convert results to a dictionary. |
| 66 | +
|
| 67 | + For duplicate tags, the last value is used. |
| 68 | +
|
| 69 | + Returns: |
| 70 | + dict: A dictionary where the keys are the tags and the |
| 71 | + values are the data. |
| 72 | +
|
| 73 | + Example: |
| 74 | + >>> results = QsysShot() |
| 75 | + >>> results.append("tag1", 1) |
| 76 | + >>> results.append("tag2", 2) |
| 77 | + >>> results.append("tag2", 3) |
| 78 | + >>> results.as_dict() |
| 79 | + {'tag1': 1, 'tag2': 3} |
| 80 | + """ |
| 81 | + return dict(self.entries) |
| 82 | + |
| 83 | + def to_register_bits(self) -> dict[str, str]: |
| 84 | + """Convert results to a dictionary of register bit values.""" |
| 85 | + reg_bits: dict[str, list[BitChar]] = {} |
| 86 | + |
| 87 | + res_dict = self.as_dict() |
| 88 | + # relies on the fact that dict preserves insertion order |
| 89 | + for tag, data in res_dict.items(): |
| 90 | + match = re.match(REG_INDEX_PATTERN, tag) |
| 91 | + if match is not None: |
| 92 | + reg_name, reg_index_str = match.groups() |
| 93 | + reg_index = int(reg_index_str) |
| 94 | + |
| 95 | + if reg_name not in reg_bits: |
| 96 | + # Initialize register counts to False |
| 97 | + reg_bits[reg_name] = ["0"] * (reg_index + 1) |
| 98 | + bitlst = reg_bits[reg_name] |
| 99 | + if reg_index >= len(bitlst): |
| 100 | + # Extend register counts with "0" |
| 101 | + bitlst += ["0"] * (reg_index - len(bitlst) + 1) |
| 102 | + |
| 103 | + bitlst[reg_index] = _cast_primitive_bit(data) |
| 104 | + continue |
| 105 | + match data: |
| 106 | + case list(vs): |
| 107 | + reg_bits[tag] = [_cast_primitive_bit(v) for v in vs] |
| 108 | + case _: |
| 109 | + reg_bits[tag] = [_cast_primitive_bit(data)] |
| 110 | + |
| 111 | + return {reg: "".join(bits) for reg, bits in reg_bits.items()} |
| 112 | + |
| 113 | + def collate_tags(self) -> dict[str, list[DataValue]]: |
| 114 | + """Collate all the entries with the same tag in to a dictionary with a list |
| 115 | + containing all the data for that tag. |
| 116 | + """ |
| 117 | + tags: dict[str, list[DataValue]] = defaultdict(list) |
| 118 | + for tag, data in self.entries: |
| 119 | + tags[tag].append(data) |
| 120 | + return dict(tags) |
| 121 | + |
| 122 | + |
| 123 | +@deprecated("Use QsysShot instead.") |
| 124 | +class HResult(QsysShot): |
| 125 | + """Deprecated alias for QsysShot.""" |
| 126 | + |
| 127 | + |
| 128 | +def _cast_primitive_bit(data: DataValue) -> BitChar: |
| 129 | + if isinstance(data, int) and data in {0, 1}: |
| 130 | + return str(data) # type: ignore[return-value] |
| 131 | + msg = f"Expected bit data for register value found {data}" |
| 132 | + raise ValueError(msg) |
| 133 | + |
| 134 | + |
| 135 | +@dataclass |
| 136 | +class QsysResult: |
| 137 | + """Results accumulated over multiple shots.""" |
| 138 | + |
| 139 | + results: list[QsysShot] |
| 140 | + |
| 141 | + def __init__( |
| 142 | + self, results: Iterable[QsysShot | Iterable[TaggedResult]] | None = None |
| 143 | + ): |
| 144 | + self.results = [ |
| 145 | + res if isinstance(res, QsysShot) else QsysShot(res) for res in results or [] |
| 146 | + ] |
| 147 | + |
| 148 | + def register_counts( |
| 149 | + self, strict_names: bool = False, strict_lengths: bool = False |
| 150 | + ) -> dict[str, Counter[str]]: |
| 151 | + """Convert results to a dictionary of register counts. |
| 152 | +
|
| 153 | + Returns: |
| 154 | + dict: A dictionary where the keys are the register names |
| 155 | + and the values are the counts of the register bitstrings. |
| 156 | + """ |
| 157 | + return { |
| 158 | + reg: Counter(bitstrs) |
| 159 | + for reg, bitstrs in self.register_bitstrings( |
| 160 | + strict_lengths=strict_lengths, strict_names=strict_names |
| 161 | + ).items() |
| 162 | + } |
| 163 | + |
| 164 | + def register_bitstrings( |
| 165 | + self, strict_names: bool = False, strict_lengths: bool = False |
| 166 | + ) -> dict[str, list[str]]: |
| 167 | + """Convert results to a dictionary from register name to list of bitstrings over |
| 168 | + the shots. |
| 169 | +
|
| 170 | + Args: |
| 171 | + strict_names: Whether to enforce that all shots have the same |
| 172 | + registers. |
| 173 | + strict_lengths: Whether to enforce that all register bitstrings have |
| 174 | + the same length. |
| 175 | +
|
| 176 | + """ |
| 177 | + shot_dct: dict[str, list[str]] = defaultdict(list) |
| 178 | + for shot in self.results: |
| 179 | + bitstrs = shot.to_register_bits() |
| 180 | + for reg, bitstr in bitstrs.items(): |
| 181 | + if ( |
| 182 | + strict_lengths |
| 183 | + and reg in shot_dct |
| 184 | + and len(shot_dct[reg][0]) != len(bitstr) |
| 185 | + ): |
| 186 | + msg = "All register bitstrings must have the same length." |
| 187 | + raise ValueError(msg) |
| 188 | + shot_dct[reg].append(bitstr) |
| 189 | + if strict_names and bitstrs.keys() != shot_dct.keys(): |
| 190 | + msg = "All shots must have the same registers." |
| 191 | + raise ValueError(msg) |
| 192 | + return dict(shot_dct) |
| 193 | + |
| 194 | + def to_pytket(self) -> BackendResult: |
| 195 | + """Convert results to a pytket BackendResult. |
| 196 | +
|
| 197 | + Returns: |
| 198 | + BackendResult: A BackendResult object with the shots. |
| 199 | +
|
| 200 | + Raises: |
| 201 | + ImportError: If pytket is not installed. |
| 202 | + ValueError: If a register's bitstrings have different lengths or not all |
| 203 | + registers are present in all shots. |
| 204 | + """ |
| 205 | + try: |
| 206 | + from pytket._tket.unit_id import Bit |
| 207 | + from pytket.backends.backendresult import BackendResult |
| 208 | + from pytket.utils.outcomearray import OutcomeArray |
| 209 | + except ImportError as e: |
| 210 | + msg = "Pytket is an optional dependency, install with the `pytket` extra" |
| 211 | + raise ImportError(msg) from e |
| 212 | + reg_shots = self.register_bitstrings(strict_lengths=True, strict_names=True) |
| 213 | + reg_sizes: dict[str, int] = { |
| 214 | + reg: len(next(iter(reg_shots[reg]), "")) for reg in reg_shots |
| 215 | + } |
| 216 | + registers = list(reg_shots.keys()) |
| 217 | + bits = [Bit(reg, i) for reg in registers for i in range(reg_sizes[reg])] |
| 218 | + |
| 219 | + int_shots = [ |
| 220 | + int("".join(reg_shots[reg][i] for reg in registers), 2) |
| 221 | + for i in range(len(self.results)) |
| 222 | + ] |
| 223 | + return BackendResult( |
| 224 | + shots=OutcomeArray.from_ints(int_shots, width=len(bits)), c_bits=bits |
| 225 | + ) |
| 226 | + |
| 227 | + def _collated_shots_iter(self) -> Iterable[dict[str, list[DataValue]]]: |
| 228 | + for shot in self.results: |
| 229 | + yield shot.collate_tags() |
| 230 | + |
| 231 | + def collated_shots(self) -> list[dict[str, list[DataValue]]]: |
| 232 | + """For each shot generate a dictionary of tags to collated data.""" |
| 233 | + return list(self._collated_shots_iter()) |
| 234 | + |
| 235 | + def collated_counts(self) -> Counter[tuple[tuple[str, str], ...]]: |
| 236 | + """Calculate counts of bit strings for each tag by collating across shots using |
| 237 | + `QsysResult.tag_collated_shots`. Each `result` entry per shot is seen to be |
| 238 | + appending to the bitstring for that tag. |
| 239 | +
|
| 240 | + If the result value is a list, it is flattened and appended to the bitstring. |
| 241 | +
|
| 242 | + Example: |
| 243 | + >>> shots = [QsysShot([("a", 1), ("a", 0)]), QsysShot([("a", [0, 1])])] |
| 244 | + >>> res = QsysResult(shots) |
| 245 | + >>> res.collated_counts() |
| 246 | + Counter({(('a', '10'),): 1, (('a', '01'),): 1}) |
| 247 | +
|
| 248 | + Raises: |
| 249 | + ValueError: If any value is a float. |
| 250 | + """ |
| 251 | + return Counter( |
| 252 | + tuple((tag, _flat_bitstring(data)) for tag, data in d.items()) |
| 253 | + for d in self._collated_shots_iter() |
| 254 | + ) |
| 255 | + |
| 256 | + |
| 257 | +@deprecated("Use QsysResult instead.") |
| 258 | +class HShots(QsysResult): |
| 259 | + """Deprecated alias for QsysResult.""" |
| 260 | + |
| 261 | + |
| 262 | +def _flat_bitstring(data: Iterable[DataValue]) -> str: |
| 263 | + return "".join(_cast_primitive_bit(prim) for prim in _flatten(data)) |
| 264 | + |
| 265 | + |
| 266 | +def _flatten(itr: Iterable[DataValue]) -> Iterable[DataPrimitive]: |
| 267 | + for i in itr: |
| 268 | + if isinstance(i, list): |
| 269 | + yield from _flatten(i) |
| 270 | + else: |
| 271 | + yield i |
0 commit comments