Skip to content

Commit 0262880

Browse files
authored
Remove attribute power and default to Power/GateWithRegisters.__pow__ (#903)
* Remove attribute `power` and default to `Power`/`GateWithRegisters.__pow__` * add test for `Power` * add circuit diagram info for `Power` * lint * fix infinite loop * check if `Power.power` is an integer * use `(0)` for 0-control (consistent with cirq) * `Power` cirq diagram: put exponent on every register * fix test * add diagram tests for controlled * fix doc in notebook
1 parent c9eb3c3 commit 0262880

11 files changed

Lines changed: 172 additions & 66 deletions

qualtran/_infra/controlled.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,3 +462,18 @@ def as_cirq_op(
462462
sub_op.controlled_by(*ctrl_qubits, control_values=self.ctrl_spec.to_cirq_cv()),
463463
cirq_quregs | ctrl_regs,
464464
)
465+
466+
def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
467+
from qualtran.cirq_interop._bloq_to_cirq import _wire_symbol_to_cirq_diagram_info
468+
469+
if isinstance(self.subbloq, cirq.Gate):
470+
sub_info = cirq.circuit_diagram_info(self.subbloq, args, None)
471+
if sub_info is not None:
472+
cv_info = cirq.circuit_diagram_info(self.ctrl_spec.to_cirq_cv())
473+
474+
return cirq.CircuitDiagramInfo(
475+
wire_symbols=(*cv_info.wire_symbols, *sub_info.wire_symbols),
476+
exponent=sub_info.exponent,
477+
)
478+
479+
return _wire_symbol_to_cirq_diagram_info(self, args)

qualtran/_infra/controlled_test.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
OneState,
4040
Swap,
4141
XGate,
42+
XPowGate,
4243
YGate,
4344
ZeroState,
4445
ZGate,
@@ -402,3 +403,51 @@ def test_controlled_tensor_for_and_bloq(ctrl_spec: CtrlSpec):
402403
_verify_ctrl_tensor_for_and(ctrl_spec, (1, 0))
403404
_verify_ctrl_tensor_for_and(ctrl_spec, (0, 1))
404405
_verify_ctrl_tensor_for_and(ctrl_spec, (0, 0))
406+
407+
408+
def test_controlled_diagrams():
409+
ctrl_gate = XPowGate(0.25).controlled()
410+
cirq.testing.assert_has_diagram(
411+
cirq.Circuit(ctrl_gate.on_registers(**get_named_qubits(ctrl_gate.signature))),
412+
'''
413+
ctrl: ───@────────
414+
415+
q: ──────X^0.25───''',
416+
)
417+
418+
ctrl_0_gate = XPowGate(0.25).controlled(CtrlSpec(cvs=0))
419+
cirq.testing.assert_has_diagram(
420+
cirq.Circuit(ctrl_0_gate.on_registers(**get_named_qubits(ctrl_0_gate.signature))),
421+
'''
422+
ctrl: ───(0)──────
423+
424+
q: ──────X^0.25───''',
425+
)
426+
427+
multi_ctrl_gate = XPowGate(0.25).controlled(CtrlSpec(cvs=[0, 1]))
428+
cirq.testing.assert_has_diagram(
429+
cirq.Circuit(multi_ctrl_gate.on_registers(**get_named_qubits(multi_ctrl_gate.signature))),
430+
'''
431+
ctrl[0]: ───(0)──────
432+
433+
ctrl[1]: ───@────────
434+
435+
q: ─────────X^0.25───''',
436+
)
437+
438+
ctrl_bloq = Swap(2).controlled(CtrlSpec(cvs=[0, 1]))
439+
cirq.testing.assert_has_diagram(
440+
cirq.Circuit(ctrl_bloq.on_registers(**get_named_qubits(ctrl_bloq.signature))),
441+
'''
442+
ctrl[0]: ───(0)────
443+
444+
ctrl[1]: ───@──────
445+
446+
x0: ────────×(x)───
447+
448+
x1: ────────×(x)───
449+
450+
y0: ────────×(y)───
451+
452+
y1: ────────×(y)───''',
453+
)

qualtran/bloqs/for_testing/atom.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from functools import cached_property
1616
from typing import Any, Dict, Optional, TYPE_CHECKING
1717

18+
import attrs
1819
import numpy as np
1920
from attrs import frozen
2021

@@ -127,6 +128,7 @@ class TestGWRAtom(GateWithRegisters):
127128
"""
128129

129130
tag: Optional[str] = None
131+
is_adjoint: bool = False
130132

131133
@cached_property
132134
def signature(self) -> Signature:
@@ -157,16 +159,15 @@ def _unitary_(self):
157159
return np.eye(2)
158160

159161
def adjoint(self) -> 'Bloq':
160-
return self
162+
return attrs.evolve(self, is_adjoint=not self.is_adjoint)
161163

162164
def _t_complexity_(self) -> 'TComplexity':
163165
return TComplexity(100)
164166

165167
def __repr__(self):
166-
if self.tag:
167-
return f'TestGWRAtom({self.tag!r})'
168-
else:
169-
return 'TestGWRAtom()'
168+
tag = f'{self.tag!r}' if self.tag else ''
169+
dagger = '†' if self.is_adjoint else ''
170+
return f'TestGWRAtom({tag}){dagger}'
170171

171172
def short_name(self) -> str:
172173
if self.tag:

qualtran/bloqs/for_testing/atom_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,5 @@ def test_test_gwr_atom():
3939
assert ta.short_name() == 'GWRAtom'
4040
with pytest.raises(DecomposeTypeError):
4141
ta.decompose_bloq()
42-
assert ta.adjoint() == ta
42+
assert ta.adjoint() == TestGWRAtom(is_adjoint=True)
4343
np.testing.assert_allclose(cirq.unitary(ta), np.eye(2))

qualtran/bloqs/mean_estimation/mean_estimation_operator.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ class MeanEstimationOperator(GateWithRegisters):
8585
cv: Tuple[int, ...] = attrs.field(
8686
converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=()
8787
)
88-
power: int = 1
8988
arctan_bitsize: int = 32
9089

9190
@cv.validator
@@ -124,17 +123,12 @@ def decompose_from_registers(
124123
) -> cirq.OP_TREE:
125124
select_reg = {reg.name: quregs[reg.name] for reg in self.select.signature}
126125
reflect_reg = {reg.name: quregs[reg.name] for reg in self.reflect.signature}
127-
select_op = self.select.on_registers(**select_reg)
128-
reflect_op = self.reflect.on_registers(**reflect_reg)
129-
for _ in range(self.power):
130-
yield select_op
131-
yield reflect_op
126+
yield self.select.on_registers(**select_reg)
127+
yield self.reflect.on_registers(**reflect_reg)
132128

133129
def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
134-
wire_symbols = [] if self.cv == () else [["@(0)", "@"][self.cv[0]]]
130+
wire_symbols = [] if self.cv == () else [["(0)", "@"][self.cv[0]]]
135131
wire_symbols += ['U_ko'] * (total_bits(self.signature) - total_bits(self.control_registers))
136-
if self.power != 1:
137-
wire_symbols[-1] = f'U_ko^{self.power}'
138132
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)
139133

140134
def controlled(
@@ -160,17 +154,8 @@ def controlled(
160154
return MeanEstimationOperator(
161155
CodeForRandomVariable(encoder=c_select, synthesizer=self.code.synthesizer),
162156
cv=self.cv + (control_values[0],),
163-
power=self.power,
164157
arctan_bitsize=self.arctan_bitsize,
165158
)
166159
raise NotImplementedError(
167160
f'Cannot create a controlled version of {self} with control_values={control_values}.'
168161
)
169-
170-
def with_power(self, new_power: int) -> 'MeanEstimationOperator':
171-
return MeanEstimationOperator(
172-
self.code, cv=self.cv, power=new_power, arctan_bitsize=self.arctan_bitsize
173-
)
174-
175-
def __pow__(self, power: int):
176-
return self.with_power(self.power * power)

qualtran/bloqs/mean_estimation/mean_estimation_operator_test.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -281,22 +281,19 @@ def test_mean_estimation_operator_consistent_protocols():
281281
with pytest.raises(NotImplementedError, match="Cannot create a controlled version"):
282282
_ = mean_gate.controlled(num_controls=2)
283283

284-
# Test with_power
285-
assert mean_gate.with_power(5) ** 2 == MeanEstimationOperator(
286-
code, arctan_bitsize=arctan_bitsize, power=10
287-
)
288284
# Test diagrams
289-
expected_symbols = ['U_ko'] * cirq.num_qubits(mean_gate)
290-
assert cirq.circuit_diagram_info(mean_gate).wire_symbols == tuple(expected_symbols)
291-
control_symbols = ['@']
285+
n_qubits = cirq.num_qubits(mean_gate)
286+
287+
assert cirq.circuit_diagram_info(mean_gate).wire_symbols == tuple(['U_ko'] * n_qubits)
288+
292289
assert cirq.circuit_diagram_info(mean_gate.controlled()).wire_symbols == tuple(
293-
control_symbols + expected_symbols
290+
['@'] + ['U_ko'] * n_qubits
294291
)
295-
control_symbols = ['@(0)']
292+
296293
assert cirq.circuit_diagram_info(
297294
mean_gate.controlled(control_values=(0,))
298-
).wire_symbols == tuple(control_symbols + expected_symbols)
299-
expected_symbols[-1] = 'U_ko^2'
295+
).wire_symbols == tuple(['(0)'] + ['U_ko'] * n_qubits)
296+
300297
assert cirq.circuit_diagram_info(
301-
mean_gate.with_power(2).controlled(control_values=(0,))
302-
).wire_symbols == tuple(control_symbols + expected_symbols)
298+
(mean_gate**2).controlled(control_values=(0,))
299+
).wire_symbols == tuple(['(0)'] + ['U_ko^2'] * n_qubits)

qualtran/bloqs/qubitization_walk_operator.ipynb

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,7 @@
5555
"#### Parameters\n",
5656
" - `select`: The SELECT lcu gate implementing $SELECT=\\sum_{l}|l><l|H_{l}$.\n",
5757
" - `prepare`: Then PREPARE lcu gate implementing $PREPARE|00...00> = \\sum_{l=0}^{L - 1}\\sqrt{\\frac{w_{l}}{\\lambda}} |l> = |\\ell>$\n",
58-
" - `control_val`: If 0/1, a controlled version of the walk operator is constructed. Defaults to None, in which case the resulting walk operator is not controlled.\n",
59-
" - `power`: Constructs $W^{power}$ by repeatedly decomposing into `power` copies of $W$. Defaults to 1. \n",
58+
" - `control_val`: If 0/1, a controlled version of the walk operator is constructed. Defaults to None, in which case the resulting walk operator is not controlled. \n",
6059
"\n",
6160
"#### References\n",
6261
" - [Encoding Electronic Spectra in Quantum Circuits with Linear T Complexity] (https://arxiv.org/abs/1805.03662). Babbush et. al. (2018). Figure 1.\n"

qualtran/bloqs/qubitization_walk_operator.py

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from qualtran._infra.gate_with_registers import total_bits
2424
from qualtran.bloqs.reflection_using_prepare import ReflectionUsingPrepare
2525
from qualtran.bloqs.select_and_prepare import PrepareOracle, SelectOracle
26-
from qualtran.cirq_interop.t_complexity_protocol import t_complexity
2726
from qualtran.resource_counting.generalizers import (
2827
cirq_to_bloqs,
2928
ignore_cliffords,
@@ -55,8 +54,6 @@ class QubitizationWalkOperator(GateWithRegisters):
5554
$PREPARE|00...00> = \sum_{l=0}^{L - 1}\sqrt{\frac{w_{l}}{\lambda}} |l> = |\ell>$
5655
control_val: If 0/1, a controlled version of the walk operator is constructed. Defaults to
5756
None, in which case the resulting walk operator is not controlled.
58-
power: Constructs $W^{power}$ by repeatedly decomposing into `power` copies of $W$.
59-
Defaults to 1.
6057
6158
References:
6259
[Encoding Electronic Spectra in Quantum Circuits with Linear T Complexity]
@@ -67,7 +64,6 @@ class QubitizationWalkOperator(GateWithRegisters):
6764
select: SelectOracle
6865
prepare: PrepareOracle
6966
control_val: Optional[int] = None
70-
power: int = 1
7167

7268
def __attrs_post_init__(self):
7369
assert self.select.control_registers == self.reflect.control_registers
@@ -105,18 +101,14 @@ def decompose_from_registers(
105101
**quregs: NDArray[cirq.Qid], # type:ignore[type-var]
106102
) -> cirq.OP_TREE:
107103
select_reg = {reg.name: quregs[reg.name] for reg in self.select.signature}
108-
select_op = self.select.on_registers(**select_reg)
104+
yield self.select.on_registers(**select_reg)
109105

110106
reflect_reg = {reg.name: quregs[reg.name] for reg in self.reflect.signature}
111-
reflect_op = self.reflect.on_registers(**reflect_reg)
112-
for _ in range(self.power):
113-
yield select_op
114-
yield reflect_op
107+
yield self.reflect.on_registers(**reflect_reg)
115108

116109
def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
117110
wire_symbols = ['@' if self.control_val else '@(0)'] * total_bits(self.control_registers)
118111
wire_symbols += ['W'] * (total_bits(self.signature) - total_bits(self.control_registers))
119-
wire_symbols[-1] = f'W^{self.power}' if self.power != 1 else 'W'
120112
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)
121113

122114
def controlled(
@@ -144,17 +136,6 @@ def controlled(
144136
f'Cannot create a controlled version of {self} with control_values={control_values}.'
145137
)
146138

147-
def with_power(self, new_power: int) -> 'QubitizationWalkOperator':
148-
return attrs.evolve(self, power=new_power)
149-
150-
def __pow__(self, power: int):
151-
return self.with_power(self.power * power)
152-
153-
def _t_complexity_(self):
154-
if self.power > 1:
155-
return self.power * t_complexity(self.with_power(1))
156-
return NotImplemented
157-
158139

159140
@bloq_example(generalizer=[cirq_to_bloqs, ignore_split_join, ignore_cliffords])
160141
def _walk_op() -> QubitizationWalkOperator:

qualtran/bloqs/qubitization_walk_operator_test.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,16 @@ def test_qubitization_walk_operator_diagrams():
112112
target3: ──────SelectPauliLCU─────────
113113
''',
114114
)
115+
115116
# 2. Diagram for $W^{2} = SELECT.R_{L}.SELCT.R_{L}$
116-
walk_squared_op = walk.with_power(2).on_registers(**g.quregs)
117-
circuit = cirq.Circuit(cirq.decompose_once(walk_squared_op))
117+
def decompose_twice(op):
118+
ops = []
119+
for sub_op in cirq.decompose_once(op):
120+
ops += cirq.decompose_once(sub_op)
121+
return ops
122+
123+
walk_squared_op = (walk**2).on_registers(**g.quregs)
124+
circuit = cirq.Circuit(decompose_twice(walk_squared_op))
118125
cirq.testing.assert_has_diagram(
119126
circuit,
120127
'''

qualtran/bloqs/util_bloqs.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,11 @@
3838
)
3939
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
4040
from qualtran.drawing import directional_text_box, WireSymbol
41-
from qualtran.resource_counting.symbolic_counting_utils import SymbolicInt
41+
from qualtran.resource_counting.symbolic_counting_utils import is_symbolic, SymbolicInt
4242
from qualtran.simulation.classical_sim import bits_to_ints, ints_to_bits
4343

4444
if TYPE_CHECKING:
45+
import cirq
4546
from numpy.typing import NDArray
4647

4748
from qualtran import AddControlledT, CtrlSpec
@@ -497,7 +498,10 @@ class Power(GateWithRegisters):
497498
def __attrs_post_init__(self):
498499
if any(reg.side != Side.THRU for reg in self.bloq.signature):
499500
raise ValueError('Bloq to repeat must have only THRU registers')
500-
if self.power < 1:
501+
502+
if not is_symbolic(self.power) and (
503+
not isinstance(self.power, (int, np.integer)) or self.power < 1
504+
):
501505
raise ValueError(f'{self.power=} must be a positive integer.')
502506

503507
def adjoint(self) -> 'Bloq':
@@ -514,3 +518,21 @@ def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str
514518

515519
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
516520
return {(self.bloq, self.power)}
521+
522+
def __pow__(self, power) -> 'Power':
523+
bloq = self.bloq.adjoint() if power < 0 else self.bloq
524+
return Power(bloq, self.power * abs(power))
525+
526+
def _circuit_diagram_info_(
527+
self, args: 'cirq.CircuitDiagramInfoArgs'
528+
) -> 'cirq.CircuitDiagramInfo':
529+
import cirq
530+
531+
info = cirq.circuit_diagram_info(self.bloq, args, default=None)
532+
533+
if info is None:
534+
info = super()._circuit_diagram_info_(args)
535+
536+
wire_symbols = [f'{symbol}^{self.power}' for symbol in info.wire_symbols]
537+
538+
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)

0 commit comments

Comments
 (0)