Skip to content

Commit c10d386

Browse files
Create a Subtract bloq (#924)
1 parent b6c7b29 commit c10d386

3 files changed

Lines changed: 176 additions & 3 deletions

File tree

qualtran/bloqs/arithmetic/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
AddConstantMod,
1818
OutOfPlaceAdder,
1919
SimpleAddConstant,
20+
Subtract,
2021
)
2122
from qualtran.bloqs.arithmetic.comparison import (
2223
BiQubitsMixer,

qualtran/bloqs/arithmetic/addition.py

Lines changed: 149 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,10 @@
5151
SoquetT,
5252
)
5353
from qualtran._infra.data_types import QMontgomeryUInt
54+
from qualtran.bloqs import util_bloqs
5455
from qualtran.bloqs.basic_gates import CNOT, XGate
5556
from qualtran.bloqs.mcmt.and_bloq import And
5657
from qualtran.bloqs.mcmt.multi_control_multi_target_pauli import MultiControlX
57-
from qualtran.bloqs.util_bloqs import ArbitraryClifford
5858
from qualtran.cirq_interop import decompose_from_cirq_style_method
5959
from qualtran.cirq_interop.bit_tools import iter_bits, iter_bits_twos_complement
6060
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
@@ -348,7 +348,7 @@ def _t_complexity_(self) -> TComplexity:
348348
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
349349
return {
350350
(And(uncompute=self.is_adjoint), self.bitsize),
351-
(ArbitraryClifford(n=2), 5 * self.bitsize),
351+
(util_bloqs.ArbitraryClifford(n=2), 5 * self.bitsize),
352352
}
353353

