Skip to content

Commit bae2356

Browse files
authored
[Costs] Bloq & Gate counts (#958)
* bloq counts * [counts] test and docs * [counts] real imports in test files * support symbolics * merge fixes
1 parent 5aabc66 commit bae2356

8 files changed

Lines changed: 318 additions & 21 deletions

File tree

qualtran/bloqs/for_testing/costing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,5 +66,5 @@ def make_example_costing_bloqs():
6666
'Func1', num_qubits=10, callees=[(TGate(), 10), (TGate().adjoint(), 10), (Hadamard(), 10)]
6767
)
6868
func2 = CostingBloq('Func2', num_qubits=3, callees=[(Toffoli(), 100)])
69-
algo = CostingBloq('Algo', num_qubits=100, callees=[(func1, 1), (func2, 1)])
69+
algo = CostingBloq('Algo', num_qubits=100, callees=[(func1, 2), (func2, 1)])
7070
return algo

qualtran/bloqs/for_testing/costing_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def test_costing_bloqs():
2222
assert (
2323
format_call_graph_debug_text(g)
2424
== """\
25-
Algo -- 1 -> Func1
25+
Algo -- 2 -> Func1
2626
Algo -- 1 -> Func2
2727
Func1 -- 10 -> Hadamard()
2828
Func1 -- 10 -> TGate()

qualtran/bloqs/hamiltonian_simulation/hamiltonian_simulation_by_gqsp_test.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,24 +19,24 @@
1919
import sympy
2020
from numpy.typing import NDArray
2121

22+
from qualtran.bloqs.basic_gates import TGate, TwoBitCSwap
2223
from qualtran.bloqs.for_testing.matrix_gate import MatrixGate
2324
from qualtran.bloqs.for_testing.random_select_and_prepare import random_qubitization_walk_operator
25+
from qualtran.bloqs.hamiltonian_simulation.hamiltonian_simulation_by_gqsp import (
26+
_hubbard_time_evolution_by_gqsp,
27+
_symbolic_hamsim_by_gqsp,
28+
HamiltonianSimulationByGQSP,
29+
)
2430
from qualtran.bloqs.qsp.generalized_qsp_test import (
2531
assert_matrices_almost_equal,
2632
check_polynomial_pair_on_random_points_on_unit_circle,
2733
verify_generalized_qsp,
2834
)
2935
from qualtran.bloqs.qubitization.qubitization_walk_operator import QubitizationWalkOperator
3036
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
31-
from qualtran.resource_counting import big_O
37+
from qualtran.resource_counting import big_O, BloqCount, get_cost_value
3238
from qualtran.symbolics import Shaped
3339

34-
from .hamiltonian_simulation_by_gqsp import (
35-
_hubbard_time_evolution_by_gqsp,
36-
_symbolic_hamsim_by_gqsp,
37-
HamiltonianSimulationByGQSP,
38-
)
39-
4040

4141
def test_examples(bloq_autotester):
4242
bloq_autotester(_hubbard_time_evolution_by_gqsp)
@@ -102,7 +102,10 @@ def test_hamiltonian_simulation_by_gqsp(
102102

103103
def test_hamiltonian_simulation_by_gqsp_t_complexity():
104104
hubbard_time_evolution_by_gqsp = _hubbard_time_evolution_by_gqsp.make()
105-
_ = hubbard_time_evolution_by_gqsp.t_complexity()
105+
t_comp = hubbard_time_evolution_by_gqsp.t_complexity()
106+
107+
counts = get_cost_value(hubbard_time_evolution_by_gqsp, BloqCount.for_gateset('t+tof+cswap'))
108+
assert t_comp.t == counts[TwoBitCSwap()] * 7 + counts[TGate()]
106109

107110
symbolic_hamsim_by_gqsp = _symbolic_hamsim_by_gqsp()
108111
tau, t, inv_eps = sympy.symbols(r"\tau t \epsilon^{-1}", positive=True)

qualtran/resource_counting/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,6 @@
3333

3434
from ._success_prob import SuccessProb
3535
from ._qubit_counts import QubitCount
36+
from ._bloq_counts import BloqCount, QECGatesCost, GateCounts
3637

3738
from . import generalizers
Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import logging
15+
from collections import defaultdict
16+
from typing import Callable, Dict, Sequence, Tuple, TYPE_CHECKING
17+
18+
import attrs
19+
import networkx as nx
20+
from attrs import field, frozen
21+
22+
from ._call_graph import get_bloq_callee_counts
23+
from ._costing import CostKey
24+
from .classify_bloqs import bloq_is_clifford
25+
26+
if TYPE_CHECKING:
27+
from qualtran import Bloq
28+
29+
logger = logging.getLogger(__name__)
30+
31+
BloqCountDict = Dict['Bloq', int]
32+
33+
34+
def _gateset_bloqs_to_tuple(bloqs: Sequence['Bloq']) -> Tuple['Bloq', ...]:
35+
return tuple(bloqs)
36+
37+
38+
@frozen
39+
class BloqCount(CostKey[BloqCountDict]):
40+
"""A cost which is the count of a specific set of bloqs forming a gateset.
41+
42+
Often, we wish to know the number of specific gates in our algorithm. This is a generic
43+
CostKey that can count any gate (bloq) of interest.
44+
45+
The cost value type for this cost is a mapping from bloq to its count.
46+
47+
Args:
48+
gateset_bloqs: A sequence of bloqs which we will count. Bloqs are counted according
49+
to their equality operator.
50+
gateset_name: A string name of the gateset. Used for display and debugging purposes.
51+
"""
52+
53+
gateset_bloqs: Sequence['Bloq'] = field(converter=_gateset_bloqs_to_tuple)
54+
gateset_name: str
55+
56+
@classmethod
57+
def for_gateset(cls, gateset_name: str):
58+
"""Helper constructor to configure this cost for some common gatesets.
59+
60+
Args:
61+
gateset_name: One of 't', 't+tof', 't+tof+cswap'. This will construct a
62+
`BloqCount` cost with the indicated gates as the `gateset_bloqs`. In all
63+
cases, both TGate and its adjoint are included.
64+
"""
65+
from qualtran.bloqs.basic_gates import TGate, Toffoli, TwoBitCSwap
66+
67+
bloqs: Tuple['Bloq', ...]
68+
if gateset_name == 't':
69+
bloqs = (TGate(), TGate(is_adjoint=True))
70+
elif gateset_name == 't+tof':
71+
bloqs = (TGate(), TGate(is_adjoint=True), Toffoli())
72+
elif gateset_name == 't+tof+cswap':
73+
bloqs = (TGate(), TGate(is_adjoint=True), Toffoli(), TwoBitCSwap())
74+
else:
75+
raise ValueError(f"Unknown gateset name {gateset_name}")
76+
77+
return cls(bloqs, gateset_name=gateset_name)
78+
79+
@classmethod
80+
def for_call_graph_leaf_bloqs(cls, g: nx.DiGraph):
81+
"""Helper constructor to configure this cost for 'leaf' bloqs in a given call graph.
82+
83+
Args:
84+
g: The call graph. Its leaves will be used for `gateset_bloqs`. This call graph
85+
can be generated from `Bloq.call_graph()`
86+
"""
87+
leaf_bloqs = {node for node in g.nodes if not g.succ[node]}
88+
return cls(tuple(leaf_bloqs), gateset_name='leaf')
89+
90+
def compute(
91+
self, bloq: 'Bloq', get_callee_cost: Callable[['Bloq'], BloqCountDict]
92+
) -> BloqCountDict:
93+
if bloq in self.gateset_bloqs:
94+
logger.info("Computing %s: %s is in the target gateset.", self, bloq)
95+
return {bloq: 1}
96+
97+
totals: BloqCountDict = defaultdict(lambda: 0)
98+
callees = get_bloq_callee_counts(bloq)
99+
logger.info("Computing %s for %s from %d callee(s)", self, bloq, len(callees))
100+
for callee, n_times_called in callees:
101+
callee_cost = get_callee_cost(callee)
102+
for gateset_bloq, count in callee_cost.items():
103+
totals[gateset_bloq] += n_times_called * count
104+
105+
return dict(totals)
106+
107+
def zero(self) -> BloqCountDict:
108+
# The additive identity of the bloq counts dictionary is an empty dictionary.
109+
return {}
110+
111+
def __str__(self):
112+
return f'{self.gateset_name} counts'
113+
114+
115+
@frozen(kw_only=True)
116+
class GateCounts:
117+
"""A data class of counts of the typical target gates in a compilation.
118+
119+
Specifically, this class holds counts for the number of `TGate` (and adjoint), `Toffoli`,
120+
`TwoBitCSwap`, `And`, and clifford bloqs.
121+
"""
122+
123+
t: int = 0
124+
toffoli: int = 0
125+
cswap: int = 0
126+
and_bloq: int = 0
127+
clifford: int = 0
128+
129+
def __add__(self, other):
130+
if not isinstance(other, GateCounts):
131+
raise TypeError(f"Can only add other `GateCounts` objects, not {self}")
132+
133+
return GateCounts(
134+
t=self.t + other.t,
135+
toffoli=self.toffoli + other.toffoli,
136+
cswap=self.cswap + other.cswap,
137+
and_bloq=self.and_bloq + other.and_bloq,
138+
clifford=self.clifford + other.clifford,
139+
)
140+
141+
def __mul__(self, other):
142+
return GateCounts(
143+
t=other * self.t,
144+
toffoli=other * self.toffoli,
145+
cswap=other * self.cswap,
146+
and_bloq=other * self.and_bloq,
147+
clifford=other * self.clifford,
148+
)
149+
150+
def __rmul__(self, other):
151+
return self.__mul__(other)
152+
153+
def __str__(self):
154+
strs = []
155+
for f in attrs.fields(self.__class__):
156+
val = getattr(self, f.name)
157+
if val != 0:
158+
strs.append(f'{f.name}: {val}')
159+
160+
if strs:
161+
return ', '.join(strs)
162+
return '-'
163+
164+
def total_t_count(
165+
self, ts_per_toffoli: int = 4, ts_per_cswap: int = 7, ts_per_and_bloq: int = 4
166+
) -> int:
167+
"""Get the total number of T Gates for the `GateCounts` object.
168+
169+
This simply multiplies each gate type by its cost in terms of T gates, which is configurable
170+
via the arguments to this method.
171+
"""
172+
return (
173+
self.t
174+
+ ts_per_toffoli * self.toffoli
175+
+ ts_per_cswap * self.cswap
176+
+ ts_per_and_bloq * self.and_bloq
177+
)
178+
179+
180+
@frozen
181+
class QECGatesCost(CostKey[GateCounts]):
182+
"""Counts specifically for 'expensive' gates in a surface code error correction scheme.
183+
184+
The cost value type for this CostKey is `GateCounts`.
185+
"""
186+
187+
def compute(self, bloq: 'Bloq', get_callee_cost: Callable[['Bloq'], GateCounts]) -> GateCounts:
188+
from qualtran.bloqs.basic_gates import TGate, Toffoli, TwoBitCSwap
189+
from qualtran.bloqs.mcmt.and_bloq import And
190+
191+
# T gates
192+
if isinstance(bloq, TGate):
193+
return GateCounts(t=1)
194+
195+
# Toffolis
196+
if isinstance(bloq, Toffoli):
197+
return GateCounts(toffoli=1)
198+
199+
# 'And' bloqs
200+
if isinstance(bloq, And) and not bloq.uncompute:
201+
return GateCounts(and_bloq=1)
202+
203+
# CSwaps aka Fredkin
204+
if isinstance(bloq, TwoBitCSwap):
205+
return GateCounts(cswap=1)
206+
207+
# Cliffords
208+
if bloq_is_clifford(bloq):
209+
return GateCounts(clifford=1)
210+
211+
# Recursive case
212+
totals = GateCounts()
213+
callees = get_bloq_callee_counts(bloq)
214+
logger.info("Computing %s for %s from %d callee(s)", self, bloq, len(callees))
215+
for callee, n_times_called in callees:
216+
callee_cost = get_callee_cost(callee)
217+
totals += n_times_called * callee_cost
218+
return totals
219+
220+
def zero(self) -> GateCounts:
221+
return GateCounts()
222+
223+
def validate_val(self, val: GateCounts):
224+
if not isinstance(val, GateCounts):
225+
raise TypeError(f"{self} values should be `GateCounts`, got {val}")
226+
227+
def __str__(self):
228+
return 'gate counts'
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from qualtran.bloqs.basic_gates import Hadamard, TGate, Toffoli
15+
from qualtran.bloqs.for_testing.costing import make_example_costing_bloqs
16+
from qualtran.resource_counting import BloqCount, GateCounts, get_cost_value, QECGatesCost
17+
18+
19+
def test_bloq_count():
20+
algo = make_example_costing_bloqs()
21+
22+
cost = BloqCount([Toffoli()], 'toffoli')
23+
tof_count = get_cost_value(algo, cost)
24+
25+
# `make_example_costing_bloqs` has `func` and `func2`. `func2` has 100 Tof
26+
assert tof_count == {Toffoli(): 100}
27+
28+
t_and_tof_count = get_cost_value(algo, BloqCount.for_gateset('t+tof'))
29+
assert t_and_tof_count == {Toffoli(): 100, TGate(): 2 * 10, TGate().adjoint(): 2 * 10}
30+
31+
g, _ = algo.call_graph()
32+
leaf = BloqCount.for_call_graph_leaf_bloqs(g)
33+
# Note: Toffoli has a decomposition in terms of T gates.
34+
assert set(leaf.gateset_bloqs) == {Hadamard(), TGate(), TGate().adjoint()}
35+
36+
t_count = get_cost_value(algo, leaf)
37+
assert t_count == {TGate(): 2 * 10 + 100 * 4, TGate().adjoint(): 2 * 10, Hadamard(): 2 * 10}
38+
39+
# count things other than leaf bloqs
40+
top_level = get_cost_value(algo, BloqCount([bloq for bloq, n in algo.callees], 'top'))
41+
assert sorted(f'{k}: {v}' for k, v in top_level.items()) == ['Func1: 2', 'Func2: 1']
42+
43+
44+
def test_gate_counts():
45+
gc = GateCounts(t=100, toffoli=13)
46+
assert str(gc) == 't: 100, toffoli: 13'
47+
48+
assert GateCounts(t=10) * 2 == GateCounts(t=20)
49+
assert 2 * GateCounts(t=10) == GateCounts(t=20)
50+
51+
assert GateCounts(toffoli=1, cswap=1, and_bloq=1).total_t_count() == 4 + 7 + 4
52+
53+
54+
def test_qec_gates_cost():
55+
algo = make_example_costing_bloqs()
56+
gc = get_cost_value(algo, QECGatesCost())
57+
assert gc == GateCounts(toffoli=100, t=2 * 2 * 10, clifford=2 * 10)

qualtran/resource_counting/classify_bloqs.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,19 @@ def classify_t_count_by_bloq_type(
105105
classification = classify_bloq(k, bloq_classification)
106106
classified_bloqs[classification] += v * t_counts_from_sigma(k.call_graph()[1])
107107
return classified_bloqs
108+
109+
110+
def bloq_is_clifford(b: Bloq):
111+
from qualtran.bloqs.basic_gates import CNOT, Hadamard, SGate, TwoBitSwap, XGate, ZGate
112+
from qualtran.bloqs.bookkeeping import ArbitraryClifford
113+
from qualtran.bloqs.mcmt.multi_control_multi_target_pauli import MultiTargetCNOT
114+
115+
if isinstance(b, Adjoint):
116+
b = b.subbloq
117+
118+
if isinstance(
119+
b, (TwoBitSwap, Hadamard, XGate, ZGate, ArbitraryClifford, CNOT, MultiTargetCNOT, SGate)
120+
):
121+
return True
122+
123+
return False

0 commit comments

Comments
 (0)