-
-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathvisitor100.py
More file actions
90 lines (70 loc) · 2.8 KB
/
visitor100.py
File metadata and controls
90 lines (70 loc) · 2.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""Contains visitor for ASYNC100.
A `with trio.fail_after(...):` or `with trio.move_on_after(...):`
context does not contain any `await` statements. This makes it pointless, as
the timeout can only be triggered by a checkpoint.
Checkpoints on Await, Async For and Async With
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import libcst as cst
import libcst.matchers as m
from .flake8asyncvisitor import Flake8AsyncVisitor_cst
from .helpers import (
AttributeCall,
error_class_cst,
flatten_preserving_comments,
with_has_call,
)
if TYPE_CHECKING:
from collections.abc import Mapping
@error_class_cst
class Visitor100_libcst(Flake8AsyncVisitor_cst):
    """Report ASYNC100: a timeout / cancel-scope `with` that contains no checkpoint.

    A `with` statement is tracked when `with_has_call` finds one of
    `fail_after`, `fail_at`, `move_on_after`, `move_on_at` or `CancelScope`
    among its items. `await`, `yield`, `async for` and `async with` count as
    checkpoints. Checkpoints inside a nested function definition do not count
    towards an enclosing `with`, because each function gets its own stack.
    """

    error_codes: Mapping[str, str] = {
        "ASYNC100": (
            "{0}.{1} context contains no checkpoints, remove the context or add"
            " `await {0}.lowlevel.checkpoint()`."
        ),
    }

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # One entry per currently-open `with` statement, innermost last.
        # An entry is False for a tracked timeout context that has not yet
        # seen a checkpoint. It is True once a checkpoint has been seen, or
        # from the start for untracked `with`s.
        self.has_checkpoint_stack: list[bool] = []
        # Matched calls for each tracked `with` node that is currently open.
        # Entries are removed again in `leave_With`.
        self.node_dict: dict[cst.With, list[AttributeCall]] = {}

    def checkpoint(self) -> None:
        """Mark every currently-open `with` as containing a checkpoint."""
        # A checkpoint inside nested `with`s is also inside every outer one,
        # so set the whole stack to True.
        self.has_checkpoint_stack = [True] * len(self.has_checkpoint_stack)

    def visit_With(self, node: cst.With) -> None:
        """Push a stack entry for `node`, recording its calls if it is tracked."""
        # Entering an `async with` is a checkpoint for the *enclosing*
        # contexts, so record it before this statement's own entry is pushed.
        if node.asynchronous is not None:
            self.checkpoint()
        if res := with_has_call(
            node, "fail_after", "fail_at", "move_on_after", "move_on_at", "CancelScope"
        ):
            self.node_dict[node] = res
            self.has_checkpoint_stack.append(False)
        else:
            # Untracked contexts can never be reported, so start them as True.
            self.has_checkpoint_stack.append(True)

    def leave_With(
        self, original_node: cst.With, updated_node: cst.With
    ) -> cst.BaseStatement | cst.FlattenSentinel[cst.BaseStatement]:
        """Pop this `with`'s entry, and report it if no checkpoint was seen.

        Returns `updated_node` unchanged, or the result of
        `flatten_preserving_comments(updated_node)` when autofixing applies.
        """
        # Drop the bookkeeping entry whether or not this `with` is reported,
        # so `node_dict` does not keep every tracked node alive for the rest
        # of the file. The entry is always present when the stack entry is
        # False, because only tracked nodes are pushed as False.
        calls = self.node_dict.pop(original_node, [])
        if self.has_checkpoint_stack.pop():
            return updated_node

        # Autofix is only attempted for single-item `with` statements.
        # Presumably this is because flattening would also discard any other
        # context managers on the same statement.
        autofix = len(updated_node.items) == 1
        for res in calls:
            # `self.error` sits on the right-hand side of `&=`, so it runs for
            # every matched call, even after `autofix` has become False.
            autofix &= self.error(
                res.node, res.base, res.function
            ) and self.should_autofix(res.node)

        if autofix:
            return flatten_preserving_comments(updated_node)
        return updated_node

    def visit_For(self, node: cst.For) -> None:
        """Treat an `async for` loop as a checkpoint."""
        if node.asynchronous is not None:
            self.checkpoint()

    def visit_Await(self, node: cst.Await | cst.Yield) -> None:
        """Treat `await` (and, via `visit_Yield`, `yield`) as a checkpoint."""
        self.checkpoint()

    visit_Yield = visit_Await

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        """Give a nested function body its own, empty `with` stack."""
        # Checkpoints in a nested function's body do not run as part of an
        # enclosing `with`. So save the outer stack and start a fresh one.
        self.save_state(node, "has_checkpoint_stack", copy=True)
        self.has_checkpoint_stack = []

    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        """Restore the enclosing scope's state saved in `visit_FunctionDef`."""
        self.restore_state(original_node)
        return updated_node