Skip to content

Commit cc0926a

Browse files
authored
Add basic support for feedback operations to stimcirq (#1055)
- Currently fails to round trip while preserving simulability - cirq->stim can handle cirq.ClassicallyControlledOperation (wrapping Paulis with a single control) and turns it in "CX rec[-1] 0" or similar - stim->cirq produces a stimcirq.FeedbackPauli which is not simulable (due to not knowing how to lookup the measurement)
1 parent 8d2770e commit cc0926a

5 files changed

Lines changed: 368 additions & 15 deletions

File tree

glue/cirq/stimcirq/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from ._cx_swap_gate import CXSwapGate
44
from ._cz_swap_gate import CZSwapGate
55
from ._det_annotation import DetAnnotation
6+
from ._feedback_pauli import FeedbackPauli
67
from ._obs_annotation import CumulativeObservableAnnotation
78
from ._shift_coords_annotation import ShiftCoordsAnnotation
89
from ._stim_sampler import StimSampler
@@ -19,6 +20,7 @@
1920
JSON_RESOLVERS_DICT = {
2021
"CumulativeObservableAnnotation": CumulativeObservableAnnotation,
2122
"DetAnnotation": DetAnnotation,
23+
"FeedbackPauli": FeedbackPauli,
2224
"MeasureAndOrResetGate": MeasureAndOrResetGate,
2325
"ShiftCoordsAnnotation": ShiftCoordsAnnotation,
2426
"SweepPauli": SweepPauli,

glue/cirq/stimcirq/_cirq_to_stim.py

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
import cirq
77
import stim
88

9-
from ._i_error_gate import IErrorGate
10-
from ._ii_error_gate import IIErrorGate
119
from ._ii_gate import IIGate
1210

1311

@@ -142,6 +140,52 @@ def cirq_circuit_to_stim_data(
142140

143141

144142
StimTypeHandler = Callable[[stim.Circuit, cirq.Gate, List[int], str], None]
143+
StimOpTypeHandler = Callable[[stim.Circuit, cirq.Operation, List[int], str, List[Tuple[str, int]]], None]
144+
145+
146+
def _stim_append_classically_controlled_gate(
147+
circuit: stim.Circuit,
148+
op: cirq.ClassicallyControlledOperation,
149+
targets: List[int],
150+
tag: str,
151+
measurement_key_lengths: List[Tuple[str, int]]):
152+
153+
if len(op.classical_controls) != 1:
154+
raise NotImplementedError(f'Stim only supports single-control Pauli feedback, but got {op=}')
155+
control, = op.classical_controls
156+
if not isinstance(control, cirq.KeyCondition):
157+
raise NotImplementedError(f'Stim only supports single-control Pauli feedback (i.e. a `cirq.KeyCondition` control), but got {control=}')
158+
control: cirq.KeyCondition
159+
gate = op.without_classical_controls().gate
160+
161+
if gate == cirq.X:
162+
stim_gate = 'X'
163+
elif gate == cirq.Y:
164+
stim_gate = 'Y'
165+
elif gate == cirq.Z:
166+
stim_gate = 'Z'
167+
else:
168+
raise NotImplementedError(f'Stim only supports Pauli feedback, but got {op=}')
169+
assert len(targets) == 1
170+
171+
skips_left = control.index
172+
for offset in range(len(measurement_key_lengths)):
173+
m_key, m_len = measurement_key_lengths[-1 - offset]
174+
if m_len != 1:
175+
raise NotImplementedError(f"multi-qubit measurement {m_key!r}")
176+
if m_key == control.key:
177+
if skips_left > 0:
178+
skips_left -= 1
179+
else:
180+
rec_target = stim.target_rec(-1 - offset)
181+
break
182+
else:
183+
raise ValueError(
184+
f"{control!r} was processed before the measurement it referenced."
185+
f" Make sure the referenced measurements keys are actually in the circuit, and come"
186+
f" in an earlier moment (or earlier in the same moment's operation order)."
187+
)
188+
circuit.append(f"C{stim_gate}", [rec_target, targets[0]], tag=tag)
145189

146190

147191
@functools.lru_cache(maxsize=1)
@@ -278,6 +322,14 @@ def gate_type_to_stim_append_func() -> Dict[Type[cirq.Gate], StimTypeHandler]:
278322
}
279323

280324

325+
@functools.lru_cache()
326+
def op_type_to_stim_append_func() -> Dict[Type[cirq.Operation], StimOpTypeHandler]:
327+
"""A dictionary mapping specific gate types to stim circuit appending functions."""
328+
return {
329+
cirq.ClassicallyControlledOperation: _stim_append_classically_controlled_gate,
330+
}
331+
332+
281333
def _stim_append_measurement_gate(
282334
circuit: stim.Circuit, gate: cirq.MeasurementGate, targets: List[int], tag: str
283335
):
@@ -454,7 +506,8 @@ def process_circuit_operation_into_repeat_block(self, op: cirq.CircuitOperation,
454506

455507
def process_operations(self, operations: Iterable[cirq.Operation]) -> None:
456508
g2f = gate_to_stim_append_func()
457-
t2f = gate_type_to_stim_append_func()
509+
tg2f = gate_type_to_stim_append_func()
510+
to2f = op_type_to_stim_append_func()
458511
for op in operations:
459512
assert isinstance(op, cirq.Operation)
460513
tag = self.tag_func(op)
@@ -500,11 +553,16 @@ def process_operations(self, operations: Iterable[cirq.Operation]) -> None:
500553
continue
501554

502555
# Look for recognized gate types like cirq.DepolarizingChannel.
503-
type_append_func = t2f.get(type(gate))
556+
type_append_func = tg2f.get(type(gate))
504557
if type_append_func is not None:
505558
type_append_func(self.out, gate, targets, tag=tag)
506559
continue
507560

561+
op_type_append_func = to2f.get(type(op))
562+
if op_type_append_func is not None:
563+
op_type_append_func(self.out, op, targets, tag, self.key_out)
564+
continue
565+
508566
# Ask unrecognized operations to decompose themselves into simpler operations.
509567
try:
510568
self.process_operations(cirq.decompose_once(op))
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from typing import Any, Dict, List, Tuple, Optional
2+
3+
import cirq
4+
import stim
5+
6+
7+
@cirq.value_equality
8+
class FeedbackPauli(cirq.Gate):
9+
"""A Pauli gate conditioned on a prior measurement."""
10+
11+
def __init__(
12+
self,
13+
*,
14+
relative_measurement_index: Optional[int] = None,
15+
pauli: cirq.Pauli,
16+
):
17+
r"""
18+
19+
Args:
20+
relative_measurement_index: A negative integer identifying how many measurements ago is the measurement that
21+
controls the Pauli operation.
22+
pauli: The cirq Pauli operation to apply when the bit is True.
23+
"""
24+
if relative_measurement_index is not None and (relative_measurement_index >= 0 or not isinstance(relative_measurement_index, int)):
25+
raise ValueError(f"{relative_measurement_index=} isn't a negative int (note {type(relative_measurement_index)=})")
26+
self.relative_measurement_index = relative_measurement_index
27+
self.pauli = pauli
28+
29+
def _is_parameterized_(self) -> bool:
30+
return False
31+
32+
def _num_qubits_(self) -> int:
33+
return 1
34+
35+
def _value_equality_values_(self) -> Any:
36+
return self.pauli, self.relative_measurement_index
37+
38+
def _circuit_diagram_info_(self, args: Any) -> str:
39+
return f"{self.pauli}^rec[{self.relative_measurement_index}]"
40+
41+
@staticmethod
42+
def _json_namespace_() -> str:
43+
return ''
44+
45+
def _json_dict_(self) -> Dict[str, Any]:
46+
return {
47+
'pauli': self.pauli,
48+
'relative_measurement_index': self.relative_measurement_index,
49+
}
50+
51+
def __repr__(self) -> str:
52+
return (
53+
f'stimcirq.FeedbackPauli('
54+
f'relative_measurement_index={self.relative_measurement_index!r}, '
55+
f'pauli={self.pauli!r})'
56+
)
57+
58+
def _stim_conversion_(
59+
self,
60+
*,
61+
edit_circuit: stim.Circuit,
62+
tag: str,
63+
targets: List[int],
64+
**kwargs,
65+
):
66+
rec_target = stim.target_rec(self.relative_measurement_index)
67+
edit_circuit.append(f"C{self.pauli}", [rec_target, targets[0]], tag=tag)
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
import cirq
2+
import pytest
3+
import stim
4+
import stimcirq
5+
6+
7+
def test_cirq_to_stim_to_cirq_classical_control():
8+
q = cirq.LineQubit(0)
9+
cirq_circuit = cirq.Circuit(
10+
cirq.measure(q, key="test"),
11+
cirq.X(q).with_classical_controls("test").with_tags("test2")
12+
)
13+
stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(cirq_circuit)
14+
assert stim_circuit == stim.Circuit("""
15+
M 0
16+
TICK
17+
CX[test2] rec[-1] 0
18+
TICK
19+
""")
20+
assert stimcirq.stim_circuit_to_cirq_circuit(stim_circuit) == cirq.Circuit(
21+
cirq.measure(q, key="0"),
22+
stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.X).on(q).with_tags("test2")
23+
)
24+
25+
26+
def test_cirq_to_stim_to_cirq_feedback_pauli():
27+
q = cirq.LineQubit(0)
28+
cirq_circuit = cirq.Circuit(
29+
cirq.measure(q, key="test"),
30+
stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.X).on(q).with_tags('test3')
31+
)
32+
stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(cirq_circuit)
33+
assert stim_circuit == stim.Circuit("""
34+
M 0
35+
TICK
36+
CX[test3] rec[-1] 0
37+
TICK
38+
""")
39+
assert stimcirq.stim_circuit_to_cirq_circuit(stim_circuit) == cirq.Circuit(
40+
cirq.measure(q, key="0"),
41+
stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.X).on(q).with_tags('test3')
42+
)
43+
44+
45+
def test_stim_to_cirq_conversion():
46+
with pytest.raises(NotImplementedError, match="wrong target"):
47+
stimcirq.stim_circuit_to_cirq_circuit(stim.Circuit("""
48+
M 0
49+
TICK
50+
XCZ rec[-1] 3
51+
"""))
52+
with pytest.raises(NotImplementedError, match="wrong target"):
53+
stimcirq.stim_circuit_to_cirq_circuit(stim.Circuit("""
54+
M 0
55+
TICK
56+
YCZ rec[-1] 3
57+
"""))
58+
with pytest.raises(NotImplementedError, match="wrong target"):
59+
stimcirq.stim_circuit_to_cirq_circuit(stim.Circuit("""
60+
M 0
61+
TICK
62+
CY 3 rec[-1]
63+
"""))
64+
with pytest.raises(NotImplementedError, match="wrong target"):
65+
stimcirq.stim_circuit_to_cirq_circuit(stim.Circuit("""
66+
M 0
67+
TICK
68+
CX 3 rec[-1]
69+
"""))
70+
with pytest.raises(NotImplementedError, match="Two classical"):
71+
stimcirq.stim_circuit_to_cirq_circuit(stim.Circuit("""
72+
M 0 1
73+
TICK
74+
CZ rec[-1] rec[-2]
75+
"""))
76+
77+
assert stimcirq.stim_circuit_to_cirq_circuit(stim.Circuit("""
78+
M 0
79+
TICK
80+
ZCX rec[-1] 0
81+
ZCY rec[-1] 1
82+
ZCZ rec[-1] 2
83+
XCZ 3 rec[-1]
84+
YCZ 4 rec[-1]
85+
ZCZ 5 rec[-1]
86+
""")) == cirq.Circuit(
87+
cirq.Moment(
88+
cirq.measure(cirq.LineQubit(0), key=cirq.MeasurementKey(name='0')),
89+
),
90+
cirq.Moment(
91+
stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.X).on(cirq.LineQubit(0)),
92+
stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.Y).on(cirq.LineQubit(1)),
93+
stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.Z).on(cirq.LineQubit(2)),
94+
stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.X).on(cirq.LineQubit(3)),
95+
stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.Y).on(cirq.LineQubit(4)),
96+
stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.Z).on(cirq.LineQubit(5)),
97+
),
98+
)
99+
100+
101+
def test_stim_conversion():
102+
a, b, c = cirq.LineQubit.range(3)
103+
104+
with pytest.raises(ValueError, match="earlier"):
105+
stimcirq.cirq_circuit_to_stim_circuit(
106+
cirq.Circuit(cirq.Moment(cirq.X(a).with_classical_controls("unknown")))
107+
)
108+
with pytest.raises(ValueError, match="earlier"):
109+
stimcirq.cirq_circuit_to_stim_circuit(
110+
cirq.Circuit(
111+
cirq.Moment(
112+
cirq.X(a).with_classical_controls("unknown"), cirq.measure(b, key="later")
113+
)
114+
)
115+
)
116+
with pytest.raises(ValueError, match="earlier"):
117+
stimcirq.cirq_circuit_to_stim_circuit(
118+
cirq.Circuit(
119+
cirq.Moment(cirq.X(a).with_classical_controls("unknown")),
120+
cirq.Moment(cirq.measure(b, key="later")),
121+
)
122+
)
123+
assert stimcirq.cirq_circuit_to_stim_circuit(
124+
cirq.Circuit(
125+
cirq.Moment(cirq.measure(b, key="earlier")),
126+
cirq.Moment(cirq.X(b).with_classical_controls("earlier")),
127+
)
128+
) == stim.Circuit(
129+
"""
130+
QUBIT_COORDS(1) 0
131+
M 0
132+
TICK
133+
CX rec[-1] 0
134+
TICK
135+
"""
136+
)
137+
138+
assert stimcirq.cirq_circuit_to_stim_circuit(
139+
cirq.Circuit(
140+
cirq.Moment(cirq.measure(a, key="a"), cirq.measure(b, key="b")),
141+
cirq.Moment(
142+
cirq.X(b).with_classical_controls("a"),
143+
),
144+
cirq.Moment(
145+
cirq.Z(b).with_classical_controls("b"),
146+
),
147+
)
148+
) == stim.Circuit(
149+
"""
150+
M 0 1
151+
TICK
152+
CX rec[-2] 1
153+
TICK
154+
CZ rec[-1] 1
155+
TICK
156+
"""
157+
)
158+
159+
160+
def test_diagram():
161+
a, b = cirq.LineQubit.range(2)
162+
cirq.testing.assert_has_diagram(
163+
cirq.Circuit(
164+
cirq.measure(a, key="a"),
165+
cirq.measure(b, key="b"),
166+
stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli='Y').on(a),
167+
),
168+
"""
169+
0: ---M('a')---Y^rec[-1]---
170+
171+
1: ---M('b')---------------
172+
""",
173+
use_unicode_characters=False,
174+
)
175+
176+
177+
def test_repr():
178+
val = stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.Y)
179+
assert eval(repr(val), {"cirq": cirq, "stimcirq": stimcirq}) == val
180+
181+
182+
def test_equality():
183+
eq = cirq.testing.EqualsTester()
184+
eq.add_equality_group(
185+
stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.X),
186+
stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.X))
187+
eq.add_equality_group(stimcirq.FeedbackPauli(relative_measurement_index=-1, pauli=cirq.Y))
188+
eq.add_equality_group(
189+
stimcirq.FeedbackPauli(relative_measurement_index=-4, pauli=cirq.X),
190+
)
191+
eq.add_equality_group(stimcirq.FeedbackPauli(relative_measurement_index=-10, pauli=cirq.Z))
192+
193+
194+
def test_json_serialization():
195+
c = cirq.Circuit(
196+
stimcirq.FeedbackPauli(relative_measurement_index=-3, pauli=cirq.X).on(cirq.LineQubit(0)),
197+
stimcirq.FeedbackPauli(relative_measurement_index=-5, pauli=cirq.Y).on(cirq.LineQubit(1)),
198+
stimcirq.FeedbackPauli(relative_measurement_index=-7, pauli=cirq.Z).on(cirq.LineQubit(2)),
199+
)
200+
json = cirq.to_json(c)
201+
c2 = cirq.read_json(json_text=json, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER])
202+
assert c == c2
203+
204+
205+
def test_json_backwards_compat_exact():
206+
raw = stimcirq.FeedbackPauli(relative_measurement_index=-3, pauli=cirq.X)
207+
packed = '{\n "cirq_type": "FeedbackPauli",\n "pauli": {\n "cirq_type": "_PauliX",\n "exponent": 1.0,\n "global_shift": 0.0\n },\n "relative_measurement_index": -3\n}'
208+
assert cirq.to_json(raw) == packed
209+
assert cirq.read_json(json_text=packed, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw

0 commit comments

Comments
 (0)