Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion qualtran/bloqs/multiplexers/selected_majorana_fermion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from functools import cached_property
from typing import Iterator, Sequence, Tuple, Union
from typing import Dict, Iterator, Sequence, Tuple, Union

import attrs
import cirq
Expand All @@ -25,6 +25,7 @@
from qualtran._infra.data_types import BQUInt
from qualtran._infra.gate_with_registers import total_bits
from qualtran.bloqs.multiplexers.unary_iteration_bloq import UnaryIterationGate
from qualtran.simulation.classical_sim import ClassicalValT


@attrs.frozen
Expand Down Expand Up @@ -137,5 +138,43 @@ def nth_operation( # type: ignore[override]
yield self.target_gate(target[target_idx]).controlled_by(control)
yield cirq.CZ(*accumulator, target[target_idx])

def on_classical_vals(self, **vals) -> Dict[str, 'ClassicalValT']:
if self.target_gate != cirq.X and self.target_gate != cirq.Z:
return NotImplemented
if len(self.control_registers) > 1 or len(self.selection_registers) > 1:
return NotImplemented
Comment on lines +144 to +145
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this restriction necessary?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure - it is hard for me to understand what this gate does in the general case. Is my understanding in #1699 (comment) correct?

Comment on lines +144 to +145
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current check len(self.control_registers) > 1 does not account for the case where there are zero control registers. If control_regs is empty, accessing self.control_registers[0] on line 146 will raise an IndexError. Given the implementation assumes exactly one control and one selection register, it's safer to check for != 1.

Suggested change
if len(self.control_registers) > 1 or len(self.selection_registers) > 1:
return NotImplemented
if len(self.control_registers) != 1 or len(self.selection_registers) != 1:
return NotImplemented

control_name = self.control_registers[0].name
control = vals[control_name]
selection_name = self.selection_registers[0].name
selection = vals[selection_name]
target = vals['target']

# When target_gate == cirq.X, the action is (modulo phase) a single bitflip.
if control and self.target_gate == cirq.X:
max_selection = self.selection_registers[0].dtype.iteration_length_or_zero() - 1
target = (2 ** (max_selection - selection)) ^ target
Comment on lines +154 to +155
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add a comment describing how this logic works

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you be even more explicit with how the bit twiddling operations correspond to the promised action of the subroutine

"flip the selection-th bit in target. the selection-th bit is addressed with the unsigned integer 2^(N - i) in our big endian convention"

or something like that

Comment on lines +153 to +155
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If selection exceeds max_selection, max_selection - selection becomes negative, causing 2 ** (max_selection - selection) to return a float. Bitwise XOR (^) with a float will then raise a TypeError. We should ensure selection <= max_selection before applying the bit flip. If the selection is out of range, the gate should act as identity.

Suggested change
if control and self.target_gate == cirq.X:
max_selection = self.selection_registers[0].dtype.iteration_length_or_zero() - 1
target = (2 ** (max_selection - selection)) ^ target
max_selection = self.selection_registers[0].dtype.iteration_length_or_zero() - 1
if control and self.target_gate == cirq.X and selection <= max_selection:
target = (2 ** (max_selection - selection)) ^ target

# When target_gate == cirq.Z, the action is only in the phase.

return {control_name: control, selection_name: selection, 'target': target}

def basis_state_phase(self, **vals) -> Union[complex, None]:
if self.target_gate != cirq.X and self.target_gate != cirq.Z:
return None
if len(self.control_registers) > 1 or len(self.selection_registers) > 1:
return None
Comment on lines +163 to +164
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similar to on_classical_vals, this check should ensure exactly one control register is present to avoid an IndexError on line 165 when control_regs is empty.

Suggested change
if len(self.control_registers) > 1 or len(self.selection_registers) > 1:
return None
if len(self.control_registers) != 1 or len(self.selection_registers) != 1:
return None

control_name = self.control_registers[0].name
control = vals[control_name]
selection_name = self.selection_registers[0].name
selection = vals[selection_name]
target = vals['target']
if control:
max_selection = self.selection_registers[0].dtype.iteration_length_or_zero() - 1
if self.target_gate == cirq.X:
num_phases = (target >> (max_selection - selection + 1)).bit_count()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

explain how this works

else:
num_phases = (target >> (max_selection - selection)).bit_count()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

explain how this works and why it's different from the cirq.X case. Maybe be more defensive and put this behind an elif in case we add more in the future e.g. cirq.Y

return 1 if (num_phases % 2) == 0 else -1
Comment on lines +170 to +176
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If selection exceeds max_selection, the bit shift count max_selection - selection can become negative, leading to a ValueError. Additionally, if selection == max_selection + 1 and target_gate == cirq.X, the shift count is 0, which would incorrectly count all bits in target. The phase should only be calculated when selection <= max_selection.

Suggested change
if control:
max_selection = self.selection_registers[0].dtype.iteration_length_or_zero() - 1
if self.target_gate == cirq.X:
num_phases = (target >> (max_selection - selection + 1)).bit_count()
else:
num_phases = (target >> (max_selection - selection)).bit_count()
return 1 if (num_phases % 2) == 0 else -1
max_selection = self.selection_registers[0].dtype.iteration_length_or_zero() - 1
if control and selection <= max_selection:
if self.target_gate == cirq.X:
num_phases = (target >> (max_selection - selection + 1)).bit_count()
else:
num_phases = (target >> (max_selection - selection)).bit_count()
return 1 if (num_phases % 2) == 0 else -1
return 1

return 1

def __str__(self):
return f'SelectedMajoranaFermion({self.target_gate})'
16 changes: 15 additions & 1 deletion qualtran/bloqs/multiplexers/selected_majorana_fermion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
from qualtran._infra.gate_with_registers import get_named_qubits, total_bits
from qualtran.bloqs.multiplexers.selected_majorana_fermion import SelectedMajoranaFermion
from qualtran.cirq_interop.testing import GateHelper
from qualtran.testing import assert_valid_bloq_decomposition
from qualtran.testing import (
assert_consistent_phased_classical_action,
assert_valid_bloq_decomposition,
)


@pytest.mark.slow
Expand Down Expand Up @@ -148,3 +151,14 @@ def test_selected_majorana_fermion_gate_make_on():
op = gate.on_registers(**get_named_qubits(gate.signature))
op2 = SelectedMajoranaFermion.make_on(target_gate=cirq.X, **get_named_qubits(gate.signature))
assert op == op2


@pytest.mark.parametrize("selection_bitsize, target_bitsize", [(2, 4), (3, 5)])
@pytest.mark.parametrize("target_gate", [cirq.X, cirq.Z])
def test_selected_majorana_fermion_classical_action(selection_bitsize, target_bitsize, target_gate):
gate = SelectedMajoranaFermion(
Register('selection', BQUInt(selection_bitsize, target_bitsize)), target_gate=target_gate
)
assert_consistent_phased_classical_action(
gate, selection=range(target_bitsize), target=range(2**target_bitsize), control=range(2)
)
27 changes: 27 additions & 0 deletions qualtran/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
Side,
)
from qualtran._infra.composite_bloq import _get_flat_dangling_soqs
from qualtran.simulation.classical_sim import do_phased_classical_simulation
from qualtran.symbolics import is_symbolic