354354
def __pow__(self, power: int):
@@ -408,7 +408,7 @@ class SimpleAddConstant(Bloq):
408408
[Improved quantum circuits for elliptic curve discrete logarithms](https://arxiv.org/abs/2001.09580) Fig 2a
409409
"""
410410

411-
bitsize: int
411+
bitsize: Union[int, sympy.Expr]
412412
k: int
413413
cvs: Tuple[int, ...] = field(
414414
converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=()
@@ -445,6 +445,10 @@ def on_classical_vals(
445445
def build_composite_bloq(
446446
self, bb: 'BloqBuilder', x: Soquet, **regs: SoquetT
447447
) -> Dict[str, 'SoquetT']:
448+
if isinstance(self.bitsize, sympy.Expr):
449+
raise ValueError(
450+
f'Symbolic bitsize {self.bitsize} not supported for SimpleAddConstant.build_composite_bloq'
451+
)
448452
# Assign registers to variables and allocate ancilla bits for classical integer k.
449453
if len(self.cvs) > 0:
450454
ctrls = regs['ctrls']
@@ -628,3 +632,145 @@ def _add_k_large() -> AddConstantMod:
628632
_ADD_K_DOC = BloqDocSpec(
629633
bloq_cls=AddConstantMod, examples=[_add_k_symb, _add_k_small, _add_k_large]
630634
)
635+
636+
637+
@frozen
638+
class Subtract(Bloq):
639+
r"""An n-bit subtraction gate.
640+
641+
Implements $U|a\rangle|b\rangle \rightarrow |a\rangle|a-b\rangle$ using $4n - 4 T$ gates.
642+
643+
Args:
644+
a_dtype: Quantum datatype used to represent the integer a.
645+
b_dtype: Quantum datatype used to represent the integer b. Must be large
646+
enough to hold the result in the output register of a - b, or else it simply
647+
drops the most significant bits. If not specified, b_dtype is set to a_dtype.
648+
649+
Registers:
650+
a: A a_dtype.bitsize-sized input register (register a above).
651+
b: A b_dtype.bitsize-sized input/output register (register b above).
652+
653+
References:
654+
[Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648)
655+
"""
656+
657+
a_dtype: Union[QInt, QUInt, QMontgomeryUInt] = field()
658+
b_dtype: Union[QInt, QUInt, QMontgomeryUInt] = field()
659+
660+
@b_dtype.default
661+
def b_dtype_default(self):
662+
return self.a_dtype
663+
664+
@a_dtype.validator
665+
def _a_dtype_validate(self, field, val):
666+
if not isinstance(val, (QInt, QUInt, QMontgomeryUInt)):
667+
raise ValueError("Only QInt, QUInt and QMontgomerUInt types are supported.")
668+
if isinstance(val.num_qubits, sympy.Expr):
669+
return
670+
if val.bitsize > self.b_dtype.bitsize:
671+
raise ValueError("a_dtype bitsize must be less than or equal to b_dtype bitsize")
672+
673+
@b_dtype.validator
674+
def _b_dtype_validate(self, field, val):
675+
if not isinstance(val, (QInt, QUInt, QMontgomeryUInt)):
676+
raise ValueError("Only QInt, QUInt and QMontgomerUInt types are supported.")
677+
678+
@property
679+
def dtype(self):
680+
if self.a_dtype != self.b_dtype:
681+
raise ValueError(
682+
"Add.dtype is only supported when both operands have the same dtype: "
683+
f"{self.a_dtype=}, {self.b_dtype=}"
684+
)
685+
return self.a_dtype
686+
687+
@property
688+
def signature(self):
689+
return Signature([Register("a", self.a_dtype), Register("b", self.b_dtype)])
690+
691+
def on_classical_vals(
692+
self, a: 'ClassicalValT', b: 'ClassicalValT'
693+
) -> Dict[str, 'ClassicalValT']:
694+
unsigned = isinstance(self.a_dtype, (QUInt, QMontgomeryUInt))
695+
b_bitsize = self.b_dtype.bitsize
696+
N = 2**b_bitsize if unsigned else 2 ** (b_bitsize - 1)
697+
return {'a': a, 'b': int(math.fmod(a - b, N))}
698+
699+
def short_name(self) -> str:
700+
return "a-b"
701+
702+
def _circuit_diagram_info_(self, _) -> cirq.CircuitDiagramInfo:
703+
wire_symbols = ["In(x)"] * int(self.a_dtype.bitsize)
704+
wire_symbols += ["In(y)/Out(x-y)"] * int(self.b_dtype.bitsize)
705+
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)
706+
707+
def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
708+
from qualtran.drawing import directional_text_box
709+
710+
if soq.reg.name == 'a':
711+
return directional_text_box('a', side=soq.reg.side)
712+
elif soq.reg.name == 'b':
713+
return directional_text_box('a-b', side=soq.reg.side)
714+
else:
715+
raise ValueError()
716+
717+
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
718+
a_dtype = (
719+
self.a_dtype if not isinstance(self.a_dtype, QInt) else QUInt(self.a_dtype.bitsize)
720+
)
721+
b_dtype = (
722+
self.b_dtype if not isinstance(self.b_dtype, QInt) else QUInt(self.b_dtype.bitsize)
723+
)
724+
return {
725+
(XGate(), self.b_dtype.bitsize),
726+
(SimpleAddConstant(self.b_dtype.bitsize, k=1), 1),
727+
(Add(a_dtype, b_dtype), 1),
728+
(util_bloqs.Split(self.b_dtype), 1),
729+
(util_bloqs.Join(self.b_dtype), 1),
730+
}
731+
732+
def build_composite_bloq(self, bb: 'BloqBuilder', a: Soquet, b: Soquet) -> Dict[str, 'SoquetT']:
733+
b = np.array([bb.add(XGate(), q=q) for q in bb.split(b)]) # 1s complement of b.
734+
b = bb.add(
735+
SimpleAddConstant(self.b_dtype.bitsize, k=1), x=bb.join(b, self.b_dtype)
736+
) # 2s complement of b.
737+
738+
a_dtype = (
739+
self.a_dtype if not isinstance(self.a_dtype, QInt) else QUInt(self.a_dtype.bitsize)
740+
)
741+
b_dtype = (
742+
self.b_dtype if not isinstance(self.b_dtype, QInt) else QUInt(self.b_dtype.bitsize)
743+
)
744+
745+
a, b = bb.add(Add(a_dtype, b_dtype), a=a, b=b) # a - b
746+
return {'a': a, 'b': b}
747+
748+
749+
@bloq_example
750+
def _sub_symb() -> Subtract:
751+
n = sympy.Symbol('n')
752+
sub_symb = Subtract(QInt(bitsize=n))
753+
return sub_symb
754+
755+
756+
@bloq_example
757+
def _sub_small() -> Subtract:
758+
sub_small = Subtract(QInt(bitsize=4))
759+
return sub_small
760+
761+
762+
@bloq_example
763+
def _sub_large() -> Subtract:
764+
sub_large = Subtract(QInt(bitsize=64))
765+
return sub_large
766+
767+
768+
@bloq_example
769+
def _sub_diff_size_regs() -> Subtract:
770+
sub_diff_size_regs = Subtract(QInt(bitsize=4), QInt(bitsize=16))
771+
return sub_diff_size_regs
772+
773+
774+
_SUB_DOC = BloqDocSpec(
775+
bloq_cls=Subtract, examples=[_sub_symb, _sub_small, _sub_large, _sub_diff_size_regs]
776+
)

qualtran/bloqs/arithmetic/addition_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
AddConstantMod,
2727
OutOfPlaceAdder,
2828
SimpleAddConstant,
29+
Subtract,
2930
)
3031
from qualtran.bloqs.arithmetic.comparison_test import identity_map
3132
from qualtran.cirq_interop.bit_tools import iter_bits, iter_bits_twos_complement
@@ -381,6 +382,31 @@ def test_classical_simple_add_constant_signed(bitsize, k, x, cvs, ctrls, result)
381382
assert bloq_classical[-1] == result
382383

383384

385+
@pytest.mark.slow
386+
def test_subtract_bloq_decomposition():
387+
gate = Subtract(QInt(3), QInt(5))
388+
qlt_testing.assert_valid_bloq_decomposition(gate)
389+
390+
want = np.zeros((256, 256))
391+
for a_b in range(256):
392+
a, b = a_b >> 5, a_b & 31
393+
c = (a - b) % 32
394+
want[(a << 5) | c][a_b] = 1
395+
got = gate.tensor_contract()
396+
np.testing.assert_equal(got, want)
397+
398+
399+
def test_subtract_bloq_validation():
400+
assert Subtract(QUInt(3)) == Subtract(QUInt(3), QUInt(3))
401+
with pytest.raises(ValueError, match='bitsize must be less'):
402+
_ = Subtract(QInt(5), QInt(3))
403+
assert Subtract(QUInt(3)).dtype == QUInt(3)
404+
405+
406+
def test_subtract_bloq_consitant_counts():
407+
qlt_testing.assert_equivalent_bloq_counts(Subtract(QInt(3), QInt(4)))
408+
409+
384410
@pytest.mark.notebook
385411
def test_notebook():
386412
qlt_testing.execute_notebook('addition')

0 commit comments

Comments
 (0)