Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ skip_glob = ["qualtran/protos/*"]
[tool.pytest.ini_options]
filterwarnings = [
'ignore::DeprecationWarning:quimb.linalg.approx_spectral:',
'ignore::qualtran.bloqs.bookkeeping.partition.LegacyPartitionWarning',
'ignore:.*standard platformdirs.*:DeprecationWarning:jupyter_client.*'
]
# we define classes like TestBloq etc. which pytest tries to collect,
Expand Down
3 changes: 2 additions & 1 deletion qualtran/bloqs/bookkeeping/partition.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@
"Partition a generic index into multiple registers.\n",
"\n",
"#### Parameters\n",
" - `n`: The total bitsize of the un-partitioned register\n",
" - `n`: The total bit-size of the un-partitioned register. Required if `dtype_in` is None. Deprecated. Kept for backward compatibility. Use `dtype_in` instead whenever possible.\n",
" - `regs`: Registers to partition into. The `side` attribute is ignored.\n",
" - `dtype_in`: Type of the un-partitioned register. Required if `n` is None. If None, the type is inferred as `QUInt(n)`.\n",
" - `partition`: `False` means un-partition instead. \n",
"\n",
"#### Registers\n",
Expand Down
96 changes: 83 additions & 13 deletions qualtran/bloqs/bookkeeping/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import warnings
from functools import cached_property
from typing import Dict, List, Sequence, Tuple, TYPE_CHECKING
from typing import cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING

import numpy as np
import sympy
Expand All @@ -29,6 +30,7 @@
DecomposeTypeError,
QAny,
QDType,
QUInt,
Register,
Side,
Signature,
Expand All @@ -46,16 +48,45 @@
from qualtran.simulation.classical_sim import ClassicalValT


class LegacyPartitionWarning(DeprecationWarning):
"""Warnings for legacy Partition usage, when declaring only n."""


def _constrain_qany_reg(reg: Register):
"""Changes the dtype of a register to note break legacy code

This function should be bound to dissapear
"""
if isinstance(reg.dtype, QAny):
warnings.warn(
f"Doing classical casting with QAny ({reg=}) is ambiguous, transforming it as QUInt for legacy purposes",
category=LegacyPartitionWarning,
)
return evolve(reg, dtype=QUInt(reg.dtype.bitsize))
return reg


def _regs_to_tuple(x):
if x is None:
return None
return x if isinstance(x, tuple) else tuple(x)


def _not_none(_inst, attr, value):
if value is None:
raise ValueError(f"{attr.name} cannot be None")


class _PartitionBase(_BookkeepingBloq, metaclass=abc.ABCMeta):
"""Generalized paritioning functionality."""

@property
@abc.abstractmethod
def n(self) -> SymbolicInt: ...
def n(self) -> Optional[SymbolicInt]: ...

@cached_property
def lumped_dtype(self) -> QDType:
return QAny(bitsize=self.n)
@property
@abc.abstractmethod
def lumped_dtype(self) -> QDType: ...

@property
@abc.abstractmethod
Expand Down Expand Up @@ -98,6 +129,8 @@ def my_tensors(
) -> List['qtn.Tensor']:
import quimb.tensor as qtn

if self.n is None:
raise DecomposeTypeError(f"cannot compute tensors with unknown n for {self}")
if is_symbolic(self.n):
raise DecomposeTypeError(f"cannot compute tensors for symbolic {self}")

Expand All @@ -124,6 +157,7 @@ def _classical_partition(self, x: 'ClassicalValT') -> Dict[str, 'ClassicalValT']
xbits = self.lumped_dtype.to_bits(x)
start = 0
for reg in self._regs:
reg = _constrain_qany_reg(reg)
size = int(np.prod(reg.shape + (reg.bitsize,)))
bits_reg = xbits[start : start + size]
if reg.shape == ():
Expand All @@ -138,6 +172,7 @@ def _classical_partition(self, x: 'ClassicalValT') -> Dict[str, 'ClassicalValT']
def _classical_unpartition_to_bits(self, **vals: 'ClassicalValT') -> NDArray[np.uint8]:
out_vals: list[NDArray[np.uint8]] = []
for reg in self._regs:
reg = _constrain_qany_reg(reg)
reg_val = np.asanyarray(vals[reg.name])
bitstrings = reg.dtype.to_bits_array(reg_val.ravel())
out_vals.append(bitstrings.ravel())
Expand Down Expand Up @@ -166,27 +201,54 @@ class Partition(_PartitionBase):
"""Partition a generic index into multiple registers.

Args:
n: The total bitsize of the un-partitioned register
n: The total bit-size of the un-partitioned register. Required if `dtype_in` is None.
Deprecated. Kept for backward compatibility. Use `dtype_in` instead whenever possible.
regs: Registers to partition into. The `side` attribute is ignored.
dtype_in: Type of the un-partitioned register. Required if `n` is None. If None,
the type is inferred as `QUInt(n)`.
partition: `False` means un-partition instead.

Registers:
x: the un-partitioned register. LEFT by default.
[user spec]: The registers provided by the `regs` argument. RIGHT by default.
"""

