Skip to content

Commit 7b602c1

Browse files
anurudhpmpharrigan
andauthored
Fix t_complexity for symbolic ham. sim. by gqsp example (#944)
* investigate `t_complexity` failure for symbolic HamSimbyGQSP * `UtilBloq` -> `_BookkeepingBloq`, improve docstring --------- Co-authored-by: Matthew Harrigan <mpharrigan@google.com>
1 parent 524209d commit 7b602c1

5 files changed

Lines changed: 69 additions & 54 deletions

File tree

qualtran/bloqs/hamiltonian_simulation/hamiltonian_simulation_by_gqsp.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from functools import cached_property
15-
from typing import cast, Dict, Tuple, TYPE_CHECKING, Union
15+
from typing import cast, Dict, Set, Tuple, TYPE_CHECKING, Union
1616

1717
import numpy as np
1818
from attrs import field, frozen
@@ -29,6 +29,7 @@
2929

3030
if TYPE_CHECKING:
3131
from qualtran import BloqBuilder, SoquetT
32+
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
3233

3334

3435
@frozen
@@ -176,6 +177,21 @@ def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str
176177

177178
return soqs
178179

180+
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
181+
if self.is_symbolic():
182+
from qualtran.bloqs.basic_gates.su2_rotation import SU2RotationGate
183+
184+
d = self.degree
185+
return {
186+
(self.walk_operator.prepare, 1),
187+
(self.walk_operator.prepare.adjoint(), 1),
188+
(self.walk_operator.controlled(control_values=[0]), d),
189+
(self.walk_operator.adjoint().controlled(), d),
190+
(SU2RotationGate.arbitrary(ssa), 2 * d + 1),
191+
}
192+
193+
return super().build_call_graph(ssa)
194+
179195

180196
@bloq_example
181197
def _hubbard_time_evolution_by_gqsp() -> HamiltonianSimulationByGQSP:
@@ -192,9 +208,8 @@ def _symbolic_hamsim_by_gqsp() -> HamiltonianSimulationByGQSP:
192208

193209
from qualtran.bloqs.hubbard_model import get_walk_operator_for_hubbard_model
194210

195-
walk_op = get_walk_operator_for_hubbard_model(2, 2, 1, 1)
196-
197-
t, inv_eps = sympy.symbols("t N")
211+
tau, t, inv_eps = sympy.symbols(r"\tau t \epsilon^{-1}", positive=True)
212+
walk_op = get_walk_operator_for_hubbard_model(2, 2, tau, 4 * tau)
198213
symbolic_hamsim_by_gqsp = HamiltonianSimulationByGQSP(walk_op, t=t, precision=1 / inv_eps)
199214
return symbolic_hamsim_by_gqsp
200215

qualtran/bloqs/hamiltonian_simulation/hamiltonian_simulation_by_gqsp_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import numpy as np
1717
import pytest
1818
import scipy
19+
import sympy
1920
from numpy.typing import NDArray
2021

2122
from qualtran.bloqs.for_testing.matrix_gate import MatrixGate
@@ -26,6 +27,8 @@
2627
verify_generalized_qsp,
2728
)
2829
from qualtran.bloqs.qubitization_walk_operator import QubitizationWalkOperator
30+
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
31+
from qualtran.resource_counting import big_O
2932
from qualtran.symbolics import Shaped
3033

