1313# limitations under the License.
1414
1515from functools import cached_property
16- from typing import Collection , Iterator , Optional , Sequence , Tuple , Union
16+ from typing import Iterator , Optional , Tuple
1717
1818import attrs
1919import cirq
2020from numpy .typing import NDArray
2121
22- from qualtran import GateWithRegisters , Register , Signature
23- from qualtran ._infra .gate_with_registers import total_bits
22+ from qualtran import CtrlSpec , Register , Signature
23+ from qualtran ._infra .gate_with_registers import SpecializedSingleQubitControlledGate , total_bits
2424from qualtran .bloqs .mean_estimation .complex_phase_oracle import ComplexPhaseOracle
2525from qualtran .bloqs .reflection_using_prepare import ReflectionUsingPrepare
2626from qualtran .bloqs .select_and_prepare import PrepareOracle , SelectOracle
@@ -63,7 +63,7 @@ def __attrs_post_init__(self):
6363
6464
6565@attrs .frozen
66- class MeanEstimationOperator (GateWithRegisters ):
66+ class MeanEstimationOperator (SpecializedSingleQubitControlledGate ):
6767 r"""Mean estimation operator $U=REFL_{p} ROT_{y}$ as per Sec 3.1 of arxiv.org:2208.07544.
6868
6969 The MeanEstimationOperator (aka KO Operator) expects `CodeForRandomVariable` to specify the
@@ -82,21 +82,13 @@ class MeanEstimationOperator(GateWithRegisters):
8282 """
8383
8484 code : CodeForRandomVariable
85- cv : Tuple [int , ...] = attrs .field (
86- converter = lambda v : (v ,) if isinstance (v , int ) else tuple (v ), default = ()
87- )
85+ control_val : Optional [int ] = None
8886 arctan_bitsize : int = 32
8987
90- @cv .validator
91- def _validate_cv (self , attribute , value ):
92- assert value in [(), (0 ,), (1 ,)]
93-
9488 @cached_property
9589 def reflect (self ) -> ReflectionUsingPrepare :
9690 return ReflectionUsingPrepare (
97- self .code .synthesizer ,
98- control_val = None if self .cv == () else self .cv [0 ],
99- global_phase = - 1 ,
91+ self .code .synthesizer , global_phase = - 1 , control_val = self .control_val
10092 )
10193
10294 @cached_property
@@ -126,36 +118,15 @@ def decompose_from_registers(
126118 yield self .select .on_registers (** select_reg )
127119 yield self .reflect .on_registers (** reflect_reg )
128120
121+ def get_single_qubit_controlled_bloq (self , control_val : int ) -> 'MeanEstimationOperator' :
122+ c_encoder = self .code .encoder .controlled (ctrl_spec = CtrlSpec (cvs = control_val ))
123+ assert isinstance (c_encoder , SelectOracle )
124+ c_code = attrs .evolve (self .code , encoder = c_encoder )
125+ return attrs .evolve (self , code = c_code , control_val = control_val )
126+
129127 def _circuit_diagram_info_ (self , args : cirq .CircuitDiagramInfoArgs ) -> cirq .CircuitDiagramInfo :
130- wire_symbols = [] if self .cv == () else [["(0)" , "@" ][self .cv [0 ]]]
128+ wire_symbols = []
129+ if self .control_val is not None :
130+ wire_symbols .append ("@" if self .control_val == 1 else "(0)" )
131131 wire_symbols += ['U_ko' ] * (total_bits (self .signature ) - total_bits (self .control_registers ))
132132 return cirq .CircuitDiagramInfo (wire_symbols = wire_symbols )
133-
134- def controlled (
135- self ,
136- num_controls : Optional [int ] = None ,
137- control_values : Optional [
138- Union [cirq .ops .AbstractControlValues , Sequence [Union [int , Collection [int ]]]]
139- ] = None ,
140- control_qid_shape : Optional [Tuple [int , ...]] = None ,
141- ) -> 'MeanEstimationOperator' :
142- if num_controls is None :
143- num_controls = 1
144- if control_values is None :
145- control_values = [1 ] * num_controls
146- if (
147- isinstance (control_values , Sequence )
148- and len (control_values ) == 1
149- and isinstance (control_values [0 ], int )
150- and not self .cv
151- ):
152- c_select = self .code .encoder .controlled (control_values = control_values )
153- assert isinstance (c_select , SelectOracle )
154- return MeanEstimationOperator (
155- CodeForRandomVariable (encoder = c_select , synthesizer = self .code .synthesizer ),
156- cv = self .cv + (control_values [0 ],),
157- arctan_bitsize = self .arctan_bitsize ,
158- )
159- raise NotImplementedError (
160- f'Cannot create a controlled version of { self } with control_values={ control_values } .'
161- )
0 commit comments