Skip to content

Commit 0378502

Browse files
authored
Refactor formula for getting costs of QROAM (#943)
1 parent abfbfab commit 0378502

3 files changed

Lines changed: 47 additions & 22 deletions

File tree

qualtran/bloqs/chemistry/black_boxes.py

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,39 @@
3232
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
3333

3434

35-
def get_qroam_cost(
35+
def qroam_cost(x, data_size: int, bitsize: int, adjoint: bool = False):
36+
# See appendix B of https://arxiv.org/pdf/1902.02134
37+
if adjoint:
38+
return data_size / x + x
39+
else:
40+
return data_size / x + bitsize * (x - 1)
41+
42+
43+
def qroam_cost_dirty(x, data_size: int, bitsize: int, adjoint: bool = False):
44+
# See appendix A of https://arxiv.org/pdf/1902.02134
45+
if adjoint:
46+
return data_size / x + x
47+
else:
48+
return 2 * (data_size / x - 1) + 4 * bitsize * (x - 1)
49+
50+
51+
def get_optimal_log_block_size_clean_ancilla(
52+
data_size: int, bitsize: int, adjoint: bool = False, qroam_block_size: Optional[int] = None
53+
) -> int:
54+
if qroam_block_size is None:
55+
if adjoint:
56+
log_blk = 0.5 * np.log2(data_size)
57+
qroam_block_size = 2**log_blk
58+
else:
59+
log_blk = 0.5 * np.log2(data_size / bitsize)
60+
assert log_blk >= 0
61+
qroam_block_size = 2**log_blk
62+
k = np.log2(qroam_block_size)
63+
k_int = np.array([np.floor(k), np.ceil(k)])
64+
return int(k_int[np.argmin(qroam_cost(2**k_int, data_size, bitsize, adjoint))])
65+
66+
67+
def get_qroam_cost_clean_ancilla(
3668
data_size: int, bitsize: int, adjoint: bool = False, qroam_block_size: Optional[int] = None
3769
) -> int:
3870
"""This gives the optimal k and minimum cost for a QROM over L values of size M.
@@ -51,22 +83,8 @@ def get_qroam_cost(
5183
"""
5284
if qroam_block_size == 1:
5385
return data_size - 1
54-
if adjoint:
55-
if qroam_block_size is None:
56-
log_blk = 0.5 * np.log2(data_size)
57-
qroam_block_size = 2**log_blk
58-
value = lambda x: data_size / x + x
59-
else:
60-
if qroam_block_size is None:
61-
log_blk = 0.5 * np.log2(data_size / bitsize)
62-
assert log_blk >= 0
63-
qroam_block_size = 2**log_blk
64-
value = lambda x: data_size / x + bitsize * (x - 1)
65-
k = np.log2(qroam_block_size)
66-
k_int = np.array([np.floor(k), np.ceil(k)])
67-
k_opt = k_int[np.argmin(value(2**k_int))]
68-
val_opt = np.ceil(value(2**k_opt))
69-
return int(val_opt)
86+
k_opt = get_optimal_log_block_size_clean_ancilla(data_size, bitsize, adjoint, qroam_block_size)
87+
return int(np.ceil(qroam_cost(2**k_opt, data_size, bitsize, adjoint)))
7088

7189

7290
@frozen
@@ -96,7 +114,7 @@ def signature(self) -> Signature:
96114
return Signature.build(sel=(self.data_size - 1).bit_length(), trg=self.target_bitsize)
97115

98116
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
99-
cost = get_qroam_cost(
117+
cost = get_qroam_cost_clean_ancilla(
100118
self.data_size,
101119
self.target_bitsize,
102120
adjoint=self.is_adjoint,

qualtran/bloqs/chemistry/black_boxes_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@
1414
import pytest
1515
from openfermion.resource_estimates.utils import QI, QR
1616

17-
from qualtran.bloqs.chemistry.black_boxes import get_qroam_cost
17+
from qualtran.bloqs.chemistry.black_boxes import get_qroam_cost_clean_ancilla
1818

1919

2020
@pytest.mark.parametrize("data_size, bitsize", ((100, 10), (100, 3), (1_000, 13), (1_000_000, 20)))
2121
def test_qroam_factors(data_size, bitsize):
22-
assert get_qroam_cost(data_size, bitsize) == QR(data_size, bitsize)[-1]
23-
assert get_qroam_cost(data_size, bitsize, adjoint=True) == QI(data_size)[-1]
22+
assert get_qroam_cost_clean_ancilla(data_size, bitsize) == QR(data_size, bitsize)[-1]
23+
assert get_qroam_cost_clean_ancilla(data_size, bitsize, adjoint=True) == QI(data_size)[-1]

qualtran/bloqs/chemistry/chem_tutorials.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,13 @@ def plot_linear_log_log(
103103
y_vals = np.exp(intr) * x_vals**slope
104104
if label is None:
105105
label = ''
106-
ax.loglog(xs, ys, marker='o', ls='None', label=rf'{label} $N^{{{slope:3.2f}}}$', color=color)
106+
ax.loglog(
107+
xs,
108+
ys,
109+
marker='o',
110+
ls='None',
111+
label=rf'{label} ${{{np.exp(intr):3.1f}}}N^{{{slope:3.2f}}}$',
112+
color=color,
113+
)
107114
ax.loglog(x_vals, y_vals, marker='None', linestyle='--', color=color)
108115
ax.legend()

0 commit comments

Comments
 (0)