1414import itertools
1515import math
1616from functools import cached_property
17- from typing import Any , Dict , Iterable , Optional , Sequence , Set , Tuple , TYPE_CHECKING , Union
17+ from typing import Any , Dict , Iterable , List , Optional , Sequence , Set , Tuple , TYPE_CHECKING , Union
1818
1919import cirq
2020import numpy as np
@@ -253,7 +253,7 @@ def _add_diff_size_regs() -> Add:
253253
254254
255255@frozen
256- class OutOfPlaceAdder (GateWithRegisters , cirq .ArithmeticGate ):
256+ class OutOfPlaceAdder (GateWithRegisters , cirq .ArithmeticGate ): # type: ignore[misc]
257257 r"""An n-bit addition gate.
258258
259259 Implements $U|a\rangle|b\rangle 0\rangle \rightarrow |a\rangle|b\rangle|a+b\rangle$
@@ -310,11 +310,14 @@ def decompose_from_registers(
310310 self , * , context : cirq .DecompositionContext , ** quregs : NDArray [cirq .Qid ]
311311 ) -> cirq .OP_TREE :
312312 a , b , c = quregs ['a' ][::- 1 ], quregs ['b' ][::- 1 ], quregs ['c' ][::- 1 ]
313- optree = [
313+ optree : List [ List [ cirq . Operation ]] = [
314314 [
315- [cirq .CX (a [i ], b [i ]), cirq .CX (a [i ], c [i ])],
315+ cirq .CX (a [i ], b [i ]),
316+ cirq .CX (a [i ], c [i ]),
316317 And ().on (b [i ], c [i ], c [i + 1 ]),
317- [cirq .CX (a [i ], b [i ]), cirq .CX (a [i ], c [i + 1 ]), cirq .CX (b [i ], c [i ])],
318+ cirq .CX (a [i ], b [i ]),
319+ cirq .CX (a [i ], c [i + 1 ]),
320+ cirq .CX (b [i ], c [i ]),
318321 ]
319322 for i in range (self .bitsize )
320323 ]
@@ -418,13 +421,13 @@ def on_classical_vals(
418421 else :
419422 return {'x' : x + self .k }
420423
421- if (self .cvs == ctrls ). all ( ):
424+ if np . all (self .cvs == ctrls ):
422425 x = x + self .k
423426
424427 return {'ctrls' : ctrls , 'x' : x }
425428
426429 def build_composite_bloq (
427- self , bb : 'BloqBuilder' , x : SoquetT , ** regs : SoquetT
430+ self , bb : 'BloqBuilder' , x : Soquet , ** regs : SoquetT
428431 ) -> Dict [str , 'SoquetT' ]:
429432 # Assign registers to variables and allocate ancilla bits for classical integer k.
430433 if len (self .cvs ) > 0 :
@@ -444,7 +447,7 @@ def build_composite_bloq(
444447 # controlled.
445448 for i in range (self .bitsize ):
446449 if binary_rep [i ] == 1 :
447- if len (self .cvs ) > 0 :
450+ if len (self .cvs ) > 0 and ctrls is not None :
448451 ctrls , k_split [i ] = bb .add (
449452 MultiControlX (cvs = self .cvs ), ctrls = ctrls , x = k_split [i ]
450453 )
@@ -453,14 +456,18 @@ def build_composite_bloq(
453456
454457 # Rejoin the qubits representing k for in-place addition.
455458 k = bb .join (k_split , dtype = x .reg .dtype )
459+ if not isinstance (x .reg .dtype , (QInt , QUInt , QMontgomeryUInt )):
460+ raise ValueError (
461+ "Only QInt, QUInt and QMontgomerUInt types are supported for composite addition."
462+ )
456463 k , x = bb .add (Add (x .reg .dtype , x .reg .dtype ), a = k , b = x )
457464
458465 # Resplit the k qubits in order to undo the original bit flips to go from the binary
459466 # representation back to the zero state.
460467 k_split = bb .split (k )
461468 for i in range (self .bitsize ):
462469 if binary_rep [i ] == 1 :
463- if len (self .cvs ) > 0 :
470+ if len (self .cvs ) > 0 and ctrls is not None :
464471 ctrls , k_split [i ] = bb .add (
465472 MultiControlX (cvs = self .cvs ), ctrls = ctrls , x = k_split [i ]
466473 )
@@ -472,7 +479,7 @@ def build_composite_bloq(
472479 bb .free (k )
473480
474481 # Return the output registers.
475- if len (self .cvs ) > 0 :
482+ if len (self .cvs ) > 0 and ctrls is not None :
476483 return {'ctrls' : ctrls , 'x' : x }
477484 else :
478485 return {'x' : x }
@@ -499,7 +506,7 @@ def _simple_add_k_large() -> SimpleAddConstant:
499506
500507
501508@frozen (auto_attribs = True )
502- class AddConstantMod (GateWithRegisters , cirq .ArithmeticGate ):
509+ class AddConstantMod (GateWithRegisters , cirq .ArithmeticGate ): # type: ignore[misc]
503510 """Applies U(add, M)|x> = |(x + add) % M> if x < M else |x>.
504511
505512 Applies modular addition to input register `|x>` given parameters `mod` and `add_val` s.t.
0 commit comments