Skip to content

Commit da80676

Browse files
committed
costing framework
1 parent 7983de5 commit da80676

7 files changed

Lines changed: 422 additions & 6 deletions

File tree

qualtran/_infra/bloq.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from qualtran.cirq_interop import CirqQuregT
3838
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
3939
from qualtran.drawing import WireSymbol
40-
from qualtran.resource_counting import BloqCountT, GeneralizerT, SympySymbolAllocator
40+
from qualtran.resource_counting import BloqCountT, CostKey, GeneralizerT, SympySymbolAllocator
4141
from qualtran.simulation.classical_sim import ClassicalValT
4242

4343

@@ -294,6 +294,20 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
294294
"""
295295
return self.decompose_bloq().build_call_graph(ssa)
296296

297+
def my_static_costs(self, cost_key: 'CostKey') -> Union[Any, NotImplemented]:
298+
"""Override this method to provide static costs.
299+
300+
The system will query a particular cost by asking for a `cost_key`. This method
301+
can optionally provide a value, which will be preferred over a computed cost.
302+
303+
Static costs can be provided if the particular cost cannot be easily computed or
304+
as a performance optimization.
305+
306+
This method must return `NotImplemented` if a value cannot be provided for the specified
307+
CostKey.
308+
"""
309+
return NotImplemented
310+
297311
def call_graph(
298312
self,
299313
generalizer: Optional[Union['GeneralizerT', Sequence['GeneralizerT']]] = None,
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from typing import Any, Sequence, Set, Tuple, Union
15+
16+
from attrs import field, frozen
17+
18+
from qualtran import Bloq, Signature
19+
from qualtran.resource_counting import BloqCountT, CostKey, SympySymbolAllocator
20+
21+
22+
@frozen
23+
class CostingBloq(Bloq):
24+
"""A bloq that lets you set the costs via attributes."""
25+
26+
name: str
27+
num_qubits: int
28+
callees: Sequence[BloqCountT] = field(converter=tuple, factory=tuple)
29+
static_costs: Sequence[Tuple[CostKey, Any]] = field(converter=tuple, factory=tuple)
30+
31+
@property
32+
def signature(self) -> 'Signature':
33+
return Signature.build(register=self.num_qubits)
34+
35+
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
36+
return set(self.callees)
37+
38+
def my_static_costs(self, k: 'CostKey') -> Union[Any, NotImplemented]:
39+
return dict(self.static_costs).get(k, NotImplemented)
40+
41+
def pretty_name(self):
42+
return self.name
43+
44+
def __str__(self):
45+
return self.name
46+
47+
48+
def make_example_costing_bloqs():
49+
from qualtran.bloqs.basic_gates import Hadamard, TGate, Toffoli
50+
51+
func1 = CostingBloq(
52+
'Func1', num_qubits=10, callees=[(TGate(), 10), (TGate().adjoint(), 10), (Hadamard(), 10)]
53+
)
54+
func2 = CostingBloq('Func2', num_qubits=3, callees=[(Toffoli(), 100)])
55+
algo = CostingBloq('Algo', num_qubits=100, callees=[(func1, 1), (func2, 1)])
56+
return algo
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from qualtran.bloqs.for_testing.costing import make_example_costing_bloqs
16+
from qualtran.resource_counting import format_call_graph_debug_text
17+
18+
19+
def test_costing_bloqs():
20+
algo = make_example_costing_bloqs()
21+
g, _ = algo.call_graph()
22+
assert (
23+
format_call_graph_debug_text(g)
24+
== """\
25+
Algo -- 1 -> Func1
26+
Algo -- 1 -> Func2
27+
Func1 -- 10 -> Hadamard()
28+
Func1 -- 10 -> TGate()
29+
Func1 -- 10 -> TGate(is_adjoint=True)
30+
Func2 -- 100 -> Toffoli()
31+
Toffoli() -- 4 -> TGate()"""
32+
)

qualtran/resource_counting/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@
2525
SympySymbolAllocator,
2626
get_bloq_callee_counts,
2727
get_bloq_call_graph,
28-
print_counts_graph,
2928
build_cbloq_call_graph,
29+
format_call_graph_debug_text,
3030
)
3131

32+
from ._costing import GeneralizerT, get_cost_value, get_cost_cache, query_costs, CostKey, CostValT
33+
3234
from . import generalizers

qualtran/resource_counting/_call_graph.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -240,8 +240,11 @@ def get_bloq_call_graph(
240240
return g, sigma
241241

242242

243-
def print_counts_graph(g: nx.DiGraph):
243+
def format_call_graph_debug_text(g: nx.DiGraph) -> str:
244244
"""Print the graph returned from `get_bloq_counts_graph`."""
245-
for b in nx.topological_sort(g):
246-
for succ in g.succ[b]:
247-
print(b, '--', g.edges[b, succ]['n'], '->', succ)
245+
lines = []
246+
for gen in nx.topological_generations(g):
247+
for b in sorted(gen, key=str):
248+
for succ in sorted(g.succ[b], key=str):
249+
lines.append(f"{b} -- {g.edges[b, succ]['n']} -> {succ}")
250+
return '\n'.join(lines)
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import abc
16+
import logging
17+
import time
18+
from collections import defaultdict
19+
from typing import (
20+
Callable,
21+
Dict,
22+
Generic,
23+
Iterable,
24+
Optional,
25+
Sequence,
26+
Tuple,
27+
TYPE_CHECKING,
28+
TypeVar,
29+
Union,
30+
)
31+
32+
from ._generalization import _make_composite_generalizer, GeneralizerT
33+
34+
if TYPE_CHECKING:
35+
from qualtran import Bloq
36+
37+
logger = logging.getLogger(__name__)
38+
39+
CostValT = TypeVar('CostValT')
40+
41+
42+
class CostKey(Generic[CostValT], metaclass=abc.ABCMeta):
43+
"""Abstract base class for different types of costs.
44+
45+
One important aspect of a bloq is the resources required to execute it on an error
46+
corrected quantum computer. Since we're usually trying to minimize these resource requirements
47+
we will generally use the catch-all term "costs".
48+
49+
There are a variety of different types or flavors of costs. Each is represented by an
50+
instance of a sublcass of `CostKey`. For example, gate counts (including T-gate counts),
51+
qubit requirements, and circuit depth are all cost metrics that may be of interest.
52+
53+
Each `CostKey` primarily encodes the behavior required to compute a cost value from a
54+
bloq. Often, these costs are defined recursively: a bloq's costs is some combination
55+
of the costs of the bloqs in its decomposition (i.e. the bloq 'callees'). Implementors
56+
must override the `compute` method to define the cost computation.
57+
58+
Each cost key has an associated CostValT. For example, the CostValT of a "t count"
59+
CostKey could be an integer. For a more complicated gateset, the value could be a mapping
60+
from gate to count. This abstract base class is generic w.r.t. `CostValT`. Subclasses
61+
should have a concrete value type. The `validate_val` method can optionally be overridden
62+
to raise an exception if a bad value type is encountered. The `zero` method must return
63+
the zero (additive identity) cost value of the correct type.
64+
"""
65+
66+
@abc.abstractmethod
67+
def compute(self, bloq: 'Bloq', get_callee_cost: Callable[['Bloq'], CostValT]) -> CostValT:
68+
"""Compute this type of cost.
69+
70+
When implementing a new CostKey, this method must be overridden.
71+
Users should not call this method directly. Instead: use the `qualtran.resource_counting`
72+
functions like `get_cost_value`, `get_cost_cache`, or `query_costs`. These provide
73+
caching, logging, generalizers, and support for static costs.
74+
75+
For recursive computations, use the provided callable to recurse.
76+
77+
Args:
78+
bloq: The bloq to compute the cost of.
79+
get_callee_cost: A qualtran-provided function for computing costs for "callees"
80+
of the bloq; i.e. bloqs in the decomposition. Use this function to accurately
81+
cache intermediate cost values and respect bloqs' static costs.
82+
83+
Returns:
84+
A value of the generic type `CostValT`. Subclasses should define their value type.
85+
"""
86+
87+
@abc.abstractmethod
88+
def zero(self) -> CostValT:
89+
"""The value corresponding to zero cost."""
90+
91+
def validate_val(self, val: CostValT):
92+
"""Assert that `val` is a valid `CostValT`.
93+
94+
This method can be optionally overridden to raise an error if an invalid value
95+
is encountered. By default, no validation is performed.
96+
"""
97+
98+
99+
def _get_cost_value(
100+
bloq: 'Bloq',
101+
cost_key: CostKey[CostValT],
102+
*,
103+
costs_cache: Dict['Bloq', CostValT],
104+
generalizer: 'GeneralizerT',
105+
) -> CostValT:
106+
"""Helper function for `query_costs`.
107+
108+
This function tries the following strategies
109+
1. Use the value found in `costs_cache`, if it exists.
110+
2. Use the value returned by `Bloq.my_static_costs` if one is returned.
111+
3. Use `cost_key.compute()` and cache the result in `costs_cache`.
112+
113+
Args:
114+
bloq: The bloq.
115+
cost_key: The cost key to get the value for.
116+
costs_cache: A dictionary to use as a cache for computed bloq costs. This cache
117+
will be mutated by this function.
118+
generalizer: The generalizer to operate on each bloq before computing its cost.
119+
"""
120+
bloq = generalizer(bloq)
121+
if bloq is None:
122+
return cost_key.zero()
123+
124+
# Strategy 1: Use cached value
125+
if bloq in costs_cache:
126+
logger.debug("Using cached %s for %s", cost_key, bloq)
127+
return costs_cache[bloq]
128+
129+
# Strategy 2: Static costs
130+
static_cost = bloq.my_static_costs(cost_key)
131+
if static_cost is not NotImplemented:
132+
cost_key.validate_val(static_cost)
133+
logger.info("Using static %s for %s", cost_key, bloq)
134+
costs_cache[bloq] = static_cost
135+
return static_cost
136+
137+
# Strategy 3: Compute
138+
# part a. set up caching of computed costs. Using the callable will use the cache if possible
139+
# and only recurse if the bloq has not been seen before. The result of a computation
140+
# will be cached.
141+
def _get_cost_val_internal(callee: 'Bloq'):
142+
return _get_cost_value(callee, cost_key, costs_cache=costs_cache, generalizer=generalizer)
143+
144+
tstart = time.perf_counter()
145+
computed_cost = cost_key.compute(bloq, _get_cost_val_internal)
146+
tdur = time.perf_counter() - tstart
147+
logger.info("Computed %s for %s in %g s", cost_key, bloq, tdur)
148+
costs_cache[bloq] = computed_cost
149+
return computed_cost
150+
151+
152+
def get_cost_value(
153+
bloq: 'Bloq',
154+
cost_key: CostKey[CostValT],
155+
costs_cache: Optional[Dict['Bloq', CostValT]] = None,
156+
generalizer: Optional[Union['GeneralizerT', Sequence['GeneralizerT']]] = None,
157+
) -> CostValT:
158+
if costs_cache is None:
159+
costs_cache = {}
160+
if generalizer is None:
161+
generalizer = lambda b: b
162+
if isinstance(generalizer, (list, tuple)):
163+
generalizer = _make_composite_generalizer(*generalizer)
164+
165+
cost_val = _get_cost_value(bloq, cost_key, costs_cache=costs_cache, generalizer=generalizer)
166+
return cost_val
167+
168+
169+
def get_cost_cache(
170+
bloq: 'Bloq',
171+
cost_key: CostKey[CostValT],
172+
costs_cache: Optional[Dict['Bloq', CostValT]] = None,
173+
generalizer: Optional[Union['GeneralizerT', Sequence['GeneralizerT']]] = None,
174+
) -> Dict['Bloq', CostValT]:
175+
if costs_cache is None:
176+
costs_cache = {}
177+
if generalizer is None:
178+
generalizer = lambda b: b
179+
if isinstance(generalizer, (list, tuple)):
180+
generalizer = _make_composite_generalizer(*generalizer)
181+
182+
_get_cost_value(bloq, cost_key, costs_cache=costs_cache, generalizer=generalizer)
183+
return costs_cache
184+
185+
186+
def query_costs(
187+
bloq: 'Bloq',
188+
cost_keys: Iterable[CostKey],
189+
generalizer: Optional[Union['GeneralizerT', Sequence['GeneralizerT']]] = None,
190+
) -> Dict['Bloq', Dict[CostKey, CostValT]]:
191+
192+
costs = defaultdict(dict)
193+
for cost_key in cost_keys:
194+
cost_for_bloqs = get_cost_cache(bloq, cost_key, generalizer=generalizer)
195+
for bloq, val in cost_for_bloqs.items():
196+
costs[bloq][cost_key] = val
197+
return dict(costs)

0 commit comments

Comments
 (0)