18
18
import functools
19
19
from typing import (
20
20
Any ,
21
- Callable ,
22
21
Tuple ,
23
22
Hashable ,
24
23
List ,
25
- Type ,
26
24
overload ,
25
+ Protocol ,
26
+ Type ,
27
27
TYPE_CHECKING ,
28
+ TypeVar ,
28
29
)
29
30
import dataclasses
30
31
import enum
31
- from cirq .circuits .circuit import CIRCUIT_TYPE
32
32
33
33
if TYPE_CHECKING :
34
34
import cirq
@@ -218,77 +218,41 @@ class TransformerContext:
218
218
ignore_tags : Tuple [Hashable , ...] = ()
219
219
220
220
221
- TRANSFORMER = Callable [['cirq.AbstractCircuit' , TransformerContext ], 'cirq.AbstractCircuit' ]
222
- _TRANSFORMER_TYPE = Callable [['cirq.AbstractCircuit' , TransformerContext ], CIRCUIT_TYPE ]
223
-
224
-
225
- def _transform_and_log (
226
- func : _TRANSFORMER_TYPE [CIRCUIT_TYPE ],
227
- transformer_name : str ,
228
- circuit : 'cirq.AbstractCircuit' ,
229
- context : TransformerContext ,
230
- ) -> CIRCUIT_TYPE :
231
- """Helper to log initial and final circuits before and after calling the transformer."""
232
-
233
- context .logger .register_initial (circuit , transformer_name )
234
- transformed_circuit = func (circuit , context )
235
- context .logger .register_final (transformed_circuit , transformer_name )
236
- return transformed_circuit
237
-
238
-
239
- def _transformer_class (
240
- cls : Type [_TRANSFORMER_TYPE [CIRCUIT_TYPE ]],
241
- ) -> Type [_TRANSFORMER_TYPE [CIRCUIT_TYPE ]]:
242
- old_func = cls .__call__
243
-
244
- def transformer_with_logging_cls (
245
- self : Type [_TRANSFORMER_TYPE [CIRCUIT_TYPE ]],
246
- circuit : 'cirq.AbstractCircuit' ,
247
- context : TransformerContext ,
248
- ) -> CIRCUIT_TYPE :
249
- def call_old_func (c : 'cirq.AbstractCircuit' , ct : TransformerContext ) -> CIRCUIT_TYPE :
250
- return old_func (self , c , ct )
251
-
252
- return _transform_and_log (call_old_func , cls .__name__ , circuit , context )
221
+ class TRANSFORMER (Protocol ):
222
+ def __call__ (
223
+ self , circuit : 'cirq.AbstractCircuit' , context : TransformerContext
224
+ ) -> 'cirq.AbstractCircuit' :
225
+ ...
253
226
254
- setattr (cls , '__call__' , transformer_with_logging_cls )
255
- return cls
256
227
257
-
258
- def _transformer_func (func : _TRANSFORMER_TYPE [CIRCUIT_TYPE ]) -> _TRANSFORMER_TYPE [CIRCUIT_TYPE ]:
259
- @functools .wraps (func )
260
- def transformer_with_logging_func (
261
- circuit : 'cirq.AbstractCircuit' ,
262
- context : TransformerContext ,
263
- ) -> CIRCUIT_TYPE :
264
- return _transform_and_log (func , func .__name__ , circuit , context )
265
-
266
- return transformer_with_logging_func
228
+ _TRANSFORMER_T = TypeVar ('_TRANSFORMER_T' , bound = TRANSFORMER )
229
+ _TRANSFORMER_CLS_T = TypeVar ('_TRANSFORMER_CLS_T' , bound = Type [TRANSFORMER ])
267
230
268
231
269
232
@overload
270
- def transformer (cls_or_func : _TRANSFORMER_TYPE [ CIRCUIT_TYPE ] ) -> _TRANSFORMER_TYPE [ CIRCUIT_TYPE ] :
233
+ def transformer (cls_or_func : _TRANSFORMER_T ) -> _TRANSFORMER_T :
271
234
pass
272
235
273
236
274
237
@overload
275
- def transformer (
276
- cls_or_func : Type [_TRANSFORMER_TYPE [CIRCUIT_TYPE ]],
277
- ) -> Type [_TRANSFORMER_TYPE [CIRCUIT_TYPE ]]:
238
+ def transformer (cls_or_func : _TRANSFORMER_CLS_T ) -> _TRANSFORMER_CLS_T :
278
239
pass
279
240
280
241
281
242
def transformer (cls_or_func : Any ) -> Any :
282
243
"""Decorator to verify API and append logging functionality to transformer functions & classes.
283
244
284
- The decorated function or class must satisfy
285
- `Callable[[cirq.Circuit, cirq.TransformerContext], cirq.Circuit]` API. For Example:
245
+ A transformer is a callable that takes as inputs a cirq.AbstractCircuit and
246
+ cirq.TransformerContext, and returns another cirq.AbstractCircuit without
247
+ modifying the input circuit. A transformer could be a function, for example:
286
248
287
249
>>> @cirq.transformer
288
- >>> def convert_to_cz(circuit: cirq.Circuit, context: cirq.TransformerContext) -> cirq.Circuit:
250
+ >>> def convert_to_cz(
251
+ >>> circuit: cirq.AbstractCircuit, context: cirq.TransformerContext
252
+ >>> ) -> cirq.Circuit:
289
253
>>> ...
290
254
291
- The decorated class must implement the `__call__` method to satisfy the above API.
255
+ Or it could be a class that implements `__call__` with the same API, for example:
292
256
293
257
>>> @cirq.transformer
294
258
>>> class ConvertToSqrtISwaps:
@@ -300,14 +264,45 @@ def transformer(cls_or_func: Any) -> Any:
300
264
>>> ...
301
265
302
266
Args:
303
- cls_or_func: The callable class or method to be decorated.
267
+ cls_or_func: The callable class or function to be decorated.
304
268
305
269
Returns:
306
- Decorated class / method which includes additional logging boilerplate. The decorated
307
- callable always receives a copy of the input circuit so that the input is never mutated.
270
+ Decorated class / function which includes additional logging boilerplate.
308
271
"""
309
272
if isinstance (cls_or_func , type ):
310
- return _transformer_class (cls_or_func )
273
+ cls = cls_or_func
274
+ method = cls .__call__
275
+
276
+ @functools .wraps (method )
277
+ def method_with_logging (self , circuit , context ):
278
+ return _transform_and_log (
279
+ lambda circuit , context : method (self , circuit , context ),
280
+ cls .__name__ ,
281
+ circuit ,
282
+ context ,
283
+ )
284
+
285
+ setattr (cls , '__call__' , method_with_logging )
286
+ return cls
311
287
else :
312
288
assert callable (cls_or_func )
313
- return _transformer_func (cls_or_func )
289
+ func = cls_or_func
290
+
291
+ @functools .wraps (func )
292
+ def func_with_logging (circuit , context ):
293
+ return _transform_and_log (func , func .__name__ , circuit , context )
294
+
295
+ return func_with_logging
296
+
297
+
298
+ def _transform_and_log (
299
+ func : TRANSFORMER ,
300
+ transformer_name : str ,
301
+ circuit : 'cirq.AbstractCircuit' ,
302
+ context : TransformerContext ,
303
+ ) -> 'cirq.AbstractCircuit' :
304
+ """Helper to log initial and final circuits before and after calling the transformer."""
305
+ context .logger .register_initial (circuit , transformer_name )
306
+ transformed_circuit = func (circuit , context )
307
+ context .logger .register_final (transformed_circuit , transformer_name )
308
+ return transformed_circuit
0 commit comments