Skip to content

Commit d6caf66

Browse files
authored
Mypy issues part 5 (#915)
* Mypy issues part 5 - Solves almost all of the remaining mypy issues - Mostly confusion about symbolic expressions and whether things should be numpy arrays or not
1 parent d2eb86a commit d6caf66

58 files changed

Lines changed: 370 additions & 196 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

qualtran/_infra/adjoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def _t_complexity_(self):
199199
return NotImplemented
200200

201201
try:
202-
return self.subbloq._t_complexity_(adjoint=True)
202+
return self.subbloq._t_complexity_(adjoint=True) # type: ignore[call-arg]
203203
except TypeError as e:
204204
if 'adjoint' in str(e):
205205
return self.subbloq._t_complexity_()

qualtran/_infra/bloq.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,9 @@ def adjoint(self) -> 'Bloq':
180180

181181
return Adjoint(self)
182182

183-
def on_classical_vals(self, **vals: 'ClassicalValT') -> Dict[str, 'ClassicalValT']:
183+
def on_classical_vals(
184+
self, **vals: Union['sympy.Symbol', 'ClassicalValT']
185+
) -> Dict[str, 'ClassicalValT']:
184186
"""How this bloq operates on classical data.
185187
186188
Override this method if your bloq represents classical, reversible logic. For example:

qualtran/_infra/bloq_example.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,11 @@ def bloq_example(
9090

9191

9292
def bloq_example(
93-
_func: Optional[Callable[[], _BloqType]] = None, *, generalizer: _GeneralizerType = lambda x: x
94-
) -> BloqExample[_BloqType]:
93+
_func: Optional[Callable[[], _BloqType]] = None,
94+
*,
95+
generalizer: _GeneralizerType = lambda x: x,
96+
**kwargs: Any,
97+
) -> Union[Callable[[Callable[[], _BloqType]], BloqExample[_BloqType]], BloqExample[_BloqType]]:
9598
"""Decorator to turn a function into a `BloqExample`.
9699
97100
This will set `name` to the name of the function and `bloq_cls` according to the return-type

qualtran/_infra/composite_bloq.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@
1616
from functools import cached_property
1717
from typing import (
1818
Callable,
19+
cast,
1920
Dict,
2021
FrozenSet,
2122
Hashable,
2223
Iterable,
2324
Iterator,
2425
List,
26+
Mapping,
2527
Optional,
2628
overload,
2729
Sequence,
@@ -176,7 +178,9 @@ def from_cirq_circuit(cls, circuit: 'cirq.Circuit') -> 'CompositeBloq':
176178

177179
return cirq_optree_to_cbloq(circuit)
178180

179-
def on_classical_vals(self, **vals: 'ClassicalValT') -> Dict[str, 'ClassicalValT']:
181+
def on_classical_vals(
182+
self, **vals: Union[sympy.Symbol, 'ClassicalValT']
183+
) -> Dict[str, 'ClassicalValT']:
180184
"""Support classical data by recursing into the composite bloq."""
181185
from qualtran.simulation.classical_sim import call_cbloq_classically
182186

@@ -606,7 +610,7 @@ def _reg_to_soq(
606610

607611
def _process_soquets(
608612
registers: Iterable[Register],
609-
in_soqs: Dict[str, SoquetT],
613+
in_soqs: Mapping[str, SoquetInT],
610614
debug_str: str,
611615
func: Callable[[Soquet, Register, Tuple[int, ...]], None],
612616
) -> None:
@@ -631,6 +635,7 @@ def _process_soquets(
631635
the incoming, indexed soquet as well as the register and (left-)index it
632636
has been mapped to.
633637
"""
638+
unchecked_names: Set[str] = set(in_soqs.keys())
634639
for reg in registers:
635640
try:
636641
# if we want fancy indexing (which we do), we need numpy
@@ -639,7 +644,7 @@ def _process_soquets(
639644
except KeyError:
640645
raise BloqError(f"{debug_str} requires a Soquet named `{reg.name}`.") from None
641646

642-
del in_soqs[reg.name] # so we can check for surplus arguments.
647+
unchecked_names.remove(reg.name) # so we can check for surplus arguments.
643648

644649
for li in reg.all_idxs():
645650
idxed_soq = in_soq[li]
@@ -652,9 +657,8 @@ def _process_soquets(
652657
raise BloqError(
653658
f"{debug_str} register dtypes are not consistent {extra_str}."
654659
) from None
655-
656-
if in_soqs:
657-
raise BloqError(f"{debug_str} does not accept Soquets: {in_soqs.keys()}.") from None
660+
if unchecked_names:
661+
raise BloqError(f"{debug_str} does not accept Soquets: {unchecked_names}.") from None
658662

659663

660664
def _map_soqs(
@@ -965,7 +969,7 @@ def add(self, bloq: Bloq, **in_soqs: SoquetInT):
965969
return outs
966970

967971
def _add_binst(
968-
self, binst: BloqInstance, in_soqs: Dict[str, SoquetInT]
972+
self, binst: BloqInstance, in_soqs: Mapping[str, SoquetInT]
969973
) -> Iterator[Tuple[str, SoquetT]]:
970974
"""Add a bloq instance.
971975
@@ -1010,7 +1014,8 @@ def add_from(self, bloq: Bloq, **in_soqs: SoquetInT) -> Tuple[SoquetT, ...]:
10101014

10111015
# Initial mapping of LeftDangle according to user-provided in_soqs.
10121016
soq_map: List[Tuple[SoquetT, SoquetT]] = [
1013-
(_reg_to_soq(LeftDangle, reg), in_soqs[reg.name]) for reg in cbloq.signature.lefts()
1017+
(_reg_to_soq(LeftDangle, reg), cast(SoquetT, in_soqs[reg.name]))
1018+
for reg in cbloq.signature.lefts()
10141019
]
10151020

10161021
for binst, in_soqs, old_out_soqs in cbloq.iter_bloqsoqs():

qualtran/_infra/controlled.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,10 @@ def _cvs_convert(
4949
cvs: Union[
5050
int,
5151
np.integer,
52-
NDArray[np.number],
52+
NDArray[np.integer],
5353
Sequence[Union[int, np.integer]],
5454
Sequence[Sequence[Union[int, np.integer]]],
55-
Sequence[NDArray[np.number]],
55+
Sequence[NDArray[np.integer]],
5656
]
5757
) -> Tuple[NDArray[np.integer], ...]:
5858
if isinstance(cvs, (int, np.integer)):
@@ -235,7 +235,7 @@ def from_cirq_cv(
235235
for qdtype, shape in zip(qdtypes, shapes):
236236
full_shape = shape + (qdtype.num_qubits,)
237237
curr_cvs_bits = np.array(cv[idx : idx + int(np.prod(full_shape))]).reshape(full_shape)
238-
curr_cvs = np.apply_along_axis(qdtype.from_bits, -1, curr_cvs_bits)
238+
curr_cvs = np.apply_along_axis(qdtype.from_bits, -1, curr_cvs_bits) # type: ignore[arg-type]
239239
bloq_cvs.append(curr_cvs)
240240
return CtrlSpec(tuple(qdtypes), tuple(bloq_cvs))
241241

@@ -274,7 +274,7 @@ def _get_nice_ctrl_reg_names(reg_names: List[str], n: int) -> Tuple[str, ...]:
274274
i = 1
275275
else:
276276
i = 0
277-
names = []
277+
names: List[str] = []
278278
while len(names) < n:
279279
while True:
280280
i += 1
@@ -358,13 +358,15 @@ def decompose_bloq(self) -> 'CompositeBloq':
358358
cbloq = self.subbloq.decompose_bloq()
359359

360360
bb, initial_soqs = BloqBuilder.from_signature(self.signature)
361-
ctrl_soqs = [initial_soqs[creg_name] for creg_name in self.ctrl_reg_names]
361+
ctrl_soqs: List['SoquetT'] = [initial_soqs[creg_name] for creg_name in self.ctrl_reg_names]
362362

363363
soq_map: List[Tuple[SoquetT, SoquetT]] = []
364364
for binst, in_soqs, old_out_soqs in cbloq.iter_bloqsoqs():
365365
in_soqs = bb.map_soqs(in_soqs, soq_map)
366366
new_bloq, adder = binst.bloq.get_ctrl_system(self.ctrl_spec)
367-
ctrl_soqs, new_out_soqs = adder(bb, ctrl_soqs=ctrl_soqs, in_soqs=in_soqs)
367+
adder_output = adder(bb, ctrl_soqs=ctrl_soqs, in_soqs=in_soqs)
368+
ctrl_soqs = list(adder_output[0])
369+
new_out_soqs = adder_output[1]
368370
soq_map.extend(zip(old_out_soqs, new_out_soqs))
369371

370372
fsoqs = bb.map_soqs(cbloq.final_soqs(), soq_map)
@@ -460,6 +462,7 @@ def as_cirq_op(
460462
ctrl_regs = {reg_name: cirq_quregs.pop(reg_name) for reg_name in self.ctrl_reg_names}
461463
ctrl_qubits = [q for reg in ctrl_regs.values() for q in reg.reshape(-1)]
462464
sub_op, cirq_quregs = self.subbloq.as_cirq_op(qubit_manager, **cirq_quregs)
465+
assert sub_op is not None
463466
return (
464467
sub_op.controlled_by(*ctrl_qubits, control_values=self.ctrl_spec.to_cirq_cv()),
465468
cirq_quregs | ctrl_regs,

qualtran/_infra/controlled_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def test_ctrl_spec():
5858
cspec1 = CtrlSpec()
5959
assert cspec1 == CtrlSpec(QBit(), cvs=1)
6060

61-
cspec2 = CtrlSpec(cvs=np.ones(27).reshape((3, 3, 3)))
61+
cspec2 = CtrlSpec(cvs=np.ones(27, dtype=np.intc).reshape((3, 3, 3)))
6262
assert cspec2.shapes == ((3, 3, 3),)
6363
assert cspec2 != cspec1
6464

qualtran/_infra/gate_with_registers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,7 @@ def controlled(
451451

452452
# pylint: disable=signature-differs
453453
@overload
454-
def controlled(self, ctrl_spec: Optional['CtrlSpec'] = None) -> 'GateWithRegisters':
454+
def controlled(self, *, ctrl_spec: Optional['CtrlSpec'] = None) -> 'GateWithRegisters':
455455
"""Bloq-style API to construct a controlled Bloq. See `Bloq.controlled()`."""
456456

457457
def controlled(

qualtran/_infra/gate_with_registers_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ def test_gate_with_registers():
7979
assert (
8080
tg.controlled(num_controls=1, control_values=[0])
8181
== tg.controlled(control_values=[0], control_qid_shape=(2,))
82-
== tg.controlled(CtrlSpec(cvs=0))
8382
== tg.controlled(ctrl_spec=CtrlSpec(cvs=0))
8483
)
8584

qualtran/bloqs/arithmetic/addition.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252

5353
from qualtran.drawing import WireSymbol
5454
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
55+
from qualtran.resource_counting.symbolic_counting_utils import SymbolicInt
5556
from qualtran.simulation.classical_sim import ClassicalValT
5657

5758

@@ -272,7 +273,7 @@ class OutOfPlaceAdder(GateWithRegisters, cirq.ArithmeticGate): # type: ignore[m
272273
[Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648)
273274
"""
274275

275-
bitsize: int
276+
bitsize: 'SymbolicInt'
276277
is_adjoint: bool = False
277278

278279
@property
@@ -287,6 +288,8 @@ def signature(self):
287288
)
288289

289290
def registers(self) -> Sequence[Union[int, Sequence[int]]]:
291+
if not isinstance(self.bitsize, int):
292+
raise ValueError(f'Symbolic bitsize {self.bitsize} not supported')
290293
return [2] * self.bitsize, [2] * self.bitsize, [2] * (self.bitsize + 1)
291294

292295
def apply(self, a: int, b: int, c: int) -> Tuple[int, int, int]:
@@ -309,6 +312,8 @@ def short_name(self) -> str:
309312
def decompose_from_registers(
310313
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid]
311314
) -> cirq.OP_TREE:
315+
if not isinstance(self.bitsize, int):
316+
raise ValueError(f'Symbolic bitsize {self.bitsize} not supported')
312317
a, b, c = quregs['a'][::-1], quregs['b'][::-1], quregs['c'][::-1]
313318
optree: List[List[cirq.Operation]] = [
314319
[

qualtran/bloqs/arithmetic/comparison.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
if TYPE_CHECKING:
4848
from qualtran import BloqBuilder
4949
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
50+
from qualtran.resource_counting.symbolic_counting_utils import SymbolicInt
5051
from qualtran.simulation.classical_sim import ClassicalValT
5152

5253

@@ -404,8 +405,8 @@ class LessThanEqual(GateWithRegisters, cirq.ArithmeticGate): # type: ignore[mis
404405
https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf
405406
"""
406407

407-
x_bitsize: int
408-
y_bitsize: int
408+
x_bitsize: 'SymbolicInt'
409+
y_bitsize: 'SymbolicInt'
409410

410411
@cached_property
411412
def signature(self) -> 'Signature':
@@ -414,6 +415,10 @@ def signature(self) -> 'Signature':
414415
)
415416

416417
def registers(self) -> Sequence[Union[int, Sequence[int]]]:
418+
if isinstance(self.x_bitsize, sympy.Expr):
419+
raise ValueError(f'Symbolic x bitsize {self.x_bitsize} not allowed')
420+
if isinstance(self.y_bitsize, sympy.Expr):
421+
raise ValueError(f'Symbolic y bitsize {self.y_bitsize} not allowed')
417422
return [2] * self.x_bitsize, [2] * self.y_bitsize, [2]
418423

419424
def with_registers(self, *new_registers) -> "LessThanEqual":
@@ -430,6 +435,10 @@ def on_classical_vals(self, *, x: int, y: int, target: int) -> Dict[str, 'Classi
430435
return {'x': x, 'y': y, 'target': target ^ (x <= y)}
431436

432437
def _circuit_diagram_info_(self, _) -> cirq.CircuitDiagramInfo:
438+
if isinstance(self.x_bitsize, sympy.Expr):
439+
raise ValueError(f'Symbolic x bitsize {self.x_bitsize} not allowed')
440+
if isinstance(self.y_bitsize, sympy.Expr):
441+
raise ValueError(f'Symbolic y bitsize {self.y_bitsize} not allowed')
433442
wire_symbols = ["In(x)"] * self.x_bitsize
434443
wire_symbols += ["In(y)"] * self.y_bitsize
435444
wire_symbols += ['⨁(x <= y)']
@@ -581,8 +590,8 @@ class GreaterThan(Bloq):
581590
target: A single bit output register to store the result of A > B.
582591
"""
583592

584-
a_bitsize: int
585-
b_bitsize: int
593+
a_bitsize: 'SymbolicInt'
594+
b_bitsize: 'SymbolicInt'
586595

587596
@property
588597
def signature(self):
@@ -690,7 +699,7 @@ def build_composite_bloq(
690699

691700
# Allocate lists to store ancillas generated by the logical-and and control pairs input
692701
# into logical-ands.
693-
ancillas = []
702+
ancillas: List[SoquetT] = []
694703
and_ctrls = []
695704

696705
# If the input registers are unsigned we need to append a sign bit to them in order to use

0 commit comments

Comments
 (0)