Skip to content

Commit 1b8b449

Browse files
anurudhpmpharrigan
andauthored
Fix custom controlled versions of SelectOracles and QubitizationWalkOperators (#904)
* Fix custom controlled Select and Walk operators * use mixin class instead * `get_ctrl_spec` delegate to base class for other cases * docstring typo * docstring cleanup * fix some type errors * `type: ignore` the `controlled` mismatch errors * remove superfluous `GateWithRegisters` inheritance --------- Co-authored-by: Matthew Harrigan <mpharrigan@google.com>
1 parent 38db1af commit 1b8b449

12 files changed

Lines changed: 116 additions & 221 deletions

qualtran/_infra/gate_with_registers.py

Lines changed: 61 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
Union,
2929
)
3030

31+
import attrs
3132
import cirq
3233
import numpy as np
3334
from numpy.typing import NDArray
@@ -39,7 +40,7 @@
3940
if TYPE_CHECKING:
4041
import quimb.tensor as qtn
4142

42-
from qualtran import CtrlSpec, SoquetT
43+
from qualtran import AddControlledT, BloqBuilder, CtrlSpec, SoquetT
4344
from qualtran.cirq_interop import CirqQuregT
4445
from qualtran.drawing import WireSymbol
4546

@@ -382,15 +383,6 @@ def _get_ctrl_spec(
382383
already accepts a `CtrlSpec` and simply returns it OR a Cirq-style API which accepts
383384
parameters expected by `cirq.Gate.controlled()` and converts them to a `CtrlSpec` object.
384385
385-
Users implementing custom `GateWithRegisters.controlled()` overrides can use this helper
386-
to generate a CtrlSpec from the cirq-style API and thus easily support both Cirq & Bloq
387-
APIs. For example
388-
389-
>>> class CustomGWR(GateWithRegisters):
390-
>>> def controlled(self, *args, **kwargs) -> 'Bloq':
391-
>>> ctrl_spec = self._get_ctrl_spec(*args, **kwargs)
392-
>>> # Use ctrl_spec to construct a controlled version of `self`.
393-
394386
Args:
395387
num_controls: Cirq style API to specify control specification -
396388
Total number of control qubits.
@@ -478,10 +470,6 @@ def controlled(
478470
bloq. Bloqs authors can declare their own, custom controlled versions by overriding
479471
`Bloq.get_ctrl_system` in the bloq.
480472
481-
If overriding the `GWR.controlled()` method directly, Bloq authors can use the
482-
`self._get_ctrl_spec` helper to construct a `CtrlSpec` object from the input parameters of
483-
`GWR.controlled()` and use it to return a custom controlled version of this Bloq.
484-
485473
486474
Args:
487475
num_controls: Cirq style API to specify control specification -
@@ -549,3 +537,62 @@ def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.Circ
549537

550538
wire_symbols[0] = self.__class__.__name__
551539
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)
540+
541+
542+
class SpecializedSingleQubitControlledGate(GateWithRegisters):
543+
"""Add a specialized single-qubit controlled version of a Bloq.
544+
545+
`control_val` is an optional single-bit control. When `control_val` is provided,
546+
the `control_registers` property should return a single named qubit register,
547+
and otherwise return an empty tuple.
548+
549+
Example usage:
550+
551+
@attrs.frozen
552+
class MyGate(SpecializedSingleQubitControlledGate):
553+
control_val: Optional[int] = None
554+
555+
@property
556+
def control_registers() -> Tuple[Register, ...]:
557+
return () if self.control_val is None else (Register('control', QBit()),)
558+
"""
559+
560+
control_val: Optional[int]
561+
562+
@property
563+
@abc.abstractmethod
564+
def control_registers(self) -> Tuple[Register, ...]:
565+
...
566+
567+
def get_single_qubit_controlled_bloq(
568+
self, control_val: int
569+
) -> 'SpecializedSingleQubitControlledGate':
570+
"""Override this to provide a custom controlled bloq"""
571+
return attrs.evolve(self, control_val=control_val) # type: ignore[misc]
572+
573+
def get_ctrl_system(
574+
self, ctrl_spec: Optional['CtrlSpec'] = None
575+
) -> Tuple['Bloq', 'AddControlledT']:
576+
if ctrl_spec is None:
577+
ctrl_spec = CtrlSpec()
578+
579+
if self.control_val is None and ctrl_spec.shapes in [((),), ((1,),)]:
580+
control_val = int(ctrl_spec.cvs[0].item())
581+
cbloq = self.get_single_qubit_controlled_bloq(control_val)
582+
583+
if not hasattr(cbloq, 'control_registers'):
584+
raise TypeError("{cbloq} should have attribute `control_registers`")
585+
586+
(ctrl_reg,) = cbloq.control_registers
587+
588+
def adder(
589+
bb: 'BloqBuilder', ctrl_soqs: Sequence['SoquetT'], in_soqs: dict[str, 'SoquetT']
590+
) -> tuple[Iterable['SoquetT'], Iterable['SoquetT']]:
591+
soqs = {ctrl_reg.name: ctrl_soqs[0]} | in_soqs
592+
soqs = bb.add_d(cbloq, **soqs)
593+
ctrl_soqs = [soqs.pop(ctrl_reg.name)]
594+
return ctrl_soqs, soqs.values()
595+
596+
return cbloq, adder
597+
598+
return super().get_ctrl_system(ctrl_spec)

qualtran/bloqs/for_testing/random_select_and_prepare.py

Lines changed: 6 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from functools import cached_property
15-
from typing import Iterator, Optional, Sequence, Tuple
15+
from typing import Iterator, Optional, Tuple
1616

17-
import attrs
1817
import cirq
1918
import numpy as np
2019
from attrs import frozen
2120
from numpy.typing import NDArray
2221

2322
from qualtran import BloqBuilder, BoundedQUInt, QBit, Register, SoquetT
23+
from qualtran._infra.gate_with_registers import SpecializedSingleQubitControlledGate
2424
from qualtran.bloqs.for_testing.matrix_gate import MatrixGate
2525
from qualtran.bloqs.qubitization_walk_operator import QubitizationWalkOperator
2626
from qualtran.bloqs.select_and_prepare import PrepareOracle, SelectOracle
@@ -82,7 +82,7 @@ def alphas(self):
8282

8383

8484
@frozen
85-
class TestPauliSelectOracle(SelectOracle):
85+
class TestPauliSelectOracle(SpecializedSingleQubitControlledGate, SelectOracle): # type: ignore[misc]
8686
r"""Paulis acting on $m$ qubits, controlled by an $n$-qubit register.
8787
8888
Given $2^n$ multi-qubit-Paulis (acting on $m$ qubits) $U_j$,
@@ -132,33 +132,6 @@ def selection_registers(self) -> Tuple[Register, ...]:
132132
def target_registers(self) -> Tuple[Register, ...]:
133133
return (Register('target', BoundedQUInt(bitsize=self.target_bitsize)),)
134134

135-
def adjoint(self):
136-
return self
137-
138-
def __pow__(self, power):
139-
if abs(power) == 1:
140-
return self
141-
return NotImplemented
142-
143-
def controlled(
144-
self,
145-
num_controls: Optional[int] = None,
146-
control_values=None,
147-
control_qid_shape: Optional[Tuple[int, ...]] = None,
148-
) -> 'cirq.Gate':
149-
if num_controls is None:
150-
num_controls = 1
151-
if control_values is None:
152-
control_values = [1] * num_controls
153-
if (
154-
isinstance(control_values, Sequence)
155-
and isinstance(control_values[0], int)
156-
and len(control_values) == 1
157-
and self.control_val is None
158-
):
159-
return attrs.evolve(self, control_val=control_values[0])
160-
raise NotImplementedError()
161-
162135
def decompose_from_registers(
163136
self,
164137
*,
@@ -167,14 +140,12 @@ def decompose_from_registers(
167140
target: NDArray[cirq.Qid], # type: ignore[type-var]
168141
**quregs: NDArray[cirq.Qid], # type: ignore[type-var]
169142
) -> Iterator[cirq.OP_TREE]:
170-
if self.control_val is not None:
171-
selection = np.concatenate([selection, quregs['control']])
172-
173143
for cv, U in enumerate(self.select_unitaries):
174144
bits = tuple(map(int, bin(cv)[2:].zfill(self.select_bitsize)))[::-1]
145+
op = U.on(*target).controlled_by(*selection, control_values=bits)
175146
if self.control_val is not None:
176-
bits = (*bits, self.control_val)
177-
yield U.on(*target).controlled_by(*selection, control_values=bits)
147+
op = op.controlled_by(*quregs['control'], control_values=[self.control_val])
148+
yield op
178149

179150

180151
def random_qubitization_walk_operator(

qualtran/bloqs/hubbard_model.py

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,15 @@
4747
See the documentation for `PrepareHubbard` and `SelectHubbard` for details.
4848
"""
4949
from functools import cached_property
50-
from typing import Collection, Iterator, Optional, Sequence, Tuple, TYPE_CHECKING, Union
50+
from typing import Iterator, Optional, Tuple, TYPE_CHECKING
5151

5252
import attrs
5353
import cirq
5454
import numpy as np
5555
from numpy.typing import NDArray
5656

5757
from qualtran import BoundedQUInt, QAny, QBit, Register, Signature
58-
from qualtran._infra.gate_with_registers import total_bits
58+
from qualtran._infra.gate_with_registers import SpecializedSingleQubitControlledGate, total_bits
5959
from qualtran.bloqs.arithmetic import AddConstantMod
6060
from qualtran.bloqs.basic_gates import CSwap
6161
from qualtran.bloqs.mcmt.and_bloq import MultiAnd
@@ -72,7 +72,7 @@
7272

7373

7474
@attrs.frozen
75-
class SelectHubbard(SelectOracle):
75+
class SelectHubbard(SpecializedSingleQubitControlledGate, SelectOracle): # type: ignore[misc]
7676
r"""The SELECT operation optimized for the 2D Hubbard model.
7777
7878
In contrast to the arbitrary chemistry Hamiltonian, we:
@@ -219,29 +219,6 @@ def decompose_from_registers(
219219
q_x=q_x, q_y=q_y, control=[*V, *control], target=target_qubits_for_apply_to_lth_gate
220220
)
221221

222-
def controlled(
223-
self,
224-
num_controls: Optional[int] = None,
225-
control_values: Optional[
226-
Union[cirq.ops.AbstractControlValues, Sequence[Union[int, Collection[int]]]]
227-
] = None,
228-
control_qid_shape: Optional[Tuple[int, ...]] = None,
229-
) -> 'SelectHubbard':
230-
if num_controls is None:
231-
num_controls = 1
232-
if control_values is None:
233-
control_values = [1] * num_controls
234-
if (
235-
isinstance(control_values, Sequence)
236-
and isinstance(control_values[0], int)
237-
and len(control_values) == 1
238-
and self.control_val is None
239-
):
240-
return SelectHubbard(self.x_dim, self.y_dim, control_val=control_values[0])
241-
raise NotImplementedError(
242-
f'Cannot create a controlled version of {self} with control_values={control_values}.'
243-
)
244-
245222
def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
246223
info = super(SelectHubbard, self)._circuit_diagram_info_(args)
247224
if self.control_val is None:

qualtran/bloqs/hubbard_model_test.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,6 @@ def test_hubbard_model_consistent_protocols():
6363
select_gate.controlled(num_controls=1, control_values=(0,)),
6464
select_op.controlled_by(cirq.q("control"), control_values=(0,)).gate,
6565
)
66-
with pytest.raises(NotImplementedError, match="Cannot create a controlled version"):
67-
_ = select_gate.controlled(num_controls=2)
6866

6967
# Test diagrams
7068
expected_symbols = ['U', 'V', 'p_x', 'p_y', 'alpha', 'q_x', 'q_y', 'beta']

qualtran/bloqs/mean_estimation/mean_estimation_operator.py

Lines changed: 15 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@
1313
# limitations under the License.
1414

1515
from functools import cached_property
16-
from typing import Collection, Iterator, Optional, Sequence, Tuple, Union
16+
from typing import Iterator, Optional, Tuple
1717

1818
import attrs
1919
import cirq
2020
from numpy.typing import NDArray
2121

22-
from qualtran import GateWithRegisters, Register, Signature
23-
from qualtran._infra.gate_with_registers import total_bits
22+
from qualtran import CtrlSpec, Register, Signature
23+
from qualtran._infra.gate_with_registers import SpecializedSingleQubitControlledGate, total_bits
2424
from qualtran.bloqs.mean_estimation.complex_phase_oracle import ComplexPhaseOracle
2525
from qualtran.bloqs.reflection_using_prepare import ReflectionUsingPrepare
2626
from qualtran.bloqs.select_and_prepare import PrepareOracle, SelectOracle
@@ -63,7 +63,7 @@ def __attrs_post_init__(self):
6363

6464

6565
@attrs.frozen
66-
class MeanEstimationOperator(GateWithRegisters):
66+
class MeanEstimationOperator(SpecializedSingleQubitControlledGate):
6767
r"""Mean estimation operator $U=REFL_{p} ROT_{y}$ as per Sec 3.1 of arxiv.org:2208.07544.
6868
6969
The MeanEstimationOperator (aka KO Operator) expects `CodeForRandomVariable` to specify the
@@ -82,21 +82,13 @@ class MeanEstimationOperator(GateWithRegisters):
8282
"""
8383

8484
code: CodeForRandomVariable
85-
cv: Tuple[int, ...] = attrs.field(
86-
converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=()
87-
)
85+
control_val: Optional[int] = None
8886
arctan_bitsize: int = 32
8987

90-
@cv.validator
91-
def _validate_cv(self, attribute, value):
92-
assert value in [(), (0,), (1,)]
93-
9488
@cached_property
9589
def reflect(self) -> ReflectionUsingPrepare:
9690
return ReflectionUsingPrepare(
97-
self.code.synthesizer,
98-
control_val=None if self.cv == () else self.cv[0],
99-
global_phase=-1,
91+
self.code.synthesizer, global_phase=-1, control_val=self.control_val
10092
)
10193

10294
@cached_property
@@ -126,36 +118,15 @@ def decompose_from_registers(
126118
yield self.select.on_registers(**select_reg)
127119
yield self.reflect.on_registers(**reflect_reg)
128120

121+
def get_single_qubit_controlled_bloq(self, control_val: int) -> 'MeanEstimationOperator':
122+
c_encoder = self.code.encoder.controlled(ctrl_spec=CtrlSpec(cvs=control_val))
123+
assert isinstance(c_encoder, SelectOracle)
124+
c_code = attrs.evolve(self.code, encoder=c_encoder)
125+
return attrs.evolve(self, code=c_code, control_val=control_val)
126+
129127
def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
130-
wire_symbols = [] if self.cv == () else [["(0)", "@"][self.cv[0]]]
128+
wire_symbols = []
129+
if self.control_val is not None:
130+
wire_symbols.append("@" if self.control_val == 1 else "(0)")
131131
wire_symbols += ['U_ko'] * (total_bits(self.signature) - total_bits(self.control_registers))
132132
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)
133-
134-
def controlled(
135-
self,
136-
num_controls: Optional[int] = None,
137-
control_values: Optional[
138-
Union[cirq.ops.AbstractControlValues, Sequence[Union[int, Collection[int]]]]
139-
] = None,
140-
control_qid_shape: Optional[Tuple[int, ...]] = None,
141-
) -> 'MeanEstimationOperator':
142-
if num_controls is None:
143-
num_controls = 1
144-
if control_values is None:
145-
control_values = [1] * num_controls
146-
if (
147-
isinstance(control_values, Sequence)
148-
and len(control_values) == 1
149-
and isinstance(control_values[0], int)
150-
and not self.cv
151-
):
152-
c_select = self.code.encoder.controlled(control_values=control_values)
153-
assert isinstance(c_select, SelectOracle)
154-
return MeanEstimationOperator(
155-
CodeForRandomVariable(encoder=c_select, synthesizer=self.code.synthesizer),
156-
cv=self.cv + (control_values[0],),
157-
arctan_bitsize=self.arctan_bitsize,
158-
)
159-
raise NotImplementedError(
160-
f'Cannot create a controlled version of {self} with control_values={control_values}.'
161-
)

qualtran/bloqs/mean_estimation/mean_estimation_operator_test.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@
2121
from attrs import frozen
2222

2323
from qualtran import BoundedQUInt, QAny, QBit, Register
24-
from qualtran._infra.gate_with_registers import get_named_qubits, total_bits
24+
from qualtran._infra.gate_with_registers import (
25+
get_named_qubits,
26+
SpecializedSingleQubitControlledGate,
27+
total_bits,
28+
)
2529
from qualtran.bloqs.mean_estimation.mean_estimation_operator import (
2630
CodeForRandomVariable,
2731
MeanEstimationOperator,
@@ -51,7 +55,7 @@ def decompose_from_registers( # type:ignore[override]
5155

5256

5357
@frozen
54-
class BernoulliEncoder(SelectOracle):
58+
class BernoulliEncoder(SpecializedSingleQubitControlledGate, SelectOracle): # type: ignore[misc]
5559
r"""Encodes Bernoulli random variable y0/y1 as $Enc|ii..i>|0> = |ii..i>|y_{i}>$ where i=0/1."""
5660

5761
p: float
@@ -86,10 +90,6 @@ def decompose_from_registers( # type:ignore[override]
8690
if y1:
8791
yield cirq.X(tq).controlled_by(*q, control_values=[1] * self.selection_bitsize)
8892

89-
def controlled(self, *args, **kwargs):
90-
cv = kwargs['control_values'][0]
91-
return BernoulliEncoder(self.p, self.y, self.selection_bitsize, self.target_bitsize, cv)
92-
9393
@cached_property
9494
def mu(self) -> float:
9595
return self.p * self.y[1] + (1 - self.p) * self.y[0]
@@ -278,8 +278,6 @@ def test_mean_estimation_operator_consistent_protocols():
278278
mean_gate.controlled(num_controls=1, control_values=(0,)),
279279
op.controlled_by(cirq.q("control"), control_values=(0,)).gate,
280280
)
281-
with pytest.raises(NotImplementedError, match="Cannot create a controlled version"):
282-
_ = mean_gate.controlled(num_controls=2)
283281

284282
# Test diagrams
285283
n_qubits = cirq.num_qubits(mean_gate)

0 commit comments

Comments
 (0)