2323from qualtran import Bloq , CompositeBloq , DecomposeNotImplementedError , DecomposeTypeError
2424
2525BloqCountT = Tuple [Bloq , Union [int , sympy .Expr ]]
26- GeneralizerT = Callable [[ Bloq ], Optional [ Bloq ]]
26+ from . _generalization import _make_composite_generalizer , GeneralizerT
2727
2828
2929def big_O (expr ) -> sympy .Order :
@@ -85,6 +85,38 @@ def _generalize_callees(
8585 return callee_counts
8686
8787
88+ def get_bloq_callee_counts (
89+ bloq : 'Bloq' , generalizer : 'GeneralizerT' = None , ssa : SympySymbolAllocator = None
90+ ) -> List [BloqCountT ]:
91+ """Get the direct callees of a bloq and the number of times they are called.
92+
93+ This calls `bloq.build_call_graph()` with the correct configuration options.
94+
95+ Args:
96+ bloq: The bloq.
97+ generalizer: If provided, run this function on each callee to consolidate attributes
98+ that do not affect resource estimates. If the callable
99+ returns `None`, the bloq is omitted from the counts graph. If a sequence of
100+ generalizers is provided, each generalizer will be run in order.
101+ ssa: A sympy symbol allocator that can be provided if one already exists in your
102+ computation.
103+
104+ Returns:
105+ A list of (bloq, n) bloq counts.
106+ """
107+ if generalizer is None :
108+ generalizer = lambda b : b
109+ if isinstance (generalizer , (list , tuple )):
110+ generalizer = _make_composite_generalizer (* generalizer )
111+ if ssa is None :
112+ ssa = SympySymbolAllocator ()
113+
114+ try :
115+ return _generalize_callees (bloq .build_call_graph (ssa ), generalizer )
116+ except (DecomposeNotImplementedError , DecomposeTypeError ):
117+ return []
118+
119+
88120def _build_call_graph (
89121 bloq : Bloq ,
90122 generalizer : GeneralizerT ,
@@ -103,8 +135,7 @@ def _build_call_graph(
103135 # We already visited this node.
104136 return
105137
106- # Make sure this node is present in the graph. You could annotate
107- # additional node properties here, too.
138+ # Make sure this node is present in the graph.
108139 g .add_node (bloq )
109140
110141 # Base case 1: This node is requested by the user to be a leaf node via the `keep` parameter.
@@ -116,12 +147,7 @@ def _build_call_graph(
116147 return
117148
118149 # Prep for recursion: get the callees and modify them according to `generalizer`.
119- try :
120- callee_counts = _generalize_callees (bloq .build_call_graph (ssa ), generalizer )
121- except (DecomposeNotImplementedError , DecomposeTypeError ):
122- # Base case 3: Decomposition (or `bloq_counts`) is not implemented. This is left as a
123- # leaf node.
124- return
150+ callee_counts = get_bloq_callee_counts (bloq , generalizer )
125151
126152 # Base case 3: Empty list of callees
127153 if not callee_counts :
@@ -165,19 +191,6 @@ def _compute_sigma(root_bloq: Bloq, g: nx.DiGraph) -> Dict[Bloq, Union[int, symp
165191 return dict (bloq_sigmas [root_bloq ])
166192
167193
168- def _make_composite_generalizer (* funcs : GeneralizerT ) -> GeneralizerT :
169- """Return a generalizer that calls each `*funcs` generalizers in order."""
170-
171- def _composite_generalize (b : Bloq ) -> Optional [Bloq ]:
172- for func in funcs :
173- b = func (b )
174- if b is None :
175- return
176- return b
177-
178- return _composite_generalize
179-
180-
181194def get_bloq_call_graph (
182195 bloq : Bloq ,
183196 generalizer : Optional [Union ['GeneralizerT' , Sequence ['GeneralizerT' ]]] = None ,
0 commit comments