1616
1717import cirq
1818import numpy as np
19- from attrs import frozen
19+ from attrs import evolve , frozen
2020
2121from qualtran import (
2222 Bloq ,
@@ -48,10 +48,10 @@ class PlusEqualProduct(GateWithRegisters, cirq.ArithmeticGate):
4848 a_bitsize : int
4949 b_bitsize : int
5050 result_bitsize : int
51- adjoint : bool = False
51+ is_adjoint : bool = False
5252
5353 def short_name (self ) -> str :
54- return "result -= a*b" if self .adjoint else "result += a*b"
54+ return "result -= a*b" if self .is_adjoint else "result += a*b"
5555
5656 @property
5757 def signature (self ) -> 'Signature' :
@@ -64,14 +64,17 @@ def signature(self) -> 'Signature':
6464 def registers (self ) -> Sequence [Union [int , Sequence [int ]]]:
6565 return [2 ] * self .a_bitsize , [2 ] * self .b_bitsize , [2 ] * self .result_bitsize
6666
67+ def adjoint (self ) -> 'PlusEqualProduct' :
68+ return evolve (self , is_adjoint = not self .is_adjoint )
69+
6770 def apply (self , a : int , b : int , result : int ) -> Union [int , Iterable [int ]]:
68- return a , b , (result + a * b * ((- 1 ) ** self .adjoint )) % (2 ** self .result_bitsize )
71+ return a , b , (result + a * b * ((- 1 ) ** self .is_adjoint )) % (2 ** self .result_bitsize )
6972
7073 def with_registers (self , * new_registers : Union [int , Sequence [int ]]):
7174 raise NotImplementedError ("Not needed." )
7275
7376 def on_classical_vals (self , a : int , b : int , result : int ) -> Dict [str , 'ClassicalValT' ]:
74- result_out = (result + a * b * ((- 1 ) ** self .adjoint )) % (2 ** self .result_bitsize )
77+ result_out = (result + a * b * ((- 1 ) ** self .is_adjoint )) % (2 ** self .result_bitsize )
7578 return {'a' : a , 'b' : b , 'result' : result_out }
7679
7780 def _t_complexity_ (self ) -> 'TComplexity' :
@@ -80,15 +83,15 @@ def _t_complexity_(self) -> 'TComplexity':
8083
8184 def _circuit_diagram_info_ (self , args : cirq .CircuitDiagramInfoArgs ) -> cirq .CircuitDiagramInfo :
8285 wire_symbols = ['a' ] * self .a_bitsize + ['b' ] * self .b_bitsize
83- wire_symbols += ['c-=a*b' if self .adjoint else 'c+=a*b' ] * self .result_bitsize
86+ wire_symbols += ['c-=a*b' if self .is_adjoint else 'c+=a*b' ] * self .result_bitsize
8487 return cirq .CircuitDiagramInfo (wire_symbols = wire_symbols )
8588
8689 def __pow__ (self , power ):
8790 if power == 1 :
8891 return self
8992 if power == - 1 :
9093 return PlusEqualProduct (
91- self .a_bitsize , self .b_bitsize , self .result_bitsize , not self .adjoint
94+ self .a_bitsize , self .b_bitsize , self .result_bitsize , not self .is_adjoint
9295 )
9396 raise NotImplementedError ("PlusEqualProduct.__pow__ defined only for powers +1/-1." )
9497
0 commit comments