Skip to content

Commit 08e24fd

Browse files
authored
Mypy issues part 4 (#906)
* Mypy issues part 4 This PR features many fixes to get sympy.Expr to type correctly in a bunch of places. Another change was adding iteration_length_or_zero() so we can safely get iteration_length from dtypes even if they don't have one defined.
1 parent 7983de5 commit 08e24fd

56 files changed

Lines changed: 277 additions & 181 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

dev_tools/conf/mypy.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ follow_imports = silent
2020
ignore_missing_imports = true
2121

2222
# Non-Google
23-
[mypy-sympy.*,matplotlib.*,proto.*,pandas.*,scipy.*,freezegun.*,mpl_toolkits.*,networkx.*,ply.*,astroid.*,pytest.*,_pytest.*,pylint.*,setuptools.*,qiskit.*,quimb.*,pylatex.*,filelock.*,sortedcontainers.*,tqdm.*,plotly.*,dash.*,tensorflow_docs.*,fxpmath.*,ipywidgets.*,cachetools.*,pydot.*]
23+
[mypy-sympy.*,matplotlib.*,proto.*,pandas.*,scipy.*,freezegun.*,mpl_toolkits.*,networkx.*,ply.*,astroid.*,pytest.*,_pytest.*,pylint.*,setuptools.*,qiskit.*,quimb.*,pylatex.*,filelock.*,sortedcontainers.*,tqdm.*,plotly.*,dash.*,tensorflow_docs.*,fxpmath.*,ipywidgets.*,cachetools.*,pydot.*,nbformat.*,nbconvert.*,openfermion.*]
2424
follow_imports = silent
2525
ignore_missing_imports = true
2626

qualtran/_infra/adjoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def decompose_bloq(self) -> 'CompositeBloq':
144144
return self.subbloq.decompose_bloq().adjoint()
145145

146146
def decompose_from_registers(
147-
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid]
147+
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] # type: ignore[type-var]
148148
) -> cirq.OP_TREE:
149149
if isinstance(self.subbloq, GateWithRegisters):
150150
return cirq.inverse(self.subbloq.decompose_from_registers(context=context, **quregs))

