88 Dict ,
99 Iterable ,
1010 List ,
11+ Optional ,
1112 Tuple ,
1213 Union ,
1314)
@@ -64,19 +65,25 @@ def _proper_transform_circuit_qubits(circuit: cirq.AbstractCircuit, remap: Dict[
6465
6566
6667class CircuitTranslationTracker :
67- def __init__ (self , flatten : bool ):
68+ def __init__ (self , flatten : bool , single_measure_key : Optional [ str ] = None ):
6869 self .qubit_coords : Dict [int , cirq .Qid ] = {}
6970 self .origin : DefaultDict [float ] = collections .defaultdict (float )
7071 self .num_measurements_seen = 0
7172 self .full_circuit = cirq .Circuit ()
7273 self .tick_circuit = cirq .Circuit ()
7374 self .flatten = flatten
7475 self .have_seen_loop = False
76+ self .single_measure_key = single_measure_key
7577
7678 def get_next_measure_id (self ) -> int :
7779 self .num_measurements_seen += 1
7880 return self .num_measurements_seen - 1
7981
82+ def get_next_measure_key (self ) -> str :
83+ if self .single_measure_key is None :
84+ return str (self .get_next_measure_id ())
85+ return self .single_measure_key
86+
8087 def append_operation (self , op : cirq .Operation ) -> None :
8188 self .tick_circuit .append (op , strategy = cirq .InsertStrategy .INLINE )
8289
@@ -186,7 +193,7 @@ def process_measurement_instruction(
186193 for t in targets :
187194 if not t .is_qubit_target :
188195 raise NotImplementedError (f"instruction={ instruction !r} " )
189- key = str ( self .get_next_measure_id () )
196+ key = self .get_next_measure_key ( )
190197 self .append_operation (
191198 MeasureAndOrResetGate (
192199 measure = measure ,
@@ -248,7 +255,7 @@ def process_mpp(self, instruction: stim.CircuitInstruction) -> None:
248255
249256 obs = _stim_targets_to_dense_pauli_string (group )
250257 qubits = [cirq .LineQubit (t .value ) for t in group ]
251- key = str ( self .get_next_measure_id () )
258+ key = self .get_next_measure_key ( )
252259 self .append_operation (cirq .PauliMeasurementGate (obs , key = key ).on (* qubits ).with_tags (* tags ))
253260
254261 def process_spp_dag (self , instruction : stim .CircuitInstruction ) -> None :
@@ -290,7 +297,7 @@ def process_m_pair(self, instruction: stim.CircuitInstruction, basis: str) -> No
290297 if targets [0 ].is_inverted_result_target ^ targets [1 ].is_inverted_result_target :
291298 obs *= - 1
292299 qubits = [cirq .LineQubit (targets [0 ].value ), cirq .LineQubit (targets [1 ].value )]
293- key = str ( self .get_next_measure_id () )
300+ key = self .get_next_measure_key ( )
294301 self .append_operation (cirq .PauliMeasurementGate (obs , key = key ).on (* qubits ).with_tags (* tags ))
295302
296303 def process_mxx (self , instruction : stim .CircuitInstruction ) -> None :
@@ -309,7 +316,7 @@ def process_mpad(self, instruction: stim.CircuitInstruction) -> None:
309316 if t .value == 1 :
310317 obs *= - 1
311318 qubits = []
312- key = str ( self .get_next_measure_id () )
319+ key = self .get_next_measure_key ( )
313320 self .append_operation (cirq .PauliMeasurementGate (obs , key = key ).on (* qubits ))
314321
315322 def process_correlated_error (self , instruction : stim .CircuitInstruction ) -> None :
@@ -632,12 +639,17 @@ def handler(
632639 }
633640
634641
635- def stim_circuit_to_cirq_circuit (circuit : stim .Circuit , * , flatten : bool = False ) -> cirq .Circuit :
642+ def stim_circuit_to_cirq_circuit (
643+ circuit : stim .Circuit ,
644+ * ,
645+ flatten : bool = False ,
646+ single_measure_key : Optional [str ] = None ,
647+ ) -> cirq .Circuit :
636648 """Converts a stim circuit into an equivalent cirq circuit.
637649
638650 Qubit indices are turned into cirq.LineQubit instances. Measurements are
639651 keyed by their ordering (e.g. the first measurement is keyed "0", the second
640- is keyed "1", etc).
652+ is keyed "1", etc) unless a fixed measure_key is provided .
641653
642654 Not all circuits can be converted:
643655 - ELSE_CORRELATED_ERROR instructions are not supported.
@@ -652,6 +664,8 @@ def stim_circuit_to_cirq_circuit(circuit: stim.Circuit, *, flatten: bool = False
652664 explicitly repeating their instructions multiple times. Also,
653665 SHIFT_COORDS instructions are removed by appropriately adjusting the
654666 coordinate metadata of later instructions.
667+ single_measure_key: Defaults to None. If provided, all measurements are
668+ keyed with this string instead of sequentially generated numbers.
655669
656670 Returns:
657671 The converted circuit.
@@ -671,6 +685,8 @@ def stim_circuit_to_cirq_circuit(circuit: stim.Circuit, *, flatten: bool = False
671685 │
672686 1: ───────X──────────────────!M('0')───
673687 """
674- tracker = CircuitTranslationTracker (flatten = flatten )
688+ tracker = CircuitTranslationTracker (
689+ flatten = flatten , single_measure_key = single_measure_key
690+ )
675691 tracker .process_circuit (repetitions = 1 , circuit = circuit )
676692 return tracker .output ()
0 commit comments