Skip to content

Commit b35999e

Browse files
Add custom operation conversion hook to cirq_optree_to_cbloq (#1834)
Adds a backward-compatible hook that lets callers provide a custom Operation -> Bloq conversion function. This is needed by the FLASQ library to perform custom analysis as we convert cirq optrees to bloqs. For example, calculating the Manhattan distance between the two qubits involved in a two-qubit gate. The new parameter defaults to None, so all existing callers are unaffected. When provided, it replaces the default extraction logic for each operation in the circuit. All existing cirq_interop tests pass unchanged and a new test was added for this new functionality. Co-authored-by: Matthew Harrigan <mpharrigan@google.com>
1 parent 5ba42c1 commit b35999e

File tree

2 files changed

+59
-3
lines changed

2 files changed

+59
-3
lines changed

qualtran/cirq_interop/_cirq_to_bloq.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,18 @@
1818
import itertools
1919
import warnings
2020
from functools import cached_property
21-
from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, TypeVar, Union
21+
from typing import (
22+
Any,
23+
Callable,
24+
Dict,
25+
List,
26+
Optional,
27+
Sequence,
28+
Tuple,
29+
TYPE_CHECKING,
30+
TypeVar,
31+
Union,
32+
)
2233

2334
import cirq
2435
import numpy as np
@@ -319,7 +330,7 @@ def _ensure_in_reg_exists(bb: BloqBuilder, in_reg: _QReg, qreg_to_qvar: Dict[_QR
319330
soqs_to_join[qreg.qubits[0]] = soq
320331
elif len(in_reg_qubits) == 1 and qreg.qubits and qreg.qubits[0] in in_reg_qubits:
321332
# Cast single QBit registers to the appropriate single-bit register dtype.
322-
err_msg = "Found non-QBit type register which shouldn't happen: " f"{soq}"
333+
err_msg = f"Found non-QBit type register which shouldn't happen: {soq}"
323334
assert isinstance(soq.dtype, QBit), err_msg
324335
if not isinstance(in_reg.dtype, QBit):
325336
qreg_to_qvar[in_reg] = bb.add(Cast(QBit(), in_reg.dtype), reg=soq)
@@ -465,6 +476,7 @@ def cirq_optree_to_cbloq(
465476
signature: Optional[Signature] = None,
466477
in_quregs: Optional[Dict[str, 'CirqQuregT']] = None,
467478
out_quregs: Optional[Dict[str, 'CirqQuregT']] = None,
479+
op_conversion_method: Optional[Callable[[cirq.Operation], Bloq]] = None,
468480
) -> CompositeBloq:
469481
"""Convert a Cirq OP-TREE into a `CompositeBloq` with signature `signature`.
470482
@@ -495,6 +507,17 @@ def cirq_optree_to_cbloq(
495507
496508
Any qubit in `optree` which is not part of `in_quregs` and `out_quregs` is considered to be
497509
allocated & deallocated inside the CompositeBloq and does not show up in it's signature.
510+
511+
Args:
512+
optree: A Cirq OP_TREE (e.g. a circuit or list of operations).
513+
signature: The signature of the resulting CompositeBloq. If not provided, a default
514+
signature with one thru-register named "qubits" is used.
515+
in_quregs: Mapping from register names to arrays of cirq qubits for LEFT registers.
516+
out_quregs: Mapping from register names to arrays of cirq qubits for RIGHT registers.
517+
op_conversion_method: An optional callable that takes a ``cirq.Operation`` and returns
518+
a ``Bloq``. If provided, this is used instead of the default ``_extract_bloq_from_op``
519+
to convert each operation. This allows callers to attach custom metadata (e.g.
520+
routing costs) to bloqs during conversion.
498521
"""
499522
circuit = cirq.Circuit(optree)
500523
if signature is None:
@@ -533,7 +556,10 @@ def cirq_optree_to_cbloq(
533556

534557
# 2. Add each operation to the composite Bloq.
535558
for op in circuit.all_operations():
536-
bloq = _extract_bloq_from_op(op)
559+
if op_conversion_method is not None:
560+
bloq = op_conversion_method(op)
561+
else:
562+
bloq = _extract_bloq_from_op(op)
537563
if bloq.signature == Signature([]):
538564
bb.add(bloq)
539565
continue

qualtran/cirq_interop/_cirq_to_bloq_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,3 +286,33 @@ def test_cirq_gate_cost_via_decomp():
286286

287287
gc_swappow = get_cost_value(swappow_bloq, QECGatesCost())
288288
assert gc_swappow == GateCounts(clifford=5, rotation=1, and_bloq=1, measurement=1)
289+
290+
291+
def test_cirq_optree_to_cbloq_op_conversion_method():
292+
"""Test the op_conversion_method parameter of cirq_optree_to_cbloq.
293+
294+
When provided, op_conversion_method should be called for each operation
295+
instead of the default _extract_bloq_from_op.
296+
"""
297+
qubits = cirq.LineQubit.range(3)
298+
circuit = cirq.Circuit(cirq.H(qubits[0]), cirq.CNOT(qubits[0], qubits[1]), cirq.T(qubits[2]))
299+
300+
# Track which operations were converted
301+
converted_ops: list[cirq.Operation] = []
302+
303+
def custom_converter(op: cirq.Operation) -> Bloq:
304+
converted_ops.append(op)
305+
# Fall back to the default behavior.
306+
from qualtran.cirq_interop._cirq_to_bloq import _extract_bloq_from_op
307+
308+
return _extract_bloq_from_op(op)
309+
310+
cbloq = cirq_optree_to_cbloq(circuit, op_conversion_method=custom_converter)
311+
312+
# Verify the custom converter was called for each operation.
313+
assert len(converted_ops) == 3
314+
315+
# The resulting CompositeBloq should still produce the correct unitary.
316+
bloq_unitary = cbloq.tensor_contract()
317+
cirq_unitary = circuit.unitary(qubits)
318+
np.testing.assert_allclose(cirq_unitary, bloq_unitary, atol=1e-8)

0 commit comments

Comments
 (0)