diff --git a/pyproject.toml b/pyproject.toml index 7b1548955..f8e976880 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ skip_glob = ["qualtran/protos/*"] [tool.pytest.ini_options] filterwarnings = [ 'ignore::DeprecationWarning:quimb.linalg.approx_spectral:', + 'ignore::qualtran.bloqs.bookkeeping.partition.LegacyPartitionWarning', 'ignore:.*standard platformdirs.*:DeprecationWarning:jupyter_client.*' ] # we define classes like TestBloq etc. which pytest tries to collect, diff --git a/qualtran/bloqs/bookkeeping/partition.ipynb b/qualtran/bloqs/bookkeeping/partition.ipynb index 65ff47688..9080931b6 100644 --- a/qualtran/bloqs/bookkeeping/partition.ipynb +++ b/qualtran/bloqs/bookkeeping/partition.ipynb @@ -39,8 +39,9 @@ "Partition a generic index into multiple registers.\n", "\n", "#### Parameters\n", - " - `n`: The total bitsize of the un-partitioned register\n", + " - `n`: The total bit-size of the un-partitioned register. Required if `dtype_in` is None. Deprecated. Kept for backward compatibility. Use `dtype_in` instead whenever possible.\n", " - `regs`: Registers to partition into. The `side` attribute is ignored.\n", + " - `dtype_in`: Type of the un-partitioned register. Required if `n` is None. If None, the type is inferred as `QUInt(n)`.\n", " - `partition`: `False` means un-partition instead. \n", "\n", "#### Registers\n", diff --git a/qualtran/bloqs/bookkeeping/partition.py b/qualtran/bloqs/bookkeeping/partition.py index 8cebbb580..b0440f25c 100644 --- a/qualtran/bloqs/bookkeeping/partition.py +++ b/qualtran/bloqs/bookkeeping/partition.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc +import warnings from functools import cached_property -from typing import Dict, List, Sequence, Tuple, TYPE_CHECKING +from typing import cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING import numpy as np import sympy @@ -29,6 +30,7 @@ DecomposeTypeError, QAny, QDType, + QUInt, Register, Side, Signature, @@ -46,16 +48,45 @@ from qualtran.simulation.classical_sim import ClassicalValT +class LegacyPartitionWarning(DeprecationWarning): + """Warnings for legacy Partition usage, when declaring only n.""" + + +def _constrain_qany_reg(reg: Register): + """Changes the dtype of a register to note break legacy code + + This function should be bound to dissapear + """ + if isinstance(reg.dtype, QAny): + warnings.warn( + f"Doing classical casting with QAny ({reg=}) is ambiguous, transforming it as QUInt for legacy purposes", + category=LegacyPartitionWarning, + ) + return evolve(reg, dtype=QUInt(reg.dtype.bitsize)) + return reg + + +def _regs_to_tuple(x): + if x is None: + return None + return x if isinstance(x, tuple) else tuple(x) + + +def _not_none(_inst, attr, value): + if value is None: + raise ValueError(f"{attr.name} cannot be None") + + class _PartitionBase(_BookkeepingBloq, metaclass=abc.ABCMeta): """Generalized paritioning functionality.""" @property @abc.abstractmethod - def n(self) -> SymbolicInt: ... + def n(self) -> Optional[SymbolicInt]: ... - @cached_property - def lumped_dtype(self) -> QDType: - return QAny(bitsize=self.n) + @property + @abc.abstractmethod + def lumped_dtype(self) -> QDType: ... @property @abc.abstractmethod @@ -98,6 +129,8 @@ def my_tensors( ) -> List['qtn.Tensor']: import quimb.tensor as qtn + if self.n is None: + raise DecomposeTypeError(f"cannot compute tensors with unknown n for {self}") if is_symbolic(self.n): raise DecomposeTypeError(f"cannot compute tensors for symbolic {self}") @@ -124,6 +157,7 @@ def _classical_partition(self, x: 'ClassicalValT') -> Dict[str, 'ClassicalValT'] xbits = self.lumped_dtype.to_bits(x) start = 0 for reg in self._regs: + reg = _constrain_qany_reg(reg) size = int(np.prod(reg.shape + (reg.bitsize,))) bits_reg = xbits[start : start + size] if reg.shape == (): @@ -138,6 +172,7 @@ def _classical_partition(self, x: 'ClassicalValT') -> Dict[str, 'ClassicalValT'] def _classical_unpartition_to_bits(self, **vals: 'ClassicalValT') -> NDArray[np.uint8]: out_vals: list[NDArray[np.uint8]] = [] for reg in self._regs: + reg = _constrain_qany_reg(reg) reg_val = np.asanyarray(vals[reg.name]) bitstrings = reg.dtype.to_bits_array(reg_val.ravel()) out_vals.append(bitstrings.ravel()) @@ -166,8 +201,11 @@ class Partition(_PartitionBase): """Partition a generic index into multiple registers. Args: - n: The total bitsize of the un-partitioned register + n: The total bit-size of the un-partitioned register. Required if `dtype_in` is None. + Deprecated. Kept for backward compatibility. Use `dtype_in` instead whenever possible. regs: Registers to partition into. The `side` attribute is ignored. + dtype_in: Type of the un-partitioned register. Required if `n` is None. If None, + the type is inferred as `QUInt(n)`. partition: `False` means un-partition instead. Registers: @@ -175,18 +213,42 @@ class Partition(_PartitionBase): [user spec]: The registers provided by the `regs` argument. RIGHT by default. """ - n: SymbolicInt - regs: Tuple[Register, ...] = field( - converter=lambda x: x if isinstance(x, tuple) else tuple(x), validator=validators.min_len(1) + n: Optional[SymbolicInt] = field(default=None) + regs: Optional[Tuple[Register, ...]] = field( + converter=_regs_to_tuple, validator=(_not_none, validators.min_len(1)), default=None ) - partition: bool = True + dtype_in: Optional[QDType] = field(default=None) + partition: bool = field(default=True) def __attrs_post_init__(self): + if self.n is None and self.dtype_in is None: + raise ValueError(f"Provide exactly n or dtype_in {self.n=}, {self.dtype_in=}") + elif self.n is not None and self.dtype_in is None: + warnings.warn( + "Partition: By not setting dtype_in you could encounter errors when running " + "assert_consistent_classical_action", + category=LegacyPartitionWarning, + ) + elif self.n is None and self.dtype_in is not None: + object.__setattr__(self, "n", self.dtype_in.num_qubits) + elif self.n is not None and self.dtype_in is not None: + if self.n != self.dtype_in.num_qubits: + raise ValueError( + f"{self.dtype_in=} should have size {self.n=}, currently {self.dtype_in.num_qubits=}" + ) + warnings.warn( + "Specifying both n and dtype_in is redundant", category=UserWarning, stacklevel=1 + ) + self._validate() + @property + def lumped_dtype(self) -> QDType: + return QUInt(bitsize=cast(SymbolicInt, self.n)) if self.dtype_in is None else self.dtype_in + @property def _regs(self) -> Sequence[Register]: - return self.regs + return cast(Tuple[Register, ...], self.regs) @cached_property def signature(self) -> 'Signature': @@ -195,11 +257,11 @@ def signature(self) -> 'Signature': return Signature( [Register('x', self.lumped_dtype, side=lumped)] - + [evolve(reg, side=partitioned) for reg in self.regs] + + [evolve(reg, side=partitioned) for reg in self._regs] ) def adjoint(self): - return evolve(self, partition=not self.partition) + return evolve(self, n=None, dtype_in=self.lumped_dtype, partition=not self.partition) @frozen @@ -228,6 +290,10 @@ class Split2(_PartitionBase): def n(self) -> SymbolicInt: return self.n1 + self.n2 + @property + def lumped_dtype(self) -> QDType: + return QUInt(bitsize=self.n) + @property def partition(self) -> bool: return True @@ -289,6 +355,10 @@ class Join2(_PartitionBase): def n(self) -> SymbolicInt: return self.n1 + self.n2 + @property + def lumped_dtype(self) -> QDType: + return QUInt(bitsize=self.n) + @property def partition(self) -> bool: return False diff --git a/qualtran/bloqs/bookkeeping/partition_test.py b/qualtran/bloqs/bookkeeping/partition_test.py index c6c323633..7be26b864 100644 --- a/qualtran/bloqs/bookkeeping/partition_test.py +++ b/qualtran/bloqs/bookkeeping/partition_test.py @@ -20,7 +20,7 @@ import pytest from attrs import frozen -from qualtran import Bloq, BloqBuilder, QAny, QGF, Register, Signature, Soquet, SoquetT +from qualtran import Bloq, BloqBuilder, QAny, QGF, QInt, QUInt, Register, Signature, Soquet, SoquetT from qualtran._infra.gate_with_registers import get_named_qubits from qualtran.bloqs.basic_gates import CNOT from qualtran.bloqs.bookkeeping import Partition @@ -37,10 +37,22 @@ def test_partition(bloq_autotester): def test_partition_check(): with pytest.raises(ValueError): _ = Partition(n=0, regs=()) + with pytest.raises(ValueError): + _ = Partition(n=10, regs=None) + with pytest.raises(ValueError): + _ = Partition(dtype_in=QUInt(10)) with pytest.raises(ValueError): _ = Partition(n=1, regs=(Register('x', QAny(2)),)) with pytest.raises(ValueError): _ = Partition(n=4, regs=(Register('x', QAny(1)), Register('x', QAny(3)))) + with pytest.raises(ValueError): + _ = Partition(n=10) + + regs = (Register("xx", QUInt(4)), Register("yy", QInt(6))) + with pytest.raises(ValueError): + _ = Partition(regs=regs) + with pytest.raises(ValueError): + _ = Partition(n=11, regs=regs) @frozen @@ -57,7 +69,7 @@ def signature(self) -> Signature: def build_composite_bloq(self, bb: 'BloqBuilder', test_regs: 'SoquetT') -> Dict[str, 'Soquet']: bloq_regs = self.test_bloq.signature - partition = Partition(self.bitsize, bloq_regs) # type: ignore[arg-type] + partition = Partition(dtype_in=QUInt(self.bitsize), regs=bloq_regs) # type: ignore[arg-type] out_regs = bb.add(partition, x=test_regs) out_regs = bb.add(self.test_bloq, **{reg.name: sp for reg, sp in zip(bloq_regs, out_regs)}) test_regs = bb.add(