Skip to content

Commit 5f29071

Browse files
authored
[arithmetic] Organize Subtract (#964)
1 parent 05504ce commit 5f29071

7 files changed

Lines changed: 232 additions & 183 deletions

File tree

qualtran/bloqs/arithmetic/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from qualtran.bloqs.arithmetic.addition import Add, AddK, OutOfPlaceAdder, Subtract
15+
from qualtran.bloqs.arithmetic.addition import Add, AddK, OutOfPlaceAdder
1616
from qualtran.bloqs.arithmetic.comparison import (
1717
BiQubitsMixer,
1818
EqualsAConstant,
@@ -34,5 +34,6 @@
3434
SumOfSquares,
3535
)
3636
from qualtran.bloqs.arithmetic.sorting import BitonicSort, Comparator
37+
from qualtran.bloqs.arithmetic.subtraction import Subtract
3738

38-
from ._shims import CHalf, Lt, MultiCToffoli, Negate, Sub
39+
from ._shims import CHalf, Lt, MultiCToffoli, Negate

qualtran/bloqs/arithmetic/_shims.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,6 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
4040
return {(Toffoli(), self.n - 2)}
4141

4242

43-
@frozen
44-
class Sub(Bloq):
45-
n: int
46-
47-
@cached_property
48-
def signature(self) -> 'Signature':
49-
return Signature([Register('x', QUInt(self.n)), Register('y', QUInt(self.n))])
50-
51-
5243
@frozen
5344
class Lt(Bloq):
5445
n: int

qualtran/bloqs/arithmetic/addition.py

Lines changed: 1 addition & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,14 @@
4747
GateWithRegisters,
4848
QBit,
4949
QInt,
50+
QMontgomeryUInt,
5051
QUInt,
5152
Register,
5253
Side,
5354
Signature,
5455
Soquet,
5556
SoquetT,
5657
)
57-
from qualtran._infra.data_types import QMontgomeryUInt
5858
from qualtran.bloqs import util_bloqs
5959
from qualtran.bloqs.basic_gates import CNOT, XGate
6060
from qualtran.bloqs.mcmt.and_bloq import And
@@ -575,145 +575,3 @@ def _add_k_large() -> AddK:
575575

576576

