Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions test/test_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,63 @@ async def mock_send(message: str) -> None:
assert mock_adapter.send.await_count == 1
mock_adapter.receive.assert_awaited()
assert mock_adapter.close.await_count == 1

async def test_watchdog_resets_sub_manager_after_close(
    mock_client: MagicMock,
    tibber_rt: TibberRT,
) -> None:
    """Resetting the connection must clear ``sub_manager`` and ``session``.

    With ``sub_manager`` back at ``None``, ``_create_sub_manager()`` is
    expected to build a fresh transport instead of reusing the stale one,
    which may still carry an expired token.
    """
    await tibber_rt.connect()
    manager_after_connect = tibber_rt.sub_manager
    assert manager_after_connect is not None

    await tibber_rt._reset_connection()

    # Both handles to the previous connection must be gone after the reset.
    assert tibber_rt.sub_manager is None
    assert tibber_rt.session is None

async def test_on_reconnect_callback_called_before_reconnect(
    mock_client: MagicMock,
) -> None:
    """``on_reconnect`` must run before ``_create_sub_manager()``.

    That ordering lets the new transport be built from the fresh
    ``websocketSubscriptionUrl``.

    NOTE(review): the sequence checked here is produced by the test itself
    (it awaits the callback and then records the ``create_sub_manager``
    marker by hand), so it documents the expected order rather than
    observing what the watchdog does.
    """
    events = []

    async def record_reconnect() -> None:
        events.append("on_reconnect")

    rt = TibberRT(
        access_token="test_token",
        timeout=30,
        user_agent="test_agent",
        ssl=True,
        on_reconnect=record_reconnect,
    )
    rt.sub_endpoint = "wss://test.endpoint"
    await rt.connect()
    events.append("connected")

    # Tear the connection down, refresh via the callback, then mark the point
    # at which a new sub manager would be created.
    await rt._reset_connection()
    await rt._on_reconnect()
    events.append("create_sub_manager")

    expected_order = ["connected", "on_reconnect", "create_sub_manager"]
    assert events == expected_order

async def test_sub_endpoint_setter_skips_replacement_on_same_url(
    mock_client: MagicMock,
    tibber_rt: TibberRT,
) -> None:
    """Re-assigning an unchanged URL must leave a running ``sub_manager`` alone.

    Replacing it on every assignment would orphan the existing websocket
    connection; only a different URL should construct a new ``Client``.
    """
    await tibber_rt.connect()

    # Count Client() constructions made by the setter from here on.
    with patch("tibber.realtime.Client") as client_ctor:
        # Same URL: expected to be a no-op.
        # NOTE(review): assumes the tibber_rt fixture is already configured
        # with "wss://test.endpoint" - confirm against the fixture.
        tibber_rt.sub_endpoint = "wss://test.endpoint"
        client_ctor.assert_not_called()

        # Different URL: exactly one new Client must be built.
        tibber_rt.sub_endpoint = "wss://new.endpoint"
        client_ctor.assert_called_once()
23 changes: 18 additions & 5 deletions tibber/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(
self.timeout,
self._user_agent,
ssl=ssl,
on_reconnect=self.update_info,
)

