diff --git a/bellows/cli/util.py b/bellows/cli/util.py index 76f83511..ebf6b755 100644 --- a/bellows/cli/util.py +++ b/bellows/cli/util.py @@ -35,8 +35,7 @@ def convert(self, value, param, ctx): def background(f): @functools.wraps(f) def inner(*args, **kwargs): - loop = asyncio.get_event_loop() - loop.run_until_complete(f(*args, **kwargs)) + asyncio.run(f(*args, **kwargs)) return inner diff --git a/bellows/ezsp/__init__.py b/bellows/ezsp/__init__.py index 29554c93..b1ba3171 100644 --- a/bellows/ezsp/__init__.py +++ b/bellows/ezsp/__init__.py @@ -117,6 +117,9 @@ def is_tcp_serial_port(self) -> bool: async def _startup_reset(self) -> None: """Start EZSP and reset the stack.""" + if self._gw is None: + raise EzspError("Gateway is not connected") + # `zigbeed` resets on startup if self.is_tcp_serial_port: try: @@ -220,8 +223,11 @@ async def get_xncp_features(self) -> xncp.FirmwareFeatures: async def disconnect(self): self.stop_ezsp() - if self._gw: - await self._gw.disconnect() + if self._gw is not None: + # Secondary loop closed; the proxy can't reach the gateway. + # Drop the reference so the caller can rebuild from scratch. + with contextlib.suppress(ConnectionError): + await self._gw.disconnect() self._gw = None async def _command(self, name: str, *args: Any, **kwargs: Any) -> Any: diff --git a/bellows/thread.py b/bellows/thread.py index 4311768d..270402f6 100644 --- a/bellows/thread.py +++ b/bellows/thread.py @@ -1,6 +1,7 @@ import asyncio from concurrent.futures import ThreadPoolExecutor import functools +import inspect import logging LOGGER = logging.getLogger(__name__) @@ -14,7 +15,7 @@ def __init__(self): self.thread_complete = None def run_coroutine_threadsafe(self, coroutine): - current_loop = asyncio.get_event_loop() + current_loop = asyncio.get_running_loop() future = asyncio.run_coroutine_threadsafe(coroutine, self.loop) return asyncio.wrap_future(future, loop=current_loop) @@ -30,7 +31,7 @@ def _thread_main(self, init_task): self.loop = None async def start(self): - current_loop = asyncio.get_event_loop() + current_loop = asyncio.get_running_loop() if self.loop is not None and not self.loop.is_closed(): return @@ -95,11 +96,21 @@ def func_wrapper(*args, **kwargs): if loop == curr_loop: return call() if loop.is_closed(): - # Disconnected - LOGGER.warning("Attempted to use a closed event loop") - return - if asyncio.iscoroutinefunction(func): - future = asyncio.run_coroutine_threadsafe(call(), loop) + raise ConnectionError( + "Attempted to use a closed event loop, " + "the connection may have been lost" + ) + if inspect.iscoroutinefunction(func): + coro = call() + try: + future = asyncio.run_coroutine_threadsafe(coro, loop) + except RuntimeError: + # Loop closed between is_closed() check and dispatch + coro.close() + raise ConnectionError( + "Attempted to use a closed event loop, " + "the connection may have been lost" + ) return asyncio.wrap_future(future, loop=curr_loop) else: diff --git a/bellows/uart.py b/bellows/uart.py index af274dc8..d4517744 100644 --- a/bellows/uart.py +++ b/bellows/uart.py @@ -14,12 +14,16 @@ class Gateway(zigpy.serial.SerialProtocol): - def __init__(self, api, connection_done_future=None): + def __init__(self, api, connection_done_future=None, loop=None): super().__init__() self._api = api self._reset_future = None - self._startup_reset_future = None + # Pre-create so reset frames arriving immediately after connect are + # captured by reset_received() instead of triggering enter_failed_state(). + # Tests construct Gateway without a loop and expect None here; in that + # case wait_for_startup_reset() will lazily create the future. + self._startup_reset_future = loop.create_future() if loop is not None else None self._connection_done_future = connection_done_future async def send_data(self, data: bytes) -> None: @@ -52,8 +56,8 @@ def error_received(self, code: t.NcpResetCode) -> None: async def wait_for_startup_reset(self) -> None: """Wait for the first reset frame on startup.""" - assert self._startup_reset_future is None - self._startup_reset_future = asyncio.get_running_loop().create_future() + if self._startup_reset_future is None: + self._startup_reset_future = asyncio.get_running_loop().create_future() try: await self._startup_reset_future @@ -98,7 +102,7 @@ async def reset(self): return await self._reset_future self._transport.send_reset() - self._reset_future = asyncio.get_event_loop().create_future() + self._reset_future = asyncio.get_running_loop().create_future() self._reset_future.add_done_callback(self._reset_cleanup) async with asyncio_timeout(RESET_TIMEOUT): @@ -106,11 +110,11 @@ async def reset(self): async def _connect(config, api): - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() connection_done_future = loop.create_future() - gateway = Gateway(api, connection_done_future) + gateway = Gateway(api, connection_done_future, loop=loop) protocol = AshProtocol(gateway) if config[zigpy.config.CONF_DEVICE_FLOW_CONTROL] is None: @@ -135,7 +139,7 @@ async def _connect(config, api): async def connect(config, api, use_thread=True): if use_thread: - api = ThreadsafeProxy(api, asyncio.get_event_loop()) + api = ThreadsafeProxy(api, asyncio.get_running_loop()) thread = EventLoopThread() await thread.start() try: diff --git a/tests/test_ezsp.py b/tests/test_ezsp.py index d548309a..7f6bc0ca 100644 --- a/tests/test_ezsp.py +++ b/tests/test_ezsp.py @@ -789,6 +789,44 @@ async def wait_forever(*args, **kwargs): assert version_mock.await_count == 1 +async def test_startup_reset_gw_none(): + """Test _startup_reset raises EzspError when gateway is None.""" + ezsp = make_ezsp( + config={ + **DEVICE_CONFIG, + zigpy.config.CONF_DEVICE_PATH: "socket://localhost:1234", + } + ) + ezsp._gw = None + + with pytest.raises(EzspError, match="Gateway is not connected"): + await ezsp._startup_reset() + + +async def test_disconnect_gw_none(): + """Test disconnect doesn't raise when gateway is already None.""" + ezsp = make_ezsp() + ezsp._gw = None + + await ezsp.disconnect() # Should not raise + + assert ezsp._gw is None + + +async def test_disconnect_swallows_connection_error(): + """If the gateway's `disconnect()` raises ConnectionError because the + secondary loop is dead, drop the gateway reference and return without + raising so the caller can rebuild from scratch.""" + ezsp = make_ezsp() + mock_gw = MagicMock() + mock_gw.disconnect = AsyncMock(side_effect=ConnectionError("loop closed")) + ezsp._gw = mock_gw + + await ezsp.disconnect() + + assert ezsp._gw is None + + async def test_wait_for_stack_status(ezsp_f): assert not ezsp_f._stack_status_listeners[t.sl_Status.NETWORK_DOWN] diff --git a/tests/test_thread.py b/tests/test_thread.py index 72efa701..056e96ff 100644 --- a/tests/test_thread.py +++ b/tests/test_thread.py @@ -157,10 +157,35 @@ async def test_proxy_loop_closed(): obj = mock.MagicMock() proxy = ThreadsafeProxy(obj, loop) loop.close() - proxy.test() + with pytest.raises(ConnectionError, match="closed event loop"): + proxy.test() assert obj.test.call_count == 0 +async def test_proxy_coroutine_loop_closed_mid_dispatch(): + """If the loop closes between the `is_closed()` check and + `run_coroutine_threadsafe()`, the proxy must close the orphaned + coroutine and surface the failure as ConnectionError instead of + leaking an un-awaited coroutine warning.""" + loop = asyncio.new_event_loop() + + async def fake_coro(): # pragma: no cover - never awaited + return None + + obj = mock.MagicMock() + obj.test = fake_coro + proxy = ThreadsafeProxy(obj, loop) + + with mock.patch( + "asyncio.run_coroutine_threadsafe", + side_effect=RuntimeError("loop closed"), + ): + with pytest.raises(ConnectionError, match="closed event loop"): + proxy.test() + + loop.close() + + async def test_thread_task_cancellation_after_stop(thread): loop = asyncio.get_event_loop() obj = mock.MagicMock()