Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions bellows/cli/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ def convert(self, value, param, ctx):
def background(f):
@functools.wraps(f)
def inner(*args, **kwargs):
loop = asyncio.get_event_loop()
loop.run_until_complete(f(*args, **kwargs))
asyncio.run(f(*args, **kwargs))

return inner

Expand Down
10 changes: 8 additions & 2 deletions bellows/ezsp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ def is_tcp_serial_port(self) -> bool:

async def _startup_reset(self) -> None:
"""Start EZSP and reset the stack."""
if self._gw is None:
raise EzspError("Gateway is not connected")

# `zigbeed` resets on startup
if self.is_tcp_serial_port:
try:
Expand Down Expand Up @@ -220,8 +223,11 @@ async def get_xncp_features(self) -> xncp.FirmwareFeatures:

async def disconnect(self):
self.stop_ezsp()
if self._gw:
await self._gw.disconnect()
if self._gw is not None:
# Secondary loop closed; the proxy can't reach the gateway.
# Drop the reference so the caller can rebuild from scratch.
with contextlib.suppress(ConnectionError):
await self._gw.disconnect()
self._gw = None

async def _command(self, name: str, *args: Any, **kwargs: Any) -> Any:
Expand Down
25 changes: 18 additions & 7 deletions bellows/thread.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
import functools
import inspect
import logging

LOGGER = logging.getLogger(__name__)
Expand All @@ -14,7 +15,7 @@ def __init__(self):
self.thread_complete = None

def run_coroutine_threadsafe(self, coroutine):
current_loop = asyncio.get_event_loop()
current_loop = asyncio.get_running_loop()
future = asyncio.run_coroutine_threadsafe(coroutine, self.loop)
return asyncio.wrap_future(future, loop=current_loop)

Expand All @@ -30,7 +31,7 @@ def _thread_main(self, init_task):
self.loop = None

async def start(self):
current_loop = asyncio.get_event_loop()
current_loop = asyncio.get_running_loop()
if self.loop is not None and not self.loop.is_closed():
return

Expand Down Expand Up @@ -95,11 +96,21 @@ def func_wrapper(*args, **kwargs):
if loop == curr_loop:
return call()
if loop.is_closed():
# Disconnected
LOGGER.warning("Attempted to use a closed event loop")
return
if asyncio.iscoroutinefunction(func):
future = asyncio.run_coroutine_threadsafe(call(), loop)
raise ConnectionError(
"Attempted to use a closed event loop, "
"the connection may have been lost"
)
if inspect.iscoroutinefunction(func):
coro = call()
try:
future = asyncio.run_coroutine_threadsafe(coro, loop)
except RuntimeError:
# Loop closed between is_closed() check and dispatch
coro.close()
raise ConnectionError(
"Attempted to use a closed event loop, "
"the connection may have been lost"
)
return asyncio.wrap_future(future, loop=curr_loop)
else:

Expand Down
20 changes: 12 additions & 8 deletions bellows/uart.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,16 @@


class Gateway(zigpy.serial.SerialProtocol):
def __init__(self, api, connection_done_future=None):
def __init__(self, api, connection_done_future=None, loop=None):
super().__init__()
self._api = api

self._reset_future = None
self._startup_reset_future = None
# Pre-create so reset frames arriving immediately after connect are
# captured by reset_received() instead of triggering enter_failed_state().
# Tests construct Gateway without a loop and expect None here; in that
# case wait_for_startup_reset() will lazily create the future.
self._startup_reset_future = loop.create_future() if loop is not None else None
self._connection_done_future = connection_done_future

async def send_data(self, data: bytes) -> None:
Expand Down Expand Up @@ -52,8 +56,8 @@ def error_received(self, code: t.NcpResetCode) -> None:

async def wait_for_startup_reset(self) -> None:
"""Wait for the first reset frame on startup."""
assert self._startup_reset_future is None
self._startup_reset_future = asyncio.get_running_loop().create_future()
if self._startup_reset_future is None:
self._startup_reset_future = asyncio.get_running_loop().create_future()

