Skip to content

Commit 03d09a0

Browse files
authored
Remove short_name method (#934)
* Remove short_name method - Replace with wire_symbol(reg=None) - Gates with no title label will now return Text('') and not print out a title.
1 parent a4a0f92 commit 03d09a0

77 files changed

Lines changed: 459 additions & 334 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

qualtran/_infra/adjoint.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
from functools import cached_property
16-
from typing import Dict, List, Set, Tuple, TYPE_CHECKING
16+
from typing import cast, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
1717

1818
import cirq
1919
from attrs import frozen
@@ -170,10 +170,6 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
170170
"""The call graph takes the adjoint of each of the bloqs in `subbloq`'s call graph."""
171171
return {(bloq.adjoint(), n) for bloq, n in self.subbloq.build_call_graph(ssa=ssa)}
172172

173-
def short_name(self) -> str:
174-
"""The subbloq's short_name with a dagger."""
175-
return self.subbloq.short_name() + '†'
176-
177173
def pretty_name(self) -> str:
178174
"""The subbloq's pretty_name with a dagger."""
179175
return self.subbloq.pretty_name() + '†'
@@ -182,10 +178,17 @@ def __str__(self) -> str:
182178
"""Delegate to subbloq's `__str__` method."""
183179
return f'Adjoint(subbloq={str(self.subbloq)})'
184180

185-
def wire_symbol(self, reg: 'Register', idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
181+
def wire_symbol(
182+
self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple()
183+
) -> 'WireSymbol':
186184
# Note: since we pass are passed a soquet which has the 'new' side, we flip it before
187185
# delegating and then flip back. Subbloqs only have to answer this protocol
188186
# if the provided soquet is facing the correct direction.
187+
from qualtran.drawing import Text
188+
189+
if reg is None:
190+
return Text(cast(Text, self.subbloq.wire_symbol(reg=None)).text + '†')
191+
189192
return self.subbloq.wire_symbol(reg=reg.adjoint(), idx=idx).adjoint()
190193

191194
def _t_complexity_(self):

qualtran/_infra/adjoint_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from functools import cached_property
15-
from typing import Dict, TYPE_CHECKING
15+
from typing import cast, Dict, TYPE_CHECKING
1616

1717
import pytest
1818
import sympy
@@ -25,7 +25,7 @@
2525
from qualtran.bloqs.for_testing.with_call_graph import TestBloqWithCallGraph
2626
from qualtran.bloqs.for_testing.with_decomposition import TestParallelCombo, TestSerialCombo
2727
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
28-
from qualtran.drawing import LarrowTextBox, RarrowTextBox
28+
from qualtran.drawing import LarrowTextBox, RarrowTextBox, Text
2929

3030
if TYPE_CHECKING:
3131
from qualtran import BloqBuilder, SoquetT
@@ -149,11 +149,11 @@ def test_call_graph():
149149
def test_names():
150150
atom = TestAtom()
151151
assert atom.pretty_name() == "TestAtom"
152-
assert atom.short_name() == "Atom"
152+
assert cast(Text, atom.wire_symbol(reg=None)).text == "TestAtom"
153153

154154
adj_atom = Adjoint(atom)
155155
assert adj_atom.pretty_name() == "TestAtom†"
156-
assert adj_atom.short_name() == "Atom†"
156+
assert cast(Text, adj_atom.wire_symbol(reg=None)).text == "TestAtom†"
157157
assert str(adj_atom) == "Adjoint(subbloq=TestAtom())"
158158

159159

qualtran/_infra/bloq.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -108,13 +108,6 @@ def signature(self) -> 'Signature':
108108
def pretty_name(self) -> str:
109109
return self.__class__.__name__
110110

111-
def short_name(self) -> str:
112-
name = self.pretty_name()
113-
if len(name) <= 10:
114-
return name
115-
116-
return name[:8] + '..'
117-
118111
def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str, 'SoquetT']:
119112
"""Override this method to define a Bloq in terms of its constituent parts.
120113
@@ -282,7 +275,9 @@ def add_my_tensors(
282275
from qualtran.simulation.tensor import cbloq_as_contracted_tensor
283276

284277
cbloq = self.decompose_bloq()
285-
tn.add(cbloq_as_contracted_tensor(cbloq, incoming, outgoing, tags=[self.short_name(), tag]))
278+
tn.add(
279+
cbloq_as_contracted_tensor(cbloq, incoming, outgoing, tags=[self.pretty_name(), tag])
280+
)
286281

287282
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
288283
"""Override this method to build the bloq call graph.
@@ -508,18 +503,30 @@ def on_registers(
508503

509504
return self.on(*merge_qubits(self.signature, **qubit_regs))
510505

511-
def wire_symbol(self, reg: 'Register', idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
506+
def wire_symbol(
507+
self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple()
508+
) -> 'WireSymbol':
512509
"""On a musical score visualization, use this `WireSymbol` to represent `soq`.
513510
514511
By default, we use a "directional text box", which is a text box that is either
515512
rectangular for thru-registers or facing to the left or right for non-thru-registers.
516513
514+
If reg is specified as `None`, this should return a Text label for the title of
515+
the gate. If no title is needed (as the wire_symbols are self-explanatory),
516+
this should return `Text('')`.
517+
517518
Override this method to provide a more relevant `WireSymbol` for the provided soquet.
518519
This method can access bloq attributes. For example: you may want to draw either
519520
a filled or empty circle for a control register depending on a control value bloq
520521
attribute.
521522
"""
522-
from qualtran.drawing import directional_text_box
523+
from qualtran.drawing import directional_text_box, Text
524+
525+
if reg is None:
526+
name = self.pretty_name()
527+
if len(name) <= 10:
528+
return Text(name)
529+
return Text(name[:8] + '..')
523530

524531
label = reg.name
525532
if len(idx) > 0:

qualtran/_infra/controlled.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ def add_my_tensors(
418418
subbloq_shape = tensor_shape_from_signature(self.subbloq.signature)
419419
data[active_idx] = self.subbloq.tensor_contract().reshape(subbloq_shape)
420420
# Add the data to the tensor network.
421-
tn.add(qtn.Tensor(data=data, inds=out_ind + in_ind, tags=[self.short_name(), tag]))
421+
tn.add(qtn.Tensor(data=data, inds=out_ind + in_ind, tags=[self.pretty_name(), tag]))
422422

423423
def _unitary_(self):
424424
if isinstance(self.subbloq, GateWithRegisters):
@@ -433,11 +433,13 @@ def _unitary_(self):
433433
# Unable to determine the unitary effect.
434434
return NotImplemented
435435

436-
def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
436+
def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
437+
from qualtran.drawing import Text
438+
439+
if reg is None:
440+
return Text(f'C[{self.subbloq.wire_symbol(reg=None)}]')
437441
if reg.name not in self.ctrl_reg_names:
438442
# Delegate to subbloq
439-
print(self.subbloq)
440-
print(type(self.subbloq))
441443
return self.subbloq.wire_symbol(reg, idx)
442444

443445
# Otherwise, it's part of the control register.
@@ -450,9 +452,6 @@ def adjoint(self) -> 'Bloq':
450452
def pretty_name(self) -> str:
451453
return f'C[{self.subbloq.pretty_name()}]'
452454

453-
def short_name(self) -> str:
454-
return f'C[{self.subbloq.short_name()}]'
455-
456455
def __str__(self) -> str:
457456
return f'C[{self.subbloq}]'
458457

qualtran/_infra/gate_with_registers.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,12 @@ def as_cirq_op(
310310
)
311311
return self.on_registers(**all_quregs), out_quregs
312312

313-
def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
313+
def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
314314
from qualtran.cirq_interop._cirq_to_bloq import _wire_symbol_from_gate
315+
from qualtran.drawing import Text
316+
317+
if reg is None:
318+
return Text(self.pretty_name())
315319

316320
return _wire_symbol_from_gate(self, self.signature, reg, idx)
317321

@@ -515,13 +519,7 @@ def add_my_tensors(
515519
from qualtran.cirq_interop._cirq_to_bloq import _add_my_tensors_from_gate
516520

517521
_add_my_tensors_from_gate(
518-
self,
519-
self.signature,
520-
self.short_name(),
521-
tn,
522-
tag,
523-
incoming=incoming,
524-
outgoing=outgoing,
522+
self, self.signature, str(self), tn, tag, incoming=incoming, outgoing=outgoing
525523
)
526524
else:
527525
return super().add_my_tensors(tn, tag, incoming=incoming, outgoing=outgoing)

qualtran/bloqs/arithmetic/addition.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def add_my_tensors(
142142
for a, b in itertools.product(range(N_a), range(N_b)):
143143
unitary[a, b, a, int(math.fmod(a + b, N_b))] = 1
144144

145-
tn.add(qtn.Tensor(data=unitary, inds=inds, tags=[self.short_name(), tag]))
145+
tn.add(qtn.Tensor(data=unitary, inds=inds, tags=[self.pretty_name(), tag]))
146146

147147
def decompose_bloq(self) -> 'CompositeBloq':
148148
return decompose_from_cirq_style_method(self)
@@ -155,17 +155,16 @@ def on_classical_vals(
155155
N = 2**b_bitsize if unsigned else 2 ** (b_bitsize - 1)
156156
return {'a': a, 'b': int(math.fmod(a + b, N))}
157157

158-
def short_name(self) -> str:
159-
return "a+b"
160-
161158
def _circuit_diagram_info_(self, _) -> cirq.CircuitDiagramInfo:
162159
wire_symbols = ["In(x)"] * int(self.a_dtype.bitsize)
163160
wire_symbols += ["In(y)/Out(x+y)"] * int(self.b_dtype.bitsize)
164161
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)
165162

166-
def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
167-
from qualtran.drawing import directional_text_box
163+
def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
164+
from qualtran.drawing import directional_text_box, Text
168165

166+
if reg is None:
167+
return Text("a+b")
169168
if reg.name == 'a':
170169
return directional_text_box('a', side=reg.side)
171170
elif reg.name == 'b':
@@ -318,7 +317,7 @@ def on_classical_vals(
318317
def with_registers(self, *new_registers: Union[int, Sequence[int]]):
319318
raise NotImplementedError("no need to implement with_registers.")
320319

321-
def short_name(self) -> str:
320+
def pretty_name(self) -> str:
322321
return "c = a + b"
323322

324323
def decompose_from_registers(
@@ -501,7 +500,7 @@ def build_composite_bloq(
501500
else:
502501
return {'x': x}
503502

504-
def short_name(self) -> str:
503+
def pretty_name(self) -> str:
505504
return f'x += {self.k}'
506505

507506

qualtran/bloqs/arithmetic/comparison.py

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,18 @@
1313
# limitations under the License.
1414

1515
from functools import cached_property
16-
from typing import Dict, Iterable, Iterator, List, Sequence, Set, Tuple, TYPE_CHECKING, Union
16+
from typing import (
17+
Dict,
18+
Iterable,
19+
Iterator,
20+
List,
21+
Optional,
22+
Sequence,
23+
Set,
24+
Tuple,
25+
TYPE_CHECKING,
26+
Union,
27+
)
1728

1829
import attrs
1930
import cirq
@@ -42,7 +53,7 @@
4253
from qualtran.cirq_interop.bit_tools import iter_bits
4354
from qualtran.cirq_interop.t_complexity_protocol import t_complexity, TComplexity
4455
from qualtran.drawing import WireSymbol
45-
from qualtran.drawing.musical_score import TextBox
56+
from qualtran.drawing.musical_score import Text, TextBox
4657

4758
if TYPE_CHECKING:
4859
from qualtran import BloqBuilder
@@ -62,8 +73,12 @@ class LessThanConstant(GateWithRegisters, cirq.ArithmeticGate): # type: ignore[
6273
def signature(self) -> Signature:
6374
return Signature.build_from_dtypes(x=QUInt(self.bitsize), target=QBit())
6475

65-
def short_name(self) -> str:
66-
return f'x<{self.less_than_val}'
76+
def wire_symbol(
77+
self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple()
78+
) -> 'WireSymbol':
79+
if reg is None:
80+
return Text(f'x<{self.less_than_val}')
81+
return super().wire_symbol(reg, idx)
6782

6883
def registers(self) -> Sequence[Union[int, Sequence[int]]]:
6984
return [2] * self.bitsize, self.less_than_val, [2]
@@ -428,8 +443,12 @@ def apply(self, *register_vals: int) -> Union[int, int, Iterable[int]]:
428443
x_val, y_val, target_val = register_vals
429444
return x_val, y_val, target_val ^ (x_val <= y_val)
430445

431-
def short_name(self) -> str:
432-
return 'x <= y'
446+
def wire_symbol(
447+
self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple()
448+
) -> 'WireSymbol':
449+
if reg is None:
450+
return Text('x <= y')
451+
return super().wire_symbol(reg, idx)
433452

434453
def on_classical_vals(self, *, x: int, y: int, target: int) -> Dict[str, 'ClassicalValT']:
435454
return {'x': x, 'y': y, 'target': target ^ (x <= y)}
@@ -599,16 +618,15 @@ def signature(self):
599618
a=QUInt(self.a_bitsize), b=QUInt(self.b_bitsize), target=QBit()
600619
)
601620

602-
def short_name(self) -> str:
603-
return "a>b"
604-
605621
def _t_complexity_(self) -> 'TComplexity':
606622
# TODO Determine precise clifford count and/or ignore.
607623
# See: https://github.com/quantumlib/Qualtran/issues/219
608624
# See: https://github.com/quantumlib/Qualtran/issues/217
609625
return t_complexity(LessThanEqual(self.a_bitsize, self.b_bitsize))
610626

611-
def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> WireSymbol:
627+
def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> WireSymbol:
628+
if reg is None:
629+
return Text("a>b")
612630
if reg.name == 'a':
613631
return TextBox("In(a)")
614632
if reg.name == 'b':
@@ -799,8 +817,12 @@ def build_composite_bloq(
799817
# Return the output registers.
800818
return {'a': a, 'b': b, 'target': target}
801819

802-
def short_name(self) -> str:
803-
return "a > b"
820+
def wire_symbol(
821+
self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple()
822+
) -> 'WireSymbol':
823+
if reg is None:
824+
return Text('a > b')
825+
return super().wire_symbol(reg, idx)
804826

805827

806828
@frozen
@@ -836,10 +858,9 @@ def _t_complexity_(self) -> TComplexity:
836858
# See: https://github.com/quantumlib/Qualtran/issues/217
837859
return t_complexity(LessThanConstant(self.bitsize, less_than_val=self.val))
838860

839-
def short_name(self) -> str:
840-
return f"x > {self.val}"
841-
842-
def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> WireSymbol:
861+
def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> WireSymbol:
862+
if reg is None:
863+
return Text(f"x > {self.val}")
843864
if reg.name == 'x':
844865
return TextBox("In(x)")
845866
elif reg.name == 'target':
@@ -889,10 +910,9 @@ def signature(self) -> Signature:
889910
def _t_complexity_(self) -> 'TComplexity':
890911
return TComplexity(t=4 * (self.bitsize - 1))
891912

892-
def short_name(self) -> str:
893-
return f"x == {self.val}"
894-
895-
def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> WireSymbol:
913+
def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> WireSymbol:
914+
if reg is None:
915+
return Text(f"x == {self.val}")
896916
if reg.name == 'x':
897917
return TextBox("In(x)")
898918
elif reg.name == 'target':

0 commit comments

Comments
 (0)