Skip to content

Commit d2eb86a

Browse files
authored
Change wire_symbol to use Register and idx (#919)
* Change wire_symbol to use Register and idx - Change wire_symbol to use Register and idx since the wire_symbol should never need to use the binst of a Soquet.
1 parent 0262880 commit d2eb86a

32 files changed

Lines changed: 174 additions & 179 deletions

qualtran/_infra/adjoint.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from functools import cached_property
1616
from typing import Dict, List, Set, Tuple, TYPE_CHECKING
1717

18-
import attrs
1918
import cirq
2019
from attrs import frozen
2120
from numpy.typing import NDArray
@@ -26,7 +25,7 @@
2625
from .registers import Signature
2726

2827
if TYPE_CHECKING:
29-
from qualtran import Bloq, CompositeBloq, Signature, Soquet, SoquetT
28+
from qualtran import Bloq, CompositeBloq, Register, Signature, Soquet, SoquetT
3029
from qualtran.drawing import WireSymbol
3130
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
3231

@@ -183,11 +182,11 @@ def __str__(self) -> str:
183182
"""Delegate to subbloq's `__str__` method."""
184183
return f'Adjoint(subbloq={str(self.subbloq)})'
185184

186-
def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
185+
def wire_symbol(self, reg: 'Register', idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
187186
# Note: since we pass are passed a soquet which has the 'new' side, we flip it before
188187
# delegating and then flip back. Subbloqs only have to answer this protocol
189188
# if the provided soquet is facing the correct direction.
190-
return self.subbloq.wire_symbol(attrs.evolve(soq, reg=soq.reg.adjoint())).adjoint()
189+
return self.subbloq.wire_symbol(reg=reg.adjoint(), idx=idx).adjoint()
191190

192191
def _t_complexity_(self):
193192
"""The cirq-style `_t_complexity_` delegates to the subbloq's method with a special shim.

qualtran/_infra/adjoint_test.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import sympy
1919

2020
import qualtran.testing as qlt_testing
21-
from qualtran import Adjoint, Bloq, BloqInstance, CompositeBloq, Side, Signature, Soquet
21+
from qualtran import Adjoint, Bloq, CompositeBloq, Side, Signature
2222
from qualtran._infra.adjoint import _adjoint_cbloq
2323
from qualtran.bloqs.basic_gates import CNOT, CSwap, ZeroState
2424
from qualtran.bloqs.for_testing.atom import TestAtom
@@ -162,11 +162,8 @@ def test_wire_symbol():
162162
(reg,) = zero.signature
163163
adj = Adjoint(zero) # specifically use the Adjoint wrapper for testing
164164

165-
# TODO: Remove binst variable. These BloqInstances are for typing only
166-
# and are not really used by the function.
167-
# See https://github.com/quantumlib/Qualtran/issues/608
168-
ws = zero.wire_symbol(Soquet(BloqInstance(CNOT(), 1), reg))
169-
adj_ws = adj.wire_symbol(Soquet(BloqInstance(CNOT(), 2), reg.adjoint()))
165+
ws = zero.wire_symbol(reg)
166+
adj_ws = adj.wire_symbol(reg.adjoint())
170167
assert isinstance(ws, LarrowTextBox)
171168
assert isinstance(adj_ws, RarrowTextBox)
172169

qualtran/_infra/bloq.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
BloqBuilder,
3131
CompositeBloq,
3232
CtrlSpec,
33+
Register,
3334
Signature,
3435
Soquet,
3536
SoquetT,
@@ -503,7 +504,7 @@ def on_registers(
503504

504505
return self.on(*merge_qubits(self.signature, **qubit_regs))
505506

506-
def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
507+
def wire_symbol(self, reg: 'Register', idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
507508
"""On a musical score visualization, use this `WireSymbol` to represent `soq`.
508509
509510
By default, we use a "directional text box", which is a text box that is either
@@ -516,4 +517,10 @@ def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
516517
"""
517518
from qualtran.drawing import directional_text_box
518519

519-
return directional_text_box(text=soq.pretty(), side=soq.reg.side)
520+
label = reg.name
521+
if len(idx) > 0:
522+
pretty_str = f'{label}[{", ".join(str(i) for i in idx)}]'
523+
else:
524+
pretty_str = label
525+
526+
return directional_text_box(text=pretty_str, side=reg.side)

qualtran/_infra/controlled.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -166,15 +166,15 @@ def is_active(self, *vals: 'ClassicalValT') -> bool:
166166
return False
167167
return True
168168

169-
def wire_symbol(self, i: int, soq: 'Soquet') -> 'WireSymbol':
169+
def wire_symbol(self, i: int, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
170170
# Return a circle for bits; a box otherwise.
171171
from qualtran.drawing import Circle, TextBox
172172

173-
if soq.reg.bitsize == 1:
174-
cv = self.cvs[i][soq.idx]
173+
if reg.bitsize == 1:
174+
cv = self.cvs[i][idx]
175175
return Circle(filled=(cv == 1))
176176

177-
cv = self.cvs[i][soq.idx]
177+
cv = self.cvs[i][idx]
178178
return TextBox(f'{cv}')
179179

180180
@cached_property
@@ -431,14 +431,16 @@ def _unitary_(self):
431431
# Unable to determine the unitary effect.
432432
return NotImplemented
433433

434-
def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
435-
if soq.reg.name not in self.ctrl_reg_names:
434+
def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
435+
if reg.name not in self.ctrl_reg_names:
436436
# Delegate to subbloq
437-
return self.subbloq.wire_symbol(soq)
437+
print(self.subbloq)
438+
print(type(self.subbloq))
439+
return self.subbloq.wire_symbol(reg, idx)
438440

439441
# Otherwise, it's part of the control register.
440-
i = self.ctrl_reg_names.index(soq.reg.name)
441-
return self.ctrl_spec.wire_symbol(i, soq)
442+
i = self.ctrl_reg_names.index(reg.name)
443+
return self.ctrl_spec.wire_symbol(i, reg, idx)
442444

443445
def adjoint(self) -> 'Bloq':
444446
return self.subbloq.adjoint().controlled(self.ctrl_spec)

qualtran/_infra/gate_with_registers.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333

3434
from qualtran._infra.bloq import Bloq, DecomposeNotImplementedError, DecomposeTypeError
3535
from qualtran._infra.composite_bloq import CompositeBloq
36-
from qualtran._infra.quantum_graph import Soquet
3736
from qualtran._infra.registers import Register, Side
3837

3938
if TYPE_CHECKING:
@@ -309,10 +308,10 @@ def as_cirq_op(
309308
)
310309
return self.on_registers(**all_quregs), out_quregs
311310

312-
def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
311+
def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
313312
from qualtran.cirq_interop._cirq_to_bloq import _wire_symbol_from_gate
314313

315-
return _wire_symbol_from_gate(self, self.signature, soq)
314+
return _wire_symbol_from_gate(self, self.signature, reg, idx)
316315

317316
# Part-2: Cirq-FT style interface can be used to implemented algorithms by Bloq authors.
318317

qualtran/bloqs/arithmetic/addition.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,13 @@ def _circuit_diagram_info_(self, _) -> cirq.CircuitDiagramInfo:
150150
wire_symbols += ["In(y)/Out(x+y)"] * int(self.b_dtype.bitsize)
151151
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)
152152

153-
def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
153+
def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
154154
from qualtran.drawing import directional_text_box
155155

156-
if soq.reg.name == 'a':
157-
return directional_text_box('a', side=soq.reg.side)
158-
elif soq.reg.name == 'b':
159-
return directional_text_box('a+b', side=soq.reg.side)
156+
if reg.name == 'a':
157+
return directional_text_box('a', side=reg.side)
158+
elif reg.name == 'b':
159+
return directional_text_box('a+b', side=reg.side)
160160
else:
161161
raise ValueError()
162162

qualtran/bloqs/arithmetic/comparison.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
from functools import cached_property
16-
from typing import Dict, Iterable, Iterator, List, Sequence, Set, TYPE_CHECKING, Union
16+
from typing import Dict, Iterable, Iterator, List, Sequence, Set, Tuple, TYPE_CHECKING, Union
1717

1818
import attrs
1919
import cirq
@@ -599,14 +599,14 @@ def _t_complexity_(self) -> 'TComplexity':
599599
# See: https://github.com/quantumlib/Qualtran/issues/217
600600
return t_complexity(LessThanEqual(self.a_bitsize, self.b_bitsize))
601601

602-
def wire_symbol(self, soq: Soquet) -> WireSymbol:
603-
if soq.reg.name == 'a':
602+
def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> WireSymbol:
603+
if reg.name == 'a':
604604
return TextBox("In(a)")
605-
if soq.reg.name == 'b':
605+
if reg.name == 'b':
606606
return TextBox("In(b)")
607-
elif soq.reg.name == 'target':
607+
elif reg.name == 'target':
608608
return TextBox("⨁(a > b)")
609-
raise ValueError(f'Unknown register name {soq.reg.name}')
609+
raise ValueError(f'Unknown register name {reg.name}')
610610

611611
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
612612
# TODO Determine precise clifford count and/or ignore.
@@ -677,7 +677,6 @@ def on_classical_vals(
677677
def build_composite_bloq(
678678
self, bb: 'BloqBuilder', a: Soquet, b: Soquet, target: SoquetT
679679
) -> Dict[str, 'SoquetT']:
680-
681680
# Base Case: Comparing two qubits.
682681
# Signed doesn't matter because we can't represent signed integers with 1 qubit.
683682
if self.bitsize == 1:
@@ -831,12 +830,12 @@ def _t_complexity_(self) -> TComplexity:
831830
def short_name(self) -> str:
832831
return f"x > {self.val}"
833832

834-
def wire_symbol(self, soq: Soquet) -> WireSymbol:
835-
if soq.reg.name == 'x':
833+
def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> WireSymbol:
834+
if reg.name == 'x':
836835
return TextBox("In(x)")
837-
elif soq.reg.name == 'target':
836+
elif reg.name == 'target':
838837
return TextBox(f"⨁(x > {self.val})")
839-
raise ValueError(f'Unknown register symbol {soq.reg.name}')
838+
raise ValueError(f'Unknown register symbol {reg.name}')
840839

841840
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
842841
# TODO Determine precise clifford count and/or ignore.
@@ -884,12 +883,12 @@ def _t_complexity_(self) -> 'TComplexity':
884883
def short_name(self) -> str:
885884
return f"x == {self.val}"
886885

887-
def wire_symbol(self, soq: Soquet) -> WireSymbol:
888-
if soq.reg.name == 'x':
886+
def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> WireSymbol:
887+
if reg.name == 'x':
889888
return TextBox("In(x)")
890-
elif soq.reg.name == 'target':
889+
elif reg.name == 'target':
891890
return TextBox(f"⨁(x = {self.val})")
892-
raise ValueError(f'Unknown register symbol {soq.reg.name}')
891+
raise ValueError(f'Unknown register symbol {reg.name}')
893892

894893
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
895894
# See: https://github.com/quantumlib/Qualtran/issues/219

qualtran/bloqs/arithmetic/conversions.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
from functools import cached_property
16-
from typing import Dict, Set, TYPE_CHECKING
16+
from typing import Dict, Set, Tuple, TYPE_CHECKING
1717

1818
from attrs import frozen
1919

@@ -28,7 +28,6 @@
2828
Side,
2929
Signature,
3030
)
31-
from qualtran._infra.quantum_graph import Soquet
3231
from qualtran.bloqs.basic_gates import Toffoli
3332
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
3433
from qualtran.drawing import WireSymbol
@@ -89,10 +88,10 @@ def _t_complexity_(self) -> 'TComplexity':
8988
num_toffoli = self.bitsize**2 + self.bitsize - 1
9089
return TComplexity(t=4 * num_toffoli)
9190

92-
def wire_symbol(self, soq: Soquet) -> WireSymbol:
93-
if soq.reg.name == 'mu':
91+
def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> WireSymbol:
92+
if reg.name == 'mu':
9493
return TextBox(r'$\mu$')
95-
elif soq.reg.name == 'nu':
94+
elif reg.name == 'nu':
9695
return TextBox(r'$\mu$')
9796
else:
9897
text = r'$\oplus\nu(\nu-1)/2+\mu$'

qualtran/bloqs/basic_gates/cnot.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@
2929
CompositeBloq,
3030
CtrlSpec,
3131
DecomposeTypeError,
32+
Register,
3233
Signature,
33-
Soquet,
3434
SoquetT,
3535
)
3636
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
@@ -131,12 +131,12 @@ def as_cirq_op(
131131
(target,) = target
132132
return cirq.CNOT(ctrl, target), {'ctrl': np.array([ctrl]), 'target': np.array([target])}
133133

134-
def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
135-
if soq.reg.name == 'ctrl':
134+
def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
135+
if reg.name == 'ctrl':
136136
return Circle(filled=True)
137-
elif soq.reg.name == 'target':
137+
elif reg.name == 'target':
138138
return ModPlus()
139-
raise ValueError(f'Bad wire symbol soquet: {soq}')
139+
raise ValueError(f'Unknown wire symbol register name: {reg.name}')
140140

141141
def _t_complexity_(self) -> 'TComplexity':
142142
return TComplexity(clifford=1)

qualtran/bloqs/basic_gates/hadamard.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
BloqDocSpec,
2525
CompositeBloq,
2626
DecomposeTypeError,
27+
Register,
2728
Signature,
28-
Soquet,
2929
SoquetT,
3030
)
3131
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
@@ -95,7 +95,7 @@ def _t_complexity_(self):
9595
def short_name(self) -> 'str':
9696
return 'H'
9797

98-
def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
98+
def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
9999
return TextBox('H')
100100

101101

0 commit comments

Comments
 (0)