Skip to content

Commit 12a10a8

Browse files
Add single-measure key option for Cirq converison (#1043)
This adds an optional kwarg to `stimcirq.stim_circuit_to_cirq_circuit` for using a fixed measure key for a all cirq.MeasurementGate ops. This will result in cirq.Result containing a single measure record array with all measurements ordered along the "instance" axis in the order they appear in the stim circuit.
1 parent a6d5c7f commit 12a10a8

2 files changed

Lines changed: 71 additions & 9 deletions

File tree

glue/cirq/stimcirq/_stim_to_cirq.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
Dict,
99
Iterable,
1010
List,
11+
Optional,
1112
Tuple,
1213
Union,
1314
)
@@ -64,19 +65,25 @@ def _proper_transform_circuit_qubits(circuit: cirq.AbstractCircuit, remap: Dict[
6465

6566

6667
class CircuitTranslationTracker:
67-
def __init__(self, flatten: bool):
68+
def __init__(self, flatten: bool, single_measure_key: Optional[str] = None):
6869
self.qubit_coords: Dict[int, cirq.Qid] = {}
6970
self.origin: DefaultDict[float] = collections.defaultdict(float)
7071
self.num_measurements_seen = 0
7172
self.full_circuit = cirq.Circuit()
7273
self.tick_circuit = cirq.Circuit()
7374
self.flatten = flatten
7475
self.have_seen_loop = False
76+
self.single_measure_key = single_measure_key
7577

7678
def get_next_measure_id(self) -> int:
7779
self.num_measurements_seen += 1
7880
return self.num_measurements_seen - 1
7981

82+
def get_next_measure_key(self) -> str:
83+
if self.single_measure_key is None:
84+
return str(self.get_next_measure_id())
85+
return self.single_measure_key
86+
8087
def append_operation(self, op: cirq.Operation) -> None:
8188
self.tick_circuit.append(op, strategy=cirq.InsertStrategy.INLINE)
8289

@@ -186,7 +193,7 @@ def process_measurement_instruction(
186193
for t in targets:
187194
if not t.is_qubit_target:
188195
raise NotImplementedError(f"instruction={instruction!r}")
189-
key = str(self.get_next_measure_id())
196+
key = self.get_next_measure_key()
190197
self.append_operation(
191198
MeasureAndOrResetGate(
192199
measure=measure,
@@ -248,7 +255,7 @@ def process_mpp(self, instruction: stim.CircuitInstruction) -> None:
248255

249256
obs = _stim_targets_to_dense_pauli_string(group)
250257
qubits = [cirq.LineQubit(t.value) for t in group]
251-
key = str(self.get_next_measure_id())
258+
key = self.get_next_measure_key()
252259
self.append_operation(cirq.PauliMeasurementGate(obs, key=key).on(*qubits).with_tags(*tags))
253260

254261
def process_spp_dag(self, instruction: stim.CircuitInstruction) -> None:
@@ -290,7 +297,7 @@ def process_m_pair(self, instruction: stim.CircuitInstruction, basis: str) -> No
290297
if targets[0].is_inverted_result_target ^ targets[1].is_inverted_result_target:
291298
obs *= -1
292299
qubits = [cirq.LineQubit(targets[0].value), cirq.LineQubit(targets[1].value)]
293-
key = str(self.get_next_measure_id())
300+
key = self.get_next_measure_key()
294301
self.append_operation(cirq.PauliMeasurementGate(obs, key=key).on(*qubits).with_tags(*tags))
295302

296303
def process_mxx(self, instruction: stim.CircuitInstruction) -> None:
@@ -309,7 +316,7 @@ def process_mpad(self, instruction: stim.CircuitInstruction) -> None:
309316
if t.value == 1:
310317
obs *= -1
311318
qubits = []
312-
key = str(self.get_next_measure_id())
319+
key = self.get_next_measure_key()
313320
self.append_operation(cirq.PauliMeasurementGate(obs, key=key).on(*qubits))
314321

315322
def process_correlated_error(self, instruction: stim.CircuitInstruction) -> None:
@@ -632,12 +639,17 @@ def handler(
632639
}
633640

634641

635-
def stim_circuit_to_cirq_circuit(circuit: stim.Circuit, *, flatten: bool = False) -> cirq.Circuit:
642+
def stim_circuit_to_cirq_circuit(
643+
circuit: stim.Circuit,
644+
*,
645+
flatten: bool = False,
646+
single_measure_key: Optional[str] = None,
647+
) -> cirq.Circuit:
636648
"""Converts a stim circuit into an equivalent cirq circuit.
637649
638650
Qubit indices are turned into cirq.LineQubit instances. Measurements are
639651
keyed by their ordering (e.g. the first measurement is keyed "0", the second
640-
is keyed "1", etc).
652+
is keyed "1", etc) unless a fixed measure_key is provided.
641653
642654
Not all circuits can be converted:
643655
- ELSE_CORRELATED_ERROR instructions are not supported.
@@ -652,6 +664,8 @@ def stim_circuit_to_cirq_circuit(circuit: stim.Circuit, *, flatten: bool = False
652664
explicitly repeating their instructions multiple times. Also,
653665
SHIFT_COORDS instructions are removed by appropriately adjusting the
654666
coordinate metadata of later instructions.
667+
single_measure_key: Defaults to None. If provided, all measurements are
668+
keyed with this string instead of sequentially generated numbers.
655669
656670
Returns:
657671
The converted circuit.
@@ -671,6 +685,8 @@ def stim_circuit_to_cirq_circuit(circuit: stim.Circuit, *, flatten: bool = False
671685
672686
1: ───────X──────────────────!M('0')───
673687
"""
674-
tracker = CircuitTranslationTracker(flatten=flatten)
688+
tracker = CircuitTranslationTracker(
689+
flatten=flatten, single_measure_key=single_measure_key
690+
)
675691
tracker.process_circuit(repetitions=1, circuit=circuit)
676692
return tracker.output()

glue/cirq/stimcirq/_stim_to_cirq_test.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -778,4 +778,50 @@ def test_round_trip_with_pauli_obs():
778778
""")
779779
cirq_circuit = stimcirq.stim_circuit_to_cirq_circuit(stim_circuit)
780780
restored_circuit = stimcirq.cirq_circuit_to_stim_circuit(cirq_circuit)
781-
assert restored_circuit == stim_circuit
781+
assert restored_circuit == stim_circuit
782+
783+
784+
def test_single_measure_key_order():
785+
stim_circuits = [
786+
stim.Circuit(
787+
"""
788+
X 1
789+
X 1 3
790+
X 1 3
791+
X 1 3 2
792+
M 1
793+
M 3
794+
M 2
795+
M 0
796+
"""
797+
),
798+
stim.Circuit(
799+
"""
800+
X 1
801+
X 1
802+
X 1
803+
X 1
804+
M 1 3
805+
X 2
806+
M 2 0
807+
"""
808+
)
809+
]
810+
measure_key = "m"
811+
for stim_circuit in stim_circuits:
812+
cirq_circuit = stimcirq.stim_circuit_to_cirq_circuit(
813+
stim_circuit, single_measure_key=measure_key
814+
)
815+
qubits = cirq.LineQubit.range(4)
816+
expected_order = [
817+
qubits[targ.qubit_value]
818+
for inst in stim_circuit if inst.name == "M"
819+
for targ in inst.targets_copy()
820+
]
821+
actual_order = []
822+
for op in cirq_circuit.all_operations():
823+
if isinstance(op.gate, cirq.MeasurementGate):
824+
assert op.gate.key == measure_key
825+
assert len(op.qubits) == 1
826+
actual_order.append(op.qubits[0])
827+
assert expected_order == actual_order

0 commit comments

Comments
 (0)