Skip to content

Commit c77db58

Browse files
maffootanujkhattar
authored andcommitted
Use a Protocol for TRANSFORMER to ensure common arg names (quantumlib#4871)
* Use a Protocol for TRANSFORMER to ensure common arg names Also cleans up some of the internals of the transformer decorator and simplifies the types. Follow-up to quantumlib#4797 * Fix Protocol import for 3.7 * Fixes from review * Add type annotations in transformer implementation Co-authored-by: Tanuj Khattar <[email protected]>
1 parent a41481c commit c77db58

File tree

1 file changed

+63
-64
lines changed

1 file changed

+63
-64
lines changed

cirq-core/cirq/transformers/transformer_api.py

Lines changed: 63 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,21 @@
1414

1515
"""Defines the API for circuit transformers in Cirq."""
1616

17-
import textwrap
17+
import dataclasses
18+
import enum
1819
import functools
20+
import textwrap
1921
from typing import (
2022
Any,
21-
Callable,
2223
Tuple,
2324
Hashable,
2425
List,
25-
Type,
2626
overload,
27+
Type,
2728
TYPE_CHECKING,
29+
TypeVar,
2830
)
29-
import dataclasses
30-
import enum
31-
from cirq.circuits.circuit import CIRCUIT_TYPE
31+
from typing_extensions import Protocol
3232

3333
if TYPE_CHECKING:
3434
import cirq
@@ -218,96 +218,95 @@ class TransformerContext:
218218
ignore_tags: Tuple[Hashable, ...] = ()
219219

220220

221-
TRANSFORMER = Callable[['cirq.AbstractCircuit', TransformerContext], 'cirq.AbstractCircuit']
222-
_TRANSFORMER_TYPE = Callable[['cirq.AbstractCircuit', TransformerContext], CIRCUIT_TYPE]
223-
224-
225-
def _transform_and_log(
226-
func: _TRANSFORMER_TYPE[CIRCUIT_TYPE],
227-
transformer_name: str,
228-
circuit: 'cirq.AbstractCircuit',
229-
context: TransformerContext,
230-
) -> CIRCUIT_TYPE:
231-
"""Helper to log initial and final circuits before and after calling the transformer."""
232-
233-
context.logger.register_initial(circuit, transformer_name)
234-
transformed_circuit = func(circuit, context)
235-
context.logger.register_final(transformed_circuit, transformer_name)
236-
return transformed_circuit
237-
221+
class TRANSFORMER(Protocol):
222+
def __call__(
223+
self, circuit: 'cirq.AbstractCircuit', context: TransformerContext
224+
) -> 'cirq.AbstractCircuit':
225+
...
238226

239-
def _transformer_class(
240-
cls: Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]],
241-
) -> Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]]:
242-
old_func = cls.__call__
243227

244-
def transformer_with_logging_cls(
245-
self: Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]],
246-
circuit: 'cirq.AbstractCircuit',
247-
context: TransformerContext,
248-
) -> CIRCUIT_TYPE:
249-
def call_old_func(c: 'cirq.AbstractCircuit', ct: TransformerContext) -> CIRCUIT_TYPE:
250-
return old_func(self, c, ct)
251-
252-
return _transform_and_log(call_old_func, cls.__name__, circuit, context)
253-
254-
setattr(cls, '__call__', transformer_with_logging_cls)
255-
return cls
256-
257-
258-
def _transformer_func(func: _TRANSFORMER_TYPE[CIRCUIT_TYPE]) -> _TRANSFORMER_TYPE[CIRCUIT_TYPE]:
259-
@functools.wraps(func)
260-
def transformer_with_logging_func(
261-
circuit: 'cirq.AbstractCircuit',
262-
context: TransformerContext,
263-
) -> CIRCUIT_TYPE:
264-
return _transform_and_log(func, func.__name__, circuit, context)
265-
266-
return transformer_with_logging_func
228+
_TRANSFORMER_T = TypeVar('_TRANSFORMER_T', bound=TRANSFORMER)
229+
_TRANSFORMER_CLS_T = TypeVar('_TRANSFORMER_CLS_T', bound=Type[TRANSFORMER])
267230

268231

