Skip to content

Commit d42d9be

Browse files
authored
[Costs] Framework (#913)
* costing framework * docstrings and linting * lint * python 3.10 type issues * lint * part b * copyright * mypy * lint
1 parent 8fcba74 commit d42d9be

File tree

8 files changed

+481
-9
lines changed

8 files changed

+481
-9
lines changed

qualtran/_infra/bloq.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from qualtran.cirq_interop import CirqQuregT
4141
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
4242
from qualtran.drawing import WireSymbol
43-
from qualtran.resource_counting import BloqCountT, GeneralizerT, SympySymbolAllocator
43+
from qualtran.resource_counting import BloqCountT, CostKey, GeneralizerT, SympySymbolAllocator
4444
from qualtran.simulation.classical_sim import ClassicalValT
4545

4646

@@ -296,6 +296,20 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
296296
"""
297297
return self.decompose_bloq().build_call_graph(ssa)
298298

299+
def my_static_costs(self, cost_key: 'CostKey'):
300+
"""Override this method to provide static costs.
301+
302+
The system will query a particular cost by asking for a `cost_key`. This method
303+
can optionally provide a value, which will be preferred over a computed cost.
304+
305+
Static costs can be provided if the particular cost cannot be easily computed or
306+
as a performance optimization.
307+
308+
This method must return `NotImplemented` if a value cannot be provided for the specified
309+
CostKey.
310+
"""
311+
return NotImplemented
312+
299313
def call_graph(
300314
self,
301315
generalizer: Optional[Union['GeneralizerT', Sequence['GeneralizerT']]] = None,
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from typing import Any, Sequence, Set, Tuple
15+
16+
from attrs import field, frozen
17+
18+
from qualtran import Bloq, Signature
19+
from qualtran.resource_counting import BloqCountT, CostKey, SympySymbolAllocator
20+
21+
22+
def _convert_callees(callees: Sequence[BloqCountT]) -> Tuple[BloqCountT, ...]:
23+
# Convert to tuples in a type-checked way.
24+
return tuple(callees)
25+
26+
27+
@frozen
28+
class CostingBloq(Bloq):
29+
"""A bloq that lets you set the costs via attributes."""
30+
31+
name: str
32+
num_qubits: int
33+
callees: Sequence[BloqCountT] = field(converter=_convert_callees, factory=tuple)
34+
static_costs: Sequence[Tuple[CostKey, Any]] = field(converter=tuple, factory=tuple)
35+
36+
@property
37+
def signature(self) -> 'Signature':
38+
return Signature.build(register=self.num_qubits)
39+
40+
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
41+
return set(self.callees)
42+
43+
def my_static_costs(self, cost_key: 'CostKey'):
44+
return dict(self.static_costs).get(cost_key, NotImplemented)
45+
46+
def pretty_name(self):
47+
return self.name
48+
49+
def __str__(self):
50+
return self.name
51+
52+
53+
def make_example_costing_bloqs():
54+
from qualtran.bloqs.basic_gates import Hadamard, TGate, Toffoli
55+
56+
func1 = CostingBloq(
57+
'Func1', num_qubits=10, callees=[(TGate(), 10), (TGate().adjoint(), 10), (Hadamard(), 10)]
58+
)
59+
func2 = CostingBloq('Func2', num_qubits=3, callees=[(Toffoli(), 100)])
60+
algo = CostingBloq('Algo', num_qubits=100, callees=[(func1, 1), (func2, 1)])
61+
return algo
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from qualtran.bloqs.for_testing.costing import make_example_costing_bloqs
16+
from qualtran.resource_counting import format_call_graph_debug_text
17+
18+
19+
def test_costing_bloqs():
20+
algo = make_example_costing_bloqs()
21+
g, _ = algo.call_graph()
22+
assert (
23+
format_call_graph_debug_text(g)
24+
== """\
25+
Algo -- 1 -> Func1
26+
Algo -- 1 -> Func2
27+
Func1 -- 10 -> Hadamard()
28+
Func1 -- 10 -> TGate()
29+
Func1 -- 10 -> TGate(is_adjoint=True)
30+
Func2 -- 100 -> Toffoli()
31+
Toffoli() -- 4 -> TGate()"""
32+
)

qualtran/bloqs/phase_estimation/lp_resource_state.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import cirq
2121
import numpy as np
2222
import sympy
23-
from numpy._typing import NDArray
23+
from numpy.typing import NDArray
2424

2525
from qualtran import (
2626
Bloq,

qualtran/resource_counting/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@
2525
SympySymbolAllocator,
2626
get_bloq_callee_counts,
2727
get_bloq_call_graph,
28-
print_counts_graph,
2928
build_cbloq_call_graph,
29+
format_call_graph_debug_text,
3030
)
3131

32+
from ._costing import GeneralizerT, get_cost_value, get_cost_cache, query_costs, CostKey, CostValT
33+
3234
from . import generalizers

qualtran/resource_counting/_call_graph.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
"""Functionality for the `Bloq.call_graph()` protocol."""
1616

17-
import collections.abc as abc
17+
import collections.abc
1818
from collections import defaultdict
1919
from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
2020

@@ -231,7 +231,7 @@ def get_bloq_call_graph(
231231
keep = lambda b: False
232232
if generalizer is None:
233233
generalizer = lambda b: b
234-
if isinstance(generalizer, abc.Sequence):
234+
if isinstance(generalizer, collections.abc.Sequence):
235235
generalizer = _make_composite_generalizer(*generalizer)
236236

237237
g = nx.DiGraph()
@@ -243,8 +243,11 @@ def get_bloq_call_graph(
243243
return g, sigma
244244

245245

246-
def print_counts_graph(g: nx.DiGraph):
246+
def format_call_graph_debug_text(g: nx.DiGraph) -> str:
247247
"""Print the graph returned from `get_bloq_counts_graph`."""
248-
for b in nx.topological_sort(g):
249-
for succ in g.succ[b]:
250-
print(b, '--', g.edges[b, succ]['n'], '->', succ)
248+
lines = []
249+
for gen in nx.topological_generations(g):
250+
for b in sorted(gen, key=str):
251+
for succ in sorted(g.succ[b], key=str):
252+
lines.append(f"{b} -- {g.edges[b, succ]['n']} -> {succ}")
253+
return '\n'.join(lines)

0 commit comments

Comments
 (0)