Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 52 additions & 14 deletions glue/cirq/stimcirq/_obs_annotation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Iterable, List, Tuple
from typing import Any, Dict, Iterable, List, Tuple, Union

import cirq
import stim
Expand All @@ -16,7 +16,7 @@ def __init__(
*,
parity_keys: Iterable[str] = (),
relative_keys: Iterable[int] = (),
pauli_keys: Iterable[str] = (),
pauli_keys: Union[Iterable[Tuple[cirq.Qid, str]], Iterable[str]] = (),
observable_index: int,
):
"""
Expand All @@ -29,32 +29,66 @@ def __init__(
"""
self.parity_keys = frozenset(parity_keys)
self.relative_keys = frozenset(relative_keys)
self.pauli_keys = frozenset(pauli_keys)
_pauli_keys = []
_qubits_to_pauli_keys = []
for k in pauli_keys:
if isinstance(k, str):
# For backward compatibility
_pauli_keys.append(k)
_qubits_to_pauli_keys.append((cirq.LineQubit(int(k[1:])), k))
else:
qubit, basis_and_id = k
assert isinstance(basis_and_id, str)
assert isinstance(qubit, cirq.Qid)
_pauli_keys.append(basis_and_id)
_qubits_to_pauli_keys.append((qubit, basis_and_id))
self._qubits_to_pauli_keys = tuple(_qubits_to_pauli_keys)
self.pauli_keys = frozenset(_pauli_keys)
Comment thread
Strilanc marked this conversation as resolved.
self.observable_index = observable_index

@property
def qubits(self) -> Tuple[cirq.Qid, ...]:
return ()
return tuple(sorted(q for q, _ in self._qubits_to_pauli_keys))

def with_qubits(self, *new_qubits) -> 'CumulativeObservableAnnotation':
return self
if len(self.qubits) == len(new_qubits):
qubits_to_pauli_keys = dict(self._qubits_to_pauli_keys)
return CumulativeObservableAnnotation(
parity_keys=self.parity_keys,
relative_keys=self.relative_keys,
pauli_keys=tuple(
(newq, qubits_to_pauli_keys[q]) for newq, q in zip(new_qubits, self.qubits)
),
observable_index=self.observable_index,
)

raise ValueError("Number of qubits does not match")

def _value_equality_values_(self) -> Any:
return self.parity_keys, self.relative_keys, self.pauli_keys, self.observable_index
return self.parity_keys, self.relative_keys, self._qubits_to_pauli_keys, self.observable_index

def _circuit_diagram_info_(self, args: Any) -> str:
def _circuit_diagram_info_(self, args: Any) -> Union[str, Tuple[str]]:
items: List[str] = [repr(e) for e in sorted(self.parity_keys)]
items += [f'rec[{e}]' for e in sorted(self.relative_keys)]
items += sorted(self.pauli_keys)
k = ",".join(str(e) for e in items)
return f"Obs{self.observable_index}({k})"

if len(self._qubits_to_pauli_keys):
pauli_map = dict(self._qubits_to_pauli_keys)
out = []
for q in self.qubits:
k = ",".join([str(e) for e in items] + [f'{str(q)}{pauli_map[q][0]}'])
out.append(f"Obs{self.observable_index}({k})")
return tuple(out)
else:
k = ",".join(str(e) for e in items)
return f"Obs{self.observable_index}({k})"


def __repr__(self) -> str:
return (
f'stimcirq.CumulativeObservableAnnotation('
f'parity_keys={sorted(self.parity_keys)}, '
f'relative_keys={sorted(self.relative_keys)}, '
f'pauli_keys={sorted(self.pauli_keys)}, '
f'pauli_keys={sorted(self._qubits_to_pauli_keys)}, '
f'observable_index={self.observable_index!r})'
)

