Skip to content

Commit 9db3b44

Browse files
authored
Add is_symbolic to data types and registers. (#942)
* Move symbolic_counting_utils to a new module * Add is_symbolic to data types and registers * Fix mypy errors * Fix tests
1 parent 717c18f commit 9db3b44

5 files changed

Lines changed: 80 additions & 17 deletions

File tree

qualtran/_infra/data_types.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,11 @@
5454

5555
import attrs
5656
import numpy as np
57-
import sympy
5857
from fxpmath import Fxp
5958
from numpy.typing import NDArray
6059

60+
from qualtran.symbolics import is_symbolic, SymbolicInt
61+
6162

6263
class QDType(metaclass=abc.ABCMeta):
6364
"""This defines the abstract interface for quantum data types."""
@@ -89,7 +90,11 @@ def assert_valid_classical_val(self, val: Any, debug_str: str = 'val'):
8990
debug_str: Optional debugging information to use in exception messages.
9091
"""
9192

92-
def iteration_length_or_zero(self) -> Union[int, sympy.Expr]:
93+
@abc.abstractmethod
94+
def is_symbolic(self) -> bool:
95+
"""Returns True if this qdtype is parameterized with symbolic objects."""
96+
97+
def iteration_length_or_zero(self) -> SymbolicInt:
9398
"""Safe version of iteration length.
9499
95100
Returns the iteration_length if the type has it or else zero.
@@ -130,6 +135,9 @@ def assert_valid_classical_val(self, val: int, debug_str: str = 'val'):
130135
if not (val == 0 or val == 1):
131136
raise ValueError(f"Bad {self} value {val} in {debug_str}")
132137

138+
def is_symbolic(self) -> bool:
139+
return False
140+
133141
def to_bits(self, x) -> List[int]:
134142
"""Yields individual bits corresponding to binary representation of x"""
135143
self.assert_valid_classical_val(x)
@@ -154,7 +162,7 @@ def __str__(self):
154162
class QAny(QDType):
155163
"""Opaque bag-of-qbits type."""
156164

157-
bitsize: Union[int, sympy.Expr]
165+
bitsize: SymbolicInt
158166

159167
@property
160168
def num_qubits(self):
@@ -171,6 +179,9 @@ def from_bits(self, bits: Sequence[int]) -> int:
171179
# TODO: Raise an error once usage of `QAny` is minimized across the library
172180
return QUInt(self.bitsize).from_bits(bits)
173181

182+
def is_symbolic(self) -> bool:
183+
return is_symbolic(self.bitsize)
184+
174185
def assert_valid_classical_val(self, val, debug_str: str = 'val'):
175186
pass
176187

@@ -188,12 +199,15 @@ class QInt(QDType):
188199
bitsize: The number of qubits used to represent the integer.
189200
"""
190201

191-
bitsize: Union[int, sympy.Expr]
202+
bitsize: SymbolicInt
192203

193204
@property
194205
def num_qubits(self):
195206
return self.bitsize
196207

208+
def is_symbolic(self) -> bool:
209+
return is_symbolic(self.bitsize)
210+
197211
def get_classical_domain(self) -> Iterable[int]:
198212
max_val = 1 << (self.bitsize - 1)
199213
return range(-max_val, max_val)
@@ -240,7 +254,7 @@ class QIntOnesComp(QDType):
240254
bitsize: The number of qubits used to represent the integer.
241255
"""
242256

243-
bitsize: Union[int, sympy.Expr]
257+
bitsize: SymbolicInt
244258

245259
def __attrs_post_init__(self):
246260
if isinstance(self.bitsize, int):
@@ -251,6 +265,9 @@ def __attrs_post_init__(self):
251265
def num_qubits(self):
252266
return self.bitsize
253267

268+
def is_symbolic(self) -> bool:
269+
return is_symbolic(self.bitsize)
270+
254271
def to_bits(self, x: int) -> List[int]:
255272
"""Yields individual bits corresponding to binary representation of x"""
256273
self.assert_valid_classical_val(x)
@@ -286,12 +303,15 @@ class QUInt(QDType):
286303
bitsize: The number of qubits used to represent the integer.
287304
"""
288305

289-
bitsize: Union[int, sympy.Expr]
306+
bitsize: SymbolicInt
290307

291308
@property
292309
def num_qubits(self):
293310
return self.bitsize
294311

312+
def is_symbolic(self) -> bool:
313+
return is_symbolic(self.bitsize)
314+
295315
def get_classical_domain(self) -> Iterable[Any]:
296316
return range(2**self.bitsize)
297317

@@ -371,11 +391,11 @@ class BoundedQUInt(QDType):
371391
iteration_length: The length of the iteration range.
372392
"""
373393

374-
bitsize: Union[int, sympy.Expr]
375-
iteration_length: Union[int, sympy.Expr] = attrs.field()
394+
bitsize: SymbolicInt
395+
iteration_length: SymbolicInt = attrs.field()
376396

377397
def __attrs_post_init__(self):
378-
if isinstance(self.bitsize, int):
398+
if not self.is_symbolic():
379399
if self.iteration_length > 2**self.bitsize:
380400
raise ValueError(
381401
"BoundedQUInt iteration length is too large for given bitsize. "
@@ -386,6 +406,9 @@ def __attrs_post_init__(self):
386406
def _default_iteration_length(self):
387407
return 2**self.bitsize
388408

409+
def is_symbolic(self) -> bool:
410+
return is_symbolic(self.bitsize, self.iteration_length)
411+
389412
@property
390413
def num_qubits(self):
391414
return self.bitsize
@@ -446,16 +469,16 @@ class QFxp(QDType):
446469
number of integer bits is reduced by 1.
447470
"""
448471

449-
bitsize: Union[int, sympy.Expr]
450-
num_frac: Union[int, sympy.Expr]
472+
bitsize: SymbolicInt
473+
num_frac: SymbolicInt
451474
signed: bool = False
452475

453476
@property
454477
def num_qubits(self):
455478
return self.bitsize
456479

457480
@property
458-
def num_int(self) -> Union[int, sympy.Expr]:
481+
def num_int(self) -> SymbolicInt:
459482
return self.bitsize - self.num_frac - int(self.signed)
460483

461484
@property
@@ -466,6 +489,9 @@ def fxp_dtype_str(self) -> str:
466489
def _fxp_dtype(self) -> Fxp:
467490
return Fxp(None, dtype=self.fxp_dtype_str)
468491

492+
def is_symbolic(self) -> bool:
493+
return is_symbolic(self.bitsize, self.num_frac)
494+
469495
def to_bits(self, x: Union[float, Fxp]) -> List[int]:
470496
"""Yields individual bits corresponding to binary representation of x"""
471497
self._assert_valid_classical_val(x)
@@ -539,12 +565,15 @@ class QMontgomeryUInt(QDType):
539565
[Montgomery modular multiplication](https://en.wikipedia.org/wiki/Montgomery_modular_multiplication)
540566
"""
541567

542-
bitsize: Union[int, sympy.Expr]
568+
bitsize: SymbolicInt
543569

544570
@property
545571
def num_qubits(self):
546572
return self.bitsize
547573

574+
def is_symbolic(self) -> bool:
575+
return is_symbolic(self.bitsize)
576+
548577
def get_classical_domain(self) -> Iterable[Any]:
549578
return range(2**self.bitsize)
550579

qualtran/_infra/data_types_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import pytest
1717
import sympy
1818

19+
from qualtran.symbolics import is_symbolic
20+
1921
from .data_types import (
2022
BoundedQUInt,
2123
check_dtypes_consistent,
@@ -39,6 +41,7 @@ def test_qint():
3941
qint_8 = QInt(n)
4042
assert qint_8.num_qubits == n
4143
assert str(qint_8) == 'QInt(x)'
44+
assert is_symbolic(QInt(sympy.Symbol('x')))
4245

4346

4447
def test_qint_ones():
@@ -50,6 +53,7 @@ def test_qint_ones():
5053
n = sympy.symbols('x')
5154
qint_8 = QIntOnesComp(n)
5255
assert qint_8.num_qubits == n
56+
assert is_symbolic(QIntOnesComp(sympy.Symbol('x')))
5357

5458

5559
def test_quint():
@@ -62,6 +66,7 @@ def test_quint():
6266
n = sympy.symbols('x')
6367
qint_8 = QUInt(n)
6468
assert qint_8.num_qubits == n
69+
assert is_symbolic(QUInt(sympy.Symbol('x')))
6570

6671

6772
def test_bounded_quint():
@@ -77,6 +82,9 @@ def test_bounded_quint():
7782
qint_8 = BoundedQUInt(n, l)
7883
assert qint_8.num_qubits == n
7984
assert qint_8.iteration_length == l
85+
assert is_symbolic(BoundedQUInt(sympy.Symbol('x'), 2))
86+
assert is_symbolic(BoundedQUInt(2, sympy.Symbol('x')))
87+
assert is_symbolic(BoundedQUInt(*sympy.symbols('x y')))
8088

8189

8290
def test_qfxp():
@@ -105,6 +113,7 @@ def test_qfxp():
105113
qfp = QFxp(b, f, True)
106114
assert qfp.num_qubits == b
107115
assert qfp.num_int == b - f - 1
116+
assert is_symbolic(QFxp(*sympy.symbols('x y')))
108117

109118

110119
def test_qmontgomeryuint():
@@ -116,6 +125,7 @@ def test_qmontgomeryuint():
116125
n = sympy.symbols('x')
117126
qmontgomeryuint_8 = QMontgomeryUInt(n)
118127
assert qmontgomeryuint_8.num_qubits == n
128+
assert is_symbolic(QMontgomeryUInt(sympy.Symbol('x')))
119129

120130

121131
@pytest.mark.parametrize('qdtype', [QBit(), QInt(4), QUInt(4), BoundedQUInt(3, 5)])

qualtran/_infra/registers.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616
import enum
1717
import itertools
1818
from collections import defaultdict
19-
from typing import Dict, Iterable, Iterator, List, overload, Tuple, Union
19+
from typing import cast, Dict, Iterable, Iterator, List, overload, Tuple, Union
2020

2121
import attrs
2222
import numpy as np
2323
import sympy
2424
from attrs import field, frozen
2525

26+
from qualtran.symbolics import is_symbolic, SymbolicInt
27+
2628
from .data_types import QAny, QBit, QDType
2729

2830

@@ -63,7 +65,7 @@ class Register:
6365

6466
name: str
6567
dtype: QDType
66-
shape: Tuple[int, ...] = field(
68+
_shape: Tuple[SymbolicInt, ...] = field(
6769
default=tuple(), converter=lambda v: (v,) if isinstance(v, int) else tuple(v)
6870
)
6971
side: Side = Side.THRU
@@ -72,6 +74,19 @@ def __attrs_post_init__(self):
7274
if not isinstance(self.dtype, QDType):
7375
raise ValueError(f'dtype must be a QDType: found {type(self.dtype)}')
7476

77+
def is_symbolic(self) -> bool:
78+
return is_symbolic(self.dtype, *self._shape)
79+
80+
@property
81+
def shape_symbolic(self) -> Tuple[SymbolicInt, ...]:
82+
return self._shape
83+
84+
@property
85+
def shape(self) -> Tuple[int, ...]:
86+
if is_symbolic(*self._shape):
87+
raise ValueError(f"{self} is symbolic. Cannot get real-valued shape.")
88+
return cast(Tuple[int, ...], self._shape)
89+
7590
@property
7691
def bitsize(self) -> int:
7792
return self.dtype.num_qubits

qualtran/_infra/registers_test.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
import cirq
1616
import numpy as np
1717
import pytest
18+
import sympy
1819

1920
from qualtran import BoundedQUInt, QAny, QBit, QInt, Register, Side, Signature
2021
from qualtran._infra.gate_with_registers import get_named_qubits
22+
from qualtran.symbolics import is_symbolic
2123

2224

2325
def test_register():
@@ -193,6 +195,13 @@ def test_dtypes_converter():
193195
r1 = Register("my_reg", QBit())
194196
r2 = Register("my_reg", QBit())
195197
assert r1 == r2
196-
r2 = Register("my_reg", QAny(5))
198+
r1 = Register("my_reg", QAny(5))
197199
r2 = Register("my_reg", QInt(5))
198200
assert r1 != r2
201+
202+
203+
def test_is_symbolic():
204+
r = Register("my_reg", QAny(sympy.Symbol("x")))
205+
assert is_symbolic(r)
206+
r = Register("my_reg", QAny(2), shape=sympy.symbols("x y"))
207+
assert is_symbolic(r)

qualtran/serialization/registers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def register_to_proto(register: Register) -> registers_pb2.Register:
3232
return registers_pb2.Register(
3333
name=register.name,
3434
dtype=data_type_to_proto(register.dtype),
35-
shape=(args.int_or_sympy_to_proto(s) for s in register.shape),
35+
shape=(args.int_or_sympy_to_proto(s) for s in register.shape_symbolic),
3636
side=_side_to_proto(register.side),
3737
)
3838

0 commit comments

Comments
 (0)