try:
await self._startup_reset_future
Expand Down Expand Up @@ -98,19 +102,19 @@ async def reset(self):
return await self._reset_future

self._transport.send_reset()
self._reset_future = asyncio.get_event_loop().create_future()
self._reset_future = asyncio.get_running_loop().create_future()
self._reset_future.add_done_callback(self._reset_cleanup)

async with asyncio_timeout(RESET_TIMEOUT):
return await self._reset_future


async def _connect(config, api):
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()

connection_done_future = loop.create_future()

gateway = Gateway(api, connection_done_future)
gateway = Gateway(api, connection_done_future, loop=loop)
protocol = AshProtocol(gateway)

if config[zigpy.config.CONF_DEVICE_FLOW_CONTROL] is None:
Expand All @@ -135,7 +139,7 @@ async def _connect(config, api):

async def connect(config, api, use_thread=True):
if use_thread:
api = ThreadsafeProxy(api, asyncio.get_event_loop())
api = ThreadsafeProxy(api, asyncio.get_running_loop())
thread = EventLoopThread()
await thread.start()
try:
Expand Down
38 changes: 38 additions & 0 deletions tests/test_ezsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,44 @@ async def wait_forever(*args, **kwargs):
assert version_mock.await_count == 1


async def test_startup_reset_gw_none():
"""Test _startup_reset raises EzspError when gateway is None."""
ezsp = make_ezsp(
config={
**DEVICE_CONFIG,
zigpy.config.CONF_DEVICE_PATH: "socket://localhost:1234",
}
)
ezsp._gw = None

with pytest.raises(EzspError, match="Gateway is not connected"):
await ezsp._startup_reset()


async def test_disconnect_gw_none():
"""Test disconnect doesn't raise when gateway is already None."""
ezsp = make_ezsp()
ezsp._gw = None

await ezsp.disconnect() # Should not raise

assert ezsp._gw is None


async def test_disconnect_swallows_connection_error():
"""If the gateway's `disconnect()` raises ConnectionError because the
secondary loop is dead, drop the gateway reference and return without
raising so the caller can rebuild from scratch."""
ezsp = make_ezsp()
mock_gw = MagicMock()
mock_gw.disconnect = AsyncMock(side_effect=ConnectionError("loop closed"))
ezsp._gw = mock_gw

await ezsp.disconnect()

assert ezsp._gw is None


async def test_wait_for_stack_status(ezsp_f):
assert not ezsp_f._stack_status_listeners[t.sl_Status.NETWORK_DOWN]

Expand Down
27 changes: 26 additions & 1 deletion tests/test_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,35 @@ async def test_proxy_loop_closed():
obj = mock.MagicMock()
proxy = ThreadsafeProxy(obj, loop)
loop.close()
proxy.test()
with pytest.raises(ConnectionError, match="closed event loop"):
proxy.test()
assert obj.test.call_count == 0


async def test_proxy_coroutine_loop_closed_mid_dispatch():
"""If the loop closes between the `is_closed()` check and
`run_coroutine_threadsafe()`, the proxy must close the orphaned
coroutine and surface the failure as ConnectionError instead of
leaking an un-awaited coroutine warning."""
loop = asyncio.new_event_loop()

async def fake_coro(): # pragma: no cover - never awaited
return None

obj = mock.MagicMock()
obj.test = fake_coro
proxy = ThreadsafeProxy(obj, loop)

with mock.patch(
"asyncio.run_coroutine_threadsafe",
side_effect=RuntimeError("loop closed"),
):
with pytest.raises(ConnectionError, match="closed event loop"):
proxy.test()

loop.close()


async def test_thread_task_cancellation_after_stop(thread):
loop = asyncio.get_event_loop()
obj = mock.MagicMock()
Expand Down
Loading