Skip to content

Commit 992afe8

Browse files
Signature.build with flexible typing and shape syntax (#1826)
- a new `ShapedQCDType` capable of specifying quantum arrays natively through subscripting (e.g., `QInt(8)[20]`) - upgraded `Signature.build` syntactic sugar. `Signature.build` can now seamlessly handle fully instantiated types, shaped types, positional lists of registers/signatures, and granular side directives via tuples. fixes #1769 ```python >>> from qualtran import QBit, QUInt >>> sig = Signature.build(ctrl=QBit()[5, 5], system=QUInt(32)) >>> sig == Signature([ ... Register('ctrl', QBit(), shape=(5, 5)), ... Register('system', QUInt(32)) ... ]) True >>> sig = Signature.build(a=(QBit(), QBit()), b=(None, QBit())) >>> sig == Signature([ ... Register('a', QBit(), side=Side.THRU), ... Register('b', QBit(), side=Side.RIGHT) ... ]) True ``` --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 857d6b5 commit 992afe8

3 files changed

Lines changed: 264 additions & 17 deletions

File tree

qualtran/_infra/data_types.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
List,
2727
Optional,
2828
Sequence,
29+
Tuple,
2930
TYPE_CHECKING,
3031
TypeVar,
3132
Union,
@@ -155,6 +156,14 @@ def assert_valid_val_array(self, val_array: NDArray, debug_str: str = 'val') ->
155156
self.qdtype.assert_valid_classical_val(val)
156157

157158

159+
@attrs.frozen
160+
class ShapedQCDType:
161+
qcdtype: 'QCDType'
162+
shape: Tuple[int, ...] = attrs.field(
163+
default=tuple(), converter=lambda v: (v,) if isinstance(v, int) else tuple(v)
164+
)
165+
166+
158167
class QCDType(Generic[T], metaclass=abc.ABCMeta):
159168
"""The abstract interface for quantum/classical quantum computing data types."""
160169

@@ -245,6 +254,10 @@ def iteration_length_or_zero(self) -> SymbolicInt:
245254
# TODO: remove https://github.com/quantumlib/Qualtran/issues/1716
246255
return getattr(self, 'iteration_length', 0)
247256

257+
def __getitem__(self, shape):
258+
"""QInt(8)[20] returns a size-20 array of QInt(8)"""
259+
return ShapedQCDType(qcdtype=self, shape=shape)
260+
248261
@classmethod
249262
def _pkg_(cls):
250263
return 'qualtran'

qualtran/_infra/registers.py

Lines changed: 121 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,11 @@
2121
from typing import cast, Dict, Iterable, Iterator, List, overload, Tuple, Union
2222

2323
import attrs
24-
import sympy
2524
from attrs import field, frozen
2625

2726
from qualtran.symbolics import is_symbolic, prod, smax, ssum, SymbolicInt
2827

29-
from .data_types import QAny, QBit, QCDType
28+
from .data_types import QAny, QBit, QCDType, ShapedQCDType
3029

3130

3231
class Side(enum.Flag):
@@ -53,6 +52,12 @@ def __repr__(self):
5352
return f'{self.__class__.__name__}.{self._name_}'
5453

5554

55+
def _consume_register_dtype(dtype: Union[QCDType, ShapedQCDType]) -> QCDType:
56+
# In __attrs_post_init__, we actually handle the ShapedQCDType case, which isn't accounted
57+
# for in attrs type checking.
58+
return cast(QCDType, dtype)
59+
60+
5661
@frozen
5762
class Register:
5863
"""A register serves as the input/output quantum data specifications in a bloq's `Signature`.
@@ -72,7 +77,7 @@ class Register:
7277
"""
7378

7479
name: str
75-
dtype: QCDType
80+
dtype: QCDType = field(converter=_consume_register_dtype)
7681
_shape: Tuple[SymbolicInt, ...] = field(
7782
default=tuple(), converter=lambda v: (v,) if isinstance(v, int) else tuple(v)
7883
)
@@ -86,6 +91,15 @@ def __repr__(self):
8691
return f'Register({self.name!r}, dtype={self.dtype!r}, shape={self._shape!r}, side={self.side!r})'
8792

8893
def __attrs_post_init__(self):
94+
if isinstance(self.dtype, ShapedQCDType):
95+
if self._shape != ():
96+
raise ValueError(
97+
f"for Register {self.name}, use either a shaped dtype {self.dtype} "
98+
f"or an explicit shape argument {self._shape}, not both."
99+
)
100+
object.__setattr__(self, '_shape', self.dtype.shape)
101+
object.__setattr__(self, 'dtype', self.dtype.qcdtype)
102+
89103
if not isinstance(self.dtype, QCDType):
90104
raise ValueError(f'dtype must be a QCDType: found {type(self.dtype)}')
91105

@@ -193,13 +207,15 @@ def __init__(self, registers: Iterable[Register]):
193207
self._rights = _dedupe((reg.name, reg) for reg in self._registers if reg.side & Side.RIGHT)
194208

195209
@classmethod
196-
def build(cls, **registers: Union[int, sympy.Expr]) -> 'Signature':
197-
"""Construct a Signature comprised of untyped thru registers of the given bitsizes.
210+
def build(cls, *args, **kwargs) -> 'Signature':
211+
"""Construct a Signature using a more natural syntax.
198212
199-
For rapid prorotyping or simple gates, this syntactic sugar can be used.
213+
This builder constructs a `Signature` flexibly from a mix of types, positional elements,
214+
and named keyword arguments. For rapid prototyping or simple gates, you can quickly define
215+
registers without manually instantiating `Register` objects.
200216
201217
Examples:
202-
The following constructors are equivalent
218+
The following constructors are equivalent:
203219
204220
>>> sig1 = Signature.build(a=32, b=1)
205221
>>> sig2 = Signature([
@@ -209,13 +225,104 @@ def build(cls, **registers: Union[int, sympy.Expr]) -> 'Signature':
209225
>>> sig1 == sig2
210226
True
211227
228+
We can also build signatures with fully instantiated `QCDType` arguments, including
229+
shaped multidimensional registers:
230+
231+
>>> from qualtran import QBit, QUInt
232+
>>> sig = Signature.build(ctrl=QBit()[5, 5], system=QUInt(32))
233+
>>> sig == Signature([
234+
... Register('ctrl', QBit(), shape=(5, 5)),
235+
... Register('system', QUInt(32))
236+
... ])
237+
True
238+
239+
Left and Right registers can be specified with a 2-tuple `(LEFT, RIGHT)`.
240+
Here, we allocate `b` as a right register.
241+
242+
>>> sig = Signature.build(a=(QBit(), QBit()), b=(None, QBit()))
243+
>>> sig == Signature([
244+
... Register('a', QBit(), side=Side.THRU),
245+
... Register('b', QBit(), side=Side.RIGHT)
246+
... ])
247+
True
248+
249+
Positional arguments can be used to join previously defined components:
250+
251+
>>> sig1 = Signature.build(a=1)
252+
>>> extra = [Register('c', QAny(5))]
253+
>>> sig2 = Signature.build(sig1, extra)
254+
212255
Args:
213-
**registers: Keyword arguments mapping register names to bitsizes. All registers
214-
will be 0-dimensional, THRU, and of type QAny/QBit.
256+
*args: Positional arguments must be instances of `Register`, `Signature`, or iterables
257+
thereof, which will be concatenated in order of layout.
258+
**kwargs: Keyword arguments mapping register names to data types or sizes.
259+
Values can be integer bitsizes (where 1 maps to `QBit` and n to `QAny(n)`),
260+
`QCDType` instances, `ShapedQCDType` instances, or 2-tuples of
261+
`(left_dtype, right_dtype)` to explicitly specify sides.
215262
"""
216-
return cls(
217-
Register(name=k, dtype=QBit() if v == 1 else QAny(v)) for k, v in registers.items() if v
218-
)
263+
if args and kwargs:
264+
raise ValueError(
265+
f"When using `Signature.build`, you must either specify a mapping "
266+
f"from register names to data types or positional Signature and "
267+
f"Register arguments, not both. Found positional {args} and keyword {kwargs}"
268+
)
269+
270+
registers = []
271+
272+
def _flat_add(arg):
273+
# add positional Signature, Register, or iterables thereof.
274+
nonlocal registers
275+
if isinstance(arg, Register):
276+
registers.append(arg)
277+
elif isinstance(arg, Signature):
278+
registers.extend(arg)
279+
elif isinstance(arg, Iterable) and not isinstance(arg, str):
280+
for a2 in arg:
281+
_flat_add(a2)
282+
else:
283+
raise ValueError(
284+
f"Unknown type for positional argument to Signature.build: {arg!r}"
285+
)
286+
287+
if args:
288+
for arg in args:
289+
_flat_add(arg)
290+
return cls(registers)
291+
292+
for k, v in kwargs.items():
293+
if not v:
294+
continue
295+
296+
if isinstance(v, (QCDType, ShapedQCDType)):
297+
registers.append(Register(name=k, dtype=v))
298+
elif isinstance(v, tuple):
299+
if len(v) != 2:
300+
raise ValueError(
301+
f"When using Signature.build with a tuple of data types, "
302+
f"you must specify a tuple of length 2. For LEFT registers, "
303+
f"the tuple is (dtype, None). For RIGHT registers, "
304+
f"the tuple is (None, dtype). You provided {v}"
305+
)
306+
ldt, rdt = v
307+
if ldt is not None:
308+
registers.append(Register(name=k, dtype=ldt, side=Side.LEFT))
309+
if rdt is not None:
310+
registers.append(Register(name=k, dtype=rdt, side=Side.RIGHT))
311+
312+
elif isinstance(v, (Register, Signature)):
313+
# mild defensiveness against common errors, but duck typing in the `else` clause.
314+
raise ValueError(
315+
f"Invalid data type for Signature.build keyword argument '{k}': {v}"
316+
)
317+
else:
318+
dt: QCDType
319+
if v == 1:
320+
dt = QBit()
321+
else:
322+
dt = QAny(v)
323+
registers.append(Register(name=k, dtype=dt))
324+
325+
return cls(registers)
219326

220327
@classmethod
221328
def build_from_dtypes(cls, **registers: QCDType) -> 'Signature':
@@ -323,12 +430,10 @@ def __repr__(self):
323430
return f'Signature({repr(self._registers)})'
324431

325432
@overload
326-
def __getitem__(self, key: int) -> Register:
327-
pass
433+
def __getitem__(self, key: int) -> Register: ...
328434

329435
@overload
330-
def __getitem__(self, key: slice) -> Tuple[Register, ...]:
331-
pass
436+
def __getitem__(self, key: slice) -> Tuple[Register, ...]: ...
332437

333438
def __getitem__(self, key):
334439
return self._registers[key]

qualtran/_infra/registers_test.py

Lines changed: 130 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,10 +155,87 @@ def test_signature_build():
155155
sig1 = Signature([Register("r1", QAny(5)), Register("r2", QAny(2))])
156156
sig2 = Signature.build(r1=5, r2=2)
157157
assert sig1 == sig2
158-
assert sig1.n_qubits() == 7
158+
assert sig2.n_qubits() == 7
159+
160+
161+
def test_signature_build_drops_falsey():
162+
should_be = Signature([Register('x', QBit())])
163+
assert Signature.build(x=1, y=0) == should_be
164+
assert Signature.build(x=1, y=None) == should_be
165+
166+
167+
def test_signature_build_dtypes():
168+
should_be = Signature([Register('system', QUInt(8))])
169+
assert Signature.build(system=QUInt(8)) == should_be
170+
171+
172+
def test_signature_build_shaped():
173+
should_be = Signature([Register('qubits', QBit(), shape=(5, 5))])
174+
assert Signature.build(qubits=QBit()[5, 5]) == should_be
175+
176+
should_be = Signature([Register('ctrl', QBit()), Register('ints', QInt(8), shape=(5,))])
177+
assert Signature.build(ctrl=1, ints=QInt(8)[5]) == should_be
178+
179+
180+
def test_signature_build_sided():
181+
should_be = Signature(
182+
[Register('x_in', QAny(3), side=Side.LEFT), Register('x_out', QAny(3), side=Side.RIGHT)]
183+
)
184+
assert Signature.build(x_in=(QAny(3), None), x_out=(None, QAny(3))) == should_be
185+
186+
187+
def test_signature_build_grouped_sided():
188+
should_be = Signature(
189+
[Register('x', QAny(3), side=Side.LEFT), Register('x', QBit(), shape=(3,), side=Side.RIGHT)]
190+
)
191+
assert Signature.build(x=(QAny(3), QBit()[3])) == should_be
192+
193+
194+
def test_signature_build_signature():
195+
should_be = Signature(
196+
[Register('x', QAny(3), side=Side.LEFT), Register('x', QBit(), shape=(3,), side=Side.RIGHT)]
197+
)
198+
assert Signature.build(should_be) == should_be
199+
200+
201+
def test_signature_build_registers():
202+
should_be = Signature([Register('ctrl', QBit()), Register('system', QAny(5))])
203+
assert Signature.build(Register('ctrl', QBit()), Register('system', QAny(5))) == should_be
204+
205+
206+
def test_signature_build_signature_registers():
207+
should_be = Signature(
208+
[
209+
Register('ctrl', QBit()),
210+
Register('system', QAny(5)),
211+
Register('x_in', QAny(3), side=Side.LEFT),
212+
Register('x_out', QAny(3), side=Side.RIGHT),
213+
Register('x', QBit()),
214+
]
215+
)
216+
217+
first_signature = Signature([Register('ctrl', QBit()), Register('system', QAny(5))])
218+
regs = [Register('x_in', QAny(3), side=Side.LEFT), Register('x_out', QAny(3), side=Side.RIGHT)]
219+
last_signature = Signature([Register('x', QBit())])
220+
assert Signature.build(first_signature, regs, last_signature) == should_be
221+
222+
223+
def test_signature_build_mixed_args_kwargs():
224+
first_signature = Signature([Register('ctrl', QBit()), Register('system', QAny(5))])
225+
with pytest.raises(ValueError, match=r'either.*not both.*'):
226+
Signature.build(first_signature, y=QBit())
227+
228+
229+
def test_signature_build_kwregs():
230+
with pytest.raises(ValueError, match=r"Invalid data type.*'x'.*"):
231+
Signature.build(x=Register('x', QBit()))
232+
233+
234+
def test_signature_build_from_dtypes():
159235
sig1 = Signature([Register("r1", QInt(7)), Register("r2", QBit())])
160236
sig2 = Signature.build_from_dtypes(r1=QInt(7), r2=QBit())
161237
assert sig1 == sig2
238+
162239
sig1 = Signature([Register("r1", QInt(7))])
163240
sig2 = Signature.build_from_dtypes(r1=QInt(7), r2=QAny(0))
164241
assert sig1 == sig2
@@ -235,3 +312,55 @@ def test_is_symbolic():
235312
assert is_symbolic(r)
236313
r = Register("my_reg", QAny(2), shape=sympy.symbols("x y"))
237314
assert is_symbolic(r)
315+
316+
317+
def test_register_pkg():
318+
assert Register._pkg_() == 'qualtran'
319+
320+
321+
def test_register_shape_error():
322+
with pytest.raises(ValueError, match="use either a shaped dtype.*or an explicit shape"):
323+
Register("my_reg", QBit()[2], shape=(2,))
324+
325+
326+
def test_register_invalid_dtype():
327+
with pytest.raises(ValueError, match="dtype must be a QCDType"):
328+
Register("my_reg", 5) # type: ignore
329+
330+
331+
def test_register_adjoint_side():
332+
r2 = Register("my_reg", QBit(), side=Side.RIGHT)
333+
assert r2.adjoint().side == Side.LEFT
334+
335+
r3 = Register("my_reg", QBit(), side=Side.LEFT)
336+
assert r3.adjoint().side == Side.RIGHT
337+
338+
339+
def test_signature_build_positional_errors():
340+
with pytest.raises(ValueError, match="Unknown type for positional argument"):
341+
Signature.build("not_a_register_or_signature")
342+
343+
344+
def test_signature_build_tuple_error():
345+
with pytest.raises(ValueError, match="you must specify a tuple of length 2"):
346+
Signature.build(a=(QBit(),))
347+
348+
349+
def test_signature_thru_registers_only():
350+
sig = Signature.build(a=1)
351+
assert sig.thru_registers_only
352+
sig2 = Signature([Register('a', QBit(), side=Side.LEFT)])
353+
assert not sig2.thru_registers_only
354+
355+
356+
def test_signature_get_left_right():
357+
sig = Signature([Register('a', QBit(), side=Side.LEFT), Register('b', QBit(), side=Side.RIGHT)])
358+
assert sig.get_left('a').name == 'a'
359+
assert sig.get_right('b').name == 'b'
360+
361+
362+
def test_signature_contains_and_hash():
363+
r = Register('a', QBit())
364+
sig = Signature([r])
365+
assert r in sig
366+
assert hash(sig) == hash(Signature([Register('a', QBit())]))

0 commit comments

Comments
 (0)