Skip to content

Commit 0338e4d

Browse files
committed
changes for reporting on costing
1 parent e18e8ad commit 0338e4d

5 files changed

Lines changed: 206 additions & 136 deletions

File tree

dev_tools/bloq-method-overrides-report.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def _call_graph(bc: Type[Bloq]):
4343
if annot['ssa'] != 'SympySymbolAllocator':
4444
print(f"{bc}.build_call_graph `ssa: 'SympySymbolAllocator'`")
4545
if annot['return'] != Set[ForwardRef('BloqCountT')]: # type: ignore[misc]
46-
print(f"{bc}.build_call_graph -> 'BloqCountT'")
46+
print(f"{bc}.build_call_graph -> Set['BloqCountT'], not {annot['return']}")
4747

4848

4949
def report_call_graph_methods():

dev_tools/qualtran_dev_tools/bloq_report_card.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import pandas.io.formats.style
2020

2121
from qualtran import Bloq, BloqExample
22+
from qualtran.resource_counting import get_cost_value, QubitCount
23+
from qualtran.simulation.tensor import cbloq_to_quimb
2224
from qualtran.testing import (
2325
BloqCheckResult,
2426
check_bloq_example_decompose,
@@ -134,3 +136,44 @@ def summarize_results(report_card: pd.DataFrame) -> pd.DataFrame:
134136
)
135137
summary.columns = [v.name.lower() for v in summary.columns]
136138
return summary
139+
140+
141+
def report_on_tensors(name: str, cls_name: str, bloq: Bloq, cxn) -> None:
    """Get timing information for tensor functionality.

    This should be used with `ExecuteWithTimeout`. The resultant
    record dictionary is sent over `cxn`.
    """
    record: Dict[str, Any] = {'name': name, 'cls': cls_name}

    try:
        # Time flattening the bloq into a flat composite bloq.
        tick = time.perf_counter()
        flat_cbloq = bloq.as_composite_bloq().flatten()
        record['flat_dur'] = time.perf_counter() - tick

        # Time building the quimb tensor network.
        tick = time.perf_counter()
        tensor_network = cbloq_to_quimb(flat_cbloq)
        record['tn_dur'] = time.perf_counter() - tick

        # Time computing the contraction width of the network.
        tick = time.perf_counter()
        record['width'] = tensor_network.contraction_width()
        record['width_dur'] = time.perf_counter() - tick

    except Exception as e:  # pylint: disable=broad-exception-caught
        # Any failure is reported in the record rather than crashing the worker.
        record['err'] = str(e)

    cxn.send(record)
166+
167+
168+
def report_on_cost_timings(name: str, cls_name: str, bloq: Bloq, cxn) -> None:
    """Get timing information for computing the `QubitCount` cost of `bloq`.

    This should be used with `ExecuteWithTimeout`. The resultant
    record dictionary is sent over `cxn`.
    """
    record: Dict[str, Any] = {'name': name, 'cls': cls_name}

    try:
        # Time the qubit-count cost computation; only the duration is recorded.
        tick = time.perf_counter()
        _ = get_cost_value(bloq, QubitCount())
        record['qubitcount_dur'] = time.perf_counter() - tick
    except Exception as e:  # pylint: disable=broad-exception-caught
        # Report failures in the record instead of crashing the worker process.
        record['err'] = str(e)

    cxn.send(record)
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import multiprocessing.connection
16+
import time
17+
from typing import Any, Callable, Dict, List, Optional, Tuple
18+
19+
from attrs import define
20+
21+
22+
@define
class _Pending:
    """Helper dataclass to track currently executing processes in `ExecuteWithTimeout`."""

    # The spawned worker process running the task.
    p: multiprocessing.Process
    # Receiving half of the pipe; the worker sends its result over the sending half.
    recv: multiprocessing.connection.Connection
    # Wall-clock time (`time.time()`) at which the process was started; used for timeouts.
    start_time: float
    # The keyword arguments the task was launched with (includes the `cxn` send handle).
    kwargs: Dict[str, Any]
