Skip to content

Commit 8fa6e2a

Browse files
authored
Merge branch 'main' into 2024-04/generic-costs-framework
2 parents 5b9ad8b + 54bc181 commit 8fa6e2a

2 files changed

Lines changed: 31 additions & 11 deletions

File tree

qualtran/bloqs/basic_gates/toffoli_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from qualtran import BloqBuilder
2020
from qualtran.bloqs.basic_gates import TGate, Toffoli, ZeroState
2121
from qualtran.bloqs.basic_gates.toffoli import _toffoli
22+
from qualtran.drawing.musical_score import Circle, ModPlus
23+
from qualtran.testing import assert_wire_symbols_match_expected
2224

2325

2426
def test_toffoli(bloq_autotester):
@@ -50,6 +52,9 @@ def test_toffoli_cirq():
5052
│ │
5153
_c(2): ───X───X───""",
5254
)
55+
assert_wire_symbols_match_expected(
56+
Toffoli(), [Circle(filled=True), Circle(filled=True), ModPlus()]
57+
)
5358

5459

5560
def test_classical_sim():

qualtran/testing.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from pathlib import Path
1919
from typing import Dict, List, Optional, Sequence, Tuple, Union
2020

21+
import numpy as np
2122
import sympy
2223

2324
from qualtran import (
@@ -34,6 +35,7 @@
3435
)
3536
from qualtran._infra.composite_bloq import _get_flat_dangling_soqs
3637
from qualtran._infra.data_types import check_dtypes_consistent, QDTypeCheckingSeverity
38+
from qualtran.drawing.musical_score import WireSymbol
3739
from qualtran.resource_counting import GeneralizerT
3840

3941

@@ -225,23 +227,36 @@ def assert_valid_bloq_decomposition(bloq: Optional[Bloq]) -> CompositeBloq:
225227
return cbloq
226228

227229

228-
def assert_wire_symbols_match_expected(bloq: Bloq, expected_ws: List[str]):
230+
def assert_wire_symbols_match_expected(bloq: Bloq, expected_ws: List[Union[str, WireSymbol]]):
229231
"""Assert a bloq's wire symbols match the expected ones.
230232
233+
For multi-dimensional registers (with a shape), this will iterate
234+
through the register indices (see numpy.ndindices for iteration order).
235+
231236
Args:
232237
bloq: the bloq whose wire symbols we want to check.
233-
expected_ws: A list of the expected wire symbols.
238+
expected_ws: A list of the expected wire symbols or their associated text.
234239
"""
240+
expected_idx = 0
235241
ws = []
236-
regs = bloq.signature
237-
# note this will only work if shape = ().
238-
# See: https://github.com/quantumlib/Qualtran/issues/608
239-
for i, r in enumerate(regs):
240-
# note this will only work if shape = ().
241-
# See: https://github.com/quantumlib/Qualtran/issues/608
242-
ws.append(bloq.wire_symbol(r, (i,)).text)
243-
244-
assert ws == expected_ws
242+
for reg in bloq.signature:
243+
if reg.shape:
244+
indices = np.ndindex(reg.shape)
245+
else:
246+
indices = [(0,)]
247+
for idx in indices:
248+
wire_symbol = bloq.wire_symbol(reg, idx)
249+
expected_symbol = expected_ws[expected_idx]
250+
if isinstance(expected_symbol, str):
251+
wire_text = getattr(wire_symbol, 'text', None)
252+
assert (
253+
wire_text == expected_symbol
254+
), f'Wire symbol {wire_text} does not match expected {expected_symbol}'
255+
else:
256+
assert (
257+
wire_symbol == expected_symbol
258+
), f'Wire symbol {wire_symbol} does not match expected {expected_symbol}'
259+
expected_idx += 1
245260

246261

247262
def execute_notebook(name: str):

0 commit comments

Comments
 (0)