qualtran/_infra/bloq.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,9 @@ def on_classical_vals(self, **vals: 'ClassicalValT') -> Dict[str, 'ClassicalValT
212212
except NotImplementedError as e:
213213
raise NotImplementedError(f"{self} does not support classical simulation: {e}") from e
214214

215-
def call_classically(self, **vals: 'ClassicalValT') -> Tuple['ClassicalValT', ...]:
215+
def call_classically(
216+
self, **vals: Union['sympy.Symbol', 'ClassicalValT']
217+
) -> Tuple['ClassicalValT', ...]:
216218
"""Call this bloq on classical data.
217219
218220
Bloq users can call this function to apply bloqs to classical data. If you're
@@ -297,7 +299,7 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
297299
def call_graph(
298300
self,
299301
generalizer: Optional[Union['GeneralizerT', Sequence['GeneralizerT']]] = None,
300-
keep: Callable[['Bloq'], bool] = None,
302+
keep: Optional[Callable[['Bloq'], bool]] = None,
301303
max_depth: Optional[int] = None,
302304
) -> Tuple['nx.DiGraph', Dict['Bloq', Union[int, 'sympy.Expr']]]:
303305
"""Get the bloq call graph and call totals.

qualtran/_infra/composite_bloq.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import attrs
3636
import networkx as nx
3737
import numpy as np
38+
import sympy
3839
from numpy.typing import NDArray
3940

4041
from .bloq import Bloq, DecomposeTypeError
@@ -45,7 +46,7 @@
4546
if TYPE_CHECKING:
4647
import cirq
4748

48-
from qualtran.cirq_interop import CirqQuregInT, CirqQuregT
49+
from qualtran.cirq_interop._cirq_to_bloq import CirqQuregInT, CirqQuregT
4950
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
5051
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
5152
from qualtran.simulation.classical_sim import ClassicalValT
@@ -447,7 +448,7 @@ def _create_binst_graph(
447448

448449

449450
def _binst_to_cxns(
450-
binst: BloqInstance, binst_graph: nx.DiGraph
451+
binst: Union[BloqInstance, DanglingT], binst_graph: nx.DiGraph
451452
) -> Tuple[List[Connection], List[Connection]]:
452453
"""Helper method to extract all predecessor and successor Connections for a binst."""
453454
pred_cxns: List[Connection] = []
@@ -494,7 +495,7 @@ def _cxn_to_soq_dict(
494495
assign = get_assign(cxn)
495496

496497
if me.reg.shape:
497-
soqdict[me.reg.name][me.idx] = assign
498+
soqdict[me.reg.name][me.idx] = assign # type: ignore[index]
498499
else:
499500
soqdict[me.reg.name] = assign
500501

@@ -598,9 +599,9 @@ def _reg_to_soq(
598599
# Annoyingly, this must be a special case.
599600
# Otherwise, x[i] = thing will nest *array* objects because our ndarray's type is
600601
# 'object'. This wouldn't happen--for example--with an integer array.
601-
soqs = Soquet(binst, reg)
602-
available.add(soqs)
603-
return soqs
602+
soq = Soquet(binst, reg)
603+
available.add(soq)
604+
return soq
604605

605606

606607
def _process_soquets(
@@ -630,7 +631,6 @@ def _process_soquets(
630631
the incoming, indexed soquet as well as the register and (left-)index it
631632
has been mapped to.
632633
"""
633-
634634
for reg in registers:
635635
try:
636636
# if we want fancy indexing (which we do), we need numpy
@@ -831,7 +831,9 @@ def from_signature(
831831
initial_soqs: Dict[str, SoquetT] = {}
832832
for reg in signature:
833833
if reg.side & Side.LEFT:
834-
initial_soqs[reg.name] = bb.add_register_from_dtype(reg)
834+
register = bb.add_register_from_dtype(reg)
835+
assert register is not None
836+
initial_soqs[reg.name] = register
835837
else:
836838
bb.add_register_from_dtype(reg)
837839

@@ -866,7 +868,11 @@ def _new_binst_i(self) -> int:
866868
return i
867869

868870
def _add_cxn(
869-
self, binst: BloqInstance, idxed_soq: Soquet, reg: Register, idx: Tuple[int, ...]
871+
self,
872+
binst: Union[BloqInstance, DanglingT],
873+
idxed_soq: Soquet,
874+
reg: Register,
875+
idx: Tuple[int, ...],
870876
) -> None:
871877
"""Helper function to be used as the base for the `func` argument of `_process_soquets`.
872878
@@ -1079,7 +1085,7 @@ def _fin(idxed_soq: Soquet, reg: Register, idx: Tuple[int, ...]):
10791085
connections=self._cxns, signature=signature, bloq_instances=self._binsts
10801086
)
10811087

1082-
def allocate(self, n: int = 1, dtype: Optional[QDType] = None) -> Soquet:
1088+
def allocate(self, n: Union[int, sympy.Expr] = 1, dtype: Optional[QDType] = None) -> Soquet:
10831089
from qualtran.bloqs.util_bloqs import Allocate
10841090

10851091
if dtype is not None:

qualtran/_infra/data_types.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,13 @@ def assert_valid_classical_val(self, val: Any, debug_str: str = 'val'):
8989
debug_str: Optional debugging information to use in exception messages.
9090
"""
9191

92+
def iteration_length_or_zero(self) -> Union[int, sympy.Expr]:
93+
"""Safe version of iteration length.
94+
95+
Returns the iteration_length if the type has it or else zero.
96+
"""
97+
return getattr(self, 'iteration_length', 0)
98+
9299
def assert_valid_classical_val_array(self, val_array: NDArray[Any], debug_str: str = 'val'):
93100
"""Raises an exception if `val_array` is not a valid array of classical values
94101
for this type.

qualtran/_infra/data_types_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def test_validation_errs():
135135
QBit().assert_valid_classical_val(-1)
136136

137137
with pytest.raises(ValueError):
138-
QBit().assert_valid_classical_val('|0>')
138+
QBit().assert_valid_classical_val('|0>') # type: ignore[arg-type]
139139

140140
with pytest.raises(ValueError):
141141
QUInt(3).assert_valid_classical_val(8)

qualtran/_infra/gate_with_registers.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import abc
1616
from typing import (
17+
Any,
1718
Collection,
1819
Dict,
1920
Iterable,
@@ -36,7 +37,9 @@
3637
from qualtran._infra.registers import Register, Side
3738

3839
if TYPE_CHECKING:
39-
from qualtran import CtrlSpec
40+
import quimb.tensor as qtn
41+
42+
from qualtran import CtrlSpec, SoquetT
4043
from qualtran.cirq_interop import CirqQuregT
4144
from qualtran.drawing import WireSymbol
4245

@@ -355,7 +358,7 @@ def on_registers(
355358
) -> cirq.Operation:
356359
return self.on(*merge_qubits(self.signature, **qubit_regs))
357360

358-
def __pow__(self, power: int) -> 'GateWithRegisters':
361+
def __pow__(self, power: int) -> 'Bloq':
359362
bloq = self if power > 0 else self.adjoint()
360363
if abs(power) == 1:
361364
return bloq

qualtran/_infra/quantum_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def _singleton_error(self, x):
129129
raise ValueError("Do not instantiate a new DanglingT. Use `LeftDangle` or `RightDangle`.")
130130

131131

132-
DanglingT.__init__ = _singleton_error
132+
DanglingT.__init__ = _singleton_error # type: ignore[method-assign]
133133

134134

135135
@frozen

qualtran/bloqs/arithmetic/addition.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import itertools
1515
import math
1616
from functools import cached_property
17-
from typing import Any, Dict, Iterable, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union
17+
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union
1818

1919
import cirq
2020
import numpy as np
@@ -253,7 +253,7 @@ def _add_diff_size_regs() -> Add:
253253

254254

255255
@frozen
256-
class OutOfPlaceAdder(GateWithRegisters, cirq.ArithmeticGate):
256+
class OutOfPlaceAdder(GateWithRegisters, cirq.ArithmeticGate): # type: ignore[misc]
257257
r"""An n-bit addition gate.
258258
259259
Implements $U|a\rangle|b\rangle 0\rangle \rightarrow |a\rangle|b\rangle|a+b\rangle$
@@ -310,11 +310,14 @@ def decompose_from_registers(
310310
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid]
311311
) -> cirq.OP_TREE:
312312
a, b, c = quregs['a'][::-1], quregs['b'][::-1], quregs['c'][::-1]
313-
optree = [
313+
optree: List[List[cirq.Operation]] = [
314314
[
315-
[cirq.CX(a[i], b[i]), cirq.CX(a[i], c[i])],
315+
cirq.CX(a[i], b[i]),
316+
cirq.CX(a[i], c[i]),
316317
And().on(b[i], c[i], c[i + 1]),
317-
[cirq.CX(a[i], b[i]), cirq.CX(a[i], c[i + 1]), cirq.CX(b[i], c[i])],
318+
cirq.CX(a[i], b[i]),
319+
cirq.CX(a[i], c[i + 1]),
320+
cirq.CX(b[i], c[i]),
318321
]
319322
for i in range(self.bitsize)
320323
]
@@ -418,13 +421,13 @@ def on_classical_vals(
418421
else:
419422
return {'x': x + self.k}
420423

421-
if (self.cvs == ctrls).all():
424+
if np.all(self.cvs == ctrls):
422425
x = x + self.k
423426

424427
return {'ctrls': ctrls, 'x': x}
425428

426429
def build_composite_bloq(
427-
self, bb: 'BloqBuilder', x: SoquetT, **regs: SoquetT
430+
self, bb: 'BloqBuilder', x: Soquet, **regs: SoquetT
428431
) -> Dict[str, 'SoquetT']:
429432
# Assign registers to variables and allocate ancilla bits for classical integer k.
430433
if len(self.cvs) > 0:
@@ -444,7 +447,7 @@ def build_composite_bloq(
444447
# controlled.
445448
for i in range(self.bitsize):
446449
if binary_rep[i] == 1:
447-
if len(self.cvs) > 0:
450+
if len(self.cvs) > 0 and ctrls is not None:
448451
ctrls, k_split[i] = bb.add(
449452
MultiControlX(cvs=self.cvs), ctrls=ctrls, x=k_split[i]
450453
)
@@ -453,14 +456,18 @@ def build_composite_bloq(
453456

454457
# Rejoin the qubits representing k for in-place addition.
455458
k = bb.join(k_split, dtype=x.reg.dtype)
459+
if not isinstance(x.reg.dtype, (QInt, QUInt, QMontgomeryUInt)):
460+
raise ValueError(
461+
"Only QInt, QUInt and QMontgomerUInt types are supported for composite addition."
462+
)
456463
k, x = bb.add(Add(x.reg.dtype, x.reg.dtype), a=k, b=x)
457464

458465
# Resplit the k qubits in order to undo the original bit flips to go from the binary
459466
# representation back to the zero state.
460467
k_split = bb.split(k)
461468
for i in range(self.bitsize):
462469
if binary_rep[i] == 1:
463-
if len(self.cvs) > 0:
470+
if len(self.cvs) > 0 and ctrls is not None:
464471
ctrls, k_split[i] = bb.add(
465472
MultiControlX(cvs=self.cvs), ctrls=ctrls, x=k_split[i]
466473
)
@@ -472,7 +479,7 @@ def build_composite_bloq(
472479
bb.free(k)
473480

474481
# Return the output registers.
475-
if len(self.cvs) > 0:
482+
if len(self.cvs) > 0 and ctrls is not None:
476483
return {'ctrls': ctrls, 'x': x}
477484
else:
478485
return {'x': x}
@@ -499,7 +506,7 @@ def _simple_add_k_large() -> SimpleAddConstant:
499506

500507

501508
@frozen(auto_attribs=True)
502-
class AddConstantMod(GateWithRegisters, cirq.ArithmeticGate):
509+
class AddConstantMod(GateWithRegisters, cirq.ArithmeticGate): # type: ignore[misc]
503510
"""Applies U(add, M)|x> = |(x + add) % M> if x < M else |x>.
504511
505512
Applies modular addition to input register `|x>` given parameters `mod` and `add_val` s.t.

0 commit comments

Comments
 (0)