if TYPE_CHECKING:
Expand Down Expand Up @@ -714,3 +715,29 @@ def assert_consistent_classical_action(
np.testing.assert_equal(
bloq_res, decomposed_res, err_msg=f'{bloq=} {call_with=} {bloq_res=} {decomposed_res=}'
)


def assert_consistent_phased_classical_action(
bloq: Bloq,
**parameter_ranges: Union[NDArray, Sequence[int], Sequence[Union[Sequence[int], NDArray]]],
):
"""Check that the bloq has a phased classical action consistent with its decomposition.

Args:
bloq: bloq to test.
parameter_ranges: named arguments giving ranges for each of the registers of the bloq.
"""
cb = bloq.decompose_bloq()
parameter_names = tuple(parameter_ranges.keys())
for vals in itertools.product(*[parameter_ranges[p] for p in parameter_names]):
call_with = {p: v for p, v in zip(parameter_names, vals)}
bloq_res, bloq_phase = do_phased_classical_simulation(bloq, call_with)
decomposed_res, decomposed_phase = do_phased_classical_simulation(cb, call_with)
np.testing.assert_equal(
bloq_res, decomposed_res, err_msg=f'{bloq=} {call_with=} {bloq_res=} {decomposed_res=}'
)
np.testing.assert_equal(
bloq_phase,
decomposed_phase,
err_msg=f'{bloq=} {call_with=} {bloq_phase=} {decomposed_phase=}',
)
Loading