Skip to content

Commit aa47606

Browse files
authored
SoquetT becomes a protocol (#1721)
Switches the `SoquetT` union type alias to a `typing.Protocol`. This supports "duck typing", which is appropriate for the situation where `SoquetT` is used as a type annotation. Specifically, `SoquetT` is now anything that has `.shape` and `.item(*args)`. With the addition of these properties/methods to the `Soquet` class, now this is satisfied by `Soquet` and `NDArray`. Note: we can't use mypy to annotate things with `NDArray[Soquet]` now or before due to limitations in numpy. When users or developers need to dispatch depending on whether a `SoquetT` is a single soquet or an array thereof, they should use the new `BloqBuilder` methods. The example from the new docstring: ```python >>> soq_or_soqs: SoquetT ... if BloqBuilder.is_ndarray(soq_or_soqs): ... first_soq = soq_or_soqs.reshape(-1).item(0) ... else: ... # Note: `.item()` raises if not a single item. ... first_soq = soq_or_soqs.item() ``` ----- In the protocol implementations and bloq standard library, this PR: - Adds the typing protocol and introduces new methods to `Soquet` - Uses the type-narrowing idioms in the `_infra/` and protocols modules. Following this, pytest passes but there are mypy issues in the `bloqs/` standard library - Then, uses the type-narrowing idioms in the `bloqs/` standard library. - ~uses new alias `soq.dtype` instead of `soq.reg.dtype`.~ (removed from this PR) Regarding #1720: this PR shows the changes to the standard library that would be required to avoid `isinstance(x, Soquet)` checks
1 parent 99fe289 commit aa47606

13 files changed

Lines changed: 138 additions & 62 deletions

File tree

qualtran/_infra/composite_bloq.py

Lines changed: 63 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,12 @@
2727
Mapping,
2828
Optional,
2929
overload,
30+
Protocol,
3031
Sequence,
3132
Set,
3233
Tuple,
3334
TYPE_CHECKING,
35+
TypeGuard,
3436
TypeVar,
3537
Union,
3638
)
@@ -56,13 +58,30 @@
5658
from qualtran.simulation.classical_sim import ClassicalValT
5759
from qualtran.symbolics import SymbolicInt
5860

59-
# NDArrays must be bound to np.generic
60-
_SoquetType = TypeVar('_SoquetType', bound=np.generic)
6161

62-
SoquetT = Union[Soquet, NDArray[_SoquetType]]
63-
"""A `Soquet` or array of soquets."""
62+
class SoquetT(Protocol):
63+
"""Either a Soquet or an array thereof.
6464
65-
SoquetInT = Union[Soquet, NDArray[_SoquetType], Sequence[Soquet]]
65+
To narrow objects of this type, use `BloqBuilder.is_single(soq)` and/or
66+
`BloqBuilder.is_ndarray(soqs)`.
67+
68+
Example:
69+
>>> soq_or_soqs: SoquetT
70+
... if BloqBuilder.is_ndarray(soq_or_soqs):
71+
... first_soq = soq_or_soqs.reshape(-1).item(0)
72+
... else:
73+
... # Note: `.item()` raises if not a single item.
74+
... first_soq = soq_or_soqs.item()
75+
76+
"""
77+
78+
@property
79+
def shape(self) -> Tuple[int, ...]: ...
80+
81+
def item(self, *args) -> Soquet: ...
82+
83+
84+
SoquetInT = Union[SoquetT, Sequence[SoquetT]]
6685
"""A soquet or array-like of soquets.
6786
6887
This type alias is used for input argument to parts of the library that are more
@@ -693,9 +712,10 @@ def _flatten_soquet_collection(vals: Iterable[SoquetT]) -> List[Soquet]:
693712
"""
694713
soqvals = []
695714
for soq_or_arr in vals:
696-
if isinstance(soq_or_arr, Soquet):
697-
soqvals.append(soq_or_arr)
715+
if BloqBuilder.is_single(soq_or_arr):
716+
soqvals.append(soq_or_arr.item())
698717
else:
718+
assert BloqBuilder.is_ndarray(soq_or_arr)
699719
soqvals.extend(soq_or_arr.reshape(-1))
700720
return soqvals
701721