n: SymbolicInt
regs: Tuple[Register, ...] = field(
converter=lambda x: x if isinstance(x, tuple) else tuple(x), validator=validators.min_len(1)
n: Optional[SymbolicInt] = field(default=None)
regs: Optional[Tuple[Register, ...]] = field(
converter=_regs_to_tuple, validator=(_not_none, validators.min_len(1)), default=None
)
partition: bool = True
dtype_in: Optional[QDType] = field(default=None)
partition: bool = field(default=True)

def __attrs_post_init__(self):
if self.n is None and self.dtype_in is None:
raise ValueError(f"Provide exactly n or dtype_in {self.n=}, {self.dtype_in=}")
elif self.n is not None and self.dtype_in is None:
warnings.warn(
"Partition: By not setting dtype_in you could encounter errors when running "
"assert_consistent_classical_action",
category=LegacyPartitionWarning,
)
elif self.n is None and self.dtype_in is not None:
object.__setattr__(self, "n", self.dtype_in.num_qubits)
elif self.n is not None and self.dtype_in is not None:
if self.n != self.dtype_in.num_qubits:
raise ValueError(
f"{self.dtype_in=} should have size {self.n=}, currently {self.dtype_in.num_qubits=}"
)
warnings.warn(
"Specifying both n and dtype_in is redundant", category=UserWarning, stacklevel=1
)

self._validate()

@property
def lumped_dtype(self) -> QDType:
return QUInt(bitsize=cast(SymbolicInt, self.n)) if self.dtype_in is None else self.dtype_in

@property
def _regs(self) -> Sequence[Register]:
return self.regs
return cast(Tuple[Register, ...], self.regs)

@cached_property
def signature(self) -> 'Signature':
Expand All @@ -195,11 +257,11 @@ def signature(self) -> 'Signature':

return Signature(
[Register('x', self.lumped_dtype, side=lumped)]
+ [evolve(reg, side=partitioned) for reg in self.regs]
+ [evolve(reg, side=partitioned) for reg in self._regs]
)

def adjoint(self):
return evolve(self, partition=not self.partition)
return evolve(self, n=None, dtype_in=self.lumped_dtype, partition=not self.partition)


@frozen
Expand Down Expand Up @@ -228,6 +290,10 @@ class Split2(_PartitionBase):
def n(self) -> SymbolicInt:
return self.n1 + self.n2

@property
def lumped_dtype(self) -> QDType:
return QUInt(bitsize=self.n)

@property
def partition(self) -> bool:
return True
Expand Down Expand Up @@ -289,6 +355,10 @@ class Join2(_PartitionBase):
def n(self) -> SymbolicInt:
return self.n1 + self.n2

@property
def lumped_dtype(self) -> QDType:
return QUInt(bitsize=self.n)

@property
def partition(self) -> bool:
return False
Expand Down
16 changes: 14 additions & 2 deletions qualtran/bloqs/bookkeeping/partition_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import pytest
from attrs import frozen

from qualtran import Bloq, BloqBuilder, QAny, QGF, Register, Signature, Soquet, SoquetT
from qualtran import Bloq, BloqBuilder, QAny, QGF, QInt, QUInt, Register, Signature, Soquet, SoquetT
from qualtran._infra.gate_with_registers import get_named_qubits
from qualtran.bloqs.basic_gates import CNOT
from qualtran.bloqs.bookkeeping import Partition
Expand All @@ -37,10 +37,22 @@ def test_partition(bloq_autotester):
def test_partition_check():
with pytest.raises(ValueError):
_ = Partition(n=0, regs=())
with pytest.raises(ValueError):
_ = Partition(n=10, regs=None)
with pytest.raises(ValueError):
_ = Partition(dtype_in=QUInt(10))
with pytest.raises(ValueError):
_ = Partition(n=1, regs=(Register('x', QAny(2)),))
with pytest.raises(ValueError):
_ = Partition(n=4, regs=(Register('x', QAny(1)), Register('x', QAny(3))))
with pytest.raises(ValueError):
_ = Partition(n=10)

regs = (Register("xx", QUInt(4)), Register("yy", QInt(6)))
with pytest.raises(ValueError):
_ = Partition(regs=regs)
with pytest.raises(ValueError):
_ = Partition(n=11, regs=regs)


@frozen
Expand All @@ -57,7 +69,7 @@ def signature(self) -> Signature:

def build_composite_bloq(self, bb: 'BloqBuilder', test_regs: 'SoquetT') -> Dict[str, 'Soquet']:
bloq_regs = self.test_bloq.signature
partition = Partition(self.bitsize, bloq_regs) # type: ignore[arg-type]
partition = Partition(dtype_in=QUInt(self.bitsize), regs=bloq_regs) # type: ignore[arg-type]
out_regs = bb.add(partition, x=test_regs)
out_regs = bb.add(self.test_bloq, **{reg.name: sp for reg, sp in zip(bloq_regs, out_regs)})
test_regs = bb.add(
Expand Down
Loading