Skip to content

Commit 2aea258

Browse files
authored
Remove remaining adjoint attributes (#909)
* Remove remaining adjoints. * Fix test failure.
1 parent e66eb6e commit 2aea258

3 files changed

Lines changed: 21 additions & 15 deletions

File tree

qualtran/bloqs/arithmetic/addition.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import cirq
2020
import numpy as np
2121
import sympy
22-
from attrs import field, frozen
22+
from attrs import evolve, field, frozen
2323
from numpy.typing import NDArray
2424

2525
from qualtran import (
@@ -273,11 +273,11 @@ class OutOfPlaceAdder(GateWithRegisters, cirq.ArithmeticGate):
273273
"""
274274

275275
bitsize: int
276-
adjoint: bool = False
276+
is_adjoint: bool = False
277277

278278
@property
279279
def signature(self):
280-
side = Side.LEFT if self.adjoint else Side.RIGHT
280+
side = Side.LEFT if self.is_adjoint else Side.RIGHT
281281
return Signature(
282282
[
283283
Register('a', QUInt(self.bitsize)),
@@ -292,6 +292,9 @@ def registers(self) -> Sequence[Union[int, Sequence[int]]]:
292292
def apply(self, a: int, b: int, c: int) -> Tuple[int, int, int]:
293293
return a, b, c + a + b
294294

295+
def adjoint(self) -> 'OutOfPlaceAdder':
296+
return evolve(self, is_adjoint=not self.is_adjoint)
297+
295298
def on_classical_vals(
296299
self, *, a: 'ClassicalValT', b: 'ClassicalValT'
297300
) -> Dict[str, 'ClassicalValT']:
@@ -315,25 +318,25 @@ def decompose_from_registers(
315318
]
316319
for i in range(self.bitsize)
317320
]
318-
return cirq.inverse(optree) if self.adjoint else optree
321+
return cirq.inverse(optree) if self.is_adjoint else optree
319322

320323
def _t_complexity_(self) -> TComplexity:
321-
and_t = And(uncompute=self.adjoint).t_complexity()
324+
and_t = And(uncompute=self.is_adjoint).t_complexity()
322325
num_clifford = self.bitsize * (5 + and_t.clifford)
323326
num_t = self.bitsize * and_t.t
324327
return TComplexity(t=num_t, clifford=num_clifford)
325328

326329
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
327330
return {
328-
(And(uncompute=self.adjoint), self.bitsize),
331+
(And(uncompute=self.is_adjoint), self.bitsize),
329332
(ArbitraryClifford(n=2), 5 * self.bitsize),
330333
}
331334

332335
def __pow__(self, power: int):
333336
if power == 1:
334337
return self
335338
if power == -1:
336-
return OutOfPlaceAdder(self.bitsize, adjoint=not self.adjoint)
339+
return OutOfPlaceAdder(self.bitsize, is_adjoint=not self.is_adjoint)
337340
raise NotImplementedError("OutOfPlaceAdder.__pow__ defined only for +1/-1.")
338341

339342

qualtran/bloqs/arithmetic/multiplication.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import cirq
1818
import numpy as np
19-
from attrs import frozen
19+
from attrs import evolve, frozen
2020

2121
from qualtran import (
2222
Bloq,
@@ -48,10 +48,10 @@ class PlusEqualProduct(GateWithRegisters, cirq.ArithmeticGate):
4848
a_bitsize: int
4949
b_bitsize: int
5050
result_bitsize: int
51-
adjoint: bool = False
51+
is_adjoint: bool = False
5252

5353
def short_name(self) -> str:
54-
return "result -= a*b" if self.adjoint else "result += a*b"
54+
return "result -= a*b" if self.is_adjoint else "result += a*b"
5555

5656
@property
5757
def signature(self) -> 'Signature':
@@ -64,14 +64,17 @@ def signature(self) -> 'Signature':
6464
def registers(self) -> Sequence[Union[int, Sequence[int]]]:
6565
return [2] * self.a_bitsize, [2] * self.b_bitsize, [2] * self.result_bitsize
6666

67+
def adjoint(self) -> 'PlusEqualProduct':
68+
return evolve(self, is_adjoint=not self.is_adjoint)
69+
6770
def apply(self, a: int, b: int, result: int) -> Union[int, Iterable[int]]:
68-
return a, b, (result + a * b * ((-1) ** self.adjoint)) % (2**self.result_bitsize)
71+
return a, b, (result + a * b * ((-1) ** self.is_adjoint)) % (2**self.result_bitsize)
6972

7073
def with_registers(self, *new_registers: Union[int, Sequence[int]]):
7174
raise NotImplementedError("Not needed.")
7275

7376
def on_classical_vals(self, a: int, b: int, result: int) -> Dict[str, 'ClassicalValT']:
74-
result_out = (result + a * b * ((-1) ** self.adjoint)) % (2**self.result_bitsize)
77+
result_out = (result + a * b * ((-1) ** self.is_adjoint)) % (2**self.result_bitsize)
7578
return {'a': a, 'b': b, 'result': result_out}
7679

7780
def _t_complexity_(self) -> 'TComplexity':
@@ -80,15 +83,15 @@ def _t_complexity_(self) -> 'TComplexity':
8083

8184
def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
8285
wire_symbols = ['a'] * self.a_bitsize + ['b'] * self.b_bitsize
83-
wire_symbols += ['c-=a*b' if self.adjoint else 'c+=a*b'] * self.result_bitsize
86+
wire_symbols += ['c-=a*b' if self.is_adjoint else 'c+=a*b'] * self.result_bitsize
8487
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)
8588

8689
def __pow__(self, power):
8790
if power == 1:
8891
return self
8992
if power == -1:
9093
return PlusEqualProduct(
91-
self.a_bitsize, self.b_bitsize, self.result_bitsize, not self.adjoint
94+
self.a_bitsize, self.b_bitsize, self.result_bitsize, not self.is_adjoint
9295
)
9396
raise NotImplementedError("PlusEqualProduct.__pow__ defined only for powers +1/-1.")
9497

qualtran/bloqs/chemistry/trotter/grid_ham/potential.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def build_composite_bloq(
107107
)
108108
for xyz in range(3):
109109
system_i[xyz], system_j[xyz] = bb.add(
110-
OutOfPlaceAdder(self.bitsize, adjoint=True),
110+
OutOfPlaceAdder(self.bitsize, is_adjoint=True),
111111
a=system_i[xyz],
112112
b=system_j[xyz],
113113
c=diff_ij[xyz],

0 commit comments

Comments
 (0)