Skip to content

Commit 6cd49f1

Browse files
authored
Enable mypy CI check (#926)
* Enable mypy CI check - This enables the mypy CI check - It also fixes the mypy issues introduced in the past few days.
1 parent 54bc181 commit 6cd49f1

50 files changed

Lines changed: 198 additions & 146 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/ci.yaml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,20 @@ jobs:
9999
pip install --no-deps -e .
100100
- run: |
101101
check/pylint
102+
103+
mypy:
104+
runs-on: ubuntu-latest
105+
steps:
106+
- uses: actions/checkout@v3
107+
with:
108+
fetch-depth: 0
109+
- uses: actions/setup-python@v4
110+
with:
111+
python-version: "3.10"
112+
- name: Install dependencies
113+
run: |
114+
python -m pip install --upgrade pip
115+
pip install -r dev_tools/requirements/envs/dev.env.txt
116+
pip install --no-deps -e .
117+
- run: |
118+
check/mypy

dev_tools/conf/mypy.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ show_error_codes = true
33
plugins = duet.typing, numpy.typing.mypy_plugin
44
allow_redefinition = true
55
check_untyped_defs = true
6+
67
# Disabling function override checking
78
# Qualtran has many places where kwargs are used
89
# with the intention to override in subclasses in ways mypy does not like

qualtran/_infra/bloq.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,11 @@
2727

2828
from qualtran import (
2929
AddControlledT,
30+
Adjoint,
3031
BloqBuilder,
3132
CompositeBloq,
3233
CtrlSpec,
34+
GateWithRegisters,
3335
Register,
3436
Signature,
3537
Soquet,

qualtran/_infra/controlled.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,7 @@ def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSym
445445
return self.ctrl_spec.wire_symbol(i, reg, idx)
446446

447447
def adjoint(self) -> 'Bloq':
448-
return self.subbloq.adjoint().controlled(self.ctrl_spec)
448+
return self.subbloq.adjoint().controlled(ctrl_spec=self.ctrl_spec)
449449

450450
def pretty_name(self) -> str:
451451
return f'C[{self.subbloq.pretty_name()}]'

qualtran/_infra/controlled_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,7 @@ def test_controlled_diagrams():
415415
q: ──────X^0.25───''',
416416
)
417417

418-
ctrl_0_gate = XPowGate(0.25).controlled(CtrlSpec(cvs=0))
418+
ctrl_0_gate = XPowGate(0.25).controlled(ctrl_spec=CtrlSpec(cvs=0))
419419
cirq.testing.assert_has_diagram(
420420
cirq.Circuit(ctrl_0_gate.on_registers(**get_named_qubits(ctrl_0_gate.signature))),
421421
'''
@@ -424,7 +424,7 @@ def test_controlled_diagrams():
424424
q: ──────X^0.25───''',
425425
)
426426

427-
multi_ctrl_gate = XPowGate(0.25).controlled(CtrlSpec(cvs=[0, 1]))
427+
multi_ctrl_gate = XPowGate(0.25).controlled(ctrl_spec=CtrlSpec(cvs=[0, 1]))
428428
cirq.testing.assert_has_diagram(
429429
cirq.Circuit(multi_ctrl_gate.on_registers(**get_named_qubits(multi_ctrl_gate.signature))),
430430
'''
@@ -435,7 +435,7 @@ def test_controlled_diagrams():
435435
q: ─────────X^0.25───''',
436436
)
437437

438-
ctrl_bloq = Swap(2).controlled(CtrlSpec(cvs=[0, 1]))
438+
ctrl_bloq = Swap(2).controlled(ctrl_spec=CtrlSpec(cvs=[0, 1]))
439439
cirq.testing.assert_has_diagram(
440440
cirq.Circuit(ctrl_bloq.on_registers(**get_named_qubits(ctrl_bloq.signature))),
441441
'''

qualtran/_infra/gate_with_registers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import abc
1616
from typing import (
1717
Any,
18+
cast,
1819
Collection,
1920
Dict,
2021
Iterable,
@@ -357,8 +358,8 @@ def on_registers(
357358
) -> cirq.Operation:
358359
return self.on(*merge_qubits(self.signature, **qubit_regs))
359360

360-
def __pow__(self, power: int) -> 'Bloq':
361-
bloq = self if power > 0 else self.adjoint()
361+
def __pow__(self, power: int) -> 'GateWithRegisters':
362+
bloq = self if power > 0 else cast(GateWithRegisters, self.adjoint())
362363
if abs(power) == 1:
363364
return bloq
364365
if all(reg.side == Side.THRU for reg in self.signature):

qualtran/_infra/gate_with_registers_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Dict, TYPE_CHECKING
15+
from typing import Dict, Iterator, TYPE_CHECKING
1616

1717
import cirq
1818
import numpy as np
@@ -46,7 +46,7 @@ def signature(self) -> Signature:
4646
regs = Signature([r1, r2, r3])
4747
return regs
4848

49-
def decompose_from_registers(self, *, context, **quregs) -> cirq.OP_TREE:
49+
def decompose_from_registers(self, *, context, **quregs) -> Iterator[cirq.OP_TREE]:
5050
yield cirq.H.on_each(quregs['r1'])
5151
yield cirq.X.on_each(quregs['r2'])
5252
yield cirq.X.on_each(quregs['r3'])

qualtran/bloqs/arithmetic/addition.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,19 @@
1414
import itertools
1515
import math
1616
from functools import cached_property
17-
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union
17+
from typing import (
18+
Any,
19+
Dict,
20+
Iterable,
21+
Iterator,
22+
List,
23+
Optional,
24+
Sequence,
25+
Set,
26+
Tuple,
27+
TYPE_CHECKING,
28+
Union,
29+
)
1830

1931
import cirq
2032
import numpy as np
@@ -197,7 +209,7 @@ def _right_building_block(self, inp, out, anc, depth):
197209

198210
def decompose_from_registers(
199211
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] # type: ignore[type-var]
200-
) -> cirq.OP_TREE:
212+
) -> Iterator[cirq.OP_TREE]:
201213
# reverse the order of qubits for big endian-ness.
202214
input_bits = quregs['a'][::-1]
203215
output_bits = quregs['b'][::-1]

qualtran/bloqs/arithmetic/comparison.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def __pow__(self, power: int):
9090

9191
def decompose_from_registers(
9292
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] # type: ignore[type-var]
93-
) -> cirq.OP_TREE:
93+
) -> Iterator[cirq.OP_TREE]:
9494
"""Decomposes the gate into 4N And and And† operations for a T complexity of 4N.
9595
9696
The decomposition proceeds from the most significant qubit -bit 0- to the least significant
@@ -217,7 +217,7 @@ def signature(self) -> Signature:
217217

218218
def decompose_from_registers(
219219
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid]
220-
) -> cirq.OP_TREE:
220+
) -> Iterator[cirq.OP_TREE]:
221221
x, y, ancilla = quregs['x'], quregs['y'], quregs['ancilla']
222222
x_msb, x_lsb = x
223223
y_msb, y_lsb = y
@@ -310,7 +310,7 @@ def signature(self) -> Signature:
310310

311311
def decompose_from_registers(
312312
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid]
313-
) -> cirq.OP_TREE:
313+
) -> Iterator[cirq.OP_TREE]:
314314
a = quregs['a']
315315
b = quregs['b']
316316
less_than = quregs['less_than']
@@ -358,7 +358,7 @@ def _sq_cmp() -> SingleQubitCompare:
358358

359359
def _equality_with_zero(
360360
context: cirq.DecompositionContext, qubits: Sequence[cirq.Qid], z: cirq.Qid
361-
) -> cirq.OP_TREE:
361+
) -> Iterator[cirq.OP_TREE]:
362362
"""Helper decomposition used in `LessThanEqual`"""
363363
if len(qubits) == 1:
364364
(q,) = qubits
@@ -451,7 +451,7 @@ def __pow__(self, power: int):
451451

452452
def _decompose_via_tree(
453453
self, context: cirq.DecompositionContext, X: Sequence[cirq.Qid], Y: Sequence[cirq.Qid]
454-
) -> cirq.OP_TREE:
454+
) -> Iterator[cirq.OP_TREE]:
455455
if len(X) == 1:
456456
return
457457
if len(X) == 2:
@@ -467,7 +467,7 @@ def _decompose_via_tree(
467467

468468
def decompose_from_registers(
469469
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid]
470-
) -> cirq.OP_TREE:
470+
) -> Iterator[cirq.OP_TREE]:
471471
lhs, rhs, (target,) = list(quregs['x']), list(quregs['y']), quregs['target']
472472

473473
n = min(len(lhs), len(rhs))

qualtran/bloqs/arithmetic/hamming_weight.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
from functools import cached_property
16-
from typing import List, Set, TYPE_CHECKING
16+
from typing import Iterator, List, Set, TYPE_CHECKING
1717

1818
import cirq
1919
from attrs import frozen
@@ -76,7 +76,7 @@ def _three_to_two_adder(self, a, b, c, out) -> cirq.OP_TREE:
7676

7777
def _decompose_using_three_to_two_adders(
7878
self, x: List[cirq.Qid], junk: List[cirq.Qid], out: List[cirq.Qid]
79-
) -> cirq.OP_TREE:
79+
) -> Iterator[cirq.OP_TREE]:
8080
for out_idx in range(len(out)):
8181
y = []
8282
for in_idx in range(0, len(x) - 2, 2):
@@ -94,7 +94,7 @@ def _decompose_using_three_to_two_adders(
9494

9595
def decompose_from_registers(
9696
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] # type: ignore[type-var]
97-
) -> cirq.OP_TREE:
97+
) -> Iterator[cirq.OP_TREE]:
9898
# Qubit order needs to be reversed because the registers store Big Endian representation
9999
# of integers.
100100
x: List[cirq.Qid] = [*quregs['x'][::-1]]

0 commit comments

Comments
 (0)