3134
from .hamiltonian_simulation_by_gqsp import (
@@ -100,3 +103,8 @@ def test_hamiltonian_simulation_by_gqsp(
100103
def test_hamiltonian_simulation_by_gqsp_t_complexity():
101104
hubbard_time_evolution_by_gqsp = _hubbard_time_evolution_by_gqsp.make()
102105
_ = hubbard_time_evolution_by_gqsp.t_complexity()
106+
107+
symbolic_hamsim_by_gqsp = _symbolic_hamsim_by_gqsp()
108+
tau, t, inv_eps = sympy.symbols(r"\tau t \epsilon^{-1}", positive=True)
109+
T = big_O(tau * t + sympy.log(inv_eps) / sympy.log(sympy.log(inv_eps)))
110+
assert symbolic_hamsim_by_gqsp.t_complexity() == TComplexity(t=T, clifford=T, rotations=T) # type: ignore[arg-type]

qualtran/bloqs/hubbard_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
from qualtran.bloqs.state_preparation.prepare_uniform_superposition import (
6767
PrepareUniformSuperposition,
6868
)
69+
from qualtran.symbolics.math_funcs import acos, ssqrt
6970

7071
if TYPE_CHECKING:
7172
from qualtran.symbolics import SymbolicFloat
@@ -308,7 +309,7 @@ def decompose_from_registers(
308309
temp = quregs['temp']
309310

310311
N = self.x_dim * self.y_dim * 2
311-
yield cirq.Ry(rads=2 * np.arccos(np.sqrt(self.t * N / self.l1_norm_of_coeffs))).on(*V)
312+
yield cirq.Ry(rads=2 * acos(ssqrt(self.t * N / self.l1_norm_of_coeffs))).on(*V)
312313
yield cirq.Ry(rads=2 * np.arccos(np.sqrt(1 / 5))).on(*U).controlled_by(*V)
313314
yield PrepareUniformSuperposition(self.x_dim).on_registers(controls=[], target=p_x)
314315
yield PrepareUniformSuperposition(self.y_dim).on_registers(controls=[], target=p_y)

qualtran/bloqs/util_bloqs.py

Lines changed: 34 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
"""Bloqs for virtual operations and register reshaping."""
16-
16+
import abc
1717
from functools import cached_property
1818
from typing import Any, Dict, Iterable, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union
1919

@@ -50,8 +50,32 @@
5050
from qualtran.simulation.classical_sim import ClassicalValT
5151

5252

53+
class _BookkeepingBloq(Bloq, metaclass=abc.ABCMeta):
54+
"""Base class for utility bloqs used for bookkeeping.
55+
56+
This bloq:
57+
- has trivial controlled versions, which pass through the control register.
58+
- does not affect T complexity.
59+
"""
60+
61+
def get_ctrl_system(
62+
self, ctrl_spec: Optional['CtrlSpec'] = None
63+
) -> Tuple['Bloq', 'AddControlledT']:
64+
def add_controlled(
65+
bb: 'BloqBuilder', ctrl_soqs: Sequence['SoquetT'], in_soqs: Dict[str, 'SoquetT']
66+
) -> Tuple[Iterable['SoquetT'], Iterable['SoquetT']]:
67+
# ignore `ctrl_soq` and pass it through for bookkeeping operation.
68+
out_soqs = bb.add_t(self, **in_soqs)
69+
return ctrl_soqs, out_soqs
70+
71+
return self, add_controlled
72+
73+
def _t_complexity_(self) -> 'TComplexity':
74+
return TComplexity()
75+
76+
5377
@frozen
54-
class Split(Bloq):
78+
class Split(_BookkeepingBloq):
5579
"""Split a bitsize `n` register into a length-`n` array-register.
5680
5781
Attributes:
@@ -75,9 +99,6 @@ def adjoint(self) -> 'Bloq':
7599
def as_cirq_op(self, qubit_manager, reg: 'CirqQuregT') -> Tuple[None, Dict[str, 'CirqQuregT']]:
76100
return None, {'reg': reg.reshape((self.dtype.num_qubits, 1))}
77101

78-
def _t_complexity_(self) -> 'TComplexity':
79-
return TComplexity()
80-
81102
def on_classical_vals(self, reg: int) -> Dict[str, 'ClassicalValT']:
82103
return {'reg': ints_to_bits(np.array([reg]), self.dtype.num_qubits)[0]}
83104

@@ -103,16 +124,6 @@ def add_my_tensors(
103124
)
104125
)
105126

106-
def get_ctrl_system(self, ctrl_spec=None) -> Tuple['Bloq', 'AddControlledT']:
107-
def add_controlled(
108-
bb: 'BloqBuilder', ctrl_soqs: Sequence['SoquetT'], in_soqs: Dict[str, 'SoquetT']
109-
) -> Tuple[Iterable['SoquetT'], Iterable['SoquetT']]:
110-
# ignore `ctrl_soq` and pass it through for bookkeeping operation.
111-
out_soqs = bb.add_t(self, **in_soqs)
112-
return ctrl_soqs, out_soqs
113-
114-
return self, add_controlled
115-
116127
def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
117128
if reg is None:
118129
return Text(self.pretty_name())
@@ -123,7 +134,7 @@ def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -
123134

124135

125136
@frozen
126-
class Join(Bloq):
137+
class Join(_BookkeepingBloq):
127138
"""Join a length-`n` array-register into one register of bitsize `n`.
128139
129140
Attributes:
@@ -147,9 +158,6 @@ def adjoint(self) -> 'Bloq':
147158
def as_cirq_op(self, qubit_manager, reg: 'CirqQuregT') -> Tuple[None, Dict[str, 'CirqQuregT']]:
148159
return None, {'reg': reg.reshape(self.dtype.num_qubits)}
149160

150-
def _t_complexity_(self) -> 'TComplexity':
151-
return TComplexity()
152-
153161
def add_my_tensors(
154162
self,
155163
tn: 'qtn.TensorNetwork',
@@ -175,18 +183,6 @@ def add_my_tensors(
175183
def on_classical_vals(self, reg: 'NDArray[np.uint]') -> Dict[str, int]:
176184
return {'reg': bits_to_ints(reg)[0]}
177185

178-
def get_ctrl_system(
179-
self, ctrl_spec: Optional['CtrlSpec'] = None
180-
) -> Tuple['Bloq', 'AddControlledT']:
181-
def add_controlled(
182-
bb: 'BloqBuilder', ctrl_soqs: Sequence['SoquetT'], in_soqs: Dict[str, 'SoquetT']
183-
) -> Tuple[Iterable['SoquetT'], Iterable['SoquetT']]:
184-
# ignore `ctrl_soq` and pass it through for bookkeeping operation.
185-
out_soqs = bb.add_t(self, **in_soqs)
186-
return ctrl_soqs, out_soqs
187-
188-
return self, add_controlled
189-
190186
def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
191187
if reg is None:
192188
return Text('')
@@ -197,7 +193,7 @@ def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -
197193

198194

199195
@frozen
200-
class Partition(Bloq):
196+
class Partition(_BookkeepingBloq):
201197
"""Partition a generic index into multiple registers.
202198
203199
Args:
@@ -240,9 +236,6 @@ def as_cirq_op(self, qubit_manager, **cirq_quregs) -> Tuple[None, Dict[str, 'Cir
240236
else:
241237
return None, {'x': np.concatenate([v.ravel() for _, v in cirq_quregs.items()])}
242238

243-
def _t_complexity_(self) -> 'TComplexity':
244-
return TComplexity()
245-
246239
def add_my_tensors(
247240
self,
248241
tn: 'qtn.TensorNetwork',
@@ -323,7 +316,7 @@ def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSym
323316

324317

325318
@frozen
326-
class Allocate(Bloq):
319+
class Allocate(_BookkeepingBloq):
327320
"""Allocate an `n` bit register.
328321
329322
Attributes:
@@ -342,9 +335,6 @@ def adjoint(self) -> 'Bloq':
342335
def on_classical_vals(self) -> Dict[str, int]:
343336
return {'reg': 0}
344337

345-
def _t_complexity_(self) -> 'TComplexity':
346-
return TComplexity()
347-
348338
def add_my_tensors(
349339
self,
350340
tn: 'qtn.TensorNetwork',
@@ -367,7 +357,7 @@ def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSym
367357

368358

369359
@frozen
370-
class Free(Bloq):
360+
class Free(_BookkeepingBloq):
371361
"""Free (i.e. de-allocate) an `n` bit register.
372362
373363
The tensor decomposition assumes the `n` bit register is uncomputed and is in the $|0^{n}>$
@@ -392,9 +382,6 @@ def on_classical_vals(self, reg: int) -> Dict[str, 'ClassicalValT']:
392382
raise ValueError(f"Tried to free a non-zero register: {reg}.")
393383
return {}
394384

395-
def _t_complexity_(self) -> 'TComplexity':
396-
return TComplexity()
397-
398385
def add_my_tensors(
399386
self,
400387
tn: 'qtn.TensorNetwork',
@@ -440,13 +427,14 @@ def _t_complexity_(self) -> 'TComplexity':
440427

441428

442429
@frozen
443-
class Cast(Bloq):
430+
class Cast(_BookkeepingBloq):
444431
"""Cast a register from one n-bit QDType to another QDType.
445432
446433
447434
Args:
448-
in_qdtype: Input QDType to cast from.
449-
out_qdtype: Output QDType to cast to.
435+
inp_dtype: Input QDType to cast from.
436+
out_dtype: Output QDType to cast to.
437+
shape: shape of the register to cast.
450438
451439
Registers:
452440
in: input register to cast from.
@@ -501,9 +489,6 @@ def on_classical_vals(self, reg: int) -> Dict[str, 'ClassicalValT']:
501489
def as_cirq_op(self, qubit_manager, reg: 'CirqQuregT') -> Tuple[None, Dict[str, 'CirqQuregT']]:
502490
return None, {'reg': reg}
503491

504-
def _t_complexity_(self) -> 'TComplexity':
505-
return TComplexity()
506-
507492

508493
@frozen
509494
class Power(GateWithRegisters):

qualtran/symbolics/math_funcs.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ def sabs(x: SymbolicFloat) -> SymbolicFloat:
4141
return cast(SymbolicFloat, abs(x))
4242

4343

44+
def ssqrt(x: SymbolicFloat) -> SymbolicFloat:
45+
if isinstance(x, sympy.Basic):
46+
return sympy.sqrt(x)
47+
return np.sqrt(x)
48+
49+
4450
def ceil(x: SymbolicFloat) -> SymbolicInt:
4551
if not isinstance(x, sympy.Basic):
4652
return int(np.ceil(x))

0 commit comments

Comments
 (0)