Skip to content

Commit 0f5c35f

Browse files
pqvrmpharrigan
andauthored
feat: multi-level decomposition for (non-symbolic) bloqs (#1643)
## Description Each `Bloq` object has a `decompose_bloq` method. Consider the following snippet ```py from qualtran.bloqs.qsp.generalized_qsp import GeneralizedQSP from qualtran.bloqs.basic_gates import XPowGate bloq = GeneralizedQSP.from_qsp_polynomial(XPowGate(), (0.5, 0.5)) decomp = bloq.decompose_bloq() ``` What I want to do is to continually decompose all of the subsequent bloqs. However, the type of the decomposed bloq is ``` CompositeBloq([3 subbloqs...]) ``` and this does *not* have a `decompose_bloq` method. This feature allows one to completely decompose it into its constituent parts. ## Example Consider the following example of defining a `Bloq` object and decomposing it. ```py from qualtran.bloqs.basic_gates import XPowGate from qualtran.bloqs.qsp.generalized_qsp import GeneralizedQSP from qualtran.qref_interop._bloq_to_qref import bloq_to_qref from qref.experimental.rendering import to_graphviz bloq_qsp = GeneralizedQSP.from_qsp_polynomial(XPowGate(), (0.5, 0.5)) ``` The following generates the decomposition (as was done before) without decomposing recursively: ```py # Non-preserve bloqs on decomposition (default). schema_qsp = bloq_to_qref(bloq_qsp) plot_qsp = to_graphviz(schema_qsp) plot_qsp.render(directory="doctest-output/qsp", format="png", cleanup=True) ``` ![Digraph gv](https://github.com/user-attachments/assets/fd408f6b-2017-4271-896b-57d0b928f0d2) Using the new feature, we can select to "blow out" the subsequent `Bloq` objects for subsequent viewing: ```py # Preserve bloqs on decomposition. schema_qsp_preserve = bloq_to_qref(bloq_qsp, decomposition_rules=True) plot_qsp_preserve = to_graphviz(schema_qsp_preserve) plot_qsp_preserve.render(directory="doctest-output/qsp_preserve", format="png", cleanup=True) ``` ![Digraph gv](https://github.com/user-attachments/assets/18739bc1-f7e3-4ad7-91bc-f06303547c23) ## Limitations This works well when the inputs are numeric, but it does not work well for symbolic inputs. As this is outside of scope of this issue, the description of that will be described in a subsequent issue. Tagging @mstechly --------- Co-authored-by: Matthew Harrigan <mpharrigan@google.com>
1 parent 2b281e0 commit 0f5c35f

10 files changed

Lines changed: 137 additions & 31 deletions

File tree

dev_tools/requirements/deps/runtime.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ qsharp
3131
qsharp-widgets
3232

3333
# qref bartiq interop
34-
qref==0.9.0
35-
bartiq==0.9.0
34+
qref==0.11.0
35+
bartiq==0.12.1
3636

3737
# pyzx interop
3838
pyzx

dev_tools/requirements/envs/dev.env.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ babel==2.17.0
6363
# sphinx
6464
backports-tarfile==1.2.0
6565
# via jaraco-context
66-
bartiq==0.9.0
66+
bartiq==0.12.1
6767
# via -r deps/runtime.txt
6868
beautifulsoup4==4.13.4
6969
# via
@@ -548,7 +548,6 @@ pylint==3.3.7
548548
# via -r deps/pylint.txt
549549
pyparsing==3.1.4
550550
# via
551-
# bartiq
552551
# matplotlib
553552
# pydot
554553
pyperclip==1.9.0
@@ -596,7 +595,7 @@ pyzmq==26.4.0
596595
# jupyter-server
597596
pyzx==0.9.0
598597
# via -r deps/runtime.txt
599-
qref==0.9.0
598+
qref==0.11.0
600599
# via
601600
# -r deps/runtime.txt
602601
# bartiq

dev_tools/requirements/envs/docs.env.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ babel==2.17.0
8787
# jupyterlab-server
8888
# pydata-sphinx-theme
8989
# sphinx
90-
bartiq==0.9.0
90+
bartiq==0.12.1
9191
# via
9292
# -c envs/dev.env.txt
9393
# -r deps/runtime.txt
@@ -624,7 +624,6 @@ pygments==2.19.1
624624
pyparsing==3.1.4
625625
# via
626626
# -c envs/dev.env.txt
627-
# bartiq
628627
# matplotlib
629628
# pydot
630629
pyperclip==1.9.0
@@ -664,7 +663,7 @@ pyzx==0.9.0
664663
# via
665664
# -c envs/dev.env.txt
666665
# -r deps/runtime.txt
667-
qref==0.9.0
666+
qref==0.11.0
668667
# via
669668
# -c envs/dev.env.txt
670669
# -r deps/runtime.txt

dev_tools/requirements/envs/format.env.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ babel==2.17.0
7070
# via
7171
# -c envs/dev.env.txt
7272
# jupyterlab-server
73-
bartiq==0.9.0
73+
bartiq==0.12.1
7474
# via
7575
# -c envs/dev.env.txt
7676
# -r deps/runtime.txt
@@ -559,7 +559,6 @@ pygments==2.19.1
559559
pyparsing==3.1.4
560560
# via
561561
# -c envs/dev.env.txt
562-
# bartiq
563562
# matplotlib
564563
# pydot
565564
pyperclip==1.9.0
@@ -595,7 +594,7 @@ pyzx==0.9.0
595594
# via
596595
# -c envs/dev.env.txt
597596
# -r deps/runtime.txt
598-
qref==0.9.0
597+
qref==0.11.0
599598
# via
600599
# -c envs/dev.env.txt
601600
# -r deps/runtime.txt

dev_tools/requirements/envs/mypy.env.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ babel==2.17.0
6666
# via
6767
# -c envs/dev.env.txt
6868
# jupyterlab-server
69-
bartiq==0.9.0
69+
bartiq==0.12.1
7070
# via
7171
# -c envs/dev.env.txt
7272
# -r deps/runtime.txt
@@ -545,7 +545,6 @@ pygments==2.19.1
545545
pyparsing==3.1.4
546546
# via
547547
# -c envs/dev.env.txt
548-
# bartiq
549548
# matplotlib
550549
# pydot
551550
pyperclip==1.9.0
@@ -581,7 +580,7 @@ pyzx==0.9.0
581580
# via
582581
# -c envs/dev.env.txt
583582
# -r deps/runtime.txt
584-
qref==0.9.0
583+
qref==0.11.0
585584
# via
586585
# -c envs/dev.env.txt
587586
# -r deps/runtime.txt

dev_tools/requirements/envs/pylint.env.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ babel==2.17.0
8787
# -c envs/dev.env.txt
8888
# jupyterlab-server
8989
# sphinx
90-
bartiq==0.9.0
90+
bartiq==0.12.1
9191
# via
9292
# -c envs/dev.env.txt
9393
# -r deps/runtime.txt
@@ -644,7 +644,6 @@ pylint==3.3.7
644644
pyparsing==3.1.4
645645
# via
646646
# -c envs/dev.env.txt
647-
# bartiq
648647
# matplotlib
649648
# pydot
650649
pyperclip==1.9.0
@@ -689,7 +688,7 @@ pyzx==0.9.0
689688
# via
690689
# -c envs/dev.env.txt
691690
# -r deps/runtime.txt
692-
qref==0.9.0
691+
qref==0.11.0
693692
# via
694693
# -c envs/dev.env.txt
695694
# -r deps/runtime.txt

dev_tools/requirements/envs/pytest.env.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ babel==2.17.0
7070
# via
7171
# -c envs/dev.env.txt
7272
# jupyterlab-server
73-
bartiq==0.9.0
73+
bartiq==0.12.1
7474
# via
7575
# -c envs/dev.env.txt
7676
# -r deps/runtime.txt
@@ -605,7 +605,6 @@ pygments==2.19.1
605605
pyparsing==3.1.4
606606
# via
607607
# -c envs/dev.env.txt
608-
# bartiq
609608
# matplotlib
610609
# pydot
611610
pyperclip==1.9.0
@@ -664,7 +663,7 @@ pyzx==0.9.0
664663
# via
665664
# -c envs/dev.env.txt
666665
# -r deps/runtime.txt
667-
qref==0.9.0
666+
qref==0.11.0
668667
# via
669668
# -c envs/dev.env.txt
670669
# -r deps/runtime.txt

dev_tools/requirements/envs/runtime.env.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ babel==2.17.0
6666
# via
6767
# -c envs/dev.env.txt
6868
# jupyterlab-server
69-
bartiq==0.9.0
69+
bartiq==0.12.1
7070
# via
7171
# -c envs/dev.env.txt
7272
# -r deps/runtime.txt
@@ -532,7 +532,6 @@ pygments==2.19.1
532532
pyparsing==3.1.4
533533
# via
534534
# -c envs/dev.env.txt
535-
# bartiq
536535
# matplotlib
537536
# pydot
538537
pyperclip==1.9.0
@@ -568,7 +567,7 @@ pyzx==0.9.0
568567
# via
569568
# -c envs/dev.env.txt
570569
# -r deps/runtime.txt
571-
qref==0.9.0
570+
qref==0.11.0
572571
# via
573572
# -c envs/dev.env.txt
574573
# -r deps/runtime.txt

qualtran/qref_interop/_bloq_to_qref.py

Lines changed: 80 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
from qualtran import Bloq, BloqInstance, CompositeBloq
3535
from qualtran import Connection as QualtranConnection
36-
from qualtran import Register, Side, Soquet
36+
from qualtran import DecomposeNotImplementedError, DecomposeTypeError, Register, Side, Soquet
3737
from qualtran.cirq_interop import CirqGateAsBloq
3838
from qualtran.symbolics import is_symbolic
3939

@@ -128,23 +128,92 @@ def _bloq_instance_name(instance: BloqInstance) -> str:
128128
return f"{_bloq_type(instance.bloq)}_{instance.i}"
129129

130130

131+
@singledispatch
131132
def bloq_to_qref(
132-
obj: Union[Bloq, CompositeBloq, BloqInstance], from_callgraph: bool = False
133+
obj: Union[Bloq, CompositeBloq, BloqInstance],
134+
*,
135+
from_callgraph: bool = False,
136+
decomposition_rules: Union[bool, Iterable[type[Bloq]]] = False,
133137
) -> SchemaV1:
134-
"""Converts Bloq to QREF SchemaV1 object.
138+
"""Converts a Qualtran Bloq into a QREF SchemaV1.
135139
136140
Args:
137-
obj: bloq to be converted
138-
from_callgraph: a flag indicating whether conversion should be performed using only the information
139-
from callgraph. It's useful when a bloq doesn't have a full decomposition, but has a callgraph.
141+
obj: Bloq, CompositeBloq, or BloqInstance to be converted.
142+
from_callgraph: if True, import purely from the bloq’s call-graph.
143+
decomposition_rules:
144+
• False (default): no extra unfolding.
145+
• True: attempt to decompose *every* Bloq via decompose_bloq().
146+
• Iterable of Bloq classes: decompose only those types.
140147
148+
Returns:
149+
A SchemaV1 whose `.program` is the top-level RoutineV1.
141150
"""
142151
if from_callgraph:
143152
if isinstance(obj, BloqInstance):
144-
raise ValueError("BloqInstance object can't be used to generate a callgraph.")
153+
raise ValueError("BloqInstance cannot drive a callgraph import.")
145154
return SchemaV1(version="v1", program=bloq_to_routine_from_callgraph(obj))
146155
else:
147-
return SchemaV1(version="v1", program=bloq_to_routine(obj))
156+
if decomposition_rules is False:
157+
program = bloq_to_routine(obj)
158+
else:
159+
keep = None if decomposition_rules is True else set(decomposition_rules)
160+
program = _routine_with_decomposition(obj, keep)
161+
return SchemaV1(version="v1", program=program)
162+
163+
164+
def _routine_with_decomposition(
165+
obj: Union[Bloq, CompositeBloq, BloqInstance],
166+
preserve: Optional[set[type[Bloq]]],
167+
*,
168+
name: Optional[str] = None,
169+
) -> RoutineV1:
170+
"""Convert a Qualtran bloq into a QREF RoutineV1, selectively decomposing.
171+
172+
This helper will import `obj` (which may be a Bloq, CompositeBloq or BloqInstance)
173+
into a nested RoutineV1, but only expand (“decompose”) those Bloq types
174+
requested in `preserve`.
175+
176+
Behavior:
177+
- If `preserve` is None, every bloq that implements `decompose_bloq()`
178+
will be peeled apart into its CompositeBloq and imported recursively.
179+
- If `preserve` is a set of Bloq classes, then only instances of those
180+
classes will be decomposed; all other Bloqs are treated as atomic leaves
181+
(using the original, default `bloq_to_routine` handler).
182+
183+
Args:
184+
obj: The root Bloq, CompositeBloq, or BloqInstance to convert.
185+
preserve:
186+
• None ⇒ decompose **all** bloqs where possible.
187+
• set of Bloq types ⇒ only those classes will be inlined; others remain atomic.
188+
189+
Returns:
190+
A RoutineV1 representing the imported bloq hierarchy, with clusters
191+
for exactly the bloq types you asked to decompose.
192+
"""
193+
# CompositeBloq → recurse
194+
if isinstance(obj, CompositeBloq):
195+
children = [_routine_with_decomposition(i, preserve) for i in obj.bloq_instances]
196+
connections = [_import_connection(c) for c in obj.connections]
197+
return RoutineV1(
198+
**_extract_common_bloq_attributes(obj, name), children=children, connections=connections
199+
)
200+
201+
# BloqInstance → unwrap
202+
if isinstance(obj, BloqInstance):
203+
return _routine_with_decomposition(obj.bloq, preserve, name=_bloq_instance_name(obj))
204+
205+
# Plain Bloq → decide
206+
bloq = obj
207+
if preserve is None or type(bloq) in preserve:
208+
try:
209+
cb = bloq.decompose_bloq()
210+
except (DecomposeTypeError, DecomposeNotImplementedError):
211+
return _default_leaf(bloq, name=name)
212+
else:
213+
return _routine_with_decomposition(cb, preserve, name=name or type(bloq).__name__)
214+
215+
# leave leaf
216+
return _default_leaf(bloq, name=name)
148217

149218

150219
@singledispatch
@@ -411,3 +480,6 @@ def bloq_to_routine_from_callgraph(bloq: Union[Bloq, CompositeBloq]) -> RoutineV
411480
nodes_to_routine_map[edge[0]] = parent
412481

413482
return nodes_to_routine_map[list(call_graph.nodes)[0]]
483+
484+
485+
_default_leaf = bloq_to_routine.registry[Bloq]

qualtran/qref_interop/_bloq_to_qref_test.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from qualtran.bloqs.arithmetic.addition import _add_oop_large
3434
from qualtran.bloqs.arithmetic.comparison import LessThanEqual
3535
from qualtran.bloqs.basic_gates import CNOT
36+
from qualtran.bloqs.basic_gates.su2_rotation import SU2RotationGate
3637
from qualtran.bloqs.block_encoding.lcu_block_encoding import _black_box_lcu_block, _lcu_block
3738
from qualtran.bloqs.chemistry.df.double_factorization import _df_block_encoding, _df_one_body
3839
from qualtran.bloqs.cryptography.rsa.rsa_phase_estimate import _rsa_pe
@@ -329,3 +330,43 @@ def _undecomposed_alias_sampling() -> tuple[Bloq, RoutineV1, str]:
329330
)
330331
def test_importing_qualtran_object_gives_expected_routine_object(qualtran_object, expected_routine):
331332
assert bloq_to_qref(qualtran_object).program == expected_routine
333+
334+
335+
@pytest.mark.parametrize("decomposition_rules", [False, True])
336+
def test_default_vs_true_decomposition_on_less_than_equal(decomposition_rules):
337+
bloq = LessThanEqual(5, 7)
338+
schema = bloq_to_qref(bloq, decomposition_rules=decomposition_rules)
339+
routine = schema.program
340+
341+
if decomposition_rules is False:
342+
# stays atomic
343+
assert routine.children == []
344+
else:
345+
# fully decomposed
346+
assert len(routine.children) == len(bloq.decompose_bloq().bloq_instances)
347+
348+
349+
def test_decomposition_rules_iterable_decomposes_only_requested_types():
350+
bloq = LessThanEqual(3, 4)
351+
# only SU2RotationGate requested → no change
352+
schema = bloq_to_qref(bloq, decomposition_rules=[SU2RotationGate])
353+
assert schema.program.children == []
354+
355+
# requesting LessThanEqual itself → inlined
356+
schema = bloq_to_qref(bloq, decomposition_rules=[LessThanEqual])
357+
assert len(schema.program.children) == len(bloq.decompose_bloq().bloq_instances)
358+
359+
360+
@pytest.mark.parametrize("a,b", [(0.1, 0.2), (1.3, -0.7)])
361+
def test_su2_rotation_can_be_auto_decomposed_with_decomposition_rules(a, b):
362+
bloq = SU2RotationGate(a, b, 0.0)
363+
# default: atomic
364+
assert bloq_to_qref(bloq).program.children == []
365+
366+
# with True → decomposed into Rz + Rx/Ry + GlobalPhase
367+
schema = bloq_to_qref(bloq, decomposition_rules=True)
368+
child_types = {c.type for c in schema.program.children}
369+
assert "GlobalPhase" in child_types
370+
assert "Rz" in child_types
371+
assert any(x in child_types for x in ("Ry", "Rx"))
372+
assert len(schema.program.children) >= 3

0 commit comments

Comments
 (0)