|
18 | 18 | from pathlib import Path |
19 | 19 | from typing import Dict, List, Optional, Sequence, Tuple, Union |
20 | 20 |
|
| 21 | +import numpy as np |
21 | 22 | import sympy |
22 | 23 |
|
23 | 24 | from qualtran import ( |
|
34 | 35 | ) |
35 | 36 | from qualtran._infra.composite_bloq import _get_flat_dangling_soqs |
36 | 37 | from qualtran._infra.data_types import check_dtypes_consistent, QDTypeCheckingSeverity |
| 38 | +from qualtran.drawing.musical_score import WireSymbol |
37 | 39 | from qualtran.resource_counting import GeneralizerT |
38 | 40 |
|
39 | 41 |
|
@@ -225,23 +227,36 @@ def assert_valid_bloq_decomposition(bloq: Optional[Bloq]) -> CompositeBloq: |
225 | 227 | return cbloq |
226 | 228 |
|
227 | 229 |
|
228 | | -def assert_wire_symbols_match_expected(bloq: Bloq, expected_ws: List[str]): |
| 230 | +def assert_wire_symbols_match_expected(bloq: Bloq, expected_ws: List[Union[str, WireSymbol]]): |
229 | 231 | """Assert a bloq's wire symbols match the expected ones. |
230 | 232 |
|
| 233 | + For multi-dimensional registers (with a shape), this will iterate |
| 234 | + through the register indices (see numpy.ndindices for iteration order). |
| 235 | +
|
231 | 236 | Args: |
232 | 237 | bloq: the bloq whose wire symbols we want to check. |
233 | | - expected_ws: A list of the expected wire symbols. |
| 238 | + expected_ws: A list of the expected wire symbols or their associated text. |
234 | 239 | """ |
| 240 | + expected_idx = 0 |
235 | 241 | ws = [] |
236 | | - regs = bloq.signature |
237 | | - # note this will only work if shape = (). |
238 | | - # See: https://github.com/quantumlib/Qualtran/issues/608 |
239 | | - for i, r in enumerate(regs): |
240 | | - # note this will only work if shape = (). |
241 | | - # See: https://github.com/quantumlib/Qualtran/issues/608 |
242 | | - ws.append(bloq.wire_symbol(r, (i,)).text) |
243 | | - |
244 | | - assert ws == expected_ws |
| 242 | + for reg in bloq.signature: |
| 243 | + if reg.shape: |
| 244 | + indices = np.ndindex(reg.shape) |
| 245 | + else: |
| 246 | + indices = [(0,)] |
| 247 | + for idx in indices: |
| 248 | + wire_symbol = bloq.wire_symbol(reg, idx) |
| 249 | + expected_symbol = expected_ws[expected_idx] |
| 250 | + if isinstance(expected_symbol, str): |
| 251 | + wire_text = getattr(wire_symbol, 'text', None) |
| 252 | + assert ( |
| 253 | + wire_text == expected_symbol |
| 254 | + ), f'Wire symbol {wire_text} does not match expected {expected_symbol}' |
| 255 | + else: |
| 256 | + assert ( |
| 257 | + wire_symbol == expected_symbol |
| 258 | + ), f'Wire symbol {wire_symbol} does not match expected {expected_symbol}' |
| 259 | + expected_idx += 1 |
245 | 260 |
|
246 | 261 |
|
247 | 262 | def execute_notebook(name: str): |
|
0 commit comments