Skip to content

Commit c4d19d1

Browse files
committed
Use a Protocol for TRANSFORMER to ensure common arg names
Also cleans up some of the internals of the transformer decorator and simplifies the types. Follow-up to #4797
1 parent 6119620 commit c4d19d1

File tree

1 file changed

+55
-60
lines changed

1 file changed

+55
-60
lines changed

cirq-core/cirq/transformers/transformer_api.py

Lines changed: 55 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,17 @@
1818
import functools
1919
from typing import (
2020
Any,
21-
Callable,
2221
Tuple,
2322
Hashable,
2423
List,
25-
Type,
2624
overload,
25+
Protocol,
26+
Type,
2727
TYPE_CHECKING,
28+
TypeVar,
2829
)
2930
import dataclasses
3031
import enum
31-
from cirq.circuits.circuit import CIRCUIT_TYPE
3232

3333
if TYPE_CHECKING:
3434
import cirq
@@ -218,77 +218,41 @@ class TransformerContext:
218218
ignore_tags: Tuple[Hashable, ...] = ()
219219

220220

221-
TRANSFORMER = Callable[['cirq.AbstractCircuit', TransformerContext], 'cirq.AbstractCircuit']
222-
_TRANSFORMER_TYPE = Callable[['cirq.AbstractCircuit', TransformerContext], CIRCUIT_TYPE]
223-
224-
225-
def _transform_and_log(
226-
func: _TRANSFORMER_TYPE[CIRCUIT_TYPE],
227-
transformer_name: str,
228-
circuit: 'cirq.AbstractCircuit',
229-
context: TransformerContext,
230-
) -> CIRCUIT_TYPE:
231-
"""Helper to log initial and final circuits before and after calling the transformer."""
232-
233-
context.logger.register_initial(circuit, transformer_name)
234-
transformed_circuit = func(circuit, context)
235-
context.logger.register_final(transformed_circuit, transformer_name)
236-
return transformed_circuit
237-
238-
239-
def _transformer_class(
240-
cls: Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]],
241-
) -> Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]]:
242-
old_func = cls.__call__
243-
244-
def transformer_with_logging_cls(
245-
self: Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]],
246-
circuit: 'cirq.AbstractCircuit',
247-
context: TransformerContext,
248-
) -> CIRCUIT_TYPE:
249-
def call_old_func(c: 'cirq.AbstractCircuit', ct: TransformerContext) -> CIRCUIT_TYPE:
250-
return old_func(self, c, ct)
251-
252-
return _transform_and_log(call_old_func, cls.__name__, circuit, context)
221+
class TRANSFORMER(Protocol):
222+
def __call__(
223+
self, circuit: 'cirq.AbstractCircuit', context: TransformerContext
224+
) -> 'cirq.AbstractCircuit':
225+
...
253226

254-
setattr(cls, '__call__', transformer_with_logging_cls)
255-
return cls
256227

257-
258-
def _transformer_func(func: _TRANSFORMER_TYPE[CIRCUIT_TYPE]) -> _TRANSFORMER_TYPE[CIRCUIT_TYPE]:
259-
@functools.wraps(func)
260-
def transformer_with_logging_func(
261-
circuit: 'cirq.AbstractCircuit',
262-
context: TransformerContext,
263-
) -> CIRCUIT_TYPE:
264-
return _transform_and_log(func, func.__name__, circuit, context)
265-
266-
return transformer_with_logging_func
228+
_TRANSFORMER_T = TypeVar('_TRANSFORMER_T', bound=TRANSFORMER)
229+
_TRANSFORMER_CLS_T = TypeVar('_TRANSFORMER_CLS_T', bound=Type[TRANSFORMER])
267230

268231

269232
@overload
270-
def transformer(cls_or_func: _TRANSFORMER_TYPE[CIRCUIT_TYPE]) -> _TRANSFORMER_TYPE[CIRCUIT_TYPE]:
233+
def transformer(cls_or_func: _TRANSFORMER_T) -> _TRANSFORMER_T:
271234
pass
272235

273236

274237
@overload
275-
def transformer(
276-
cls_or_func: Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]],
277-
) -> Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]]:
238+
def transformer(cls_or_func: _TRANSFORMER_CLS_T) -> _TRANSFORMER_CLS_T:
278239
pass
279240

280241

281242
def transformer(cls_or_func: Any) -> Any:
282243
"""Decorator to verify API and append logging functionality to transformer functions & classes.
283244
284-
The decorated function or class must satisfy
285-
`Callable[[cirq.Circuit, cirq.TransformerContext], cirq.Circuit]` API. For Example:
245+
A transformer is a callable that takes as inputs a cirq.AbstractCircuit and
246+
cirq.TransformerContext, and returns another cirq.AbstractCircuit without
247+
modifying the input circuit. A transformer could be a function, for example:
286248
287249
>>> @cirq.transformer
288-
>>> def convert_to_cz(circuit: cirq.Circuit, context: cirq.TransformerContext) -> cirq.Circuit:
250+
>>> def convert_to_cz(
251+
>>> circuit: cirq.AbstractCircuit, context: cirq.TransformerContext
252+
>>> ) -> cirq.Circuit:
289253
>>> ...
290254
291-
The decorated class must implement the `__call__` method to satisfy the above API.
255+
Or it could be a class that implements `__call__` with the same API, for example:
292256
293257
>>> @cirq.transformer
294258
>>> class ConvertToSqrtISwaps:
@@ -300,14 +264,45 @@ def transformer(cls_or_func: Any) -> Any:
300264
>>> ...
301265
302266
Args:
303-
cls_or_func: The callable class or method to be decorated.
267+
cls_or_func: The callable class or function to be decorated.
304268
305269
Returns:
306-
Decorated class / method which includes additional logging boilerplate. The decorated
307-
callable always receives a copy of the input circuit so that the input is never mutated.
270+
Decorated class / function which includes additional logging boilerplate.
308271
"""
309272
if isinstance(cls_or_func, type):
310-
return _transformer_class(cls_or_func)
273+
cls = cls_or_func
274+
method = cls.__call__
275+
276+
@functools.wraps(method)
277+
def method_with_logging(self, circuit, context):
278+
return _transform_and_log(
279+
lambda circuit, context: method(self, circuit, context),
280+
cls.__name__,
281+
circuit,
282+
context,
283+
)
284+
285+
setattr(cls, '__call__', method_with_logging)
286+
return cls
311287
else:
312288
assert callable(cls_or_func)
313-
return _transformer_func(cls_or_func)
289+
func = cls_or_func
290+
291+
@functools.wraps(func)
292+
def func_with_logging(circuit, context):
293+
return _transform_and_log(func, func.__name__, circuit, context)
294+
295+
return func_with_logging
296+
297+
298+
def _transform_and_log(
299+
func: TRANSFORMER,
300+
transformer_name: str,
301+
circuit: 'cirq.AbstractCircuit',
302+
context: TransformerContext,
303+
) -> 'cirq.AbstractCircuit':
304+
"""Helper to log initial and final circuits before and after calling the transformer."""
305+
context.logger.register_initial(circuit, transformer_name)
306+
transformed_circuit = func(circuit, context)
307+
context.logger.register_final(transformed_circuit, transformer_name)
308+
return transformed_circuit

0 commit comments

Comments
 (0)