Skip to content

Commit b6c7b29

Browse files
authored
SwapWithZero learns how to swap multidimensional selection index. (#954)
* SwapWithZero learns how to swap multidimensional selection index. * Clear output * Fix mypy and use wire_symbol instead of cirq diagram. Fix a bug in _wire_symbol_to_cirq_diagram_info * Fix failing test
1 parent 3097c0e commit b6c7b29

6 files changed

Lines changed: 245 additions & 70 deletions

File tree

qualtran/bloqs/swap_network/cswap_approx.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
ignore_cliffords,
3131
ignore_split_join,
3232
)
33+
from qualtran.symbolics import SymbolicInt
3334

3435
if TYPE_CHECKING:
3536
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
@@ -59,7 +60,7 @@ class CSwapApprox(GateWithRegisters):
5960
Low et. al. 2018. See Appendix B.2.c.
6061
"""
6162

62-
bitsize: int
63+
bitsize: SymbolicInt
6364

6465
@cached_property
6566
def signature(self) -> Signature:

qualtran/bloqs/swap_network/swap_network.ipynb

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -313,11 +313,18 @@
313313
},
314314
"source": [
315315
"## `SwapWithZero`\n",
316-
"Swaps |Psi_0> with |Psi_x> if selection register stores index `x`.\n",
316+
"Swaps $|\\Psi_0\\rangle$ with $|\\Psi_x\\rangle$ if selection register stores index `x`.\n",
317317
"\n",
318-
"Implements the unitary U |x> |Psi_0> |Psi_1> ... |Psi_{n-1}> --> |x> |Psi_x> |Rest of Psi>.\n",
319-
"Note that the state of `|Rest of Psi>` is allowed to be anything and should not be depended\n",
320-
"upon.\n",
318+
"Implements the unitary\n",
319+
"$$\n",
320+
"U |x\\rangle |\\Psi_0\\rangle |\\Psi_1\\rangle \\dots \\Psi_{M-1}\\rangle \\rightarrow\n",
321+
" |x\\rangle |\\Psi_x\\rangle |\\text{Rest of } \\Psi\\rangle$\n",
322+
"$$\n",
323+
"Note that the state of $|\\text{Rest of } \\Psi\\rangle$ is allowed to be anything and\n",
324+
"should not be depended upon.\n",
325+
"\n",
326+
"Also supports the multidimensional version where $|x\\rangle$ can be an n-dimensional index\n",
327+
"of the form $|x_1\\rangle|x_2\\rangle \\dots |x_n\\rangle$\n",
321328
"\n",
322329
"#### References\n",
323330
" - [Trading T-gates for dirty qubits in state preparation and unitary synthesis](https://arxiv.org/abs/1812.00954). Low, Kliuchnikov, Schaeffer. 2018.\n"
@@ -354,7 +361,7 @@
354361
},
355362
"outputs": [],
356363
"source": [
357-
"swz = SwapWithZero(selection_bitsize=8, target_bitsize=32, n_target_registers=4)"
364+
"swz = SwapWithZero(selection_bitsizes=8, target_bitsize=32, n_target_registers=4)"
358365
]
359366
},
360367
{
@@ -367,7 +374,19 @@
367374
"outputs": [],
368375
"source": [
369376
"# A small version on four bits.\n",
370-
"swz_small = SwapWithZero(selection_bitsize=3, target_bitsize=2, n_target_registers=2)"
377+
"swz_small = SwapWithZero(selection_bitsizes=3, target_bitsize=2, n_target_registers=2)"
378+
]
379+
},
380+
{
381+
"cell_type": "code",
382+
"execution_count": null,
383+
"id": "38dfda92-364a-4e65-8914-2ff0a947d421",
384+
"metadata": {},
385+
"outputs": [],
386+
"source": [
387+
"swz_multi_dimensional = SwapWithZero(\n",
388+
" selection_bitsizes=(4, 3), target_bitsize=2, n_target_registers=(15, 5)\n",
389+
")"
371390
]
372391
},
373392
{
@@ -390,8 +409,8 @@
390409
"outputs": [],
391410
"source": [
392411
"from qualtran.drawing import show_bloqs\n",
393-
"show_bloqs([swz, swz_small],\n",
394-
" ['`swz`', '`swz_small`'])"
412+
"show_bloqs([swz, swz_small, swz_multi_dimensional],\n",
413+
" ['`swz`', '`swz_small`', '`swz_multi_dimensional`'])"
395414
]
396415
},
397416
{
@@ -545,6 +564,40 @@
545564
"show_call_graph(multiplexed_cswap_g)\n",
546565
"show_counts_sigma(multiplexed_cswap_sigma)"
547566
]
567+
},
568+
{
569+
"cell_type": "code",
570+
"execution_count": null,
571+
"id": "33258192",
572+
"metadata": {
573+
"cq.autogen": "SwapWithZero.swz_multi_dimensional"
574+
},
575+
"outputs": [],
576+
"source": [
577+
"swz_multi_dimensional = SwapWithZero(\n",
578+
" selection_bitsizes=(2, 2), target_bitsize=2, n_target_registers=(3, 4)\n",
579+
")"
580+
]
581+
},
582+
{
583+
"cell_type": "code",
584+
"execution_count": null,
585+
"id": "80e73ef0",
586+
"metadata": {
587+
"cq.autogen": "SwapWithZero.swz_multi_symbolic"
588+
},
589+
"outputs": [],
590+
"source": [
591+
"# A small version on four bits.\n",
592+
"selection = sympy.symbols(\"p q r\")\n",
593+
"target_bitsize = sympy.Symbol(\"b\")\n",
594+
"n_target_registers = sympy.symbols(\"P Q R\")\n",
595+
"swz_multi_symbolic = SwapWithZero(\n",
596+
" selection_bitsizes=selection,\n",
597+
" target_bitsize=target_bitsize,\n",
598+
" n_target_registers=n_target_registers,\n",
599+
")"
600+
]
548601
}
549602
],
550603
"metadata": {
@@ -563,7 +616,7 @@
563616
"name": "python",
564617
"nbconvert_exporter": "python",
565618
"pygments_lexer": "ipython3",
566-
"version": "3.11.7"
619+
"version": "3.11.8"
567620
}
568621
},
569622
"nbformat": 4,

qualtran/bloqs/swap_network/swap_with_zero.py

Lines changed: 115 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313
# limitations under the License.
1414

1515
from functools import cached_property
16-
from typing import Dict, Set, Tuple, TYPE_CHECKING
16+
from typing import Dict, Iterable, Set, Tuple, TYPE_CHECKING, Union
1717

18+
import attrs
1819
import cirq
1920
import numpy as np
20-
from attrs import frozen
2121
from numpy.typing import NDArray
2222

2323
from qualtran import (
@@ -33,42 +33,59 @@
3333
SoquetT,
3434
)
3535
from qualtran.bloqs.swap_network.cswap_approx import CSwapApprox
36+
from qualtran.drawing import TextBox, WireSymbol
3637
from qualtran.resource_counting.generalizers import ignore_split_join
38+
from qualtran.symbolics import is_symbolic, prod, SymbolicInt
3739

3840
if TYPE_CHECKING:
3941
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
4042

4143

42-
@frozen
44+
def _to_tuple(x: Union[SymbolicInt, Iterable[SymbolicInt]]) -> Tuple[SymbolicInt, ...]:
45+
if isinstance(x, np.ndarray):
46+
return _to_tuple(x.tolist())
47+
if isinstance(x, Iterable):
48+
return tuple(x)
49+
return (x,)
50+
51+
52+
@attrs.frozen
4353
class SwapWithZero(GateWithRegisters):
44-
"""Swaps |Psi_0> with |Psi_x> if selection register stores index `x`.
54+
r"""Swaps $|\Psi_0\rangle$ with $|\Psi_x\rangle$ if selection register stores index `x`.
4555
46-
Implements the unitary U |x> |Psi_0> |Psi_1> ... |Psi_{n-1}> --> |x> |Psi_x> |Rest of Psi>.
47-
Note that the state of `|Rest of Psi>` is allowed to be anything and should not be depended
48-
upon.
56+
Implements the unitary
57+
$$
58+
U |x\rangle |\Psi_0\rangle |\Psi_1\rangle \dots \Psi_{M-1}\rangle \rightarrow
59+
|x\rangle |\Psi_x\rangle |\text{Rest of } \Psi\rangle$
60+
$$
61+
Note that the state of $|\text{Rest of } \Psi\rangle$ is allowed to be anything and
62+
should not be depended upon.
63+
64+
Also supports the multidimensional version where $|x\rangle$ can be an n-dimensional index
65+
of the form $|x_1\rangle|x_2\rangle \dots |x_n\rangle$
4966
5067
References:
5168
[Trading T-gates for dirty qubits in state preparation and unitary synthesis](https://arxiv.org/abs/1812.00954).
5269
Low, Kliuchnikov, Schaeffer. 2018.
5370
"""
5471

