|
14 | 14 |
|
15 | 15 | """Defines the API for circuit transformers in Cirq."""
|
16 | 16 |
|
17 |
| -import textwrap |
| 17 | +import dataclasses |
| 18 | +import enum |
18 | 19 | import functools
|
| 20 | +import textwrap |
19 | 21 | from typing import (
|
20 | 22 | Any,
|
21 |
| - Callable, |
22 | 23 | Tuple,
|
23 | 24 | Hashable,
|
24 | 25 | List,
|
25 |
| - Type, |
26 | 26 | overload,
|
| 27 | + Type, |
27 | 28 | TYPE_CHECKING,
|
| 29 | + TypeVar, |
28 | 30 | )
|
29 |
| -import dataclasses |
30 |
| -import enum |
31 |
| -from cirq.circuits.circuit import CIRCUIT_TYPE |
| 31 | +from typing_extensions import Protocol |
32 | 32 |
|
33 | 33 | if TYPE_CHECKING:
|
34 | 34 | import cirq
|
@@ -218,96 +218,95 @@ class TransformerContext:
|
218 | 218 | ignore_tags: Tuple[Hashable, ...] = ()
|
219 | 219 |
|
220 | 220 |
|
221 |
| -TRANSFORMER = Callable[['cirq.AbstractCircuit', TransformerContext], 'cirq.AbstractCircuit'] |
222 |
| -_TRANSFORMER_TYPE = Callable[['cirq.AbstractCircuit', TransformerContext], CIRCUIT_TYPE] |
223 |
| - |
224 |
| - |
225 |
| -def _transform_and_log( |
226 |
| - func: _TRANSFORMER_TYPE[CIRCUIT_TYPE], |
227 |
| - transformer_name: str, |
228 |
| - circuit: 'cirq.AbstractCircuit', |
229 |
| - context: TransformerContext, |
230 |
| -) -> CIRCUIT_TYPE: |
231 |
| - """Helper to log initial and final circuits before and after calling the transformer.""" |
232 |
| - |
233 |
| - context.logger.register_initial(circuit, transformer_name) |
234 |
| - transformed_circuit = func(circuit, context) |
235 |
| - context.logger.register_final(transformed_circuit, transformer_name) |
236 |
| - return transformed_circuit |
237 |
| - |
| 221 | +class TRANSFORMER(Protocol): |
| 222 | + def __call__( |
| 223 | + self, circuit: 'cirq.AbstractCircuit', context: TransformerContext |
| 224 | + ) -> 'cirq.AbstractCircuit': |
| 225 | + ... |
238 | 226 |
|
239 |
| -def _transformer_class( |
240 |
| - cls: Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]], |
241 |
| -) -> Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]]: |
242 |
| - old_func = cls.__call__ |
243 | 227 |
|
244 |
| - def transformer_with_logging_cls( |
245 |
| - self: Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]], |
246 |
| - circuit: 'cirq.AbstractCircuit', |
247 |
| - context: TransformerContext, |
248 |
| - ) -> CIRCUIT_TYPE: |
249 |
| - def call_old_func(c: 'cirq.AbstractCircuit', ct: TransformerContext) -> CIRCUIT_TYPE: |
250 |
| - return old_func(self, c, ct) |
251 |
| - |
252 |
| - return _transform_and_log(call_old_func, cls.__name__, circuit, context) |
253 |
| - |
254 |
| - setattr(cls, '__call__', transformer_with_logging_cls) |
255 |
| - return cls |
256 |
| - |
257 |
| - |
258 |
| -def _transformer_func(func: _TRANSFORMER_TYPE[CIRCUIT_TYPE]) -> _TRANSFORMER_TYPE[CIRCUIT_TYPE]: |
259 |
| - @functools.wraps(func) |
260 |
| - def transformer_with_logging_func( |
261 |
| - circuit: 'cirq.AbstractCircuit', |
262 |
| - context: TransformerContext, |
263 |
| - ) -> CIRCUIT_TYPE: |
264 |
| - return _transform_and_log(func, func.__name__, circuit, context) |
265 |
| - |
266 |
| - return transformer_with_logging_func |
| 228 | +_TRANSFORMER_T = TypeVar('_TRANSFORMER_T', bound=TRANSFORMER) |
| 229 | +_TRANSFORMER_CLS_T = TypeVar('_TRANSFORMER_CLS_T', bound=Type[TRANSFORMER]) |
267 | 230 |
|
268 | 231 |
|
269 | 232 | @overload
|
270 |
| -def transformer(cls_or_func: _TRANSFORMER_TYPE[CIRCUIT_TYPE]) -> _TRANSFORMER_TYPE[CIRCUIT_TYPE]: |
| 233 | +def transformer(cls_or_func: _TRANSFORMER_T) -> _TRANSFORMER_T: |
271 | 234 | pass
|
272 | 235 |
|
273 | 236 |
|
274 | 237 | @overload
|
275 |
| -def transformer( |
276 |
| - cls_or_func: Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]], |
277 |
| -) -> Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]]: |
| 238 | +def transformer(cls_or_func: _TRANSFORMER_CLS_T) -> _TRANSFORMER_CLS_T: |
278 | 239 | pass
|
279 | 240 |
|
280 | 241 |
|
281 | 242 | def transformer(cls_or_func: Any) -> Any:
|
282 | 243 | """Decorator to verify API and append logging functionality to transformer functions & classes.
|
283 | 244 |
|
284 |
| - The decorated function or class must satisfy |
285 |
| - `Callable[[cirq.Circuit, cirq.TransformerContext], cirq.Circuit]` API. For Example: |
| 245 | + A transformer is a callable that takes as inputs a `cirq.AbstractCircuit` and |
| 246 | + `cirq.TransformerContext`, and returns another `cirq.AbstractCircuit` without |
| 247 | + modifying the input circuit. A transformer could be a function, for example: |
286 | 248 |
|
287 | 249 | >>> @cirq.transformer
|
288 |
| - >>> def convert_to_cz(circuit: cirq.Circuit, context: cirq.TransformerContext) -> cirq.Circuit: |
| 250 | + >>> def convert_to_cz( |
| 251 | + >>> circuit: cirq.AbstractCircuit, context: cirq.TransformerContext |
| 252 | + >>> ) -> cirq.Circuit: |
289 | 253 | >>> ...
|
290 | 254 |
|
291 |
| - The decorated class must implement the `__call__` method to satisfy the above API. |
| 255 | + Or it could be a class that implements `__call__` with the same API, for example: |
292 | 256 |
|
293 | 257 | >>> @cirq.transformer
|
294 | 258 | >>> class ConvertToSqrtISwaps:
|
295 | 259 | >>> def __init__(self):
|
296 | 260 | >>> ...
|
297 | 261 | >>> def __call__(
|
298 |
| - >>> self, circuit: cirq.Circuit, context: cirq.TransformerContext |
| 262 | + >>> self, circuit: cirq.AbstractCircuit, context: cirq.TransformerContext |
299 | 263 | >>> ) -> cirq.Circuit:
|
300 | 264 | >>> ...
|
301 | 265 |
|
302 | 266 | Args:
|
303 |
| - cls_or_func: The callable class or method to be decorated. |
| 267 | + cls_or_func: The callable class or function to be decorated. |
304 | 268 |
|
305 | 269 | Returns:
|
306 |
| - Decorated class / method which includes additional logging boilerplate. The decorated |
307 |
| - callable always receives a copy of the input circuit so that the input is never mutated. |
| 270 | + Decorated class / function which includes additional logging boilerplate. |
308 | 271 | """
|
309 | 272 | if isinstance(cls_or_func, type):
|
310 |
| - return _transformer_class(cls_or_func) |
| 273 | + cls = cls_or_func |
| 274 | + method = cls.__call__ |
| 275 | + |
| 276 | + @functools.wraps(method) |
| 277 | + def method_with_logging( |
| 278 | + self, circuit: 'cirq.AbstractCircuit', context: TransformerContext |
| 279 | + ) -> 'cirq.AbstractCircuit': |
| 280 | + return _transform_and_log( |
| 281 | + lambda circuit, context: method(self, circuit, context), |
| 282 | + cls.__name__, |
| 283 | + circuit, |
| 284 | + context, |
| 285 | + ) |
| 286 | + |
| 287 | + setattr(cls, '__call__', method_with_logging) |
| 288 | + return cls |
311 | 289 | else:
|
312 | 290 | assert callable(cls_or_func)
|
313 |
| - return _transformer_func(cls_or_func) |
| 291 | + func = cls_or_func |
| 292 | + |
| 293 | + @functools.wraps(func) |
| 294 | + def func_with_logging( |
| 295 | + circuit: 'cirq.AbstractCircuit', context: TransformerContext |
| 296 | + ) -> 'cirq.AbstractCircuit': |
| 297 | + return _transform_and_log(func, func.__name__, circuit, context) |
| 298 | + |
| 299 | + return func_with_logging |
| 300 | + |
| 301 | + |
| 302 | +def _transform_and_log( |
| 303 | + func: TRANSFORMER, |
| 304 | + transformer_name: str, |
| 305 | + circuit: 'cirq.AbstractCircuit', |
| 306 | + context: TransformerContext, |
| 307 | +) -> 'cirq.AbstractCircuit': |
| 308 | + """Helper to log initial and final circuits before and after calling the transformer.""" |
| 309 | + context.logger.register_initial(circuit, transformer_name) |
| 310 | + transformed_circuit = func(circuit, context) |
| 311 | + context.logger.register_final(transformed_circuit, transformer_name) |
| 312 | + return transformed_circuit |
0 commit comments