diff --git a/test/test_realtime.py b/test/test_realtime.py index d113e7e..2f30e12 100644 --- a/test/test_realtime.py +++ b/test/test_realtime.py @@ -159,3 +159,63 @@ async def mock_send(message: str) -> None: assert mock_adapter.send.await_count == 1 mock_adapter.receive.assert_awaited() assert mock_adapter.close.await_count == 1 + +async def test_watchdog_resets_sub_manager_after_close( + mock_client: MagicMock, + tibber_rt: TibberRT, +) -> None: + """sub_manager must be None after watchdog closes connection + so _create_sub_manager() builds a fresh transport instead of + reusing the stale one with an expired token. NOTE(review): _reset_connection() is not defined in the realtime.py hunks shown in this diff — confirm it exists.""" + await tibber_rt.connect() + assert tibber_rt.sub_manager is not None + + await tibber_rt._reset_connection() + + assert tibber_rt.sub_manager is None + assert tibber_rt.session is None + +async def test_on_reconnect_callback_called_before_reconnect( + mock_client: MagicMock, +) -> None: + """on_reconnect must be called before _create_sub_manager() + so the fresh websocketSubscriptionUrl is used for the new transport. NOTE(review): the asserted order comes from this test's own append() calls and its direct _on_reconnect() call, not from the watchdog loop — confirm this covers the intended path.""" + call_order = [] + + async def mock_reconnect() -> None: + call_order.append("on_reconnect") + + tibber_rt = TibberRT( + access_token="test_token", + timeout=30, + user_agent="test_agent", + ssl=True, + on_reconnect=mock_reconnect, + ) + tibber_rt.sub_endpoint = "wss://test.endpoint" + await tibber_rt.connect() + call_order.append("connected") + + await tibber_rt._reset_connection() + await tibber_rt._on_reconnect() + call_order.append("create_sub_manager") + + assert call_order == ["connected", "on_reconnect", "create_sub_manager"] + +async def test_sub_endpoint_setter_skips_replacement_on_same_url( + mock_client: MagicMock, + tibber_rt: TibberRT, +) -> None: + """Setting the same URL must not replace a running sub_manager. 
+ Previously this would orphan the existing websocket connection.""" + await tibber_rt.connect() + + # Track how many times Client() was instantiated + with patch("tibber.realtime.Client") as mock_client_class: + # Set same URL — should be a no-op, Client() not called + tibber_rt.sub_endpoint = "wss://test.endpoint" + mock_client_class.assert_not_called() + + # Set different URL — should create new Client + tibber_rt.sub_endpoint = "wss://new.endpoint" + mock_client_class.assert_called_once() diff --git a/tibber/__init__.py b/tibber/__init__.py index cacad7b..8501465 100644 --- a/tibber/__init__.py +++ b/tibber/__init__.py @@ -71,6 +71,7 @@ def __init__( self.timeout, self._user_agent, ssl=ssl, + on_reconnect=self.update_info, ) self.time_zone: dt.tzinfo = time_zone or dt.UTC @@ -121,7 +122,7 @@ async def execute( timeout=aiohttp.ClientTimeout(total=self.timeout), ) return (await extract_response_data(resp)).get("data") - except (TimeoutError, aiohttp.ClientError) as err: + except (TimeoutError, aiohttp.ClientError, RetryableHttpExceptionError) as err: if retry > 0: return await self.execute( document, @@ -228,16 +229,28 @@ async def rt_disconnect(self) -> None: return await self.realtime.disconnect() async def set_access_token(self, access_token: str) -> None: + """Set access token and reauthorize clients.""" if access_token == self._access_token: return - """Set access token and reauthorize clients.""" + restore_realtime = self.realtime.should_restore_connection + old_token = self._access_token # Store old token in case of rollback. + self._access_token = access_token await self.realtime.set_access_token(access_token) self.data_api.set_access_token(access_token) - await self.update_info() - if restore_realtime: - await self.realtime.reconnect() + + try: + await self.update_info() + except Exception: + # Rollback: If token was wrong or API threw transient error. NOTE(review): only self._access_token is restored below; the new token already handed to realtime.set_access_token() and data_api.set_access_token() is not reverted — confirm. 
+ self._access_token = old_token + raise + else: + # Restore the realtime connection only when update_info() succeeded. + if restore_realtime: + await self.realtime.reconnect() + @property def user_id(self) -> str | None: diff --git a/tibber/home.py b/tibber/home.py index 7ed1786..b69dc07 100644 --- a/tibber/home.py +++ b/tibber/home.py @@ -466,16 +466,14 @@ async def _start() -> None: self._rt_stopped = False async def rt_resubscribe(self) -> None: - """Resubscribe to Tibber data.""" + """Resubscribe to Tibber data. + + Note: websocketSubscriptionUrl refresh is handled by the watchdog + via on_reconnect callback before calling this method. + """ self.rt_unsubscribe() _LOGGER.debug("Resubscribe, %s", self.home_id) - await asyncio.gather( - *[ - self.update_info(), - self._tibber_control.update_info(), - ], - return_exceptions=False, - ) + await self.update_info() if self._rt_callback is None: _LOGGER.warning("No callback set for rt_resubscribe") return diff --git a/tibber/realtime.py b/tibber/realtime.py index 5bdf645..0b618fd 100644 --- a/tibber/realtime.py +++ b/tibber/realtime.py @@ -1,11 +1,15 @@ """Tibber RT connection.""" +from __future__ import annotations +from typing import TYPE_CHECKING, Any +if TYPE_CHECKING: + from ssl import SSLContext + from collections.abc import Awaitable, Callable + import asyncio import datetime as dt import logging import random -from ssl import SSLContext -from typing import Any from gql import Client @@ -13,6 +17,8 @@ from .home import TibberHome from .websocket_transport import TibberWebsocketsTransport + + LOCK_CONNECT = asyncio.Lock() _LOGGER = logging.getLogger(__name__) @@ -21,7 +27,13 @@ class TibberRT: """Class to handle real time connection with the Tibber api.""" - def __init__(self, access_token: str, timeout: int, user_agent: str, ssl: SSLContext | bool) -> None: + def __init__(self, + access_token: str, + timeout: int, + user_agent: str, + ssl: SSLContext | bool, + on_reconnect: Callable[[], Awaitable[None]] | None = None, + ) -> 
None: """Initialize the Tibber connection. :param access_token: The access token to access the Tibber API with. @@ -32,6 +44,7 @@ def __init__(self, access_token: str, timeout: int, user_agent: str, ssl: SSLCon self._timeout: int = timeout self._user_agent: str = user_agent self._ssl_context = ssl + self._on_reconnect = on_reconnect self._sub_endpoint: str | None = None self._homes: list[TibberHome] = [] @@ -69,7 +82,11 @@ async def connect(self) -> None: """Start subscription manager.""" self._create_sub_manager() - assert self.sub_manager is not None + # _create_sub_manager() already raises SubscriptionEndpointMissingError + # if sub_endpoint is None, so sub_manager is guaranteed to be set here. + # This guard catches future regressions if _create_sub_manager() changes. + if self.sub_manager is None: + raise RuntimeError("sub_manager not initialized before connect()") async with LOCK_CONNECT: if self.subscription_running: @@ -78,6 +95,11 @@ async def connect(self) -> None: _LOGGER.debug("Starting watchdog") self._watchdog_running = True self._watchdog_runner = asyncio.create_task(self._watchdog()) + # Log any exception that terminates the watchdog task so it is not silently lost. + self._watchdog_runner.add_done_callback( + lambda t: _LOGGER.error("Watchdog task failed: %s", t.exception()) + if not t.cancelled() and t.exception() else None + ) self.session = await self.sub_manager.connect_async() async def reconnect(self) -> None: @@ -107,8 +129,17 @@ def _create_sub_manager(self) -> None: async def _watchdog(self) -> None: """Watchdog to keep connection alive.""" - assert self.sub_manager is not None - assert isinstance(self.sub_manager.transport, TibberWebsocketsTransport) + + # Watchdog is started from connect() which calls _create_sub_manager() first, + # so sub_manager is guaranteed to exist and have the correct transport type. + # This guard catches future regressions and/or rogue watchdog calls. 
+ if self.sub_manager is None: + raise RuntimeError("Watchdog started without sub_manager") + if not isinstance(self.sub_manager.transport, TibberWebsocketsTransport): + raise RuntimeError( + f"Watchdog started with unexpected transport type: " + f"{type(self.sub_manager.transport)}" + ) await asyncio.sleep(60) @@ -116,61 +147,84 @@ async def _watchdog(self) -> None: next_test_all_homes_running = dt.datetime.now(tz=dt.UTC) while self._watchdog_running: await asyncio.sleep(5) - if ( - self.sub_manager.transport.running - and self.sub_manager.transport.reconnect_at - > dt.datetime.now( - tz=dt.UTC, - ) - and dt.datetime.now(tz=dt.UTC) > next_test_all_homes_running - ): - is_running = True - for home in self._homes: - _LOGGER.debug( - "Watchdog: Checking if home %s is alive, %s, %s", - home.home_id, - home.has_real_time_consumption, - home.rt_subscription_running, - ) - if not home.rt_subscription_running: - is_running = False - next_test_all_homes_running = dt.datetime.now(tz=dt.UTC) + dt.timedelta(seconds=60) - break - _LOGGER.debug( - "Watchdog: Home %s is alive", - home.home_id, + + # Liveness check (skipped when sub_manager has been reset to None) + if self.sub_manager: + if ( + self.sub_manager.transport.running + and self.sub_manager.transport.reconnect_at + > dt.datetime.now( + tz=dt.UTC, ) - if is_running: - _retry_count = 0 - _LOGGER.debug("Watchdog: Connection is alive") - continue + and dt.datetime.now(tz=dt.UTC) > next_test_all_homes_running + ): + is_running = True + for home in self._homes: + _LOGGER.debug( + "Watchdog: Checking if home %s is alive, %s, %s", + home.home_id, + home.has_real_time_consumption, + home.rt_subscription_running, + ) + if not home.rt_subscription_running: + is_running = False + next_test_all_homes_running = dt.datetime.now(tz=dt.UTC) + dt.timedelta(seconds=60) + break + _LOGGER.debug( + "Watchdog: Home %s is alive", + home.home_id, + ) + if is_running: + _retry_count = 0 + _LOGGER.debug("Watchdog: Connection is alive") + continue + if self.sub_manager: 
+ self.sub_manager.transport.reconnect_at = dt.datetime.now(tz=dt.UTC) + dt.timedelta( + seconds=self._timeout) + reconnect_at = self.sub_manager.transport.reconnect_at + else: + reconnect_at = dt.datetime.now(tz=dt.UTC) + dt.timedelta(seconds=self._timeout) - self.sub_manager.transport.reconnect_at = dt.datetime.now(tz=dt.UTC) + dt.timedelta(seconds=self._timeout) - _LOGGER.error( - "Watchdog: Connection is down, %s", - self.sub_manager.transport.reconnect_at, - ) + _LOGGER.error("Watchdog: Connection is down, %s", reconnect_at) try: - if self.session is not None: + if self.session is not None and self.sub_manager is not None: await self.sub_manager.close_async() - self.session = None - except Exception: - _LOGGER.exception("Error in watchdog close") + except Exception as e: + _LOGGER.exception(f"Error in watchdog close: {e}") + finally: + # Reset connection state so _create_sub_manager() builds a fresh + # transport with current credentials instead of reusing the stale one. + self.session = None + self.sub_manager = None if not self._watchdog_running: _LOGGER.debug("Watchdog: Stopping") return + delay_seconds = min( + random.SystemRandom().randint(1, 30) + _retry_count ** 2, + 5 * 60, + ) + if self._on_reconnect is not None: + try: + await self._on_reconnect() # fetch fresh websocketSubscriptionUrl before reconnecting + except Exception as err: + # Tibber API unreachable or token expired. No point connecting + # with stale credentials, wait and retry. 
+ _retry_count += 1 + _LOGGER.error( + "Failed to refresh connection info before reconnect, aborting: %s", err + ) + await asyncio.sleep(delay_seconds) + continue + self._create_sub_manager() + try: self.session = await self.sub_manager.connect_async() await self._resubscribe_homes() except Exception as err: # noqa: BLE001 - delay_seconds = min( - random.SystemRandom().randint(1, 30) + _retry_count**2, - 5 * 60, - ) _retry_count += 1 _LOGGER.error( "Error in watchdog connect, retrying in %s seconds, %s: %s", @@ -180,6 +234,14 @@ async def _watchdog(self) -> None: exc_info=_retry_count > 1, ) await asyncio.sleep(delay_seconds) + # Reset sub_manager on 4403 to force a fresh websocketSubscriptionUrl + # on the next attempt. This handles a race condition where Tibber + # invalidates the session server-side (e.g. during a rolling deployment) + # between us fetching the URL and completing the WebSocket handshake. + # We only reset on 4403 specifically — transient network errors don't + # invalidate the URL so resetting unnecessarily would cause extra API calls. NOTE(review): the `finally` of the close step earlier in this loop already sets sub_manager to None on every pass, so this branch looks redundant; the condition below also matches "Invalid token", not only 4403 — confirm intent. + if "4403" in str(err) or "Invalid token" in str(err): + self.sub_manager = None else: _LOGGER.debug("Watchdog: Reconnected successfully") await asyncio.sleep(60) @@ -221,6 +283,8 @@ def sub_endpoint(self) -> str | None: @sub_endpoint.setter def sub_endpoint(self, sub_endpoint: str) -> None: """Set subscription endpoint.""" + if self._sub_endpoint == sub_endpoint: + return # URL unchanged, don't replace a running sub_manager self._sub_endpoint = sub_endpoint if self.sub_manager is not None and isinstance(self.sub_manager.transport, TibberWebsocketsTransport): self.sub_manager = Client(