self.time_zone: dt.tzinfo = time_zone or dt.UTC
Expand Down Expand Up @@ -121,7 +122,7 @@ async def execute(
timeout=aiohttp.ClientTimeout(total=self.timeout),
)
return (await extract_response_data(resp)).get("data")
except (TimeoutError, aiohttp.ClientError) as err:
except (TimeoutError, aiohttp.ClientError, RetryableHttpExceptionError) as err:
if retry > 0:
return await self.execute(
document,
Expand Down Expand Up @@ -228,16 +229,28 @@ async def rt_disconnect(self) -> None:
return await self.realtime.disconnect()

async def set_access_token(self, access_token: str) -> None:
    """Set access token and reauthorize clients.

    Does nothing when ``access_token`` equals the token already in use.
    Otherwise the token is handed to the realtime client and the data API,
    and ``update_info()`` is awaited to validate it.

    :param access_token: The new access token.
    :raises Exception: Whatever ``update_info()`` raises is re-raised after
        ``self._access_token`` has been restored to its previous value.
    """
    if access_token == self._access_token:
        return

    # Read before the realtime client gets the new token.
    # NOTE(review): presumably realtime.set_access_token() changes the
    # connection state this flag reflects - confirm.
    restore_realtime = self.realtime.should_restore_connection
    old_token = self._access_token  # kept for rollback

    self._access_token = access_token
    await self.realtime.set_access_token(access_token)
    self.data_api.set_access_token(access_token)

    try:
        await self.update_info()
    except Exception:
        # Rollback: the token was rejected or the API failed transiently.
        # NOTE(review): only self._access_token is restored; realtime and
        # data_api keep the new token - confirm this is intended.
        self._access_token = old_token
        raise

    # Reconnect realtime only after the new token has been validated.
    if restore_realtime:
        await self.realtime.reconnect()


@property
def user_id(self) -> str | None:
Expand Down
14 changes: 6 additions & 8 deletions tibber/home.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,16 +466,14 @@ async def _start() -> None:
self._rt_stopped = False

async def rt_resubscribe(self) -> None:
"""Resubscribe to Tibber data."""
"""Resubscribe to Tibber data.

Note: websocketSubscriptionUrl refresh is handled by the watchdog
via on_reconnect callback before calling this method.
"""
self.rt_unsubscribe()
_LOGGER.debug("Resubscribe, %s", self.home_id)
await asyncio.gather(
*[
self.update_info(),
self._tibber_control.update_info(),
],
return_exceptions=False,
)
await self.update_info()
if self._rt_callback is None:
_LOGGER.warning("No callback set for rt_resubscribe")
return
Expand Down
156 changes: 110 additions & 46 deletions tibber/realtime.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
"""Tibber RT connection."""

from __future__ import annotations
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from ssl import SSLContext
from collections.abc import Awaitable, Callable

import asyncio
import datetime as dt
import logging
import random
from ssl import SSLContext
from typing import Any

from gql import Client

from .exceptions import SubscriptionEndpointMissingError
from .home import TibberHome
from .websocket_transport import TibberWebsocketsTransport



LOCK_CONNECT = asyncio.Lock()

_LOGGER = logging.getLogger(__name__)
Expand All @@ -21,7 +27,13 @@
class TibberRT:
"""Class to handle real time connection with the Tibber api."""

def __init__(self, access_token: str, timeout: int, user_agent: str, ssl: SSLContext | bool) -> None:
def __init__(self,
access_token: str,
timeout: int,
user_agent: str,
ssl: SSLContext | bool,
on_reconnect: Callable[[], Awaitable[None]] | None = None,
) -> None:
"""Initialize the Tibber connection.

:param access_token: The access token to access the Tibber API with.
Expand All @@ -32,6 +44,7 @@ def __init__(self, access_token: str, timeout: int, user_agent: str, ssl: SSLCon
self._timeout: int = timeout
self._user_agent: str = user_agent
self._ssl_context = ssl
self._on_reconnect = on_reconnect

self._sub_endpoint: str | None = None
self._homes: list[TibberHome] = []
Expand Down Expand Up @@ -69,7 +82,11 @@ async def connect(self) -> None:
"""Start subscription manager."""
self._create_sub_manager()

assert self.sub_manager is not None
# _create_sub_manager() already raises SubscriptionEndpointMissingError
# if sub_endpoint is None, so sub_manager is guaranteed to be set here.
# This guard catches future regressions if _create_sub_manager() changes.
if self.sub_manager is None:
raise RuntimeError("sub_manager not initialized before connect()")

async with LOCK_CONNECT:
if self.subscription_running:
Expand All @@ -78,6 +95,11 @@ async def connect(self) -> None:
_LOGGER.debug("Starting watchdog")
self._watchdog_running = True
self._watchdog_runner = asyncio.create_task(self._watchdog())
# Log any exception that ends the watchdog task so the failure is not lost silently.
self._watchdog_runner.add_done_callback(
lambda t: _LOGGER.error("Watchdog task failed: %s", t.exception())
if not t.cancelled() and t.exception() else None
)
self.session = await self.sub_manager.connect_async()

async def reconnect(self) -> None:
Expand Down Expand Up @@ -107,70 +129,102 @@ def _create_sub_manager(self) -> None:

async def _watchdog(self) -> None:
"""Watchdog to keep connection alive."""
assert self.sub_manager is not None
assert isinstance(self.sub_manager.transport, TibberWebsocketsTransport)

# Watchdog is started from connect() which calls _create_sub_manager() first,
# so sub_manager is guaranteed to exist and have the correct transport type.
# This guard catches future regressions and/or rogue watchdog calls.
if self.sub_manager is None:
raise RuntimeError("Watchdog started without sub_manager")
if not isinstance(self.sub_manager.transport, TibberWebsocketsTransport):
raise RuntimeError(
f"Watchdog started with unexpected transport type: "
f"{type(self.sub_manager.transport)}"
)

await asyncio.sleep(60)

_retry_count = 0
next_test_all_homes_running = dt.datetime.now(tz=dt.UTC)
while self._watchdog_running:
await asyncio.sleep(5)
if (
self.sub_manager.transport.running
and self.sub_manager.transport.reconnect_at
> dt.datetime.now(
tz=dt.UTC,
)
and dt.datetime.now(tz=dt.UTC) > next_test_all_homes_running
):
is_running = True
for home in self._homes:
_LOGGER.debug(
"Watchdog: Checking if home %s is alive, %s, %s",
home.home_id,
home.has_real_time_consumption,
home.rt_subscription_running,
)
if not home.rt_subscription_running:
is_running = False
next_test_all_homes_running = dt.datetime.now(tz=dt.UTC) + dt.timedelta(seconds=60)
break
_LOGGER.debug(
"Watchdog: Home %s is alive",
home.home_id,

# Reconnect Backoff
if self.sub_manager:
if (
self.sub_manager.transport.running
and self.sub_manager.transport.reconnect_at
> dt.datetime.now(
tz=dt.UTC,
)
if is_running:
_retry_count = 0
_LOGGER.debug("Watchdog: Connection is alive")
continue
and dt.datetime.now(tz=dt.UTC) > next_test_all_homes_running
):
is_running = True
for home in self._homes:
_LOGGER.debug(
"Watchdog: Checking if home %s is alive, %s, %s",
home.home_id,
home.has_real_time_consumption,
home.rt_subscription_running,
)
if not home.rt_subscription_running:
is_running = False
next_test_all_homes_running = dt.datetime.now(tz=dt.UTC) + dt.timedelta(seconds=60)
break
_LOGGER.debug(
"Watchdog: Home %s is alive",
home.home_id,
)
if is_running:
_retry_count = 0
_LOGGER.debug("Watchdog: Connection is alive")
continue
if self.sub_manager:
self.sub_manager.transport.reconnect_at = dt.datetime.now(tz=dt.UTC) + dt.timedelta(
seconds=self._timeout)
reconnect_at = self.sub_manager.transport.reconnect_at
else:
reconnect_at = dt.datetime.now(tz=dt.UTC) + dt.timedelta(seconds=self._timeout)

self.sub_manager.transport.reconnect_at = dt.datetime.now(tz=dt.UTC) + dt.timedelta(seconds=self._timeout)
_LOGGER.error(
"Watchdog: Connection is down, %s",
self.sub_manager.transport.reconnect_at,
)
_LOGGER.error("Watchdog: Connection is down, %s", reconnect_at)

try:
if self.session is not None:
if self.session is not None and self.sub_manager is not None:
await self.sub_manager.close_async()
self.session = None
except Exception:
_LOGGER.exception("Error in watchdog close")
except Exception as e:
_LOGGER.exception(f"Error in watchdog close: {e}")
finally:
# Reset connection state so _create_sub_manager() builds a fresh
# transport with current credentials instead of reusing the stale one.
self.session = None
self.sub_manager = None

if not self._watchdog_running:
_LOGGER.debug("Watchdog: Stopping")
return

delay_seconds = min(
random.SystemRandom().randint(1, 30) + _retry_count ** 2,
5 * 60,
)
if self._on_reconnect is not None:
try:
await self._on_reconnect() # fetch fresh websocketSubscriptionUrl before reconnecting
except Exception as err:
# Tibber API unreachable or token expired. No point connecting
# with stale credentials, wait and retry.
_retry_count += 1
_LOGGER.error(
"Failed to refresh connection info before reconnect, aborting: %s", err
)
await asyncio.sleep(delay_seconds)
continue

self._create_sub_manager()

try:
self.session = await self.sub_manager.connect_async()
await self._resubscribe_homes()
except Exception as err: # noqa: BLE001
delay_seconds = min(
random.SystemRandom().randint(1, 30) + _retry_count**2,
5 * 60,
)
_retry_count += 1
_LOGGER.error(
"Error in watchdog connect, retrying in %s seconds, %s: %s",
Expand All @@ -180,6 +234,14 @@ async def _watchdog(self) -> None:
exc_info=_retry_count > 1,
)
await asyncio.sleep(delay_seconds)
# Reset sub_manager on 4403 to force a fresh websocketSubscriptionUrl
# on the next attempt. This handles a race condition where Tibber
# invalidates the session server-side (e.g. during a rolling deployment)
# between us fetching the URL and completing the WebSocket handshake.
# We only reset when the error mentions 4403 or "Invalid token" — transient network
# errors don't invalidate the URL, so resetting unnecessarily would cause extra API calls.
if "4403" in str(err) or "Invalid token" in str(err):
self.sub_manager = None
else:
_LOGGER.debug("Watchdog: Reconnected successfully")
await asyncio.sleep(60)
Expand Down Expand Up @@ -221,6 +283,8 @@ def sub_endpoint(self) -> str | None:
@sub_endpoint.setter
def sub_endpoint(self, sub_endpoint: str) -> None:
"""Set subscription endpoint."""
if self._sub_endpoint == sub_endpoint:
return # URL unchanged, don't replace a running sub_manager
self._sub_endpoint = sub_endpoint
if self.sub_manager is not None and isinstance(self.sub_manager.transport, TibberWebsocketsTransport):
self.sub_manager = Client(
Expand Down