Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 51 additions & 1 deletion qualtran/bloqs/multiplexers/selected_majorana_fermion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from functools import cached_property
from typing import Iterator, Sequence, Tuple, Union
from typing import Dict, Iterator, Sequence, Tuple, Union

import attrs
import cirq
Expand All @@ -25,6 +25,7 @@
from qualtran._infra.data_types import BQUInt
from qualtran._infra.gate_with_registers import total_bits
from qualtran.bloqs.multiplexers.unary_iteration_bloq import UnaryIterationGate
from qualtran.simulation.classical_sim import ClassicalValT


@attrs.frozen
Expand Down Expand Up @@ -137,5 +138,54 @@ def nth_operation( # type: ignore[override]
yield self.target_gate(target[target_idx]).controlled_by(control)
yield cirq.CZ(*accumulator, target[target_idx])

def on_classical_vals(self, **vals) -> Dict[str, 'ClassicalValT']:
if self.target_gate != cirq.X and self.target_gate != cirq.Z:
return NotImplemented
if len(self.control_registers) != 1 or len(self.selection_registers) != 1:
return NotImplemented
control_name = self.control_registers[0].name
control = vals[control_name]
selection_name = self.selection_registers[0].name
selection = vals[selection_name]
target = vals['target']

# When target_gate == cirq.X, flip the selection-th bit in target. The ith bit of a
# size N regirster is addressed with the unsigned integer 2^(N - 1 - i) in our big
# endian convention.
if control and self.target_gate == cirq.X:
max_selection = self.selection_registers[0].dtype.iteration_length_or_zero() - 1
target = (2 ** (max_selection - selection)) ^ target
Comment thread
mpharrigan marked this conversation as resolved.
Comment thread
mpharrigan marked this conversation as resolved.
# When target_gate == cirq.Z, the action is only in the phase.

return {control_name: control, selection_name: selection, 'target': target}

def basis_state_phase(self, **vals) -> Union[complex, None]:
if self.target_gate != cirq.X and self.target_gate != cirq.Z:
return None
if len(self.control_registers) != 1 or len(self.selection_registers) != 1:
return None
control_name = self.control_registers[0].name
control = vals[control_name]
selection_name = self.selection_registers[0].name
selection = vals[selection_name]
target = vals['target']
if control:
max_selection = self.selection_registers[0].dtype.iteration_length_or_zero() - 1
# This gate applies Z in positions 0 through (selection - 1). The effect is
# a phase of plus or minus 1 depending on the parity of the number of ones
# in those positions. For an N-bit big endien integer, the first j bits can
# be isolated by shifting right by N - j.
#
# The target gate X has no additional phase, so calculate as in the
# previous paragraph.
if self.target_gate == cirq.X:
num_phases = (target >> (max_selection - selection + 1)).bit_count()
Comment thread
mpharrigan marked this conversation as resolved.
# The taget gate Z is applied in position selection, so consider the full
# range 0 through selection.
else:
num_phases = (target >> (max_selection - selection)).bit_count()
Comment thread
mpharrigan marked this conversation as resolved.
return 1 if (num_phases % 2) == 0 else -1
Comment thread
mpharrigan marked this conversation as resolved.
return 1

def __str__(self):
return f'SelectedMajoranaFermion({self.target_gate})'
16 changes: 15 additions & 1 deletion qualtran/bloqs/multiplexers/selected_majorana_fermion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
from qualtran._infra.gate_with_registers import get_named_qubits, total_bits
from qualtran.bloqs.multiplexers.selected_majorana_fermion import SelectedMajoranaFermion
from qualtran.cirq_interop.testing import GateHelper
from qualtran.testing import assert_valid_bloq_decomposition
from qualtran.testing import (
assert_consistent_phased_classical_action,
assert_valid_bloq_decomposition,
)


@pytest.mark.slow
Expand Down Expand Up @@ -148,3 +151,14 @@ def test_selected_majorana_fermion_gate_make_on():
op = gate.on_registers(**get_named_qubits(gate.signature))
op2 = SelectedMajoranaFermion.make_on(target_gate=cirq.X, **get_named_qubits(gate.signature))
assert op == op2


@pytest.mark.parametrize("selection_bitsize, target_bitsize", [(2, 4), (3, 5)])
@pytest.mark.parametrize("target_gate", [cirq.X, cirq.Z])
def test_selected_majorana_fermion_classical_action(selection_bitsize, target_bitsize, target_gate):
gate = SelectedMajoranaFermion(
Register('selection', BQUInt(selection_bitsize, target_bitsize)), target_gate=target_gate
)
assert_consistent_phased_classical_action(
gate, selection=range(target_bitsize), target=range(2**target_bitsize), control=range(2)
)
27 changes: 27 additions & 0 deletions qualtran/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
Side,
)
from qualtran._infra.composite_bloq import _get_flat_dangling_soqs
from qualtran.simulation.classical_sim import do_phased_classical_simulation
from qualtran.symbolics import is_symbolic

if TYPE_CHECKING:
Expand Down Expand Up @@ -714,3 +715,29 @@ def assert_consistent_classical_action(
np.testing.assert_equal(
bloq_res, decomposed_res, err_msg=f'{bloq=} {call_with=} {bloq_res=} {decomposed_res=}'
)


def assert_consistent_phased_classical_action(
bloq: Bloq,
**parameter_ranges: Union[NDArray, Sequence[int], Sequence[Union[Sequence[int], NDArray]]],
):
"""Check that the bloq has a phased classical action consistent with its decomposition.

Args:
bloq: bloq to test.
parameter_ranges: named arguments giving ranges for each of the registers of the bloq.
"""
cb = bloq.decompose_bloq()
parameter_names = tuple(parameter_ranges.keys())
for vals in itertools.product(*[parameter_ranges[p] for p in parameter_names]):
call_with = {p: v for p, v in zip(parameter_names, vals)}
bloq_res, bloq_phase = do_phased_classical_simulation(bloq, call_with)
decomposed_res, decomposed_phase = do_phased_classical_simulation(cb, call_with)
np.testing.assert_equal(
bloq_res, decomposed_res, err_msg=f'{bloq=} {call_with=} {bloq_res=} {decomposed_res=}'
)
np.testing.assert_equal(
bloq_phase,
decomposed_phase,
err_msg=f'{bloq=} {call_with=} {bloq_phase=} {decomposed_phase=}',
)
Loading