@@ -802,13 +822,10 @@ def _process_soquets(
802822
unchecked_names.remove(reg.name) # so we can check for surplus arguments.
803823

804824
for li in reg.all_idxs():
805-
idxed_soq = in_soq[li]
806-
assert isinstance(idxed_soq, Soquet), idxed_soq
825+
idxed_soq = in_soq[li].item()
807826
func(idxed_soq, reg, li)
808-
if not check_dtypes_consistent(idxed_soq.reg.dtype, reg.dtype):
809-
extra_str = (
810-
f"{idxed_soq.reg.name}: {idxed_soq.reg.dtype} vs {reg.name}: {reg.dtype}"
811-
)
827+
if not check_dtypes_consistent(idxed_soq.dtype, reg.dtype):
828+
extra_str = f"{idxed_soq.reg.name}: {idxed_soq.dtype} vs {reg.name}: {reg.dtype}"
812829
raise BloqError(
813830
f"{debug_str} register dtypes are not consistent {extra_str}."
814831
) from None
@@ -838,9 +855,9 @@ def _map_soqs(
838855
# First: flatten out any numpy arrays
839856
flat_soq_map: Dict[Soquet, Soquet] = {}
840857
for old_soqs, new_soqs in soq_map:
841-
if isinstance(old_soqs, Soquet):
842-
assert isinstance(new_soqs, Soquet), new_soqs
843-
flat_soq_map[old_soqs] = new_soqs
858+
if BloqBuilder.is_single(old_soqs):
859+
assert BloqBuilder.is_single(new_soqs), new_soqs
860+
flat_soq_map[old_soqs] = new_soqs.item()
844861
continue
845862

846863
assert isinstance(old_soqs, np.ndarray), old_soqs
@@ -858,9 +875,9 @@ def _map_soq(soq: Soquet) -> Soquet:
858875
vmap = np.vectorize(_map_soq, otypes=[object])
859876

860877
def _map_soqs(soqs: SoquetT) -> SoquetT:
861-
if isinstance(soqs, Soquet):
862-
return _map_soq(soqs)
863-
return vmap(soqs)
878+
if BloqBuilder.is_ndarray(soqs):
879+
return vmap(soqs)
880+
return _map_soq(soqs.item())
864881

865882
return {name: _map_soqs(soqs) for name, soqs in soqs.items()}
866883

@@ -1098,6 +1115,24 @@ def from_signature(
10981115

10991116
return bb, initial_soqs
11001117

1118+
@staticmethod
1119+
def is_single(x: 'SoquetT') -> TypeGuard['Soquet']:
1120+
"""Returns True if `x` is a single soquet (not an ndarray of them).
1121+
1122+
This doesn't use stringent runtime type checking; it uses the SoquetT protocol
1123+
for "duck typing".
1124+
"""
1125+
return x.shape == ()
1126+
1127+
@staticmethod
1128+
def is_ndarray(x: 'SoquetT') -> TypeGuard['NDArray']:
1129+
"""Returns True if `x` is an ndarray of soquets (not a single one).
1130+
1131+
This doesn't use stringent runtime type checking; it uses the SoquetT protocol
1132+
for "duck typing".
1133+
"""
1134+
return x.shape != ()
1135+
11011136
@staticmethod
11021137
def map_soqs(
11031138
soqs: Dict[str, SoquetT], soq_map: Iterable[Tuple[SoquetT, SoquetT]]
@@ -1302,8 +1337,7 @@ def add_from(self, bloq: Bloq, **in_soqs: SoquetInT) -> Tuple[SoquetT, ...]:
13021337
cbloq = bloq.decompose_bloq()
13031338

13041339
for k, v in in_soqs.items():
1305-
if not isinstance(v, Soquet):
1306-
in_soqs[k] = np.asarray(v)
1340+
in_soqs[k] = np.asarray(v)
13071341

13081342
# Initial mapping of LeftDangle according to user-provided in_soqs.
13091343
soq_map: List[Tuple[SoquetT, SoquetT]] = [
@@ -1343,12 +1377,13 @@ def finalize(self, **final_soqs: SoquetT) -> CompositeBloq:
13431377

13441378
def _infer_reg(name: str, soq: SoquetT) -> Register:
13451379
"""Go from Soquet -> register, but use a specific name for the register."""
1346-
if isinstance(soq, Soquet):
1347-
return Register(name=name, dtype=soq.reg.dtype, side=Side.RIGHT)
1380+
if BloqBuilder.is_single(soq):
1381+
return Register(name=name, dtype=soq.dtype, side=Side.RIGHT)
1382+
assert BloqBuilder.is_ndarray(soq)
13481383

13491384
# Get info from 0th soquet in an ndarray.
13501385
return Register(
1351-
name=name, dtype=soq.reshape(-1)[0].reg.dtype, shape=soq.shape, side=Side.RIGHT
1386+
name=name, dtype=soq.reshape(-1).item(0).dtype, shape=soq.shape, side=Side.RIGHT
13521387
)
13531388

13541389
right_reg_names = [reg.name for reg in self._regs if reg.side & Side.RIGHT]
@@ -1395,10 +1430,10 @@ def allocate(
13951430
def free(self, soq: Soquet, dirty: bool = False) -> None:
13961431
from qualtran.bloqs.bookkeeping import Free
13971432

1398-
if not isinstance(soq, Soquet):
1433+
if not BloqBuilder.is_single(soq):
13991434
raise ValueError("`free` expects a single Soquet to free.")
14001435

1401-
qdtype = soq.reg.dtype
1436+
qdtype = soq.dtype
14021437
if not isinstance(qdtype, QDType):
14031438
raise ValueError("`free` can only free quantum registers.")
14041439

@@ -1408,10 +1443,10 @@ def split(self, soq: SoquetInT) -> NDArray[Soquet]: # type: ignore[type-var]
14081443
"""Add a Split bloq to split up a register."""
14091444
from qualtran.bloqs.bookkeeping import Split
14101445

1411-
if not isinstance(soq, Soquet):
1446+
if not BloqBuilder.is_single(soq): # type: ignore[arg-type]
14121447
raise ValueError("`split` expects a single Soquet to split.")
14131448

1414-
qdtype = soq.reg.dtype
1449+
qdtype = soq.dtype
14151450
if not isinstance(qdtype, QDType):
14161451
raise ValueError("`split` can only split quantum registers.")
14171452

qualtran/_infra/composite_bloq_test.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,15 @@
1313
# limitations under the License.
1414

1515
from functools import cached_property
16-
from typing import Dict, List, Tuple
16+
from typing import cast, Dict, List, Tuple
1717

1818
import attrs
1919
import networkx as nx
2020
import numpy as np
2121
import pytest
2222
import sympy
2323
from numpy.typing import NDArray
24+
from typing_extensions import assert_type
2425

2526
import qualtran.testing as qlt_testing
2627
from qualtran import (
@@ -646,6 +647,40 @@ def test_get_soquet():
646647
_ = _get_soquet(binst=binst, reg_name='in', right=True, binst_graph=binst_graph)
647648

648649

650+
def test_can_tell_individual_from_ndsoquet():
651+
s1 = Soquet(cast(BloqInstance, None), Register('test', QBit(), shape=(4,)), idx=(0,))
652+
s2 = Soquet(cast(BloqInstance, None), Register('test', QBit(), shape=(4,)), idx=(1,))
653+
s3 = Soquet(cast(BloqInstance, None), Register('test', QBit(), shape=(4,)), idx=(2,))
654+
s4 = Soquet(cast(BloqInstance, None), Register('test', QBit(), shape=(4,)), idx=(3,))
655+
656+
# A ndarray of soquet objects should be SoquetT and we can tell by checking its shape.
657+
ndsoq: SoquetT = np.array([s1, s2, s3, s4])
658+
assert_type(ndsoq, SoquetT)
659+
assert ndsoq.shape
660+
assert ndsoq.shape == (4,)
661+
assert ndsoq.item(2) == s3
662+
with pytest.raises(ValueError, match=r'scalar'):
663+
_ = ndsoq.item()
664+
665+
# A single soquet is still a valid SoquetT, and it has a false-y shape.
666+
single_soq: SoquetT = s1
667+
assert_type(single_soq, SoquetT)
668+
assert not single_soq.shape
669+
assert single_soq.shape == ()
670+
single_soq_unwarp = single_soq.item()
671+
assert single_soq_unwarp == s1
672+
673+
# A single soquet wrapped in a 0-dim ndarray is ok if you call `item()`.
674+
single_soq2: SoquetT = np.asarray(s1)
675+
assert_type(single_soq2, SoquetT)
676+
assert not single_soq2.shape
677+
assert single_soq2.shape == ()
678+
single_soq2_unwrap = single_soq2.item()
679+
assert hash(single_soq2_unwrap) == hash(s1)
680+
assert single_soq2_unwrap == s1
681+
assert isinstance(single_soq2_unwrap, Soquet)
682+
683+
649684
@pytest.mark.notebook
650685
def test_notebook():
651686
qlt_testing.execute_notebook('composite_bloq')

qualtran/_infra/quantum_graph.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from attrs import field, frozen
2121

2222
if TYPE_CHECKING:
23-
from qualtran import Bloq, Register
23+
from qualtran import Bloq, BloqBuilder, QCDType, Register
2424

2525

2626
@frozen
@@ -107,6 +107,20 @@ def _check_idx(self, attribute, value):
107107
for i, shape in zip(value, self.reg.shape):
108108
if i >= shape:
109109
raise ValueError(f"Bad index {i} for {self.reg}.")
110+
return value
111+
112+
@property
113+
def dtype(self) -> 'QCDType':
114+
return self.reg.dtype
115+
116+
@property
117+
def shape(self) -> Tuple[int, ...]:
118+
return ()
119+
120+
def item(self, *args) -> 'Soquet':
121+
if args:
122+
raise ValueError("Tried to index into a single soquet.")
123+
return self
110124

111125
def pretty(self) -> str:
112126
label = self.reg.name

qualtran/_infra/quantum_graph_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ def test_soquet():
4848
assert soq.idx == ()
4949
assert soq.pretty() == 'x'
5050

51+
assert soq.item() == soq
52+
assert soq.dtype == QAny(10)
53+
5154

5255
def test_soquet_idxed():
5356
binst = BloqInstance(TestTwoBitOp(), i=0)

qualtran/bloqs/basic_gates/rotation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,7 @@ def signature(self) -> 'Signature':
219219
def build_composite_bloq(self, bb: 'BloqBuilder', q: 'SoquetT') -> Dict[str, 'SoquetT']:
220220
from qualtran.bloqs.mcmt import And
221221

222-
q1, q2 = q # type: ignore
223-
(q1, q2), anc = bb.add(And(), ctrl=[q1, q2])
222+
(q1, q2), anc = bb.add(And(), ctrl=q)
224223
anc = bb.add(ZPowGate(self.exponent, eps=self.eps), q=anc)
225224
(q1, q2) = bb.add(And().adjoint(), ctrl=[q1, q2], target=anc)
226225
return {'q': np.array([q1, q2])}

qualtran/bloqs/block_encoding/sparse_matrix.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,7 @@ def build_composite_bloq(
242242
if is_symbolic(self.system_bitsize) or is_symbolic(self.row_oracle.num_nonzero):
243243
raise DecomposeTypeError(f"Cannot decompose symbolic {self=}")
244244

245-
assert not isinstance(ancilla, np.ndarray)
246-
ancilla_bits = bb.split(ancilla)
245+
ancilla_bits = bb.split(ancilla.item())
247246
q, l = ancilla_bits[0], bb.join(ancilla_bits[1:])
248247

249248
l = bb.add(self.diffusion, target=l)

qualtran/bloqs/chemistry/pbc/first_quantization/select_and_prepare.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,11 +154,11 @@ def _reshape_reg(
154154
"""
155155
# np.prod(()) returns a float (1.0), so take int
156156
size = int(np.prod(out_shape))
157-
if isinstance(in_reg, np.ndarray):
157+
if BloqBuilder.is_ndarray(in_reg):
158158
# split an array of bitsize qubits into flat list of qubits
159159
split_qubits = bb.split(bb.join(np.concatenate([bb.split(x) for x in in_reg.ravel()])))
160160
else:
161-
split_qubits = bb.split(in_reg)
161+
split_qubits = bb.split(in_reg.item())
162162
merged_qubits = np.array(
163163
[bb.join(split_qubits[i * bitsize : (i + 1) * bitsize]) for i in range(size)]
164164
)

qualtran/bloqs/chemistry/trotter/grid_ham/potential.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
QAny,
2828
Register,
2929
Signature,
30-
Soquet,
3130
SoquetT,
3231
)
3332
from qualtran._infra.data_types import BQUInt
@@ -93,7 +92,7 @@ def wire_symbol(
9392
def build_composite_bloq(
9493
self, bb: BloqBuilder, *, system_i: SoquetT, system_j: SoquetT
9594
) -> Dict[str, SoquetT]:
96-
if isinstance(system_i, Soquet) or isinstance(system_j, Soquet):
95+
if not (BloqBuilder.is_ndarray(system_i) and BloqBuilder.is_ndarray(system_j)):
9796
raise ValueError("system_i and system_j must be numpy arrays of Soquet")
9897
# compute r_i - r_j
9998
# r_i + (-r_j), in practice we need to flip the sign bit, but this is just 3 cliffords.
@@ -227,7 +226,7 @@ def wire_symbol(
227226
return super().wire_symbol(reg, idx)
228227

229228
def build_composite_bloq(self, bb: BloqBuilder, *, system: SoquetT) -> Dict[str, SoquetT]:
230-
if isinstance(system, Soquet):
229+
if not BloqBuilder.is_ndarray(system):
231230
raise ValueError("system must be a numpy array of Soquet")
232231
bitsize = (self.num_grid - 1).bit_length() + 1
233232
ij_pairs = np.triu_indices(self.num_elec, k=1)

qualtran/bloqs/data_loading/qroam_clean.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -520,8 +520,8 @@ def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str
520520
# Construct and return dictionary of final soquets.
521521
soqs |= {reg.name: soq for reg, soq in zip(self.control_registers, ctrl)}
522522
soqs |= {reg.name: soq for reg, soq in zip(self.selection_registers, selection)}
523-
soqs |= {reg.name: soq.flat[1:] for reg, soq in zip(self.junk_registers, qrom_targets)} # type: ignore[union-attr]
524-
soqs |= {reg.name: soq.flat[0] for reg, soq in zip(self.target_registers, qrom_targets)} # type: ignore[union-attr]
523+
soqs |= {reg.name: soq.flat[1:] for reg, soq in zip(self.junk_registers, qrom_targets)} # type: ignore[attr-defined]
524+
soqs |= {reg.name: soq.flat[0] for reg, soq in zip(self.target_registers, qrom_targets)} # type: ignore[attr-defined]
525525
return soqs
526526

527527
def on_classical_vals(

qualtran/bloqs/for_testing/with_decomposition.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,13 @@
1313
# limitations under the License.
1414

1515
from functools import cached_property
16-
from typing import Dict, TYPE_CHECKING
16+
from typing import Dict
1717

1818
from attrs import frozen
1919

20-
from qualtran import Bloq, BloqBuilder, Signature, Soquet
20+
from qualtran import Bloq, BloqBuilder, Signature, SoquetT
2121
from qualtran.bloqs.for_testing.atom import TestAtom
2222

23-
if TYPE_CHECKING:
24-
from qualtran import SoquetT
25-
2623

2724
@frozen
2825
class TestSerialCombo(Bloq):
@@ -47,7 +44,7 @@ def signature(self) -> Signature:
4744
return Signature.build(reg=3)
4845

4946
def build_composite_bloq(self, bb: 'BloqBuilder', reg: 'SoquetT') -> Dict[str, 'SoquetT']:
50-
assert isinstance(reg, Soquet)
47+
assert BloqBuilder.is_single(reg)
5148
reg = bb.split(reg)
5249
for i in range(len(reg)):
5350
reg[i] = bb.add(TestAtom(), q=reg[i])

0 commit comments

Comments
 (0)