diff --git a/qualtran/_infra/bloq.py b/qualtran/_infra/bloq.py index a5acf9d135..8ef231b82b 100644 --- a/qualtran/_infra/bloq.py +++ b/qualtran/_infra/bloq.py @@ -322,7 +322,7 @@ def call_graph( according to `keep` and `max_depth` (if provided) or if a bloq cannot be decomposed. """ - from qualtran.resource_counting.bloq_counts import get_bloq_call_graph + from qualtran.resource_counting import get_bloq_call_graph return get_bloq_call_graph(self, generalizer=generalizer, keep=keep, max_depth=max_depth) diff --git a/qualtran/cirq_interop/t_complexity_protocol.py b/qualtran/cirq_interop/t_complexity_protocol.py index ee97c65f4d..2855350771 100644 --- a/qualtran/cirq_interop/t_complexity_protocol.py +++ b/qualtran/cirq_interop/t_complexity_protocol.py @@ -150,16 +150,16 @@ def _from_iterable(it: Any) -> Optional[TComplexity]: def _from_bloq_build_call_graph(stc: Any) -> Optional[TComplexity]: # Uses the depth 1 call graph of Bloq `stc` to recursively compute the complexity. + from qualtran.resource_counting import get_bloq_callee_counts from qualtran.resource_counting.generalizers import cirq_to_bloqs if not isinstance(stc, Bloq): return None - _, sigma = stc.call_graph(max_depth=1, generalizer=cirq_to_bloqs) - if sigma == {stc: 1}: - # No decomposition found. + callee_counts = get_bloq_callee_counts(bloq=stc, generalizer=cirq_to_bloqs) + if len(callee_counts) == 0: return None ret = TComplexity() - for bloq, n in sigma.items(): + for bloq, n in callee_counts: r = t_complexity(bloq) if r is None: return None diff --git a/qualtran/cirq_interop/t_complexity_protocol_test.py b/qualtran/cirq_interop/t_complexity_protocol_test.py index bf5c80b61c..6ed67ac402 100644 --- a/qualtran/cirq_interop/t_complexity_protocol_test.py +++ b/qualtran/cirq_interop/t_complexity_protocol_test.py @@ -15,6 +15,7 @@ import cirq import pytest +from attrs import frozen from qualtran import Bloq, GateWithRegisters, Signature from qualtran._infra.gate_with_registers import get_named_qubits @@ -36,6 +37,7 @@ class DoesNotSupportTComplexity: ... +@frozen class SupportsTComplexityGateWithRegisters(GateWithRegisters): @property def signature(self) -> Signature: @@ -64,6 +66,7 @@ def signature(self) -> 'Signature': return Signature.build(q=1) +@frozen class SupportsTComplexityBloqViaBuildCallGraph(Bloq): @property def signature(self) -> 'Signature': @@ -75,6 +78,8 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: def test_t_complexity_for_bloq_via_build_call_graph(): bloq = SupportsTComplexityBloqViaBuildCallGraph() + _, sigma = bloq.call_graph(max_depth=1) + assert sigma != {} assert t_complexity(bloq) == TComplexity(t=5, clifford=10) diff --git a/qualtran/drawing/bloq_counts_graph_test.py b/qualtran/drawing/bloq_counts_graph_test.py index e9b64a6f0c..c46d725ed4 100644 --- a/qualtran/drawing/bloq_counts_graph_test.py +++ b/qualtran/drawing/bloq_counts_graph_test.py @@ -44,11 +44,12 @@ def test_format_counts_graph_markdown(): ret = format_counts_graph_markdown(graph) assert ( ret - == r""" - `MultiAnd(cvs=(1, 1, 1, 1, 1, 1))` - - `And(cv1=1, cv2=1, uncompute=False)`: $\displaystyle 5$ + == """\ + - `MultiAnd(cvs=(1, 1, 1, 1, 1, 1))` + - `And(cv1=1, cv2=1, uncompute=False)`: $\\displaystyle 5$ - `And(cv1=1, cv2=1, uncompute=False)` - - `ArbitraryClifford(n=2)`: $\displaystyle 9$ - - `TGate()`: $\displaystyle 4$ + - `ArbitraryClifford(n=2)`: $\\displaystyle 9$ + - `TGate()`: $\\displaystyle 4$ """ ) diff --git a/qualtran/drawing/flame_graph.py b/qualtran/drawing/flame_graph.py index 75b522afc6..0366fd8e88 100644 --- a/qualtran/drawing/flame_graph.py +++ b/qualtran/drawing/flame_graph.py @@ -24,7 +24,7 @@ import sympy from qualtran import Bloq -from qualtran.resource_counting.bloq_counts import _compute_sigma +from qualtran.resource_counting._call_graph import _compute_sigma from qualtran.resource_counting.t_counts_from_sigma import t_counts_from_sigma diff --git a/qualtran/resource_counting/__init__.py b/qualtran/resource_counting/__init__.py index 4cebe08c27..dc77fd8762 100644 --- a/qualtran/resource_counting/__init__.py +++ b/qualtran/resource_counting/__init__.py @@ -17,11 +17,13 @@ isort:skip_file """ -from .bloq_counts import ( +from ._generalization import GeneralizerT + +from ._call_graph import ( BloqCountT, - GeneralizerT, big_O, SympySymbolAllocator, + get_bloq_callee_counts, get_bloq_call_graph, print_counts_graph, build_cbloq_call_graph, diff --git a/qualtran/resource_counting/bloq_counts.py b/qualtran/resource_counting/_call_graph.py similarity index 85% rename from qualtran/resource_counting/bloq_counts.py rename to qualtran/resource_counting/_call_graph.py index 72dbaaafa5..79c180ca1e 100644 --- a/qualtran/resource_counting/bloq_counts.py +++ b/qualtran/resource_counting/_call_graph.py @@ -23,7 +23,7 @@ from qualtran import Bloq, CompositeBloq, DecomposeNotImplementedError, DecomposeTypeError BloqCountT = Tuple[Bloq, Union[int, sympy.Expr]] -GeneralizerT = Callable[[Bloq], Optional[Bloq]] +from ._generalization import _make_composite_generalizer, GeneralizerT def big_O(expr) -> sympy.Order: @@ -85,6 +85,38 @@ def _generalize_callees( return callee_counts +def get_bloq_callee_counts( + bloq: 'Bloq', generalizer: 'GeneralizerT' = None, ssa: SympySymbolAllocator = None +) -> List[BloqCountT]: + """Get the direct callees of a bloq and the number of times they are called. + + This calls `bloq.build_call_graph()` with the correct configuration options. + + Args: + bloq: The bloq. + generalizer: If provided, run this function on each callee to consolidate attributes + that do not affect resource estimates. If the callable + returns `None`, the bloq is omitted from the counts graph. If a sequence of + generalizers is provided, each generalizer will be run in order. + ssa: A sympy symbol allocator that can be provided if one already exists in your + computation. + + Returns: + A list of (bloq, n) bloq counts. + """ + if generalizer is None: + generalizer = lambda b: b + if isinstance(generalizer, (list, tuple)): + generalizer = _make_composite_generalizer(*generalizer) + if ssa is None: + ssa = SympySymbolAllocator() + + try: + return _generalize_callees(bloq.build_call_graph(ssa), generalizer) + except (DecomposeNotImplementedError, DecomposeTypeError): + return [] + + def _build_call_graph( bloq: Bloq, generalizer: GeneralizerT, @@ -103,8 +135,7 @@ def _build_call_graph( # We already visited this node. return - # Make sure this node is present in the graph. You could annotate - # additional node properties here, too. + # Make sure this node is present in the graph. g.add_node(bloq) # Base case 1: This node is requested by the user to be a leaf node via the `keep` parameter. @@ -116,12 +147,7 @@ def _build_call_graph( return # Prep for recursion: get the callees and modify them according to `generalizer`. - try: - callee_counts = _generalize_callees(bloq.build_call_graph(ssa), generalizer) - except (DecomposeNotImplementedError, DecomposeTypeError): - # Base case 3: Decomposition (or `bloq_counts`) is not implemented. This is left as a - # leaf node. - return + callee_counts = get_bloq_callee_counts(bloq, generalizer) # Base case 3: Empty list of callees if not callee_counts: @@ -165,19 +191,6 @@ def _compute_sigma(root_bloq: Bloq, g: nx.DiGraph) -> Dict[Bloq, Union[int, symp return dict(bloq_sigmas[root_bloq]) -def _make_composite_generalizer(*funcs: GeneralizerT) -> GeneralizerT: - """Return a generalizer that calls each `*funcs` generalizers in order.""" - - def _composite_generalize(b: Bloq) -> Optional[Bloq]: - for func in funcs: - b = func(b) - if b is None: - return - return b - - return _composite_generalize - - def get_bloq_call_graph( bloq: Bloq, generalizer: Optional[Union['GeneralizerT', Sequence['GeneralizerT']]] = None, diff --git a/qualtran/resource_counting/bloq_counts_test.py b/qualtran/resource_counting/_call_graph_test.py similarity index 89% rename from qualtran/resource_counting/bloq_counts_test.py rename to qualtran/resource_counting/_call_graph_test.py index 9f9f3bdba3..0ffa0a8fa6 100644 --- a/qualtran/resource_counting/bloq_counts_test.py +++ b/qualtran/resource_counting/_call_graph_test.py @@ -25,7 +25,12 @@ from qualtran import Bloq, BloqBuilder, Signature, SoquetT from qualtran.bloqs.basic_gates import TGate from qualtran.bloqs.util_bloqs import ArbitraryClifford, Join, Split -from qualtran.resource_counting import BloqCountT, get_bloq_call_graph, SympySymbolAllocator +from qualtran.resource_counting import ( + BloqCountT, + get_bloq_call_graph, + get_bloq_callee_counts, + SympySymbolAllocator, +) @frozen @@ -88,6 +93,23 @@ def test_bloq_counts_method(): assert str(expr) == '3*log(100)' +def test_get_bloq_callee_counts(): + bloq = BigBloq(100) + callee_counts = get_bloq_callee_counts(bloq) + assert callee_counts == [(SubBloq(unrelated_param=0.5), sympy.log(100))] + + bloq = DecompBloq(10) + callee_counts = get_bloq_callee_counts(bloq) + assert len(callee_counts) == 10 + 2 # 2 for split/join + + bloq = SubBloq(unrelated_param=0.5) + callee_counts = get_bloq_callee_counts(bloq) + assert callee_counts == [(TGate(), 3)] + + callee_counts = get_bloq_callee_counts(TGate()) + assert callee_counts == [] + + def test_bloq_counts_decomp(): graph, sigma = get_bloq_call_graph(DecompBloq(10)) assert len(sigma) == 3 # includes split and join @@ -107,7 +129,7 @@ def generalize(bloq): @pytest.mark.notebook def test_notebook(): - qlt_testing.execute_notebook('bloq_counts') + qlt_testing.execute_notebook('call_graph') def _to_tuple(x: Iterable[BloqCountT]) -> Sequence[BloqCountT]: diff --git a/qualtran/resource_counting/_generalization.py b/qualtran/resource_counting/_generalization.py new file mode 100644 index 0000000000..85b9b2e7e3 --- /dev/null +++ b/qualtran/resource_counting/_generalization.py @@ -0,0 +1,32 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from qualtran import Bloq + +GeneralizerT = Callable[['Bloq'], Optional['Bloq']] + + +def _make_composite_generalizer(*funcs: 'GeneralizerT') -> 'GeneralizerT': + """Return a generalizer that calls each `*funcs` generalizers in order.""" + + def _composite_generalize(b: 'Bloq') -> Optional['Bloq']: + for func in funcs: + b = func(b) + if b is None: + return + return b + + return _composite_generalize diff --git a/qualtran/resource_counting/_generalization_test.py b/qualtran/resource_counting/_generalization_test.py new file mode 100644 index 0000000000..c2cdc8c529 --- /dev/null +++ b/qualtran/resource_counting/_generalization_test.py @@ -0,0 +1,46 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from qualtran import Bloq +from qualtran.bloqs.for_testing import TestAtom +from qualtran.resource_counting._generalization import _make_composite_generalizer + + +def test_make_composite_generalizer(): + def func1(b: Bloq) -> Optional[Bloq]: + if isinstance(b, TestAtom): + return TestAtom() + return b + + def func2(b: Bloq) -> Optional[Bloq]: + if isinstance(b, TestAtom): + return + return b + + b = TestAtom(tag='test') + assert func1(b) == TestAtom() + assert func2(b) is None + + g00 = _make_composite_generalizer() + g10 = _make_composite_generalizer(func1) + g01 = _make_composite_generalizer(func2) + g11 = _make_composite_generalizer(func1, func2) + g11_r = _make_composite_generalizer(func2, func1) + + assert g00(b) == b + assert g10(b) == TestAtom() + assert g01(b) is None + assert g11(b) is None + assert g11_r(b) is None diff --git a/qualtran/resource_counting/bloq_counts.ipynb b/qualtran/resource_counting/call_graph.ipynb similarity index 92% rename from qualtran/resource_counting/bloq_counts.ipynb rename to qualtran/resource_counting/call_graph.ipynb index 8804c7ef57..25ec008efd 100644 --- a/qualtran/resource_counting/bloq_counts.ipynb +++ b/qualtran/resource_counting/call_graph.ipynb @@ -7,7 +7,7 @@ "source": [ "# The Call Graph Protocol\n", "\n", - "The call graph protocol lets you query which subbloq are called in a bloq's decomposition. Proper accounting of the quantity of subroutine calls is a crucial tool in estimating resource requirements for an algorithm. For example, you can expand the call graph until you reach 'expensive' gates like `TGate` or `Toffoli`. The total number of these gates set the runtime of the algorithm." + "The call graph protocol lets you query which subbloq are called in a bloq's decomposition. Proper accounting of the quantity of subroutine calls is a crucial tool in estimating resource requirements for an algorithm. For example, the number of 'expensive' gates like `TGate` or `Toffoli` required by a bloq is the sum of the number of those gates used by the bloq's callees." ] }, { @@ -20,10 +20,8 @@ "from qualtran.drawing import show_call_graph, show_counts_sigma\n", "from qualtran.bloqs.mcmt import MultiAnd, And\n", "\n", - "graph, sigma = MultiAnd(cvs=(1,)*6).call_graph()\n", - "\n", - "show_call_graph(graph)\n", - "show_counts_sigma(sigma)" + "graph, _ = MultiAnd(cvs=(1,)*6).call_graph()\n", + "show_call_graph(graph)" ] }, { @@ -33,7 +31,7 @@ "source": [ "## Interface\n", "\n", - "The primary method for accessing the call graph of a bloq is `Bloq.call_graph()`. It returns a networkx graph as well as a dictionary of totals for \"leaf\" bloqs. \n", + "The primary method for accessing the call graph of a bloq is `Bloq.call_graph()`. It returns a networkx graph as well as an accounting of total bloq counts for \"leaf\" bloqs. \n", "\n", "Another method is `Bloq.bloq_counts`, which will return a dictionary of immediate children." ] @@ -190,8 +188,7 @@ "source": [ "myfunc = MyFunc(n=sympy.sympify('n'))\n", "graph, sigma = myfunc.call_graph()\n", - "show_call_graph(graph)\n", - "show_counts_sigma(sigma)" + "show_call_graph(graph)" ] }, { @@ -203,7 +200,7 @@ "\n", "If a bloq does not override `build_call_graph(...)`, the default fallback will be used by Qualtran to support the call graph protocol.\n", "\n", - "By default, Qualtran will use the decomposition to count subbloqs called by the bloq. For example, below we author a `SWAP` bloq. We define a decomposition but do not explicitly provide the call graph counts." + "By default, Qualtran will extract the call graph from the full decomposition. For example, below we author a `SWAP` bloq. We define a decomposition but do not explicitly override `build_call_graph`." ] }, { diff --git a/qualtran/resource_counting/generalizers_test.py b/qualtran/resource_counting/generalizers_test.py index f48e1f7a9b..4db1977068 100644 --- a/qualtran/resource_counting/generalizers_test.py +++ b/qualtran/resource_counting/generalizers_test.py @@ -18,7 +18,7 @@ from qualtran.bloqs.mcmt.and_bloq import And, MultiAnd from qualtran.bloqs.util_bloqs import Allocate, Free, Join, Partition, Split from qualtran.cirq_interop import CirqGateAsBloq -from qualtran.resource_counting.bloq_counts import _make_composite_generalizer +from qualtran.resource_counting._generalization import _make_composite_generalizer from qualtran.resource_counting.generalizers import ( cirq_to_bloqs, CV, diff --git a/qualtran/resource_counting/t_counts_from_sigma.py b/qualtran/resource_counting/t_counts_from_sigma.py index 7337c0be33..99cab14770 100644 --- a/qualtran/resource_counting/t_counts_from_sigma.py +++ b/qualtran/resource_counting/t_counts_from_sigma.py @@ -17,7 +17,6 @@ import cirq -from qualtran.cirq_interop.t_complexity_protocol import TComplexity from qualtran.resource_counting.symbolic_counting_utils import SymbolicInt if TYPE_CHECKING: @@ -46,6 +45,7 @@ def t_counts_from_sigma( ) -> SymbolicInt: """Aggregates T-counts from a sigma dictionary by summing T-costs for all rotation bloqs.""" from qualtran.bloqs.basic_gates import TGate + from qualtran.cirq_interop.t_complexity_protocol import TComplexity if rotation_types is None: rotation_types = _get_all_rotation_types()