269232
@overload
270-
def transformer(cls_or_func: _TRANSFORMER_TYPE[CIRCUIT_TYPE]) -> _TRANSFORMER_TYPE[CIRCUIT_TYPE]:
233+
def transformer(cls_or_func: _TRANSFORMER_T) -> _TRANSFORMER_T:
271234
pass
272235

273236

274237
@overload
275-
def transformer(
276-
cls_or_func: Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]],
277-
) -> Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]]:
238+
def transformer(cls_or_func: _TRANSFORMER_CLS_T) -> _TRANSFORMER_CLS_T:
278239
pass
279240

280241

281242
def transformer(cls_or_func: Any) -> Any:
282243
"""Decorator to verify API and append logging functionality to transformer functions & classes.
283244
284-
The decorated function or class must satisfy
285-
`Callable[[cirq.Circuit, cirq.TransformerContext], cirq.Circuit]` API. For Example:
245+
A transformer is a callable that takes as inputs a `cirq.AbstractCircuit` and
246+
`cirq.TransformerContext`, and returns another `cirq.AbstractCircuit` without
247+
modifying the input circuit. A transformer could be a function, for example:
286248
287249
>>> @cirq.transformer
288-
>>> def convert_to_cz(circuit: cirq.Circuit, context: cirq.TransformerContext) -> cirq.Circuit:
250+
>>> def convert_to_cz(
251+
>>> circuit: cirq.AbstractCircuit, context: cirq.TransformerContext
252+
>>> ) -> cirq.Circuit:
289253
>>> ...
290254
291-
The decorated class must implement the `__call__` method to satisfy the above API.
255+
Or it could be a class that implements `__call__` with the same API, for example:
292256
293257
>>> @cirq.transformer
294258
>>> class ConvertToSqrtISwaps:
295259
>>> def __init__(self):
296260
>>> ...
297261
>>> def __call__(
298-
>>> self, circuit: cirq.Circuit, context: cirq.TransformerContext
262+
>>> self, circuit: cirq.AbstractCircuit, context: cirq.TransformerContext
299263
>>> ) -> cirq.Circuit:
300264
>>> ...
301265
302266
Args:
303-
cls_or_func: The callable class or method to be decorated.
267+
cls_or_func: The callable class or function to be decorated.
304268
305269
Returns:
306-
Decorated class / method which includes additional logging boilerplate. The decorated
307-
callable always receives a copy of the input circuit so that the input is never mutated.
270+
Decorated class / function which includes additional logging boilerplate.
308271
"""
309272
if isinstance(cls_or_func, type):
310-
return _transformer_class(cls_or_func)
273+
cls = cls_or_func
274+
method = cls.__call__
275+
276+
@functools.wraps(method)
277+
def method_with_logging(
278+
self, circuit: 'cirq.AbstractCircuit', context: TransformerContext
279+
) -> 'cirq.AbstractCircuit':
280+
return _transform_and_log(
281+
lambda circuit, context: method(self, circuit, context),
282+
cls.__name__,
283+
circuit,
284+
context,
285+
)
286+
287+
setattr(cls, '__call__', method_with_logging)
288+
return cls
311289
else:
312290
assert callable(cls_or_func)
313-
return _transformer_func(cls_or_func)
291+
func = cls_or_func
292+
293+
@functools.wraps(func)
294+
def func_with_logging(
295+
circuit: 'cirq.AbstractCircuit', context: TransformerContext
296+
) -> 'cirq.AbstractCircuit':
297+
return _transform_and_log(func, func.__name__, circuit, context)
298+
299+
return func_with_logging
300+
301+
302+
def _transform_and_log(
303+
func: TRANSFORMER,
304+
transformer_name: str,
305+
circuit: 'cirq.AbstractCircuit',
306+
context: TransformerContext,
307+
) -> 'cirq.AbstractCircuit':
308+
"""Helper to log initial and final circuits before and after calling the transformer."""
309+
context.logger.register_initial(circuit, transformer_name)
310+
transformed_circuit = func(circuit, context)
311+
context.logger.register_final(transformed_circuit, transformer_name)
312+
return transformed_circuit

0 commit comments

Comments
 (0)