Expand All @@ -66,7 +100,7 @@ def _json_dict_(self) -> Dict[str, Any]:
result = {
'parity_keys': sorted(self.parity_keys),
'observable_index': self.observable_index,
'pauli_keys': sorted(self.pauli_keys),
'pauli_keys': sorted(self._qubits_to_pauli_keys),
}
if self.relative_keys:
result['relative_keys'] = sorted(self.relative_keys)
Expand All @@ -85,6 +119,7 @@ def _stim_conversion_(
edit_measurement_key_lengths: List[Tuple[str, int]],
have_seen_loop: bool = False,
tag: str,
targets: List[int],
**kwargs,
):
# Ideally these references would all be resolved ahead of time, to avoid the redundant
Expand All @@ -109,10 +144,13 @@ def _stim_conversion_(
rec_targets.append(stim.target_rec(-1 - offset))
if not remaining:
break

qubit_to_basis = dict([(q,k[0]) for q, k in self._qubits_to_pauli_keys])

rec_targets.extend(
[
stim.target_pauli(qubit_index=int(k[1:]), pauli=k[0])
for k in sorted(self.pauli_keys)
stim.target_pauli(qubit_index=tid, pauli=qubit_to_basis[q])
for q, tid in zip(self.qubits, targets)
]
)
if remaining:
Expand Down
23 changes: 15 additions & 8 deletions glue/cirq/stimcirq/_obs_annotation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,14 @@ def test_json_serialization():
assert c == c2

def test_json_serialization_with_pauli_keys():
pauli_keys = [(cirq.LineQubit(0), "X"), (cirq.LineQubit(1), "Y"), (cirq.LineQubit(2), "Z")]
c = cirq.Circuit(
stimcirq.CumulativeObservableAnnotation(parity_keys=["a", "b"], observable_index=5, pauli_keys=["X0", "Y1", "Z2"]),
stimcirq.CumulativeObservableAnnotation(parity_keys=["a", "b"], observable_index=5, pauli_keys=pauli_keys),
stimcirq.CumulativeObservableAnnotation(
parity_keys=["a", "b"], relative_keys=[-1, -3], observable_index=5, pauli_keys=["X0", "Y1", "Z2"]
parity_keys=["a", "b"], relative_keys=[-1, -3], observable_index=5, pauli_keys=pauli_keys
),
stimcirq.CumulativeObservableAnnotation(observable_index=2, pauli_keys=["X0", "Y1", "Z2"]),
stimcirq.CumulativeObservableAnnotation(parity_keys=["d", "c"], observable_index=5, pauli_keys=["X0", "Y1", "Z2"]),
stimcirq.CumulativeObservableAnnotation(observable_index=2, pauli_keys=pauli_keys),
stimcirq.CumulativeObservableAnnotation(parity_keys=["d", "c"], observable_index=5, pauli_keys=pauli_keys),
)
json = cirq.to_json(c)
c2 = cirq.read_json(json_text=json, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER])
Expand All @@ -208,13 +209,19 @@ def test_json_serialization_with_pauli_keys():
def test_json_backwards_compat_exact():
raw = stimcirq.CumulativeObservableAnnotation(parity_keys=['z'], relative_keys=[-2], observable_index=5)
packed_v1 = '{\n "cirq_type": "CumulativeObservableAnnotation",\n "parity_keys": [\n "z"\n ],\n "observable_index": 5,\n "relative_keys": [\n -2\n ]\n}'
packed_v2 ='{\n "cirq_type": "CumulativeObservableAnnotation",\n "parity_keys": [\n "z"\n ],\n "observable_index": 5,\n "pauli_keys": [],\n "relative_keys": [\n -2\n ]\n}'
packed_v2 = '{\n "cirq_type": "CumulativeObservableAnnotation",\n "parity_keys": [\n "z"\n ],\n "observable_index": 5,\n "pauli_keys": [],\n "relative_keys": [\n -2\n ]\n}'
packed_v3 = '{\n "cirq_type": "CumulativeObservableAnnotation",\n "parity_keys": [\n "z"\n ],\n "observable_index": 5,\n "pauli_keys": [],\n "relative_keys": [\n -2\n ]\n}'
assert cirq.read_json(json_text=packed_v1, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw
assert cirq.read_json(json_text=packed_v2, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw
assert cirq.to_json(raw) == packed_v2
assert cirq.read_json(json_text=packed_v3, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw
assert cirq.to_json(raw) == packed_v3

# With pauli_keys
raw = stimcirq.CumulativeObservableAnnotation(parity_keys=['z'], relative_keys=[-2], observable_index=5, pauli_keys=["X0", "Y1", "Z2"])
pauli_keys = [(cirq.LineQubit(0), "X0"), (cirq.LineQubit(1), "Y1"), (cirq.LineQubit(2), "Z2")]
raw = stimcirq.CumulativeObservableAnnotation(parity_keys=['z'], relative_keys=[-2], observable_index=5, pauli_keys=pauli_keys)
packed_v2 ='{\n "cirq_type": "CumulativeObservableAnnotation",\n "parity_keys": [\n "z"\n ],\n "observable_index": 5,\n "pauli_keys": [\n "X0",\n "Y1",\n "Z2"\n ],\n "relative_keys": [\n -2\n ]\n}'
packed_v3 ='{\n "cirq_type": "CumulativeObservableAnnotation",\n "parity_keys": [\n "z"\n ],\n "observable_index": 5,\n "pauli_keys": [\n [\n {\n "cirq_type": "LineQubit",\n "x": 0\n },\n "X0"\n ],\n [\n {\n "cirq_type": "LineQubit",\n "x": 1\n },\n "Y1"\n ],\n [\n {\n "cirq_type": "LineQubit",\n "x": 2\n },\n "Z2"\n ]\n ],\n "relative_keys": [\n -2\n ]\n}'

assert cirq.read_json(json_text=packed_v2, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw
assert cirq.to_json(raw) == packed_v2
assert cirq.read_json(json_text=packed_v3, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw
assert cirq.to_json(raw) == packed_v3
4 changes: 2 additions & 2 deletions glue/cirq/stimcirq/_stim_to_cirq.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,13 +348,13 @@ def coords_after_offset(

def resolve_measurement_record_keys(
self, targets: Iterable[stim.GateTarget]
) -> Tuple[List[str], List[int], List[str]]:
) -> Tuple[List[str], List[int], List[Tuple[cirq.Qid, str]]]:
pauli_targets, meas_targets = [], []
for t in targets:
if t.is_measurement_record_target:
meas_targets.append(t)
else:
pauli_targets.append(f'{t.pauli_type}{t.value}')
pauli_targets.append((cirq.LineQubit(t.value), f'{t.pauli_type}{t.value}'))

if self.have_seen_loop:
return [], [t.value for t in meas_targets], pauli_targets
Expand Down
1 change: 1 addition & 0 deletions glue/cirq/stimcirq/_stim_to_cirq_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,7 @@ def test_round_trip_with_pauli_obs():
stim_circuit = stim.Circuit("""
QUBIT_COORDS(5, 5) 0
R 0
TICK
OBSERVABLE_INCLUDE(0) X0
TICK
H 0
Expand Down
Loading