55-
selection_bitsize: int
56-
target_bitsize: int
57-
n_target_registers: int
72+
selection_bitsizes: Tuple[SymbolicInt, ...] = attrs.field(converter=_to_tuple)
73+
target_bitsize: SymbolicInt
74+
n_target_registers: Tuple[SymbolicInt, ...] = attrs.field(converter=_to_tuple)
5875

5976
def __attrs_post_init__(self):
60-
assert self.n_target_registers <= 2**self.selection_bitsize
77+
assert len(self.n_target_registers) == len(self.selection_bitsizes)
6178

6279
@cached_property
6380
def selection_registers(self) -> Tuple[Register, ...]:
64-
return (
65-
Register(
66-
'selection',
67-
BoundedQUInt(
68-
bitsize=self.selection_bitsize, iteration_length=self.n_target_registers
69-
),
70-
),
71-
)
81+
types = [
82+
BoundedQUInt(sb, l)
83+
for sb, l in zip(self.selection_bitsizes, self.n_target_registers)
84+
if is_symbolic(sb) or sb > 0
85+
]
86+
if len(types) == 1:
87+
return (Register('selection', types[0]),)
88+
return tuple(Register(f'selection{i}_', qdtype) for i, qdtype in enumerate(types))
7289

7390
@cached_property
7491
def target_registers(self) -> Tuple[Register, ...]:
@@ -80,61 +97,104 @@ def target_registers(self) -> Tuple[Register, ...]:
8097
def signature(self) -> Signature:
8198
return Signature([*self.selection_registers, *self.target_registers])
8299

