Skip to content

Commit 6c64835

Browse files
committed
bloq counts
1 parent 15418d6 commit 6c64835

4 files changed

Lines changed: 240 additions & 11 deletions

File tree

qualtran/resource_counting/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,6 @@
3131

3232
from ._costing import GeneralizerT, get_cost_value, get_cost_cache, query_costs, CostKey, CostValT
3333

34+
from ._bloq_counts import BloqCount, QECGatesCost
35+
3436
from . import generalizers
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import logging
15+
from collections import defaultdict
16+
from typing import Callable, Dict, Optional, Tuple, TYPE_CHECKING
17+
18+
import attrs
19+
import networkx as nx
20+
from attrs import frozen
21+
22+
from ._call_graph import get_bloq_callee_counts
23+
from ._costing import CostKey
24+
from .classify_bloqs import bloq_is_clifford
25+
26+
if TYPE_CHECKING:
27+
from qualtran import Bloq
28+
29+
logger = logging.getLogger(__name__)
30+
31+
BloqCountDict = Dict['Bloq', int]
32+
33+
34+
@frozen
35+
class BloqCount(CostKey[BloqCountDict]):
36+
gateset_bloqs: Tuple['Bloq', ...]
37+
gateset_name: str
38+
39+
@classmethod
40+
def for_gateset(cls, gateset_name: str):
41+
from qualtran.bloqs.basic_gates import TGate, Toffoli, TwoBitCSwap
42+
43+
if gateset_name == 't':
44+
bloqs = (TGate(), TGate(is_adjoint=True))
45+
elif gateset_name == 't+tof':
46+
bloqs = (TGate(), TGate(is_adjoint=True), Toffoli())
47+
elif gateset_name == 't+tof+cswap':
48+
bloqs = (TGate(), TGate(is_adjoint=True), Toffoli(), TwoBitCSwap())
49+
else:
50+
raise ValueError(f"Unknown gateset name {gateset_name}")
51+
52+
return cls(bloqs, gateset_name=gateset_name)
53+
54+
@classmethod
55+
def for_call_graph_leaf_bloqs(cls, g: nx.DiGraph):
56+
leaf_bloqs = {node for node in g.nodes if not g.succ[node]}
57+
return cls(tuple(leaf_bloqs), gateset_name='leaf')
58+
59+
def compute(
60+
self, bloq: 'Bloq', get_callee_cost: Callable[['Bloq'], BloqCountDict]
61+
) -> BloqCountDict:
62+
if bloq in self.gateset_bloqs:
63+
logger.info("Computing %s: %s is in the target gateset.", self, bloq)
64+
return {bloq: 1}
65+
66+
totals: BloqCountDict = defaultdict(lambda: 0)
67+
callees = get_bloq_callee_counts(bloq)
68+
logger.info("Computing %s for %s from %d callee(s)", self, bloq, len(callees))
69+
for callee, n_times_called in callees:
70+
callee_cost = get_callee_cost(callee)
71+
for gateset_bloq, count in callee_cost.items():
72+
totals[gateset_bloq] += n_times_called * count
73+
74+
return dict(totals)
75+
76+
def zero(self) -> BloqCountDict:
77+
return {}
78+
79+
def __str__(self):
80+
return f'{self.gateset_name} counts'
81+
82+
83+
@frozen(kw_only=True)
84+
class GateCounts:
85+
t: int = 0
86+
toffoli: int = 0
87+
cswap: int = 0
88+
and_bloq: int = 0
89+
clifford: int = 0
90+
91+
def __add__(self, other):
92+
if not isinstance(other, GateCounts):
93+
raise TypeError(f"Can only add other `GateCounts` objects, not {self}")
94+
95+
return GateCounts(
96+
t=self.t + other.t,
97+
toffoli=self.toffoli + other.toffoli,
98+
cswap=self.cswap + other.cswap,
99+
and_bloq=self.and_bloq + other.and_bloq,
100+
clifford=self.clifford + other.clifford,
101+
)
102+
103+
def __mul__(self, other):
104+
if not isinstance(other, int):
105+
raise TypeError(f"Can only multiply `GateCounts` objects by integers, not {self}")
106+
107+
return GateCounts(
108+
t=other * self.t,
109+
toffoli=other * self.toffoli,
110+
cswap=other * self.cswap,
111+
and_bloq=other * self.and_bloq,
112+
clifford=other * self.clifford,
113+
)
114+
115+
def __rmul__(self, other):
116+
return self.__mul__(other)
117+
118+
def __str__(self):
119+
strs = []
120+
for f in attrs.fields(self.__class__):
121+
val = getattr(self, f.name)
122+
if val != 0:
123+
strs.append(f'{f.name}: {val}')
124+
125+
if strs:
126+
return ', '.join(strs)
127+
return '-'
128+
129+
@property
130+
def total_n_magic(self):
131+
"""The total number of magic states.
132+
133+
This can be used as a rough proxy for total cost. It is the sum of all the attributes
134+
other than `clifford`.
135+
"""
136+
return self.t + self.toffoli + self.cswap + self.and_bloq
137+
138+
@property
139+
def total_n_magic(self):
140+
"""The total number of magic states.
141+
142+
This can be used as a rough proxy for total cost. It is the sum of all the attributes
143+
other than `clifford`.
144+
"""
145+
return self.t + self.toffoli + self.cswap + self.and_bloq
146+
147+
148+
@frozen
149+
class QECGatesCost(CostKey[GateCounts]):
150+
"""Counts specifically for 'expensive' gates in a surface code error correction scheme."""
151+
152+
ts_per_toffoli: Optional[int] = None
153+
toffolis_per_and: Optional[int] = None
154+
ts_per_and: Optional[int] = None
155+
toffolis_per_cswap: Optional[int] = None
156+
ts_per_cswap: Optional[int] = None
157+
158+
def compute(self, bloq: 'Bloq', get_callee_cost: Callable[['Bloq'], GateCounts]) -> GateCounts:
159+
from qualtran.bloqs.basic_gates import TGate, Toffoli, TwoBitCSwap
160+
from qualtran.bloqs.mcmt.and_bloq import And
161+
162+
# T gates
163+
if isinstance(bloq, TGate):
164+
return GateCounts(t=1)
165+
166+
# Toffolis
167+
if isinstance(bloq, Toffoli):
168+
if self.ts_per_toffoli is not None:
169+
return GateCounts(t=self.ts_per_toffoli)
170+
else:
171+
return GateCounts(toffoli=1)
172+
173+
# 'And' bloqs
174+
if isinstance(bloq, And) and not bloq.uncompute:
175+
if self.toffolis_per_and is not None:
176+
return GateCounts(toffoli=self.toffolis_per_and * self.ts_per_toffoli)
177+
elif self.ts_per_and is not None:
178+
return GateCounts(t=self.ts_per_and)
179+
else:
180+
return GateCounts(and_bloq=1)
181+
182+
# CSwaps aka Fredkin
183+
if isinstance(bloq, TwoBitCSwap):
184+
if self.toffolis_per_cswap is not None:
185+
return GateCounts(toffoli=self.toffolis_per_cswap)
186+
elif self.ts_per_cswap is not None:
187+
return GateCounts(t=self.ts_per_cswap)
188+
else:
189+
return GateCounts(cswap=1)
190+
191+
# Cliffords
192+
if bloq_is_clifford(bloq):
193+
return GateCounts(clifford=1)
194+
195+
# Recursive case
196+
totals = GateCounts()
197+
callees = get_bloq_callee_counts(bloq)
198+
logger.info("Computing %s for %s from %d callee(s)", self, bloq, len(callees))
199+
for callee, n_times_called in callees:
200+
callee_cost = get_callee_cost(callee)
201+
totals += n_times_called * callee_cost
202+
return totals
203+
204+
def zero(self) -> GateCounts:
205+
return GateCounts()
206+
207+
def validate_val(self, val: GateCounts):
208+
if not isinstance(val, GateCounts):
209+
raise TypeError(f"{self} values should be `GateCounts`, got {val}")
210+
211+
def __str__(self):
212+
gates = ['t']
213+
if self.ts_per_toffoli is None:
214+
gates.append('tof')
215+
if self.toffolis_per_and is None and self.ts_per_and is None:
216+
gates.append('and')
217+
if self.toffolis_per_cswap is None and self.ts_per_cswap is None:
218+
gates.append('cswap')
219+
return ','.join(gates) + ' counts'

