Skip to content

Commit 7983de5

Browse files
authored
[Costs] get_bloq_callee_counts and call graph organization (#900)
get_bloq_callee_counts and call graph organization
1 parent 4ae45be commit 7983de5

13 files changed

Lines changed: 165 additions & 47 deletions

qualtran/_infra/bloq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def call_graph(
322322
according to `keep` and `max_depth` (if provided) or if a bloq cannot be
323323
decomposed.
324324
"""
325-
from qualtran.resource_counting.bloq_counts import get_bloq_call_graph
325+
from qualtran.resource_counting import get_bloq_call_graph
326326

327327
return get_bloq_call_graph(self, generalizer=generalizer, keep=keep, max_depth=max_depth)
328328

qualtran/cirq_interop/t_complexity_protocol.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,16 +150,16 @@ def _from_iterable(it: Any) -> Optional[TComplexity]:
150150

151151
def _from_bloq_build_call_graph(stc: Any) -> Optional[TComplexity]:
152152
# Uses the depth 1 call graph of Bloq `stc` to recursively compute the complexity.
153+
from qualtran.resource_counting import get_bloq_callee_counts
153154
from qualtran.resource_counting.generalizers import cirq_to_bloqs
154155

155156
if not isinstance(stc, Bloq):
156157
return None
157-
_, sigma = stc.call_graph(max_depth=1, generalizer=cirq_to_bloqs)
158-
if sigma == {stc: 1}:
159-
# No decomposition found.
158+
callee_counts = get_bloq_callee_counts(bloq=stc, generalizer=cirq_to_bloqs)
159+
if len(callee_counts) == 0:
160160
return None
161161
ret = TComplexity()
162-
for bloq, n in sigma.items():
162+
for bloq, n in callee_counts:
163163
r = t_complexity(bloq)
164164
if r is None:
165165
return None

qualtran/cirq_interop/t_complexity_protocol_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import cirq
1717
import pytest
18+
from attrs import frozen
1819

1920
from qualtran import Bloq, GateWithRegisters, Signature
2021
from qualtran._infra.gate_with_registers import get_named_qubits
@@ -36,6 +37,7 @@ class DoesNotSupportTComplexity:
3637
...
3738

3839

40+
@frozen
3941
class SupportsTComplexityGateWithRegisters(GateWithRegisters):
4042
@property
4143
def signature(self) -> Signature:
@@ -64,6 +66,7 @@ def signature(self) -> 'Signature':
6466
return Signature.build(q=1)
6567

6668

69+
@frozen
6770
class SupportsTComplexityBloqViaBuildCallGraph(Bloq):
6871
@property
6972
def signature(self) -> 'Signature':
@@ -75,6 +78,8 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
7578

7679
def test_t_complexity_for_bloq_via_build_call_graph():
7780
bloq = SupportsTComplexityBloqViaBuildCallGraph()
81+
_, sigma = bloq.call_graph(max_depth=1)
82+
assert sigma != {}
7883
assert t_complexity(bloq) == TComplexity(t=5, clifford=10)
7984

8085

qualtran/drawing/bloq_counts_graph_test.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,12 @@ def test_format_counts_graph_markdown():
4444
ret = format_counts_graph_markdown(graph)
4545
assert (
4646
ret
47-
== r""" - `MultiAnd(cvs=(1, 1, 1, 1, 1, 1))`
48-
- `And(cv1=1, cv2=1, uncompute=False)`: $\displaystyle 5$
47+
== """\
48+
- `MultiAnd(cvs=(1, 1, 1, 1, 1, 1))`
49+
- `And(cv1=1, cv2=1, uncompute=False)`: $\\displaystyle 5$
4950
- `And(cv1=1, cv2=1, uncompute=False)`
50-
- `ArbitraryClifford(n=2)`: $\displaystyle 9$
51-
- `TGate()`: $\displaystyle 4$
51+
- `ArbitraryClifford(n=2)`: $\\displaystyle 9$
52+
- `TGate()`: $\\displaystyle 4$
5253
"""
5354
)
5455

qualtran/drawing/flame_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import sympy
2525

2626
from qualtran import Bloq
27-
from qualtran.resource_counting.bloq_counts import _compute_sigma
27+
from qualtran.resource_counting._call_graph import _compute_sigma
2828
from qualtran.resource_counting.t_counts_from_sigma import t_counts_from_sigma
2929

3030

qualtran/resource_counting/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717
isort:skip_file
1818
"""
1919

20-
from .bloq_counts import (
20+
from ._generalization import GeneralizerT
21+
22+
from ._call_graph import (
2123
BloqCountT,
22-
GeneralizerT,
2324
big_O,
2425
SympySymbolAllocator,
26+
get_bloq_callee_counts,
2527
get_bloq_call_graph,
2628
print_counts_graph,
2729
build_cbloq_call_graph,
Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from qualtran import Bloq, CompositeBloq, DecomposeNotImplementedError, DecomposeTypeError
2424

2525
BloqCountT = Tuple[Bloq, Union[int, sympy.Expr]]
26-
GeneralizerT = Callable[[Bloq], Optional[Bloq]]
26+
from ._generalization import _make_composite_generalizer, GeneralizerT
2727

2828

2929
def big_O(expr) -> sympy.Order:
@@ -85,6 +85,38 @@ def _generalize_callees(
8585
return callee_counts
8686

8787

88+
def get_bloq_callee_counts(
89+
bloq: 'Bloq', generalizer: 'GeneralizerT' = None, ssa: SympySymbolAllocator = None
90+
) -> List[BloqCountT]:
91+
"""Get the direct callees of a bloq and the number of times they are called.
92+
93+
This calls `bloq.build_call_graph()` with the correct configuration options.
94+
95+
Args:
96+
bloq: The bloq.
97+
generalizer: If provided, run this function on each callee to consolidate attributes
98+
that do not affect resource estimates. If the callable
99+
returns `None`, the bloq is omitted from the counts graph. If a sequence of
100+
generalizers is provided, each generalizer will be run in order.
101+
ssa: A sympy symbol allocator that can be provided if one already exists in your
102+
computation.
103+
104+
Returns:
105+
A list of (bloq, n) bloq counts.
106+
"""
107+
if generalizer is None:
108+
generalizer = lambda b: b
109+
if isinstance(generalizer, (list, tuple)):
110+
generalizer = _make_composite_generalizer(*generalizer)
111+
if ssa is None:
112+
ssa = SympySymbolAllocator()
113+
114+
try:
115+
return _generalize_callees(bloq.build_call_graph(ssa), generalizer)
116+
except (DecomposeNotImplementedError, DecomposeTypeError):
117+
return []
118+
119+
88120
def _build_call_graph(
89121
bloq: Bloq,
90122
generalizer: GeneralizerT,
@@ -103,8 +135,7 @@ def _build_call_graph(
103135
# We already visited this node.
104136
return
105137

106-
# Make sure this node is present in the graph. You could annotate
107-
# additional node properties here, too.
138+
# Make sure this node is present in the graph.
108139
g.add_node(bloq)
109140

110141
# Base case 1: This node is requested by the user to be a leaf node via the `keep` parameter.
@@ -116,12 +147,7 @@ def _build_call_graph(
116147
return
117148

118149
# Prep for recursion: get the callees and modify them according to `generalizer`.
119-
try:
120-
callee_counts = _generalize_callees(bloq.build_call_graph(ssa), generalizer)
121-
except (DecomposeNotImplementedError, DecomposeTypeError):
122-
# Base case 3: Decomposition (or `bloq_counts`) is not implemented. This is left as a
123-
# leaf node.
124-
return
150+
callee_counts = get_bloq_callee_counts(bloq, generalizer)
125151

126152
# Base case 3: Empty list of callees
127153
if not callee_counts:
@@ -165,19 +191,6 @@ def _compute_sigma(root_bloq: Bloq, g: nx.DiGraph) -> Dict[Bloq, Union[int, symp
165191
return dict(bloq_sigmas[root_bloq])
166192

167193

168-
def _make_composite_generalizer(*funcs: GeneralizerT) -> GeneralizerT:
169-
"""Return a generalizer that calls each `*funcs` generalizers in order."""
170-
171-
def _composite_generalize(b: Bloq) -> Optional[Bloq]:
172-
for func in funcs:
173-
b = func(b)
174-
if b is None:
175-
return
176-
return b
177-
178-
return _composite_generalize
179-
180-
181194
def get_bloq_call_graph(
182195
bloq: Bloq,
183196
generalizer: Optional[Union['GeneralizerT', Sequence['GeneralizerT']]] = None,

qualtran/resource_counting/bloq_counts_test.py renamed to qualtran/resource_counting/_call_graph_test.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,12 @@
2525
from qualtran import Bloq, BloqBuilder, Signature, SoquetT
2626
from qualtran.bloqs.basic_gates import TGate
2727
from qualtran.bloqs.util_bloqs import ArbitraryClifford, Join, Split
28-
from qualtran.resource_counting import BloqCountT, get_bloq_call_graph, SympySymbolAllocator
28+
from qualtran.resource_counting import (
29+
BloqCountT,
30+
get_bloq_call_graph,
31+
get_bloq_callee_counts,
32+
SympySymbolAllocator,
33+
)
2934

3035

3136
@frozen
@@ -88,6 +93,23 @@ def test_bloq_counts_method():
8893
assert str(expr) == '3*log(100)'
8994

9095

96+
def test_get_bloq_callee_counts():
97+
bloq = BigBloq(100)
98+
callee_counts = get_bloq_callee_counts(bloq)
99+
assert callee_counts == [(SubBloq(unrelated_param=0.5), sympy.log(100))]
100+
101+
bloq = DecompBloq(10)
102+
callee_counts = get_bloq_callee_counts(bloq)
103+
assert len(callee_counts) == 10 + 2 # 2 for split/join
104+
105+
bloq = SubBloq(unrelated_param=0.5)
106+
callee_counts = get_bloq_callee_counts(bloq)
107+
assert callee_counts == [(TGate(), 3)]
108+
109+
callee_counts = get_bloq_callee_counts(TGate())
110+
assert callee_counts == []
111+
112+
91113
def test_bloq_counts_decomp():
92114
graph, sigma = get_bloq_call_graph(DecompBloq(10))
93115
assert len(sigma) == 3 # includes split and join
@@ -107,7 +129,7 @@ def generalize(bloq):
107129

108130
@pytest.mark.notebook
109131
def test_notebook():
110-
qlt_testing.execute_notebook('bloq_counts')
132+
qlt_testing.execute_notebook('call_graph')
111133

112134

113135
def _to_tuple(x: Iterable[BloqCountT]) -> Sequence[BloqCountT]:
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from typing import Callable, Optional, TYPE_CHECKING
15+
16+
if TYPE_CHECKING:
17+
from qualtran import Bloq
18+
19+
GeneralizerT = Callable[['Bloq'], Optional['Bloq']]
20+
21+
22+
def _make_composite_generalizer(*funcs: 'GeneralizerT') -> 'GeneralizerT':
23+
"""Return a generalizer that calls each `*funcs` generalizers in order."""
24+
25+
def _composite_generalize(b: 'Bloq') -> Optional['Bloq']:
26+
for func in funcs:
27+
b = func(b)
28+
if b is None:
29+
return
30+
return b
31+
32+
return _composite_generalize
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from typing import Optional
15+
16+
from qualtran import Bloq
17+
from qualtran.bloqs.for_testing import TestAtom
18+
from qualtran.resource_counting._generalization import _make_composite_generalizer
19+
20+
21+
def test_make_composite_generalizer():
22+
def func1(b: Bloq) -> Optional[Bloq]:
23+
if isinstance(b, TestAtom):
24+
return TestAtom()
25+
return b
26+
27+
def func2(b: Bloq) -> Optional[Bloq]:
28+
if isinstance(b, TestAtom):
29+
return
30+
return b
31+
32+
b = TestAtom(tag='test')
33+
assert func1(b) == TestAtom()
34+
assert func2(b) is None
35+
36+
g00 = _make_composite_generalizer()
37+
g10 = _make_composite_generalizer(func1)
38+
g01 = _make_composite_generalizer(func2)
39+
g11 = _make_composite_generalizer(func1, func2)
40+
g11_r = _make_composite_generalizer(func2, func1)
41+
42+
assert g00(b) == b
43+
assert g10(b) == TestAtom()
44+
assert g01(b) is None
45+
assert g11(b) is None
46+
assert g11_r(b) is None

0 commit comments

Comments
 (0)