83-
def build_composite_bloq(
84-
self, bb: 'BloqBuilder', selection: Soquet, targets: NDArray[Soquet] # type: ignore[type-var]
85-
) -> Dict[str, 'SoquetT']:
86-
cswap_n = CSwapApprox(self.target_bitsize)
87-
# Imagine a complete binary tree of depth `logN` with `N` leaves, each denoting a target
88-
# register. If the selection register stores index `r`, we want to bring the value stored
89-
# in leaf indexed `r` to the leaf indexed `0`. At each node of the binary tree, the left
90-
# subtree contains node with current bit 0 and right subtree contains nodes with current
91-
# bit 1. Thus, leaf indexed `0` is the leftmost node in the tree.
92-
# Start iterating from the root of the tree. If the j'th bit is set in the selection
93-
# register (i.e. the control would be activated); we know that the value we are searching
94-
# for is in the right subtree. In order to (eventually) bring the desired value to node
95-
# 0; we swap all values in the right subtree with all values in the left subtree. This
96-
# takes (N / (2 ** (j + 1)) swaps at level `j`.
97-
# Therefore, in total, we need $\sum_{j=0}^{logN-1} \frac{N}{2 ^ {j + 1}}$ controlled swaps.
98-
selection_dtype = selection.reg.dtype
99-
selection = bb.split(selection)
100-
for j in range(self.selection_bitsize):
101-
for i in range(0, self.n_target_registers - 2**j, 2 ** (j + 1)):
102-
# The inner loop is executed at-most `N - 1` times, where `N:= len(target_regs)`.
103-
sel_i = self.selection_bitsize - j - 1
104-
selection[sel_i], targets[i], targets[i + 2**j] = bb.add(
105-
cswap_n, ctrl=selection[sel_i], x=targets[i], y=targets[i + 2**j]
100+
@cached_property
101+
def cswap_n(self) -> 'CSwapApprox':
102+
return CSwapApprox(self.target_bitsize)
103+
104+
def build_via_tree(
105+
self,
106+
bb: 'BloqBuilder',
107+
sel: Dict[str, 'Soquet'],
108+
targets: NDArray['Soquet'], # type: ignore[type-var]
109+
idx: Tuple[int, ...],
110+
) -> None:
111+
sel_idx = len(idx)
112+
if sel_idx == len(self.selection_bitsizes):
113+
return
114+
115+
n_target_registers = self.n_target_registers[sel_idx]
116+
assert isinstance(n_target_registers, int)
117+
for i in range(n_target_registers):
118+
# First make sure that value to be searched is present at the LEFT most position
119+
# of the composite index by recursively swapping the subtrees attached on leaf nodes of
120+
# the current segment tree.
121+
self.build_via_tree(bb, sel, targets, idx + (i,))
122+
123+
sel_reg = self.selection_registers[sel_idx] # type: ignore[type-var]
124+
sel_soqs = bb.split(sel[sel_reg.name])
125+
sel_bitsize = self.selection_registers[sel_idx].bitsize
126+
for j in range(sel_bitsize):
127+
# Imagine a complete binary tree of depth `logN` with `N` leaves, each denoting a target
128+
# register. If the selection register stores index `r`, we want to bring the value stored
129+
# in leaf indexed `r` to the leaf indexed `0`. At each node of the binary tree, the left
130+
# subtree contains node with current bit 0 and right subtree contains nodes with current
131+
# bit 1. Thus, leaf indexed `0` is the leftmost node in the tree.
132+
# Start iterating from the root of the tree. If the j'th bit is set in the selection
133+
# register (i.e. the control would be activated); we know that the value we are searching
134+
# for is in the right subtree. In order to (eventually) bring the desired value to node
135+
# 0; we swap all values in the right subtree with all values in the left subtree. This
136+
# takes (N / (2 ** (j + 1)) swaps at level `j`.
137+
# Therefore, in total, we need $\sum_{j=0}^{logN-1} \frac{N}{2 ^ {j + 1}}$ controlled swaps.
138+
sel_i = sel_bitsize - j - 1
139+
for i in range(0, self.n_target_registers[sel_idx] - 2**j, 2 ** (j + 1)):
140+
zero_pad = (0,) * (len(self.n_target_registers) - len(idx) - 1)
141+
left_idx = idx + (i,) + zero_pad
142+
right_idx = idx + (i + 2**j,) + zero_pad
143+
sel_soqs[sel_i], targets[left_idx], targets[right_idx] = bb.add(
144+
self.cswap_n, ctrl=sel_soqs[sel_i], x=targets[left_idx], y=targets[right_idx]
106145
)
146+
sel[sel_reg.name] = bb.join(sel_soqs, dtype=sel_reg.dtype)
107147

108-
return {'selection': bb.join(selection, dtype=selection_dtype), 'targets': targets}
148+
def build_composite_bloq(
149+
self, bb: 'BloqBuilder', targets: NDArray['Soquet'], **sel: 'Soquet' # type: ignore[type-var]
150+
) -> Dict[str, 'SoquetT']:
151+
self.build_via_tree(bb, sel, targets, ())
152+
return sel | {'targets': targets}
109153

110154
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
111-
num_swaps = np.floor(
112-
sum([self.n_target_registers / (2 ** (j + 1)) for j in range(self.selection_bitsize)])
113-
)
114-
return {(CSwapApprox(self.target_bitsize), int(num_swaps))}
155+
num_swaps = prod(*[x for x in self.n_target_registers]) - 1
156+
return {(CSwapApprox(self.target_bitsize), num_swaps)}
157+
158+
def _circuit_diagram_info_(self, args) -> cirq.CircuitDiagramInfo:
159+
from qualtran.cirq_interop._bloq_to_cirq import _wire_symbol_to_cirq_diagram_info
115160

116-
def _circuit_diagram_info_(self, _) -> cirq.CircuitDiagramInfo:
117-
wire_symbols = ["@(r⇋0)"] * self.selection_bitsize
118-
for i in range(self.n_target_registers):
119-
wire_symbols += [f"swap_{i}"] * self.target_bitsize
120-
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)
161+
return _wire_symbol_to_cirq_diagram_info(self, args)
162+
163+
def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
164+
if reg is None:
165+
return super().wire_symbol(reg, idx)
166+
name = reg.name
167+
if 'selection' in name:
168+
return TextBox('@(r⇋0)')
169+
elif name == 'targets':
170+
subscript = "".join(f"_{i}" for i in idx)
171+
return TextBox(f'swap{subscript}')
172+
raise ValueError(f'Unrecognized register name {name}')
121173