577577
_ADD_K_DOC = BloqDocSpec(bloq_cls=AddK, examples=[_add_k, _add_k_small, _add_k_large])
578-
579-
580-
@frozen
581-
class Subtract(Bloq):
582-
r"""An n-bit subtraction gate.
583-
584-
Implements $U|a\rangle|b\rangle \rightarrow |a\rangle|a-b\rangle$ using $4n - 4 T$ gates.
585-
586-
Args:
587-
a_dtype: Quantum datatype used to represent the integer a.
588-
b_dtype: Quantum datatype used to represent the integer b. Must be large
589-
enough to hold the result in the output register of a - b, or else it simply
590-
drops the most significant bits. If not specified, b_dtype is set to a_dtype.
591-
592-
Registers:
593-
a: A a_dtype.bitsize-sized input register (register a above).
594-
b: A b_dtype.bitsize-sized input/output register (register b above).
595-
596-
References:
597-
[Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648)
598-
"""
599-
600-
a_dtype: Union[QInt, QUInt, QMontgomeryUInt] = field()
601-
b_dtype: Union[QInt, QUInt, QMontgomeryUInt] = field()
602-
603-
@b_dtype.default
604-
def b_dtype_default(self):
605-
return self.a_dtype
606-
607-
@a_dtype.validator
608-
def _a_dtype_validate(self, field, val):
609-
if not isinstance(val, (QInt, QUInt, QMontgomeryUInt)):
610-
raise ValueError("Only QInt, QUInt and QMontgomerUInt types are supported.")
611-
if isinstance(val.num_qubits, sympy.Expr):
612-
return
613-
if val.bitsize > self.b_dtype.bitsize:
614-
raise ValueError("a_dtype bitsize must be less than or equal to b_dtype bitsize")
615-
616-
@b_dtype.validator
617-
def _b_dtype_validate(self, field, val):
618-
if not isinstance(val, (QInt, QUInt, QMontgomeryUInt)):
619-
raise ValueError("Only QInt, QUInt and QMontgomerUInt types are supported.")
620-
621-
@property
622-
def dtype(self):
623-
if self.a_dtype != self.b_dtype:
624-
raise ValueError(
625-
"Add.dtype is only supported when both operands have the same dtype: "
626-
f"{self.a_dtype=}, {self.b_dtype=}"
627-
)
628-
return self.a_dtype
629-
630-
@property
631-
def signature(self):
632-
return Signature([Register("a", self.a_dtype), Register("b", self.b_dtype)])
633-
634-
def on_classical_vals(
635-
self, a: 'ClassicalValT', b: 'ClassicalValT'
636-
) -> Dict[str, 'ClassicalValT']:
637-
unsigned = isinstance(self.a_dtype, (QUInt, QMontgomeryUInt))
638-
b_bitsize = self.b_dtype.bitsize
639-
N = 2**b_bitsize if unsigned else 2 ** (b_bitsize - 1)
640-
return {'a': a, 'b': int(math.fmod(a - b, N))}
641-
642-
def short_name(self) -> str:
643-
return "a-b"
644-
645-
def _circuit_diagram_info_(self, _) -> cirq.CircuitDiagramInfo:
646-
wire_symbols = ["In(x)"] * int(self.a_dtype.bitsize)
647-
wire_symbols += ["In(y)/Out(x-y)"] * int(self.b_dtype.bitsize)
648-
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)
649-
650-
def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
651-
from qualtran.drawing import directional_text_box
652-
653-
if soq.reg.name == 'a':
654-
return directional_text_box('a', side=soq.reg.side)
655-
elif soq.reg.name == 'b':
656-
return directional_text_box('a-b', side=soq.reg.side)
657-
else:
658-
raise ValueError()
659-
660-
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
661-
a_dtype = (
662-
self.a_dtype if not isinstance(self.a_dtype, QInt) else QUInt(self.a_dtype.bitsize)
663-
)
664-
b_dtype = (
665-
self.b_dtype if not isinstance(self.b_dtype, QInt) else QUInt(self.b_dtype.bitsize)
666-
)
667-
return {
668-
(XGate(), self.b_dtype.bitsize),
669-
(AddK(self.b_dtype.bitsize, k=1), 1),
670-
(Add(a_dtype, b_dtype), 1),
671-
(util_bloqs.Split(self.b_dtype), 1),
672-
(util_bloqs.Join(self.b_dtype), 1),
673-
}
674-
675-
def build_composite_bloq(self, bb: 'BloqBuilder', a: Soquet, b: Soquet) -> Dict[str, 'SoquetT']:
676-
b = np.array([bb.add(XGate(), q=q) for q in bb.split(b)]) # 1s complement of b.
677-
b = bb.add(
678-
AddK(self.b_dtype.bitsize, k=1), x=bb.join(b, self.b_dtype)
679-
) # 2s complement of b.
680-
681-
a_dtype = (
682-
self.a_dtype if not isinstance(self.a_dtype, QInt) else QUInt(self.a_dtype.bitsize)
683-
)
684-
b_dtype = (
685-
self.b_dtype if not isinstance(self.b_dtype, QInt) else QUInt(self.b_dtype.bitsize)
686-
)
687-
688-
a, b = bb.add(Add(a_dtype, b_dtype), a=a, b=b) # a - b
689-
return {'a': a, 'b': b}
690-
691-
692-
@bloq_example
693-
def _sub_symb() -> Subtract:
694-
n = sympy.Symbol('n')
695-
sub_symb = Subtract(QInt(bitsize=n))
696-
return sub_symb
697-
698-
699-
@bloq_example
700-
def _sub_small() -> Subtract:
701-
sub_small = Subtract(QInt(bitsize=4))
702-
return sub_small
703-
704-
705-
@bloq_example
706-
def _sub_large() -> Subtract:
707-
sub_large = Subtract(QInt(bitsize=64))
708-
return sub_large
709-
710-
711-
@bloq_example
712-
def _sub_diff_size_regs() -> Subtract:
713-
sub_diff_size_regs = Subtract(QInt(bitsize=4), QInt(bitsize=16))
714-
return sub_diff_size_regs
715-
716-
717-
_SUB_DOC = BloqDocSpec(
718-
bloq_cls=Subtract, examples=[_sub_symb, _sub_small, _sub_large, _sub_diff_size_regs]
719-
)

qualtran/bloqs/arithmetic/addition_test.py

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import qualtran.testing as qlt_testing
2323
from qualtran import BloqBuilder, CtrlSpec, QInt, QUInt
24-
from qualtran.bloqs.arithmetic.addition import Add, AddK, OutOfPlaceAdder, Subtract
24+
from qualtran.bloqs.arithmetic.addition import Add, AddK, OutOfPlaceAdder
2525
from qualtran.cirq_interop.bit_tools import iter_bits, iter_bits_twos_complement
2626
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
2727
from qualtran.cirq_interop.testing import (
@@ -329,31 +329,6 @@ def test_classical_add_k_signed(bitsize, k, x, cvs, ctrls, result):
329329
assert bloq_classical[-1] == result
330330

331331

332-
@pytest.mark.slow
333-
def test_subtract_bloq_decomposition():
334-
gate = Subtract(QInt(3), QInt(5))
335-
qlt_testing.assert_valid_bloq_decomposition(gate)
336-
337-
want = np.zeros((256, 256))
338-
for a_b in range(256):
339-
a, b = a_b >> 5, a_b & 31
340-
c = (a - b) % 32
341-
want[(a << 5) | c][a_b] = 1
342-
got = gate.tensor_contract()
343-
np.testing.assert_equal(got, want)
344-
345-
346-
def test_subtract_bloq_validation():
347-
assert Subtract(QUInt(3)) == Subtract(QUInt(3), QUInt(3))
348-
with pytest.raises(ValueError, match='bitsize must be less'):
349-
_ = Subtract(QInt(5), QInt(3))
350-
assert Subtract(QUInt(3)).dtype == QUInt(3)
351-
352-
353-
def test_subtract_bloq_consitant_counts():
354-
qlt_testing.assert_equivalent_bloq_counts(Subtract(QInt(3), QInt(4)))
355-
356-
357332
@pytest.mark.notebook
358333
def test_notebook():
359334
qlt_testing.execute_notebook('addition')

0 commit comments

Comments
 (0)