From eb60089ec5daa4fa8bccd559b53bd15f827c2e2f Mon Sep 17 00:00:00 2001 From: tachard Date: Fri, 13 Feb 2026 19:40:04 +0100 Subject: [PATCH 1/8] Changes Partition to accept an input dtype --- qualtran/bloqs/bookkeeping/partition.py | 69 ++++++++++++++++++++++--- 1 file changed, 63 insertions(+), 6 deletions(-) diff --git a/qualtran/bloqs/bookkeeping/partition.py b/qualtran/bloqs/bookkeeping/partition.py index 8cebbb5809..1452741b65 100644 --- a/qualtran/bloqs/bookkeeping/partition.py +++ b/qualtran/bloqs/bookkeeping/partition.py @@ -13,12 +13,13 @@ # limitations under the License. import abc from functools import cached_property -from typing import Dict, List, Sequence, Tuple, TYPE_CHECKING +from typing import Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING import numpy as np import sympy from attrs import evolve, field, frozen, validators from numpy.typing import NDArray +import warnings from qualtran import ( Bloq, @@ -28,6 +29,7 @@ ConnectionT, DecomposeTypeError, QAny, + QUInt, QDType, Register, Side, @@ -46,6 +48,26 @@ from qualtran.simulation.classical_sim import ClassicalValT +class LegacyPartitionWarning(DeprecationWarning): + """Warnings for legacy Partition usage, when declaring only n.""" + + pass + + +def _constrain_qany_reg(reg: Register): + """Changes the dtype of a register to note break legacy code + + This function should be bound to dissapear + """ + if isinstance(reg.dtype, QAny): + warnings.warn( + f"Doing classical casting with QAny ({reg=}) is ambiguous, transforming it as QUInt for legacy purposes", + category=LegacyPartitionWarning, + ) + return evolve(reg, dtype=QUInt(reg.dtype.bitsize)) + return reg + + class _PartitionBase(_BookkeepingBloq, metaclass=abc.ABCMeta): """Generalized paritioning functionality.""" @@ -53,9 +75,9 @@ class _PartitionBase(_BookkeepingBloq, metaclass=abc.ABCMeta): @abc.abstractmethod def n(self) -> SymbolicInt: ... - @cached_property - def lumped_dtype(self) -> QDType: - return QAny(bitsize=self.n) + @property + @abc.abstractmethod + def lumped_dtype(self) -> QDType: ... @property @abc.abstractmethod @@ -124,6 +146,7 @@ def _classical_partition(self, x: 'ClassicalValT') -> Dict[str, 'ClassicalValT'] xbits = self.lumped_dtype.to_bits(x) start = 0 for reg in self._regs: + reg = _constrain_qany_reg(reg) size = int(np.prod(reg.shape + (reg.bitsize,))) bits_reg = xbits[start : start + size] if reg.shape == (): @@ -138,6 +161,7 @@ def _classical_partition(self, x: 'ClassicalValT') -> Dict[str, 'ClassicalValT'] def _classical_unpartition_to_bits(self, **vals: 'ClassicalValT') -> NDArray[np.uint8]: out_vals: list[NDArray[np.uint8]] = [] for reg in self._regs: + reg = _constrain_qany_reg(reg) reg_val = np.asanyarray(vals[reg.name]) bitstrings = reg.dtype.to_bits_array(reg_val.ravel()) out_vals.append(bitstrings.ravel()) @@ -168,6 +192,7 @@ class Partition(_PartitionBase): Args: n: The total bitsize of the un-partitioned register regs: Registers to partition into. The `side` attribute is ignored. + dtype_in: Type of the un-partitioned register, this is partition: `False` means un-partition instead. Registers: @@ -175,15 +200,39 @@ class Partition(_PartitionBase): [user spec]: The registers provided by the `regs` argument. RIGHT by default. """ - n: SymbolicInt + n: Optional[SymbolicInt] regs: Tuple[Register, ...] = field( converter=lambda x: x if isinstance(x, tuple) else tuple(x), validator=validators.min_len(1) ) - partition: bool = True + dtype_in: Optional[QDType] = field(default=None) + partition: bool = field(default=True) def __attrs_post_init__(self): + match (self.n, self.dtype_in): + case (None, None): + raise ValueError("Provide exactly n or dtype_in") + case (None, dt): + object.__setattr__(self, "n", dt.num_qubits) + case (n, None): + warnings.warn( + "Partition: By not setting dtype_in you could encounter errors when running " + "assert_consistent_classical_action", + category=LegacyPartitionWarning, + ) + case (n, dt): + if n != dt.num_qubits: + raise ValueError(f"{dt=} should have size {n=}, currently {dt.num_qubits=}") + warnings.warn( + "Specifying both n and dtype_in is redundant", + category=UserWarning, + stacklevel=1, + ) self._validate() + @property + def lumped_dtype(self) -> QDType: + return QUInt(bitsize=self.n) if self.dtype_in is None else self.dtype_in + @property def _regs(self) -> Sequence[Register]: return self.regs @@ -228,6 +277,10 @@ class Split2(_PartitionBase): def n(self) -> SymbolicInt: return self.n1 + self.n2 + @property + def lumped_dtype(self) -> QDType: + return QUInt(bitsize=self.n) + @property def partition(self) -> bool: return True @@ -289,6 +342,10 @@ class Join2(_PartitionBase): def n(self) -> SymbolicInt: return self.n1 + self.n2 + @property + def lumped_dtype(self) -> QDType: + return QUInt(bitsize=self.n) + @property def partition(self) -> bool: return False From 3f3f2a7d9dcae0dbe4fae2809f3d33ad03e12184 Mon Sep 17 00:00:00 2001 From: tachard Date: Fri, 13 Feb 2026 19:44:41 +0100 Subject: [PATCH 2/8] Changes pytest settigs so it ignores LegacyPartitionWarning --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 7b15489557..f8e9768800 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ skip_glob = ["qualtran/protos/*"] [tool.pytest.ini_options] filterwarnings = [ 'ignore::DeprecationWarning:quimb.linalg.approx_spectral:', + 'ignore::qualtran.bloqs.bookkeeping.partition.LegacyPartitionWarning', 'ignore:.*standard platformdirs.*:DeprecationWarning:jupyter_client.*' ] # we define classes like TestBloq etc. which pytest tries to collect, From 2405da0642e575a11402294fc984a8639e75f8fe Mon Sep 17 00:00:00 2001 From: tachard Date: Fri, 13 Feb 2026 20:10:33 +0100 Subject: [PATCH 3/8] Mypy made me add a lot of assertions + pylint --- qualtran/bloqs/bookkeeping/partition.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/qualtran/bloqs/bookkeeping/partition.py b/qualtran/bloqs/bookkeeping/partition.py index 1452741b65..532d00f2ef 100644 --- a/qualtran/bloqs/bookkeeping/partition.py +++ b/qualtran/bloqs/bookkeeping/partition.py @@ -51,8 +51,6 @@ class LegacyPartitionWarning(DeprecationWarning): """Warnings for legacy Partition usage, when declaring only n.""" - pass - def _constrain_qany_reg(reg: Register): """Changes the dtype of a register to note break legacy code @@ -73,7 +71,7 @@ class _PartitionBase(_BookkeepingBloq, metaclass=abc.ABCMeta): @property @abc.abstractmethod - def n(self) -> SymbolicInt: ... + def n(self) -> Optional[SymbolicInt]: ... @property @abc.abstractmethod @@ -120,6 +118,8 @@ def my_tensors( ) -> List['qtn.Tensor']: import quimb.tensor as qtn + if self.n is None: + raise DecomposeTypeError(f"cannot compute tensors with unknown n for {self}") if is_symbolic(self.n): raise DecomposeTypeError(f"cannot compute tensors for symbolic {self}") @@ -211,15 +211,18 @@ def __attrs_post_init__(self): match (self.n, self.dtype_in): case (None, None): raise ValueError("Provide exactly n or dtype_in") - case (None, dt): - object.__setattr__(self, "n", dt.num_qubits) case (n, None): warnings.warn( "Partition: By not setting dtype_in you could encounter errors when running " "assert_consistent_classical_action", category=LegacyPartitionWarning, ) + case (None, dt): + assert dt is not None + object.__setattr__(self, "n", dt.num_qubits) case (n, dt): + assert dt is not None + assert n is not None # for mypy if n != dt.num_qubits: raise ValueError(f"{dt=} should have size {n=}, currently {dt.num_qubits=}") warnings.warn( @@ -231,7 +234,11 @@ def __attrs_post_init__(self): @property def lumped_dtype(self) -> QDType: - return QUInt(bitsize=self.n) if self.dtype_in is None else self.dtype_in + if self.dtype_in is not None: + return self.dtype_in + + assert self.n is not None + return QUInt(bitsize=self.n) @property def _regs(self) -> Sequence[Register]: From a4e79de47a7ccc7db1f72c0b19c40bf3e4142066 Mon Sep 17 00:00:00 2001 From: tachard Date: Sun, 15 Feb 2026 00:15:08 +0100 Subject: [PATCH 4/8] linter --- qualtran/bloqs/bookkeeping/partition.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/qualtran/bloqs/bookkeeping/partition.py b/qualtran/bloqs/bookkeeping/partition.py index 532d00f2ef..ed7f103e05 100644 --- a/qualtran/bloqs/bookkeeping/partition.py +++ b/qualtran/bloqs/bookkeeping/partition.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc +import warnings from functools import cached_property from typing import Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING @@ -19,7 +20,6 @@ import sympy from attrs import evolve, field, frozen, validators from numpy.typing import NDArray -import warnings from qualtran import ( Bloq, @@ -29,8 +29,8 @@ ConnectionT, DecomposeTypeError, QAny, - QUInt, QDType, + QUInt, Register, Side, Signature, From 4848dae2b1c3d11f473c456bfb479a886d760892 Mon Sep 17 00:00:00 2001 From: tachard Date: Sun, 15 Feb 2026 21:00:23 +0100 Subject: [PATCH 5/8] Adds test + small fixes --- qualtran/bloqs/bookkeeping/partition.py | 30 ++++++++++++++------ qualtran/bloqs/bookkeeping/partition_test.py | 16 +++++++++-- 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/qualtran/bloqs/bookkeeping/partition.py b/qualtran/bloqs/bookkeeping/partition.py index ed7f103e05..18791c339f 100644 --- a/qualtran/bloqs/bookkeeping/partition.py +++ b/qualtran/bloqs/bookkeeping/partition.py @@ -66,6 +66,17 @@ def _constrain_qany_reg(reg: Register): return reg +def _regs_to_tuple(x): + if x is None: + return None + return x if isinstance(x, tuple) else tuple(x) + + +def _not_none(_inst, attr, value): + if value is None: + raise ValueError(f"{attr.name} cannot be None") + + class _PartitionBase(_BookkeepingBloq, metaclass=abc.ABCMeta): """Generalized paritioning functionality.""" @@ -190,9 +201,11 @@ class Partition(_PartitionBase): """Partition a generic index into multiple registers. Args: - n: The total bitsize of the un-partitioned register + n: The total bit-size of the un-partitioned register. Required if `dtype_in` is None. + Deprecated. Kept for backward compatibility. Use `dtype_in` instead whenever possible. regs: Registers to partition into. The `side` attribute is ignored. - dtype_in: Type of the un-partitioned register, this is + dtype_in: Type of the un-partitioned register. Required if `n` is None. If None, + the type is inferred as `QUInt(n)`. partition: `False` means un-partition instead. Registers: @@ -200,9 +213,9 @@ class Partition(_PartitionBase): [user spec]: The registers provided by the `regs` argument. RIGHT by default. """ - n: Optional[SymbolicInt] - regs: Tuple[Register, ...] = field( - converter=lambda x: x if isinstance(x, tuple) else tuple(x), validator=validators.min_len(1) + n: Optional[SymbolicInt] = field(default=None) + regs: Optional[Tuple[Register, ...]] = field( + converter=_regs_to_tuple, validator=(_not_none, validators.min_len(1)), default=None ) dtype_in: Optional[QDType] = field(default=None) partition: bool = field(default=True) @@ -218,7 +231,7 @@ def __attrs_post_init__(self): category=LegacyPartitionWarning, ) case (None, dt): - assert dt is not None + assert dt is not None # for mypy object.__setattr__(self, "n", dt.num_qubits) case (n, dt): assert dt is not None @@ -230,6 +243,7 @@ def __attrs_post_init__(self): category=UserWarning, stacklevel=1, ) + self._validate() @property @@ -237,7 +251,7 @@ def lumped_dtype(self) -> QDType: if self.dtype_in is not None: return self.dtype_in - assert self.n is not None + assert self.n is not None # for mypy return QUInt(bitsize=self.n) @property @@ -255,7 +269,7 @@ def signature(self) -> 'Signature': ) def adjoint(self): - return evolve(self, partition=not self.partition) + return evolve(self, n=None, dtype_in=self.lumped_dtype, partition=not self.partition) @frozen diff --git a/qualtran/bloqs/bookkeeping/partition_test.py b/qualtran/bloqs/bookkeeping/partition_test.py index c6c3236332..7be26b864c 100644 --- a/qualtran/bloqs/bookkeeping/partition_test.py +++ b/qualtran/bloqs/bookkeeping/partition_test.py @@ -20,7 +20,7 @@ import pytest from attrs import frozen -from qualtran import Bloq, BloqBuilder, QAny, QGF, Register, Signature, Soquet, SoquetT +from qualtran import Bloq, BloqBuilder, QAny, QGF, QInt, QUInt, Register, Signature, Soquet, SoquetT from qualtran._infra.gate_with_registers import get_named_qubits from qualtran.bloqs.basic_gates import CNOT from qualtran.bloqs.bookkeeping import Partition @@ -37,10 +37,22 @@ def test_partition(bloq_autotester): def test_partition_check(): with pytest.raises(ValueError): _ = Partition(n=0, regs=()) + with pytest.raises(ValueError): + _ = Partition(n=10, regs=None) + with pytest.raises(ValueError): + _ = Partition(dtype_in=QUInt(10)) with pytest.raises(ValueError): _ = Partition(n=1, regs=(Register('x', QAny(2)),)) with pytest.raises(ValueError): _ = Partition(n=4, regs=(Register('x', QAny(1)), Register('x', QAny(3)))) + with pytest.raises(ValueError): + _ = Partition(n=10) + + regs = (Register("xx", QUInt(4)), Register("yy", QInt(6))) + with pytest.raises(ValueError): + _ = Partition(regs=regs) + with pytest.raises(ValueError): + _ = Partition(n=11, regs=regs) @frozen @@ -57,7 +69,7 @@ def signature(self) -> Signature: def build_composite_bloq(self, bb: 'BloqBuilder', test_regs: 'SoquetT') -> Dict[str, 'Soquet']: bloq_regs = self.test_bloq.signature - partition = Partition(self.bitsize, bloq_regs) # type: ignore[arg-type] + partition = Partition(dtype_in=QUInt(self.bitsize), regs=bloq_regs) # type: ignore[arg-type] out_regs = bb.add(partition, x=test_regs) out_regs = bb.add(self.test_bloq, **{reg.name: sp for reg, sp in zip(bloq_regs, out_regs)}) test_regs = bb.add( From d0e067f29ebe17407f3451a999033f72aa8b7ddf Mon Sep 17 00:00:00 2001 From: tachard Date: Sun, 15 Feb 2026 21:38:51 +0100 Subject: [PATCH 6/8] Modified stuff for mypy, cleaner now --- qualtran/bloqs/bookkeeping/partition.py | 50 +++++++++++-------------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/qualtran/bloqs/bookkeeping/partition.py b/qualtran/bloqs/bookkeeping/partition.py index 18791c339f..9db5785e34 100644 --- a/qualtran/bloqs/bookkeeping/partition.py +++ b/qualtran/bloqs/bookkeeping/partition.py @@ -14,7 +14,7 @@ import abc import warnings from functools import cached_property -from typing import Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING +from typing import Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, cast import numpy as np import sympy @@ -221,42 +221,34 @@ class Partition(_PartitionBase): partition: bool = field(default=True) def __attrs_post_init__(self): - match (self.n, self.dtype_in): - case (None, None): - raise ValueError("Provide exactly n or dtype_in") - case (n, None): - warnings.warn( - "Partition: By not setting dtype_in you could encounter errors when running " - "assert_consistent_classical_action", - category=LegacyPartitionWarning, - ) - case (None, dt): - assert dt is not None # for mypy - object.__setattr__(self, "n", dt.num_qubits) - case (n, dt): - assert dt is not None - assert n is not None # for mypy - if n != dt.num_qubits: - raise ValueError(f"{dt=} should have size {n=}, currently {dt.num_qubits=}") - warnings.warn( - "Specifying both n and dtype_in is redundant", - category=UserWarning, - stacklevel=1, + if self.n is None and self.dtype_in is None: + raise ValueError(f"Provide exactly n or dtype_in {self.n=}, {self.dtype_in=}") + elif self.n is not None and self.dtype_in is None: + warnings.warn( + "Partition: By not setting dtype_in you could encounter errors when running " + "assert_consistent_classical_action", + category=LegacyPartitionWarning, + ) + elif self.n is None and self.dtype_in is not None: + object.__setattr__(self, "n", self.dtype_in.num_qubits) + elif self.n is not None and self.dtype_in is not None: + if self.n != self.dtype_in.num_qubits: + raise ValueError( + f"{self.dtype_in=} should have size {self.n=}, currently {self.dtype_in.num_qubits=}" ) + warnings.warn( + "Specifying both n and dtype_in is redundant", category=UserWarning, stacklevel=1 + ) self._validate() @property def lumped_dtype(self) -> QDType: - if self.dtype_in is not None: - return self.dtype_in - - assert self.n is not None # for mypy - return QUInt(bitsize=self.n) + return QUInt(bitsize=cast(SymbolicInt, self.n)) if self.dtype_in is None else self.dtype_in @property def _regs(self) -> Sequence[Register]: - return self.regs + return cast(Tuple[Register, ...], self.regs) @cached_property def signature(self) -> 'Signature': @@ -265,7 +257,7 @@ def signature(self) -> 'Signature': return Signature( [Register('x', self.lumped_dtype, side=lumped)] - + [evolve(reg, side=partitioned) for reg in self.regs] + + [evolve(reg, side=partitioned) for reg in self._regs] ) def adjoint(self): From 10b3e0667825b2afbc8537a1a55de5637173b8fa Mon Sep 17 00:00:00 2001 From: tachard Date: Sun, 15 Feb 2026 21:50:10 +0100 Subject: [PATCH 7/8] linter --- qualtran/bloqs/bookkeeping/partition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qualtran/bloqs/bookkeeping/partition.py b/qualtran/bloqs/bookkeeping/partition.py index 9db5785e34..b0440f25ca 100644 --- a/qualtran/bloqs/bookkeeping/partition.py +++ b/qualtran/bloqs/bookkeeping/partition.py @@ -14,7 +14,7 @@ import abc import warnings from functools import cached_property -from typing import Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, cast +from typing import cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING import numpy as np import sympy From 8ebb73f65f87be1bd47292166521561007c7cc62 Mon Sep 17 00:00:00 2001 From: tachard Date: Sun, 15 Feb 2026 21:57:04 +0100 Subject: [PATCH 8/8] Notebook change --- qualtran/bloqs/bookkeeping/partition.ipynb | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/qualtran/bloqs/bookkeeping/partition.ipynb b/qualtran/bloqs/bookkeeping/partition.ipynb index 65ff47688d..9080931b62 100644 --- a/qualtran/bloqs/bookkeeping/partition.ipynb +++ b/qualtran/bloqs/bookkeeping/partition.ipynb @@ -39,8 +39,9 @@ "Partition a generic index into multiple registers.\n", "\n", "#### Parameters\n", - " - `n`: The total bitsize of the un-partitioned register\n", + " - `n`: The total bit-size of the un-partitioned register. Required if `dtype_in` is None. Deprecated. Kept for backward compatibility. Use `dtype_in` instead whenever possible.\n", " - `regs`: Registers to partition into. The `side` attribute is ignored.\n", + " - `dtype_in`: Type of the un-partitioned register. Required if `n` is None. If None, the type is inferred as `QUInt(n)`.\n", " - `partition`: `False` means un-partition instead. \n", "\n", "#### Registers\n",