122174

123175
@bloq_example(generalizer=ignore_split_join)
124176
def _swz() -> SwapWithZero:
125-
swz = SwapWithZero(selection_bitsize=8, target_bitsize=32, n_target_registers=4)
177+
swz = SwapWithZero(selection_bitsizes=8, target_bitsize=32, n_target_registers=4)
126178
return swz
127179

128180

129181
@bloq_example(generalizer=ignore_split_join)
130182
def _swz_small() -> SwapWithZero:
131183
# A small version on four bits.
132-
swz_small = SwapWithZero(selection_bitsize=3, target_bitsize=2, n_target_registers=2)
184+
swz_small = SwapWithZero(selection_bitsizes=3, target_bitsize=2, n_target_registers=2)
133185
return swz_small
134186

135187

188+
@bloq_example(generalizer=ignore_split_join)
189+
def _swz_multi_dimensional() -> SwapWithZero:
190+
swz_multi_dimensional = SwapWithZero(
191+
selection_bitsizes=(2, 2), target_bitsize=2, n_target_registers=(3, 4)
192+
)
193+
return swz_multi_dimensional
194+
195+
136196
_SWZ_DOC = BloqDocSpec(
137197
bloq_cls=SwapWithZero,
138198
import_line='from qualtran.bloqs.swap_network import SwapWithZero',
139-
examples=(_swz, _swz_small),
199+
examples=(_swz, _swz_small, _swz_multi_dimensional),
140200
)

0 commit comments

Comments
 (0)