30+
31+
32+
class ExecuteWithTimeout:
    """Execute tasks in processes where each task will be killed if it exceeds `timeout`.

    Seemingly all the existing "timeout" parameters in the various built-in concurrency
    primitives in Python won't actually terminate the process. This one does.

    Args:
        timeout: Maximum number of seconds a task may run before being terminated.
        max_workers: Maximum number of task processes running concurrently.
    """

    def __init__(self, timeout: float, max_workers: int):
        self.timeout = timeout
        self.max_workers = max_workers

        # FIFO of (func, kwargs) tasks that have not been started yet.
        self.queued: List[Tuple[Callable, Dict[str, Any]]] = []
        # Tasks currently running, each in its own process.
        self.pending: List[_Pending] = []

    @property
    def work_to_be_done(self) -> int:
        """The number of tasks currently executing or queued."""
        return len(self.queued) + len(self.pending)

    def submit(self, func: Callable, kwargs: Dict[str, Any]) -> None:
        """Add a task to the queue.

        `func` must be a callable that can accept `kwargs` in addition to
        a keyword argument `cxn` which is a multiprocessing `Connection` object that forms
        the sending-half of a `mp.Pipe`. The callable must call `cxn.send(...)`
        to return a result.
        """
        self.queued.append((func, kwargs))

    def _submit_from_queue(self):
        # Helper method that takes an item from the queue, launches a process,
        # and records it in the `pending` attribute. This must only be called
        # if we're allowed to spawn a new process.
        func, submitted_kwargs = self.queued.pop(0)
        recv, send = multiprocessing.Pipe(duplex=False)
        # Copy before adding 'cxn' so we don't mutate the dict the caller
        # handed to `submit`.
        kwargs = {**submitted_kwargs, 'cxn': send}
        p = multiprocessing.Process(target=func, kwargs=kwargs)
        start_time = time.time()
        p.start()
        # The child now owns the sending half; close the parent's copy so we
        # don't leak a file descriptor for every task.
        send.close()
        self.pending.append(_Pending(p=p, recv=recv, start_time=start_time, kwargs=kwargs))

    def _scan_pendings(self) -> Optional[_Pending]:
        # Helper method that goes through the currently pending tasks, terminates the ones
        # that have been going on too long, and accounts for ones that have finished.
        # Returns the `_Pending` of the killed or completed job or `None` if each pending
        # task is still running but none have exceeded the timeout.
        for i, pen in enumerate(self.pending):
            if not pen.p.is_alive():
                self.pending.pop(i)
                pen.p.join()
                return pen

            if time.time() - pen.start_time > self.timeout:
                pen.p.terminate()
                # Join the terminated process so it doesn't linger as a zombie.
                pen.p.join()
                self.pending.pop(i)
                return pen

        return None

    def next_result(self) -> Tuple[Dict[str, Any], Optional[Any]]:
        """Get the next available result.

        This call is blocking, but should never take longer than `self.timeout`. This should
        be called in a loop to make sure the queue continues to be processed.

        Returns:
            task kwargs: The keyword arguments used to submit the task.
            result: If the process finished successfully, this is the object that was
                sent through the multiprocessing pipe as the result. Otherwise, the result
                is None.
        """
        # Fill up the worker slots from the queue.
        while len(self.queued) > 0 and len(self.pending) < self.max_workers:
            self._submit_from_queue()

        # Poll until one task finishes or exceeds its timeout. Sleep between
        # scans so this loop doesn't busy-spin a full CPU core while waiting.
        while True:
            finished = self._scan_pendings()
            if finished is not None:
                break
            time.sleep(0.01)

        # A non-zero exitcode means the task crashed or was terminated; there
        # is nothing to receive in that case.
        if finished.p.exitcode == 0:
            result = finished.recv.recv()
        else:
            result = None

        finished.recv.close()

        # Backfill the freed worker slot(s) before returning.
        while len(self.queued) > 0 and len(self.pending) < self.max_workers:
            self._submit_from_queue()

        return (finished.kwargs, result)
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import time
15+
16+
from .execute_with_timeout import ExecuteWithTimeout
17+
18+
19+
def a_long_function(n_seconds: int, cxn) -> None:
    """Toy task: block for `n_seconds` seconds, then send "Done" over `cxn`."""
    time.sleep(n_seconds)
    cxn.send("Done")
22+
23+
24+
def test_execute_with_timeout():
    """One fast task returns a result; one slow task is killed at the timeout."""
    # Note: renamed the local from `exec`, which shadowed the builtin.
    executor = ExecuteWithTimeout(timeout=1, max_workers=1)

    # 0.1s finishes well within the 1s timeout; 100s is guaranteed to be killed.
    for ns in [0.1, 100]:
        executor.submit(a_long_function, {'n_seconds': ns})

    results = []
    while executor.work_to_be_done:
        _, result = executor.next_result()
        if result is None:
            # A `None` result signals the task was terminated (or crashed).
            results.append('Timeout')
        else:
            results.append(result)

    assert set(results) == {'Done', 'Timeout'}

dev_tools/qualtran_dev_tools/tensor_report_card.py

Lines changed: 1 addition & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -11,142 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import multiprocessing.connection
1514
import time
16-
from typing import Any, Callable, Dict, List, Optional, Tuple
17-
18-
from attrs import define
15+
from typing import Any, Dict
1916

