1313# limitations under the License.
1414
1515from functools import cached_property
16- from typing import Dict , Set , Tuple , TYPE_CHECKING
16+ from typing import Dict , Iterable , Set , Tuple , TYPE_CHECKING , Union
1717
18+ import attrs
1819import cirq
1920import numpy as np
20- from attrs import frozen
2121from numpy .typing import NDArray
2222
2323from qualtran import (
3333 SoquetT ,
3434)
3535from qualtran .bloqs .swap_network .cswap_approx import CSwapApprox
36+ from qualtran .drawing import TextBox , WireSymbol
3637from qualtran .resource_counting .generalizers import ignore_split_join
38+ from qualtran .symbolics import is_symbolic , prod , SymbolicInt
3739
3840if TYPE_CHECKING :
3941 from qualtran .resource_counting import BloqCountT , SympySymbolAllocator
4042
4143
42- @frozen
44+ def _to_tuple (x : Union [SymbolicInt , Iterable [SymbolicInt ]]) -> Tuple [SymbolicInt , ...]:
45+ if isinstance (x , np .ndarray ):
46+ return _to_tuple (x .tolist ())
47+ if isinstance (x , Iterable ):
48+ return tuple (x )
49+ return (x ,)
50+
51+
52+ @attrs .frozen
4353class SwapWithZero (GateWithRegisters ):
44- """Swaps | Psi_0> with | Psi_x> if selection register stores index `x`.
54+ r """Swaps $|\ Psi_0\rangle$ with $|\ Psi_x\rangle$ if selection register stores index `x`.
4555
46- Implements the unitary U |x> |Psi_0> |Psi_1> ... |Psi_{n-1}> --> |x> |Psi_x> |Rest of Psi>.
47- Note that the state of `|Rest of Psi>` is allowed to be anything and should not be depended
48- upon.
56+ Implements the unitary
57+ $$
58+ U |x\rangle |\Psi_0\rangle |\Psi_1\rangle \dots \Psi_{M-1}\rangle \rightarrow
59+ |x\rangle |\Psi_x\rangle |\text{Rest of } \Psi\rangle$
60+ $$
61+ Note that the state of $|\text{Rest of } \Psi\rangle$ is allowed to be anything and
62+ should not be depended upon.
63+
64+ Also supports the multidimensional version where $|x\rangle$ can be an n-dimensional index
65+ of the form $|x_1\rangle|x_2\rangle \dots |x_n\rangle$
4966
5067 References:
5168 [Trading T-gates for dirty qubits in state preparation and unitary synthesis](https://arxiv.org/abs/1812.00954).
5269 Low, Kliuchnikov, Schaeffer. 2018.
5370 """
5471
55- selection_bitsize : int
56- target_bitsize : int
57- n_target_registers : int
72+ selection_bitsizes : Tuple [ SymbolicInt , ...] = attrs . field ( converter = _to_tuple )
73+ target_bitsize : SymbolicInt
74+ n_target_registers : Tuple [ SymbolicInt , ...] = attrs . field ( converter = _to_tuple )
5875
5976 def __attrs_post_init__ (self ):
60- assert self .n_target_registers <= 2 ** self .selection_bitsize
77+ assert len ( self .n_target_registers ) == len ( self .selection_bitsizes )
6178
6279 @cached_property
6380 def selection_registers (self ) -> Tuple [Register , ...]:
64- return (
65- Register (
66- 'selection' ,
67- BoundedQUInt (
68- bitsize = self . selection_bitsize , iteration_length = self . n_target_registers
69- ),
70- ),
71- )
81+ types = [
82+ BoundedQUInt ( sb , l )
83+ for sb , l in zip ( self . selection_bitsizes , self . n_target_registers )
84+ if is_symbolic ( sb ) or sb > 0
85+ ]
86+ if len ( types ) == 1 :
87+ return ( Register ( 'selection' , types [ 0 ]),)
88+ return tuple ( Register ( f'selection { i } _' , qdtype ) for i , qdtype in enumerate ( types ) )
7289
7390 @cached_property
7491 def target_registers (self ) -> Tuple [Register , ...]:
@@ -80,61 +97,104 @@ def target_registers(self) -> Tuple[Register, ...]:
8097 def signature (self ) -> Signature :
8198 return Signature ([* self .selection_registers , * self .target_registers ])
8299
83- def build_composite_bloq (
84- self , bb : 'BloqBuilder' , selection : Soquet , targets : NDArray [Soquet ] # type: ignore[type-var]
85- ) -> Dict [str , 'SoquetT' ]:
86- cswap_n = CSwapApprox (self .target_bitsize )
87- # Imagine a complete binary tree of depth `logN` with `N` leaves, each denoting a target
88- # register. If the selection register stores index `r`, we want to bring the value stored
89- # in leaf indexed `r` to the leaf indexed `0`. At each node of the binary tree, the left
90- # subtree contains node with current bit 0 and right subtree contains nodes with current
91- # bit 1. Thus, leaf indexed `0` is the leftmost node in the tree.
92- # Start iterating from the root of the tree. If the j'th bit is set in the selection
93- # register (i.e. the control would be activated); we know that the value we are searching
94- # for is in the right subtree. In order to (eventually) bring the desired value to node
95- # 0; we swap all values in the right subtree with all values in the left subtree. This
96- # takes (N / (2 ** (j + 1)) swaps at level `j`.
97- # Therefore, in total, we need $\sum_{j=0}^{logN-1} \frac{N}{2 ^ {j + 1}}$ controlled swaps.
98- selection_dtype = selection .reg .dtype
99- selection = bb .split (selection )
100- for j in range (self .selection_bitsize ):
101- for i in range (0 , self .n_target_registers - 2 ** j , 2 ** (j + 1 )):
102- # The inner loop is executed at-most `N - 1` times, where `N:= len(target_regs)`.
103- sel_i = self .selection_bitsize - j - 1
104- selection [sel_i ], targets [i ], targets [i + 2 ** j ] = bb .add (
105- cswap_n , ctrl = selection [sel_i ], x = targets [i ], y = targets [i + 2 ** j ]
100+ @cached_property
101+ def cswap_n (self ) -> 'CSwapApprox' :
102+ return CSwapApprox (self .target_bitsize )
103+
104+ def build_via_tree (
105+ self ,
106+ bb : 'BloqBuilder' ,
107+ sel : Dict [str , 'Soquet' ],
108+ targets : NDArray ['Soquet' ], # type: ignore[type-var]
109+ idx : Tuple [int , ...],
110+ ) -> None :
111+ sel_idx = len (idx )
112+ if sel_idx == len (self .selection_bitsizes ):
113+ return
114+
115+ n_target_registers = self .n_target_registers [sel_idx ]
116+ assert isinstance (n_target_registers , int )
117+ for i in range (n_target_registers ):
118+ # First make sure that value to be searched is present at the LEFT most position
119+ # of the composite index by recursively swapping the subtrees attached on leaf nodes of
120+ # the current segment tree.
121+ self .build_via_tree (bb , sel , targets , idx + (i ,))
122+
123+ sel_reg = self .selection_registers [sel_idx ] # type: ignore[type-var]
124+ sel_soqs = bb .split (sel [sel_reg .name ])
125+ sel_bitsize = self .selection_registers [sel_idx ].bitsize
126+ for j in range (sel_bitsize ):
127+ # Imagine a complete binary tree of depth `logN` with `N` leaves, each denoting a target
128+ # register. If the selection register stores index `r`, we want to bring the value stored
129+ # in leaf indexed `r` to the leaf indexed `0`. At each node of the binary tree, the left
130+ # subtree contains node with current bit 0 and right subtree contains nodes with current
131+ # bit 1. Thus, leaf indexed `0` is the leftmost node in the tree.
132+ # Start iterating from the root of the tree. If the j'th bit is set in the selection
133+ # register (i.e. the control would be activated); we know that the value we are searching
134+ # for is in the right subtree. In order to (eventually) bring the desired value to node
135+ # 0; we swap all values in the right subtree with all values in the left subtree. This
136+ # takes (N / (2 ** (j + 1)) swaps at level `j`.
137+ # Therefore, in total, we need $\sum_{j=0}^{logN-1} \frac{N}{2 ^ {j + 1}}$ controlled swaps.
138+ sel_i = sel_bitsize - j - 1
139+ for i in range (0 , self .n_target_registers [sel_idx ] - 2 ** j , 2 ** (j + 1 )):
140+ zero_pad = (0 ,) * (len (self .n_target_registers ) - len (idx ) - 1 )
141+ left_idx = idx + (i ,) + zero_pad
142+ right_idx = idx + (i + 2 ** j ,) + zero_pad
143+ sel_soqs [sel_i ], targets [left_idx ], targets [right_idx ] = bb .add (
144+ self .cswap_n , ctrl = sel_soqs [sel_i ], x = targets [left_idx ], y = targets [right_idx ]
106145 )
146+ sel [sel_reg .name ] = bb .join (sel_soqs , dtype = sel_reg .dtype )
107147
108- return {'selection' : bb .join (selection , dtype = selection_dtype ), 'targets' : targets }
148+ def build_composite_bloq (
149+ self , bb : 'BloqBuilder' , targets : NDArray ['Soquet' ], ** sel : 'Soquet' # type: ignore[type-var]
150+ ) -> Dict [str , 'SoquetT' ]:
151+ self .build_via_tree (bb , sel , targets , ())
152+ return sel | {'targets' : targets }
109153
110154 def build_call_graph (self , ssa : 'SympySymbolAllocator' ) -> Set ['BloqCountT' ]:
111- num_swaps = np .floor (
112- sum ([self .n_target_registers / (2 ** (j + 1 )) for j in range (self .selection_bitsize )])
113- )
114- return {(CSwapApprox (self .target_bitsize ), int (num_swaps ))}
155+ num_swaps = prod (* [x for x in self .n_target_registers ]) - 1
156+ return {(CSwapApprox (self .target_bitsize ), num_swaps )}
157+
158+ def _circuit_diagram_info_ (self , args ) -> cirq .CircuitDiagramInfo :
159+ from qualtran .cirq_interop ._bloq_to_cirq import _wire_symbol_to_cirq_diagram_info
115160
116- def _circuit_diagram_info_ (self , _ ) -> cirq .CircuitDiagramInfo :
117- wire_symbols = ["@(r⇋0)" ] * self .selection_bitsize
118- for i in range (self .n_target_registers ):
119- wire_symbols += [f"swap_{ i } " ] * self .target_bitsize
120- return cirq .CircuitDiagramInfo (wire_symbols = wire_symbols )
161+ return _wire_symbol_to_cirq_diagram_info (self , args )
162+
163+ def wire_symbol (self , reg : Register , idx : Tuple [int , ...] = tuple ()) -> 'WireSymbol' :
164+ if reg is None :
165+ return super ().wire_symbol (reg , idx )
166+ name = reg .name
167+ if 'selection' in name :
168+ return TextBox ('@(r⇋0)' )
169+ elif name == 'targets' :
170+ subscript = "" .join (f"_{ i } " for i in idx )
171+ return TextBox (f'swap{ subscript } ' )
172+ raise ValueError (f'Unrecognized register name { name } ' )
121173
122174
123175@bloq_example (generalizer = ignore_split_join )
124176def _swz () -> SwapWithZero :
125- swz = SwapWithZero (selection_bitsize = 8 , target_bitsize = 32 , n_target_registers = 4 )
177+ swz = SwapWithZero (selection_bitsizes = 8 , target_bitsize = 32 , n_target_registers = 4 )
126178 return swz
127179
128180
129181@bloq_example (generalizer = ignore_split_join )
130182def _swz_small () -> SwapWithZero :
131183 # A small version on four bits.
132- swz_small = SwapWithZero (selection_bitsize = 3 , target_bitsize = 2 , n_target_registers = 2 )
184+ swz_small = SwapWithZero (selection_bitsizes = 3 , target_bitsize = 2 , n_target_registers = 2 )
133185 return swz_small
134186
135187
188+ @bloq_example (generalizer = ignore_split_join )
189+ def _swz_multi_dimensional () -> SwapWithZero :
190+ swz_multi_dimensional = SwapWithZero (
191+ selection_bitsizes = (2 , 2 ), target_bitsize = 2 , n_target_registers = (3 , 4 )
192+ )
193+ return swz_multi_dimensional
194+
195+
136196_SWZ_DOC = BloqDocSpec (
137197 bloq_cls = SwapWithZero ,
138198 import_line = 'from qualtran.bloqs.swap_network import SwapWithZero' ,
139- examples = (_swz , _swz_small ),
199+ examples = (_swz , _swz_small , _swz_multi_dimensional ),
140200)
0 commit comments