qualtran/resource_counting/classify_bloqs.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,19 @@ def classify_t_count_by_bloq_type(
105105
classification = classify_bloq(k, bloq_classification)
106106
classified_bloqs[classification] += v * t_counts_from_sigma(k.call_graph()[1])
107107
return classified_bloqs
108+
109+
110+
def bloq_is_clifford(b: Bloq):
111+
from qualtran.bloqs.basic_gates import CNOT, Hadamard, SGate, TwoBitSwap, XGate, ZGate
112+
from qualtran.bloqs.mcmt.multi_control_multi_target_pauli import MultiTargetCNOT
113+
from qualtran.bloqs.util_bloqs import ArbitraryClifford
114+
115+
if isinstance(b, Adjoint):
116+
b = b.subbloq
117+
118+
if isinstance(
119+
b, (TwoBitSwap, Hadamard, XGate, ZGate, ArbitraryClifford, CNOT, MultiTargetCNOT, SGate)
120+
):
121+
return True
122+
123+
return False

qualtran/resource_counting/generalizers.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import attrs
2525
import sympy
2626

27-
from qualtran import Adjoint, Bloq
27+
from qualtran import Bloq
2828

2929
PHI = sympy.Symbol(r'\phi')
3030
CV = sympy.Symbol("cv")
@@ -79,18 +79,10 @@ def generalize_cvs(b: Bloq) -> Optional[Bloq]:
7979

8080
def ignore_cliffords(b: Bloq) -> Optional[Bloq]:
8181
"""A generalizer that ignores known clifford bloqs."""
82-
from qualtran.bloqs.basic_gates import CNOT, Hadamard, SGate, TwoBitSwap, XGate, ZGate
83-
from qualtran.bloqs.mcmt.multi_control_multi_target_pauli import MultiTargetCNOT
84-
from qualtran.bloqs.util_bloqs import ArbitraryClifford
82+
from qualtran.resource_counting.classify_bloqs import bloq_is_clifford
8583

86-
if isinstance(b, Adjoint):
87-
b = b.subbloq
88-
89-
if isinstance(
90-
b, (TwoBitSwap, Hadamard, XGate, ZGate, ArbitraryClifford, CNOT, MultiTargetCNOT, SGate)
91-
):
84+
if bloq_is_clifford(b):
9285
return None
93-
9486
return b
9587

9688

0 commit comments

Comments
 (0)