2017
from qualtran import Bloq
2118
from qualtran.simulation.tensor import cbloq_to_quimb
22-
23-
24-
@define
25-
class _Pending:
26-
"""Helper dataclass to track currently executing processes in `ExecuteWithTimeout`."""
27-
28-
p: multiprocessing.Process
29-
recv: multiprocessing.connection.Connection
30-
start_time: float
31-
kwargs: Dict[str, Any]
32-
33-
34-
class ExecuteWithTimeout:
35-
"""Execute tasks in processes where each task will be killed if it exceeds `timeout`.
36-
37-
Seemingly all the existing "timeout" parameters in the various built-in concurrency
38-
primitives in Python won't actually terminate the process. This one does.
39-
"""
40-
41-
def __init__(self, timeout: float, max_workers: int):
42-
self.timeout = timeout
43-
self.max_workers = max_workers
44-
45-
self.queued: List[Tuple[Callable, Dict[str, Any]]] = []
46-
self.pending: List[_Pending] = []
47-
48-
@property
49-
def work_to_be_done(self) -> int:
50-
"""The number of tasks currently executing or queued."""
51-
return len(self.queued) + len(self.pending)
52-
53-
def submit(self, func: Callable, kwargs: Dict[str, Any]) -> None:
54-
"""Add a task to the queue.
55-
56-
`func` must be a callable that can accept `kwargs` in addition to
57-
a keyword argument `cxn` which is a multiprocessing `Connection` object that forms
58-
the sending-half of a `mp.Pipe`. The callable must call `cxn.send(...)`
59-
to return a result.
60-
"""
61-
self.queued.append((func, kwargs))
62-
63-
def _submit_from_queue(self):
64-
# helper method that takes an item from the queue, launches a process,
65-
# and records it in the `pending` attribute. This must only be called
66-
# if we're allowed to spawn a new process.
67-
func, kwargs = self.queued.pop(0)
68-
recv, send = multiprocessing.Pipe(duplex=False)
69-
kwargs['cxn'] = send
70-
p = multiprocessing.Process(target=func, kwargs=kwargs)
71-
start_time = time.time()
72-
p.start()
73-
self.pending.append(_Pending(p=p, recv=recv, start_time=start_time, kwargs=kwargs))
74-
75-
def _scan_pendings(self) -> Optional[_Pending]:
76-
# helper method that goes through the currently pending tasks, terminates the ones
77-
# that have been going on too long, and accounts for ones that have finished.
78-
# Returns the `_Pending` of the killed or completed job or `None` if each pending
79-
# task is still running but none have exceeded the timeout.
80-
for i in range(len(self.pending)):
81-
pen = self.pending[i]
82-
83-
if not pen.p.is_alive():
84-
self.pending.pop(i)
85-
pen.p.join()
86-
return pen
87-
88-
if time.time() - pen.start_time > self.timeout:
89-
pen.p.terminate()
90-
self.pending.pop(i)
91-
return pen
92-
93-
return None
94-
95-
def next_result(self) -> Tuple[Dict[str, Any], Optional[Any]]:
96-
"""Get the next available result.
97-
98-
This call is blocking, but should never take longer than `self.timeout`. This should
99-
be called in a loop to make sure the queue continues to be processed.
100-
101-
Returns:
102-
task kwargs: The keyword arguments used to submit the task.
103-
result: If the process finished successfully, this is the object that was
104-
sent through the multiprocessing pipe as the result. Otherwise, the result
105-
is None.
106-
"""
107-
while len(self.queued) > 0 and len(self.pending) < self.max_workers:
108-
self._submit_from_queue()
109-
110-
while True:
111-
finished = self._scan_pendings()
112-
if finished is not None:
113-
break
114-
115-
if finished.p.exitcode == 0:
116-
result = finished.recv.recv()
117-
else:
118-
result = None
119-
120-
finished.recv.close()
121-
122-
while len(self.queued) > 0 and len(self.pending) < self.max_workers:
123-
self._submit_from_queue()
124-
125-
return (finished.kwargs, result)
126-
127-
128-
def report_on_tensors(name: str, cls_name: str, bloq: Bloq, cxn) -> None:
129-
"""Get timing information for tensor functionality.
130-
131-
This should be used with `ExecuteWithTimeout`. The resultant
132-
record dictionary is sent over `cxn`.
133-
"""
134-
record: Dict[str, Any] = {'name': name, 'cls': cls_name}
135-
136-
try:
137-
start = time.perf_counter()
138-
flat = bloq.as_composite_bloq().flatten()
139-
record['flat_dur'] = time.perf_counter() - start
140-
141-
start = time.perf_counter()
142-
tn = cbloq_to_quimb(flat)
143-
record['tn_dur'] = time.perf_counter() - start
144-
145-
start = time.perf_counter()
146-
record['width'] = tn.contraction_width()
147-
record['width_dur'] = time.perf_counter() - start
148-
149-
except Exception as e: # pylint: disable=broad-exception-caught
150-
record['err'] = str(e)
151-
152-
cxn.send(record)

0 commit comments

Comments
 (0)