From 0f5fc7dc6082c2e425560f5f7ca20ac386bd19c6 Mon Sep 17 00:00:00 2001 From: Martin Hjelmare Date: Sun, 21 Dec 2025 18:29:02 +0100 Subject: [PATCH 01/14] Redesign websocket --- pyproject.toml | 1 + test/test_home.py | 281 ++++++++++++++++++++++++++++ test/test_realtime.py | 239 ++++++++++++++++++++++-- test/test_tibber.py | 80 +------- tibber/__init__.py | 30 +-- tibber/data_api.py | 1 + tibber/exceptions.py | 32 +++- tibber/home.py | 218 +++++++++++++--------- tibber/realtime.py | 342 ++++++++++++++++------------------ tibber/websocket_transport.py | 59 ------ 10 files changed, 837 insertions(+), 446 deletions(-) create mode 100644 test/test_home.py delete mode 100644 tibber/websocket_transport.py diff --git a/pyproject.toml b/pyproject.toml index ab98971..b363013 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ classifiers = [ dependencies = [ "aiohttp>=3.0.6", "gql>=4.0.0", + "tenacity>=9.0.0", "websockets>=14.0.0", ] dynamic = ["version"] diff --git a/test/test_home.py b/test/test_home.py new file mode 100644 index 0000000..fa12668 --- /dev/null +++ b/test/test_home.py @@ -0,0 +1,281 @@ +"""Tests for TibberHome.""" + +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, Any +from unittest.mock import AsyncMock, MagicMock, call, create_autospec + +import aiohttp +import pytest + +import tibber +from tibber.exceptions import WebsocketReconnectedError, WebsocketTransportError +from tibber.gql_queries import INFO, UPDATE_INFO_PRICE +from tibber.realtime import TibberRT + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + +HOME_ID = "test-home-id" + + +@pytest.fixture +def tibber_connection(mock_websession: MagicMock) -> tibber.Tibber: + tibber_client = tibber.Tibber( + access_token="test-token", + websession=mock_websession, + user_agent="test", + ) + tibber_client._user_agent = "test" # noqa: SLF001 + return tibber_client + + +@pytest.fixture +def mock_websession() -> MagicMock: + session = MagicMock(spec=aiohttp.ClientSession) + session.post = AsyncMock() + return session + + +@pytest.fixture +def mock_realtime(tibber_connection: tibber.Tibber) -> MagicMock: + rt = create_autospec(TibberRT, instance=True, subscription_running=False) + rt.connect = AsyncMock(side_effect=lambda: setattr(rt, "subscription_running", True)) + tibber_connection.realtime = rt + return rt + + +@pytest.fixture +def home(tibber_connection: tibber.Tibber) -> tibber.TibberHome: + home = tibber.TibberHome(HOME_ID, tibber_connection) + home._has_real_time_consumption = True # noqa: SLF001 + return home + + +def _make_blocking_subscribe( + yielded: list[Any], +) -> tuple[asyncio.Event, Any]: + """Return (release_event, subscribe_fn) that yields *yielded* then blocks.""" + release = asyncio.Event() + + async def subscribe(*args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]: # noqa: ANN401, ARG001 + for item in yielded: + yield item + await release.wait() + + return release, subscribe + + +async def test_rt_subscribe_connects_and_calls_callback( + home: tibber.TibberHome, + mock_realtime: MagicMock, +) -> None: + """Test that rt_subscribe connects via realtime and delivers subscription data to the callback.""" + sample_data = {"key": "value"} + _, subscribe_fn = _make_blocking_subscribe([sample_data]) + mock_realtime.subscribe = subscribe_fn + + received: list[dict] = [] + callback_called = asyncio.Event() + + def callback(data: dict) -> None: + received.append(data) + callback_called.set() + + await home.rt_subscribe(callback) + await asyncio.wait_for(callback_called.wait(), timeout=1.0) + + mock_realtime.connect.assert_awaited_once() + assert received == [{"data": sample_data}] + assert home.rt_subscription_running + + home.rt_unsubscribe() + assert not home.rt_subscription_running + + +async def test_rt_unsubscribe_noop_when_not_subscribed(home: tibber.TibberHome) -> None: + """Calling rt_unsubscribe on a fresh home must not raise.""" + assert not home.rt_subscription_running + home.rt_unsubscribe() # should be a no-op + assert not home.rt_subscription_running + + +async def test_rt_subscribe_multiple_items_all_delivered( + home: tibber.TibberHome, + mock_realtime: MagicMock, +) -> None: + """All items yielded by subscribe must be delivered to the callback in order.""" + items = [{"n": 1}, {"n": 2}, {"n": 3}] + _, subscribe_fn = _make_blocking_subscribe(items) + mock_realtime.subscribe = subscribe_fn + + received: list[dict] = [] + all_received = asyncio.Event() + + def callback(data: dict) -> None: + received.append(data) + if len(received) == len(items): + all_received.set() + + await home.rt_subscribe(callback) + await asyncio.wait_for(all_received.wait(), timeout=1.0) + + assert received == [{"data": item} for item in items] + + home.rt_unsubscribe() + + +@pytest.mark.parametrize( + ("real_time_consumption", "http_calls"), + [ + ( + False, + [ + call( + "https://api.tibber.com/v1-beta/gql", + headers={ + "Authorization": "Bearer test-token", + "User-Agent": "test", + }, + data={"query": UPDATE_INFO_PRICE % HOME_ID, "variables": {}}, + timeout=aiohttp.ClientTimeout(total=10), + ), + ], + ), + ( + True, + [ + call( + "https://api.tibber.com/v1-beta/gql", + headers={ + "Authorization": "Bearer test-token", + "User-Agent": "test", + }, + data={"query": UPDATE_INFO_PRICE % HOME_ID, "variables": {}}, + timeout=aiohttp.ClientTimeout(total=10), + ), + call( + "https://api.tibber.com/v1-beta/gql", + headers={ + "Authorization": "Bearer test-token", + "User-Agent": "test", + }, + data={"query": INFO, "variables": {}}, + timeout=aiohttp.ClientTimeout(total=10), + ), + ], + ), + ], +) +@pytest.mark.parametrize( + "error", + [ + WebsocketReconnectedError("reconnected"), + WebsocketTransportError("transport error"), + RuntimeError("unexpected"), + ], +) +async def test_rt_subscribe_on_error_called_on_exception( + mock_websession: MagicMock, + home: tibber.TibberHome, + mock_realtime: MagicMock, + error: Exception, + real_time_consumption: bool, + http_calls: list, +) -> None: + """on_error must be called when subscribe raises an exception.""" + home._has_real_time_consumption = real_time_consumption # noqa: SLF001 + wait_for_events = asyncio.Event() + wait_for_events.set() # allow subscribe to raise immediately + + async def subscribe_raises(*args: Any, **kwargs: Any) -> AsyncGenerator: # noqa: ANN401, ARG001 + await wait_for_events.wait() + raise error + yield + + mock_realtime.subscribe = subscribe_raises + + on_error_called = asyncio.Event() + caught: list[Exception] = [] + + def on_error(exc: Exception) -> None: + caught.append(exc) + on_error_called.set() + wait_for_events.clear() # allow test to control the flow after error is caught + + await home.rt_subscribe(MagicMock(), on_error=on_error) + await asyncio.wait_for(on_error_called.wait(), timeout=1.0) + + assert caught == [error] + # resubscription should have been triggered + assert mock_websession.post.call_count == len(http_calls) + assert mock_websession.post.call_args_list == http_calls + assert home.rt_subscription_running is real_time_consumption + + home.rt_unsubscribe() + + assert not home.rt_subscription_running + + +async def test_rt_subscribe_no_crash_when_subscribe_raises_without_on_error( + home: tibber.TibberHome, + mock_realtime: MagicMock, +) -> None: + """_start_listen must not propagate exceptions when no on_error is provided.""" + + async def subscribe_raises(*args: Any, **kwargs: Any) -> AsyncGenerator: # noqa: ANN401, ARG001 + raise WebsocketTransportError("transport error") + yield + + mock_realtime.subscribe = subscribe_raises + + callback = MagicMock() + await home.rt_subscribe(callback) + + # give the listener task a chance to run and finish without raising + await asyncio.sleep(0) + await asyncio.sleep(0) + + callback.assert_not_called() + home.rt_unsubscribe() + + +async def test_rt_resubscribe_raises_without_prior_subscribe(home: tibber.TibberHome) -> None: + """rt_resubscribe must raise RuntimeError when rt_subscribe has not been called.""" + with pytest.raises(RuntimeError, match="rt_subscribe"): + await home.rt_resubscribe() + + +async def test_rt_subscribe_raises_when_already_subscribed( + home: tibber.TibberHome, + mock_realtime: MagicMock, +) -> None: + """rt_subscribe must raise RuntimeError when called while already subscribed.""" + _, subscribe_fn = _make_blocking_subscribe([]) + mock_realtime.subscribe = subscribe_fn + + callback = MagicMock() + await home.rt_subscribe(callback) + + with pytest.raises(RuntimeError, match="rt_unsubscribe"): + await home.rt_subscribe(callback) + + home.rt_unsubscribe() + + +async def test_rt_resubscribe_emits_deprecation_warning( + home: tibber.TibberHome, + mock_realtime: MagicMock, +) -> None: + """rt_resubscribe must emit a DeprecationWarning.""" + _, subscribe_fn = _make_blocking_subscribe([]) + mock_realtime.subscribe = subscribe_fn + + callback = MagicMock() + await home.rt_subscribe(callback) + + with pytest.warns(DeprecationWarning, match="deprecated"): + await home.rt_resubscribe() + + home.rt_unsubscribe() diff --git a/test/test_realtime.py b/test/test_realtime.py index d113e7e..6d00cc3 100644 --- a/test/test_realtime.py +++ b/test/test_realtime.py @@ -10,25 +10,31 @@ import pytest from gql.client import AsyncClientSession, Client from gql.transport.common.adapters.websockets import WebSocketsAdapter +from gql.transport.exceptions import TransportConnectionFailed, TransportError from websockets.asyncio.connection import State -from tibber.realtime import TibberRT -from tibber.websocket_transport import TibberWebsocketsTransport +from tibber.exceptions import SubscriptionEndpointMissingError, WebsocketReconnectedError, WebsocketTransportError +from tibber.realtime import TibberRT, TibberWebsocketsTransport if TYPE_CHECKING: from collections.abc import Generator +@pytest.fixture +def timeout() -> int: + return 30 + + @pytest.fixture(name="tibber_rt") -def tibber_rt_fixture() -> TibberRT: +async def tibber_rt_fixture(mock_client: MagicMock, timeout: int) -> TibberRT: # noqa: ARG001, ASYNC109 """Create a TibberRT instance for testing.""" tibber_rt = TibberRT( access_token="test_token", - timeout=30, + timeout=timeout, user_agent="test_agent", ssl=True, ) - tibber_rt.sub_endpoint = "wss://test.endpoint" + await tibber_rt.set_subscription_endpoint("wss://test.endpoint") return tibber_rt @@ -71,7 +77,7 @@ async def test_connect_disconnect( mock_client.connect_async.assert_awaited_once() - # Second connect should not call connect_async again since subscription_running is True + # Second connect should not call connect_async again since the client is already connected await tibber_rt.connect() # connect_async should still only have been called once @@ -83,20 +89,15 @@ async def test_connect_disconnect( async def test_subscription_running( - mock_client: MagicMock, tibber_rt: TibberRT, ) -> None: - """Test subscription_running.""" + """Test subscription running.""" assert tibber_rt.subscription_running is False await tibber_rt.connect() assert tibber_rt.subscription_running is True - mock_client.transport.adapter.websocket.state = State.CLOSED - - assert tibber_rt.subscription_running is False - await tibber_rt.disconnect() assert tibber_rt.subscription_running is False @@ -105,30 +106,75 @@ async def test_subscription_running( assert tibber_rt.subscription_running is True - mock_client.transport.adapter.websocket = None - assert tibber_rt.subscription_running is False +async def test_update_endpoint(mock_client: MagicMock) -> None: + """Test update subscription endpoint.""" + tibber_rt = TibberRT( + access_token="test_token", + timeout=30, + user_agent="test_agent", + ssl=True, + ) + with pytest.raises(SubscriptionEndpointMissingError, match="Subscription endpoint not initialized"): + await tibber_rt.connect() -async def test_update_endpoint(mock_client: MagicMock, tibber_rt: TibberRT) -> None: - """Test update subscription endpoint.""" + mock_client.reset_mock() + await tibber_rt.set_subscription_endpoint("wss://new.endpoint") await tibber_rt.connect() - assert mock_client.transport.url == "wss://test.endpoint" + assert mock_client.transport.url == "wss://new.endpoint" + assert mock_client.close_async.call_count == 0 + assert mock_client.connect_async.call_count == 1 + mock_client.reset_mock() - # Set new endpoint - tibber_rt.sub_endpoint = "wss://new.endpoint" + await tibber_rt.set_subscription_endpoint("wss://new.endpoint") assert mock_client.transport.url == "wss://new.endpoint" + assert mock_client.close_async.call_count == 0 + assert mock_client.connect_async.call_count == 0 + mock_client.reset_mock() + + await tibber_rt.set_subscription_endpoint("wss://another_connected.endpoint") + + assert mock_client.transport.url == "wss://another_connected.endpoint" + assert mock_client.close_async.call_count == 1 + assert mock_client.connect_async.call_count == 1 + mock_client.reset_mock() + + connect_event = asyncio.Event() + original_connect_async = mock_client.connect_async + + async def mock_connect_async(**kwargs: Any) -> MagicMock: # noqa: ANN401 + session = await original_connect_async(**kwargs) + await connect_event.wait() + return session + + mock_client.connect_async = AsyncMock(wraps=mock_connect_async) + + set_endpoint_task_1 = asyncio.create_task(tibber_rt.set_subscription_endpoint("wss://connected.endpoint.1")) + set_endpoint_task_2 = asyncio.create_task(tibber_rt.set_subscription_endpoint("wss://connected.endpoint.2")) + + await asyncio.sleep(0.1) + assert mock_client.transport.url == "wss://connected.endpoint.1" + connect_event.set() + await asyncio.gather(set_endpoint_task_1, set_endpoint_task_2) + + assert mock_client.transport.url == "wss://connected.endpoint.2" + assert mock_client.close_async.call_count == 2 + assert mock_client.connect_async.call_count == 2 async def test_websocket_transport() -> None: """Test websocket transport.""" + tibber_connected = asyncio.Event() transport = TibberWebsocketsTransport( url="wss://test.endpoint", access_token="test_token", user_agent="test_agent", + tibber_connected=tibber_connected, ) + transport.keep_alive_timeout = 0 mock_adapter = MagicMock(spec=WebSocketsAdapter) sent_messages: asyncio.Queue[str] = asyncio.Queue() @@ -151,6 +197,9 @@ async def mock_send(message: str) -> None: client = Client(transport=transport) connect_task = asyncio.create_task(client.connect_async()) + await tibber_connected.wait() + + assert tibber_connected.is_set() await connect_task await client.close_async() @@ -159,3 +208,155 @@ async def mock_send(message: str) -> None: assert mock_adapter.send.await_count == 1 mock_adapter.receive.assert_awaited() assert mock_adapter.close.await_count == 1 + assert not tibber_connected.is_set() + + +async def test_subscribe_raises_when_not_connected(tibber_rt: TibberRT) -> None: + """subscribe must raise RuntimeError when called before connect.""" + with pytest.raises(RuntimeError, match="Connect must be called before subscribe"): + await anext(tibber_rt.subscribe(MagicMock())) + + +async def test_subscribe_yields_results( + mock_client: MagicMock, + tibber_rt: TibberRT, +) -> None: + """subscribe must yield every item produced by the underlying session.""" + await tibber_rt.connect() + + sample = {"key": "value"} + + async def mock_subscribe(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401, ARG001 + yield sample + + mock_client.session.subscribe = mock_subscribe + + results = [item async for item in tibber_rt.subscribe(MagicMock())] + + assert results == [sample] + + +async def test_subscribe_transport_connection_failed_calls_on_error_and_raises_reconnected( + mock_client: MagicMock, + tibber_rt: TibberRT, +) -> None: + """TransportConnectionFailed must call on_error, wait for reconnect, then raise WebsocketReconnectedError.""" + await tibber_rt.connect() + + err = TransportConnectionFailed("connection failed") + caught: list[Exception] = [] + + def on_error(exc: Exception) -> None: + caught.append(exc) + # Unblock _tibber_connected.wait() so the generator can finish + tibber_rt._tibber_connected.set() # noqa: SLF001 + + async def failing_subscribe(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401, ARG001 + raise err + yield + + mock_client.session.subscribe = failing_subscribe + + with pytest.raises(WebsocketReconnectedError): + await anext(tibber_rt.subscribe(MagicMock(), on_error=on_error)) + + assert caught == [err] + + +async def test_subscribe_other_transport_error_raises_websocket_transport_error( + mock_client: MagicMock, + tibber_rt: TibberRT, +) -> None: + """A TransportError that is not TransportConnectionFailed must raise WebsocketTransportError.""" + await tibber_rt.connect() + + err = TransportError("generic transport error") + + async def failing_subscribe(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401, ARG001 + raise err + yield + + mock_client.session.subscribe = failing_subscribe + + with pytest.raises(WebsocketTransportError): + await anext(tibber_rt.subscribe(MagicMock())) + + +async def test_reconnect_noop_when_not_connected( + mock_client: MagicMock, + tibber_rt: TibberRT, +) -> None: + """reconnect must be a no-op when the client is not connected.""" + await tibber_rt.reconnect() + + mock_client.connect_async.assert_not_awaited() + mock_client.close_async.assert_not_awaited() + + +async def test_set_access_token_reconnects_with_new_token( + mock_client: MagicMock, + tibber_rt: TibberRT, +) -> None: + """set_access_token must update the token and reconnect so the new token is used.""" + await tibber_rt.connect() + mock_client.connect_async.reset_mock() + mock_client.close_async.reset_mock() + + await tibber_rt.set_access_token("new_token") + + mock_client.close_async.assert_awaited_once() + mock_client.connect_async.assert_awaited_once() + assert mock_client.transport.init_payload["token"] == "new_token" + + +@pytest.mark.parametrize("timeout", [0]) +async def test_connect_timeout_leaves_no_session_and_subscription_not_running( + mock_client: MagicMock, + tibber_rt: TibberRT, +) -> None: + """When connect_async times out, subscription_running must remain False and no session is set.""" + + async def slow_connect(**kwargs: Any) -> Any: # noqa: ANN401, ARG001 + await asyncio.sleep(9999) + + mock_client.connect_async = AsyncMock(side_effect=slow_connect) + + await tibber_rt.connect() + + assert tibber_rt.subscription_running is False + + await tibber_rt.disconnect() + + mock_client.close_async.assert_not_awaited() + assert tibber_rt.subscription_running is False + + +async def test_subscribe_transport_connection_failed_without_on_error_raises_reconnected( + mock_client: MagicMock, + tibber_rt: TibberRT, +) -> None: + """TransportConnectionFailed with on_error=None must still raise WebsocketReconnectedError after reconnect.""" + await tibber_rt.connect() + + err = TransportConnectionFailed("connection failed") + + async def failing_subscribe(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401, ARG001 + raise err + yield + + mock_client.session.subscribe = failing_subscribe + + async def set_connected_after_clear() -> None: + # Wait until subscribe() clears the event, then unblock the wait() + while tibber_rt._tibber_connected.is_set(): # noqa: ASYNC110, SLF001 + await asyncio.sleep(0) + tibber_rt._tibber_connected.set() # noqa: SLF001 + + unblock_task = asyncio.create_task(set_connected_after_clear()) + + with pytest.raises(WebsocketReconnectedError): + await anext(tibber_rt.subscribe(MagicMock(), on_error=None)) + + await unblock_task + + assert tibber_rt.subscription_running is True diff --git a/test/test_tibber.py b/test/test_tibber.py index 6b198b3..9b2b508 100644 --- a/test/test_tibber.py +++ b/test/test_tibber.py @@ -9,10 +9,8 @@ import pytest import tibber -import tibber.realtime as tibber_realtime from tibber.const import RESOLUTION_DAILY from tibber.exceptions import FatalHttpExceptionError, InvalidLoginError, NotForDemoUserError -from tibber.websocket_transport import TibberWebsocketsTransport @pytest.mark.asyncio @@ -179,18 +177,14 @@ def _callback(_: dict) -> None: @pytest.mark.asyncio -async def test_set_access_token_updates_clients_without_realtime(monkeypatch: pytest.MonkeyPatch) -> None: +async def test_set_access_token_updates_clients(monkeypatch: pytest.MonkeyPatch) -> None: tibber_connection = tibber.Tibber( websession=MagicMock(), user_agent="test", ) - update_info = AsyncMock() - reconnect = AsyncMock() rt_set_access_token = AsyncMock() data_api_set_access_token = MagicMock() - monkeypatch.setattr(tibber_connection, "update_info", update_info) - monkeypatch.setattr(tibber_connection.realtime, "reconnect", reconnect) monkeypatch.setattr(tibber_connection.realtime, "set_access_token", rt_set_access_token) monkeypatch.setattr(tibber_connection.data_api, "set_access_token", data_api_set_access_token) @@ -198,78 +192,22 @@ async def test_set_access_token_updates_clients_without_realtime(monkeypatch: py rt_set_access_token.assert_awaited_once_with("new-token") data_api_set_access_token.assert_called_once_with("new-token") - update_info.assert_awaited_once_with() - reconnect.assert_not_awaited() @pytest.mark.asyncio -async def test_set_access_token_reconnects_active_realtime(monkeypatch: pytest.MonkeyPatch) -> None: +async def test_set_access_token_noop_when_token_unchanged(monkeypatch: pytest.MonkeyPatch) -> None: tibber_connection = tibber.Tibber( + access_token="existing-token", websession=MagicMock(), user_agent="test", ) - calls: list[str] = [] - - async def fake_realtime_set_access_token(_access_token: str) -> None: - calls.append("realtime.set_access_token") - - async def fake_update_info() -> None: - calls.append("update_info") - - async def fake_reconnect() -> None: - calls.append("reconnect") - - monkeypatch.setattr( - type(tibber_connection.realtime), - "should_restore_connection", - property(lambda _: True), - ) - monkeypatch.setattr( - tibber_connection.realtime, - "set_access_token", - AsyncMock(side_effect=fake_realtime_set_access_token), - ) + rt_set_access_token = AsyncMock() data_api_set_access_token = MagicMock() - monkeypatch.setattr(tibber_connection, "update_info", AsyncMock(side_effect=fake_update_info)) - monkeypatch.setattr(tibber_connection.realtime, "reconnect", AsyncMock(side_effect=fake_reconnect)) - monkeypatch.setattr(tibber_connection.data_api, "set_access_token", data_api_set_access_token) - - await tibber_connection.set_access_token("new-token") - data_api_set_access_token.assert_called_once_with("new-token") - assert calls == ["realtime.set_access_token", "update_info", "reconnect"] + monkeypatch.setattr(tibber_connection.realtime, "set_access_token", rt_set_access_token) + monkeypatch.setattr(tibber_connection.data_api, "set_access_token", data_api_set_access_token) + await tibber_connection.set_access_token("existing-token") -@pytest.mark.asyncio -async def test_realtime_set_access_token_recreates_subscription_manager(monkeypatch: pytest.MonkeyPatch) -> None: - class FakeClient: - def __init__(self, transport: TibberWebsocketsTransport) -> None: - self.transport = transport - self.connect_async = AsyncMock(return_value=object()) - self.close_async_mock = AsyncMock() - self.close_async = self.close_async_mock - - monkeypatch.setattr(tibber_realtime, "Client", FakeClient) - - realtime = tibber_realtime.TibberRT("old-token", 10, "test-agent", True) - realtime.sub_endpoint = "wss://example.test/v1-beta/gql/subscriptions" - - await realtime.connect() - old_manager = realtime.sub_manager - assert old_manager is not None - assert isinstance(old_manager, FakeClient) - assert isinstance(old_manager.transport, TibberWebsocketsTransport) - assert old_manager.transport.init_payload["token"] == "old-token" - - await realtime.set_access_token("new-token") - - old_manager.close_async_mock.assert_awaited_once_with() - assert realtime.session is None - assert realtime.sub_manager is None - - await realtime.connect() - - assert realtime.sub_manager is not None - assert realtime.sub_manager is not old_manager - assert isinstance(realtime.sub_manager.transport, TibberWebsocketsTransport) - assert realtime.sub_manager.transport.init_payload["token"] == "new-token" + rt_set_access_token.assert_not_awaited() + data_api_set_access_token.assert_not_called() diff --git a/tibber/__init__.py b/tibber/__init__.py index cacad7b..8d981d1 100644 --- a/tibber/__init__.py +++ b/tibber/__init__.py @@ -65,10 +65,9 @@ def __init__( self.websession = websession self.timeout: int = timeout self._access_token: str = access_token - - self.realtime: TibberRT = TibberRT( - self._access_token, - self.timeout, + self.realtime = TibberRT( + access_token, + timeout, self._user_agent, ssl=ssl, ) @@ -110,6 +109,11 @@ async def execute( payload = {"query": document, "variables": variable_values or {}} + _LOGGER.debug( + "Executing query: %s with variables: %s", + document.replace(" ", "").replace("\n", "_"), + variable_values, + ) try: resp = await self.websession.post( API_ENDPOINT, @@ -144,8 +148,7 @@ async def update_info(self) -> None: return if sub_endpoint := viewer.get("websocketSubscriptionUrl"): - _LOGGER.debug("Using websocket subscription url %s", sub_endpoint) - self.realtime.sub_endpoint = sub_endpoint + await self.realtime.set_subscription_endpoint(sub_endpoint) self._name = viewer.get("name") self._user_id = viewer.get("userId") @@ -222,22 +225,19 @@ async def fetch_production_data_active_homes(self) -> None: ) async def rt_disconnect(self) -> None: - """Stop subscription manager. - This method simply calls the stop method of the SubscriptionManager if it is defined. - """ - return await self.realtime.disconnect() + """Stop subscription manager.""" + for home in self._homes.values(): + home.rt_unsubscribe() + await self.realtime.disconnect() async def set_access_token(self, access_token: str) -> None: + """Set access token and reauthorize clients.""" if access_token == self._access_token: return - """Set access token and reauthorize clients.""" - restore_realtime = self.realtime.should_restore_connection + _LOGGER.debug("Updating access token") self._access_token = access_token await self.realtime.set_access_token(access_token) self.data_api.set_access_token(access_token) - await self.update_info() - if restore_realtime: - await self.realtime.reconnect() @property def user_id(self) -> str | None: diff --git a/tibber/data_api.py b/tibber/data_api.py index 39a6195..4736194 100644 --- a/tibber/data_api.py +++ b/tibber/data_api.py @@ -98,6 +98,7 @@ async def _make_request( headers[aiohttp.hdrs.USER_AGENT] = self._user_agent response: aiohttp.ClientResponse | None = None + _LOGGER.debug("Request %s: %s with params %s", method, url, params) try: response = await self.websession.request( method, diff --git a/tibber/exceptions.py b/tibber/exceptions.py index 8be9a75..b2f2715 100644 --- a/tibber/exceptions.py +++ b/tibber/exceptions.py @@ -3,16 +3,20 @@ from .const import API_ERR_CODE_UNKNOWN -class SubscriptionEndpointMissingError(Exception): - """Exception raised when subscription endpoint is missing""" +class TibberError(Exception): + """Base exception for Tibber errors.""" -class UserAgentMissingError(Exception): - """Exception raised when user agent is missing""" +class SubscriptionEndpointMissingError(TibberError): + """Exception raised when subscription endpoint is missing.""" -class HttpExceptionError(Exception): - """Exception base for HTTP errors +class UserAgentMissingError(TibberError): + """Exception raised when user agent is missing.""" + + +class HttpExceptionError(TibberError): + """Exception base for HTTP errors. :param status: http response code :param message: http response message if any @@ -32,11 +36,11 @@ def __init__( class FatalHttpExceptionError(HttpExceptionError): - """Exception raised for HTTP codes that are non-retriable""" + """Exception raised for HTTP codes that are non-retriable.""" class RetryableHttpExceptionError(HttpExceptionError): - """Exception raised for HTTP codes that are possible to retry""" + """Exception raised for HTTP codes that are possible to retry.""" class RateLimitExceededError(RetryableHttpExceptionError): @@ -53,3 +57,15 @@ class InvalidLoginError(FatalHttpExceptionError): class NotForDemoUserError(FatalHttpExceptionError): """Exception raised when trying to use a feature not available for demo users""" + + +class WebsocketError(TibberError): + """Base exception for Tibber websocket errors.""" + + +class WebsocketReconnectedError(WebsocketError): + """Exception raised when websocket has been reconnected.""" + + +class WebsocketTransportError(WebsocketError): + """Exception raised when websocket transport fails.""" diff --git a/tibber/home.py b/tibber/home.py index 7ed1786..7d995b7 100644 --- a/tibber/home.py +++ b/tibber/home.py @@ -7,11 +7,13 @@ import contextlib import datetime as dt import logging +import warnings from typing import TYPE_CHECKING, Any from gql import gql from .const import RESOLUTION_DAILY, RESOLUTION_HOURLY, RESOLUTION_MONTHLY, RESOLUTION_WEEKLY +from .exceptions import WebsocketReconnectedError, WebsocketTransportError from .gql_queries import ( HISTORIC_DATA, HISTORIC_DATA_DATE, @@ -75,14 +77,13 @@ def __init__(self, home_id: str, tibber_control: Tibber) -> None: self.info: dict[str, dict[Any, Any]] = {} self.last_data_timestamp: dt.datetime | None = None - self._hourly_consumption_data: HourlyData = HourlyData() - self._hourly_production_data: HourlyData = HourlyData(production=True) - self._last_rt_data_received: dt.datetime = dt.datetime.now(tz=dt.UTC) - self._rt_listener: None | asyncio.Task[Any] = None + self._hourly_consumption_data = HourlyData() + self._hourly_production_data = HourlyData(production=True) + self._rt_listener: asyncio.Task[Any] | None = None self._rt_callback: Callable[..., Any] | None = None - self._rt_stopped: bool = True self._has_real_time_consumption: None | bool = None self._real_time_consumption_suggested_disabled: dt.datetime | None = None + self._resubscribe_task: asyncio.Task[None] | None = None async def _fetch_data(self, hourly_data: HourlyData) -> None: """Update hourly consumption or production data asynchronously.""" @@ -382,109 +383,146 @@ def current_price_data(self) -> tuple[float | None, dt.datetime | None, float | return round(price_total, 3), price_time, price_rank return None, None, None - async def rt_subscribe(self, callback: Callable[..., Any]) -> None: + async def rt_subscribe( + self, + callback: Callable[..., Any], + *, + on_error: Callable[[Exception], None] | None = None, + ) -> None: """Connect to Tibber and subscribe to Tibber real time subscription. :param callback: The function to call when data is received. """ + if self._rt_listener is not None: + raise RuntimeError("Already subscribed to real time data, call rt_unsubscribe first") + _LOGGER.debug("Subscribe, %s", self.home_id) + self._rt_callback = callback + await self._tibber_control.realtime.connect() + self._rt_listener = asyncio.create_task(self._start_listen(callback, on_error=on_error)) - def _add_extra_data(data: dict[str, Any]) -> dict[str, Any]: - live_data = data["data"]["liveMeasurement"] - _timestamp = dt.datetime.fromisoformat(live_data["timestamp"]).astimezone(self._tibber_control.time_zone) - while self._rt_power and self._rt_power[0][0] < _timestamp - dt.timedelta(minutes=5): - self._rt_power.pop(0) + async def _start_listen( + self, + callback: Callable[..., Any], + *, + on_error: Callable[[Exception], None] | None = None, + ) -> None: + """Subscribe to Tibber.""" + try: + async for _data in self._tibber_control.realtime.subscribe( + gql( + LIVE_SUBSCRIBE % self.home_id, + ), + on_error=on_error, + ): + data = {"data": _data} + with contextlib.suppress(KeyError): + data = self._add_extra_data(data) + _LOGGER.debug( + "Data received for %s: %s", + self.home_id, + data, + ) + callback(data) + except WebsocketReconnectedError as err: + _LOGGER.debug("Websocket reconnected for home %s, restarting subscription", self.home_id) + if on_error: + on_error(err) + except Exception as err: + if not isinstance(err, WebsocketTransportError): + _LOGGER.exception("Error in subscription") + else: + level = logging.DEBUG if on_error is not None else logging.ERROR + _LOGGER.log( + level, + "Error in subscription for home %s: %s: %s", + self.home_id, + err.__class__.__name__, + err, + ) + if on_error is not None: + on_error(err) - self._rt_power.append((_timestamp, live_data["power"] / 1000)) - if "lastMeterProduction" in live_data: - live_data["lastMeterProduction"] = max(0, live_data["lastMeterProduction"] or 0) + if self._resubscribe_task is not None: + self._resubscribe_task.cancel() + self._resubscribe_task = asyncio.create_task(self._rt_resubscribe(callback, on_error=on_error)) - if ( - (power_production := live_data.get("powerProduction")) - and power_production > 0 - and live_data.get("power") is None - ): - live_data["power"] = 0 + def _add_extra_data(self, data: dict[str, Any]) -> dict[str, Any]: + """Add extra data to live subscription result.""" + live_data = data["data"]["liveMeasurement"] + _timestamp = dt.datetime.fromisoformat(live_data["timestamp"]).astimezone(self._tibber_control.time_zone) + while self._rt_power and self._rt_power[0][0] < _timestamp - dt.timedelta(minutes=5): + self._rt_power.pop(0) - if live_data.get("power", 0) > 0 and live_data.get("powerProduction") is None: - live_data["powerProduction"] = 0 + self._rt_power.append((_timestamp, live_data["power"] / 1000)) + if "lastMeterProduction" in live_data: + live_data["lastMeterProduction"] = max(0, live_data["lastMeterProduction"] or 0) - current_hour = live_data["accumulatedConsumptionLastHour"] - if current_hour is not None: - power = sum(p[1] for p in self._rt_power) / len(self._rt_power) - live_data["estimatedHourConsumption"] = round( - current_hour + power * (3600 - (_timestamp.minute * 60 + _timestamp.second)) / 3600, - 3, - ) - if self._hourly_consumption_data.peak_hour and current_hour > self._hourly_consumption_data.peak_hour: - self._hourly_consumption_data.peak_hour = round(current_hour, 2) - self._hourly_consumption_data.peak_hour_time = _timestamp - return data - - async def _start() -> None: - """Subscribe to Tibber.""" - for _ in range(30): - if self._rt_stopped: - _LOGGER.debug("Stopping rt_subscribe") - return - if self._tibber_control.realtime.subscription_running: - break - - _LOGGER.debug("Waiting for rt_connect") - await asyncio.sleep(1) - else: - _LOGGER.error("rt not running") - return + if ( + (power_production := live_data.get("powerProduction")) + and power_production > 0 + and live_data.get("power") is None + ): + live_data["power"] = 0 - try: - session = self._tibber_control.realtime.session - if session is None or not hasattr(session, "subscribe"): - _LOGGER.error("Session is not connected or does not support subscribe method") - return - async for _data in session.subscribe( - gql(LIVE_SUBSCRIBE % self.home_id), - ): - data = {"data": _data} - with contextlib.suppress(KeyError): - data = _add_extra_data(data) - callback(data) - self._last_rt_data_received = dt.datetime.now(tz=dt.UTC) - _LOGGER.debug( - "Data received for %s: %s", - self.home_id, - data, - ) - if self._rt_stopped or not self._tibber_control.realtime.subscription_running: - _LOGGER.debug("Stopping rt_subscribe loop") - return - except Exception: - _LOGGER.exception("Error in rt_subscribe") + if live_data.get("power", 0) > 0 and live_data.get("powerProduction") is None: + live_data["powerProduction"] = 0 - self._rt_callback = callback - self._tibber_control.realtime.add_home(self) - await self._tibber_control.realtime.connect() - self._rt_listener = asyncio.create_task(_start()) - self._rt_stopped = False + current_hour = live_data["accumulatedConsumptionLastHour"] + if current_hour is not None: + power = sum(p[1] for p in self._rt_power) / len(self._rt_power) + live_data["estimatedHourConsumption"] = round( + current_hour + power * (3600 - (_timestamp.minute * 60 + _timestamp.second)) / 3600, + 3, + ) + if self._hourly_consumption_data.peak_hour and current_hour > self._hourly_consumption_data.peak_hour: + self._hourly_consumption_data.peak_hour = round(current_hour, 2) + self._hourly_consumption_data.peak_hour_time = _timestamp + return data async def rt_resubscribe(self) -> None: + """Resubscribe to Tibber data. + + Deprecated. Resubscription will happen automatically. + """ + if self._rt_callback is None: + raise RuntimeError("No callback set for rt_resubscribe, call rt_subscribe first") + + warnings.warn( + "TibberHome.rt_resubscribe is deprecated, resubscription will happen automatically", + DeprecationWarning, + stacklevel=2, + ) + await self._rt_resubscribe(self._rt_callback) + + async def _rt_resubscribe( + self, + callback: Callable[..., Any], + *, + on_error: Callable[[Exception], None] | None = None, + ) -> None: """Resubscribe to Tibber data.""" - self.rt_unsubscribe() + if self._rt_listener is None: + _LOGGER.debug("No active subscription to resubscribe for %s", self.home_id) + return + _LOGGER.debug("Resubscribe, %s", self.home_id) - await asyncio.gather( - *[ - self.update_info(), - self._tibber_control.update_info(), - ], - return_exceptions=False, - ) - if self._rt_callback is None: - _LOGGER.warning("No callback set for rt_resubscribe") + self.rt_unsubscribe() + + with contextlib.suppress(Exception): + await self.update_info() # Update home info to check if real time is enabled + if not self.has_real_time_consumption: + _LOGGER.debug("Home %s does not have real time consumption enabled", self.home_id) return - await self.rt_subscribe(self._rt_callback) + + # Update info to set websocket subscription url + with contextlib.suppress(Exception): + await self._tibber_control.update_info() + + await self.rt_subscribe(callback, on_error=on_error) def rt_unsubscribe(self) -> None: """Unsubscribe to Tibber data.""" _LOGGER.debug("Unsubscribe, %s", self.home_id) - self._rt_stopped = True if self._rt_listener is None: return self._rt_listener.cancel() @@ -493,9 +531,7 @@ def rt_unsubscribe(self) -> None: @property def rt_subscription_running(self) -> bool: """Is real time subscription running.""" - if not self._tibber_control.realtime.subscription_running: - return False - return not self._last_rt_data_received < dt.datetime.now(tz=dt.UTC) - dt.timedelta(seconds=60) + return self._tibber_control.realtime.subscription_running and self._rt_listener is not None async def get_historic_data( self, diff --git a/tibber/realtime.py b/tibber/realtime.py index 5bdf645..fe358df 100644 --- a/tibber/realtime.py +++ b/tibber/realtime.py @@ -1,19 +1,27 @@ """Tibber RT connection.""" import asyncio -import datetime as dt import logging -import random +from collections.abc import AsyncGenerator, Callable from ssl import SSLContext -from typing import Any +from typing import TYPE_CHECKING, Any -from gql import Client +from gql import Client, GraphQLRequest +from gql.transport.exceptions import TransportClosed, TransportConnectionFailed, TransportError +from gql.transport.websockets import WebsocketsTransport +from tenacity import before_sleep_log, retry, wait_exponential_jitter -from .exceptions import SubscriptionEndpointMissingError -from .home import TibberHome -from .websocket_transport import TibberWebsocketsTransport +from .exceptions import SubscriptionEndpointMissingError, WebsocketReconnectedError, WebsocketTransportError +if TYPE_CHECKING: + from gql.client import AsyncClientSession + +KEEP_ALIVE_TIMEOUT = 90 LOCK_CONNECT = asyncio.Lock() +MIN_RECONNECT_INTERVAL = 1 +MAX_RECONNECT_INTERVAL = 60 +PING_INTERVAL = 30 +PONG_TIMEOUT = 20 _LOGGER = logging.getLogger(__name__) @@ -32,202 +40,170 @@ def __init__(self, access_token: str, timeout: int, user_agent: str, ssl: SSLCon self._timeout: int = timeout self._user_agent: str = user_agent self._ssl_context = ssl - self._sub_endpoint: str | None = None - self._homes: list[TibberHome] = [] - self._watchdog_runner: None | asyncio.Task[Any] = None - self._watchdog_running: bool = False - - self.sub_manager: Client | None = None - self.session: Any | None = None + self._tibber_connected = asyncio.Event() + self._client: Client | None = None + self.subscription_running = False + self._session: AsyncClientSession | None = None + + def _create_client(self) -> Client: + """Create a new gql Client with the current transport settings.""" + self._tibber_connected.clear() + return Client( + transport=TibberWebsocketsTransport( + self._sub_endpoint, + self._access_token, + self._user_agent, + ssl=self._ssl_context, + tibber_connected=self._tibber_connected, + ), + ) async def disconnect(self) -> None: - """Stop subscription manager. - This method simply calls the stop method of the SubscriptionManager if it is defined. - """ + """Disconnect the websocket client.""" _LOGGER.debug("Stopping subscription manager") - await self._reset_connection(unsubscribe_homes=True) - - async def _reset_connection(self, unsubscribe_homes: bool = False) -> None: - """Reset websocket connection state.""" - if self._watchdog_runner is not None: - _LOGGER.debug("Stopping watchdog") - self._watchdog_running = False - self._watchdog_runner.cancel() - self._watchdog_runner = None - if unsubscribe_homes: - for home in self._homes: - home.rt_unsubscribe() - try: - if self.session is not None and self.sub_manager is not None: - await self.sub_manager.close_async() - finally: - self.session = None - self.sub_manager = None + async with LOCK_CONNECT: + await self._disconnect() + + async def _disconnect(self) -> None: + """Disconnect the websocket client.""" + if self._client is not None and self._session is not None: + await self._client.close_async() + self._session = None + self.subscription_running = False async def connect(self) -> None: - """Start subscription manager.""" - self._create_sub_manager() + """Connect the websocket client.""" + async with LOCK_CONNECT: + await self._connect() - assert self.sub_manager is not None + async def _connect(self) -> None: + """Connect the websocket client.""" + if self._sub_endpoint is None: + raise SubscriptionEndpointMissingError("Subscription endpoint not initialized") - async with LOCK_CONNECT: - if self.subscription_running: - return - if self._watchdog_runner is None: - _LOGGER.debug("Starting watchdog") - self._watchdog_running = True - self._watchdog_runner = asyncio.create_task(self._watchdog()) - self.session = await self.sub_manager.connect_async() + if self.subscription_running or self._session: + return + + self._client = self._create_client() + try: + self._session = await asyncio.wait_for( + self._client.connect_async( + reconnecting=True, + retry_connect=retry( + wait=wait_exponential_jitter( + initial=MIN_RECONNECT_INTERVAL, + max=MAX_RECONNECT_INTERVAL, + jitter=MAX_RECONNECT_INTERVAL, + ), + before_sleep=before_sleep_log(_LOGGER, logging.INFO), + ), + ), + timeout=self._timeout, + ) + except TimeoutError as err: + _LOGGER.debug("Timeout connecting to websocket: %s", err) + # The connection will be retried by the reconnecting task + else: + self.subscription_running = True async def reconnect(self) -> None: - """Reconnect and resubscribe all homes.""" - await self.connect() - await self._resubscribe_homes() + """Reconnect the websocket client.""" + async with LOCK_CONNECT: + if self._session is None: + return + _LOGGER.debug("Reconnecting websocket client") + await self._disconnect() + await self._connect() async def set_access_token(self, access_token: str) -> None: """Set access token.""" - reconnect_running = self.subscription_running or self._watchdog_runner is not None self._access_token = access_token - await self._reset_connection(unsubscribe_homes=reconnect_running) + await self.reconnect() + + async def subscribe( + self, + request: GraphQLRequest, + *, + on_error: Callable[[Exception], None] | None = None, + ) -> AsyncGenerator[dict[str, Any], None]: + """Subscribe to a GraphQL query.""" + if self._session is None: + raise RuntimeError("Connect must be called before subscribe") - def _create_sub_manager(self) -> None: - if self.sub_endpoint is None: - raise SubscriptionEndpointMissingError("Subscription endpoint not initialized") - if self.sub_manager is not None: + try: + async for result in self._session.subscribe(request): + yield result + except TransportError as err: + _LOGGER.debug("%s: %s", err.__class__.__name__, err) + self.subscription_running = False + self._tibber_connected.clear() + if isinstance(err, TransportConnectionFailed): + if on_error: + on_error(err) + _LOGGER.debug("Waiting for reconnect") + await self._tibber_connected.wait() + self.subscription_running = True + _LOGGER.info("Reconnected") + raise WebsocketReconnectedError("Websocket reconnected") from err + raise WebsocketTransportError(err) from err + + async def set_subscription_endpoint(self, url: str) -> None: + """Set subscription endpoint.""" + old_url = self._sub_endpoint + if url == old_url: return - self.sub_manager = Client( - transport=TibberWebsocketsTransport( - self.sub_endpoint, - self._access_token, - self._user_agent, - ssl=self._ssl_context, - ), + _LOGGER.debug("Updating subscription endpoint to %s", url) + self._sub_endpoint = url + await self.reconnect() + + +class TibberWebsocketsTransport(WebsocketsTransport): + """Tibber websockets transport.""" + + def __init__( + self, + url: str, + access_token: str, + user_agent: str, + *, + ssl: SSLContext | bool = True, + tibber_connected: asyncio.Event, + ) -> None: + """Initialize TibberWebsocketsTransport.""" + super().__init__( + url=url, + init_payload={"token": access_token}, + headers={"User-Agent": user_agent}, + ssl=ssl, + keep_alive_timeout=KEEP_ALIVE_TIMEOUT, + ping_interval=PING_INTERVAL, + pong_timeout=PONG_TIMEOUT, ) + self._tibber_connected = tibber_connected + self._user_agent = user_agent - async def _watchdog(self) -> None: - """Watchdog to keep connection alive.""" - assert self.sub_manager is not None - assert isinstance(self.sub_manager.transport, TibberWebsocketsTransport) - - await asyncio.sleep(60) - - _retry_count = 0 - next_test_all_homes_running = dt.datetime.now(tz=dt.UTC) - while self._watchdog_running: - await asyncio.sleep(5) - if ( - self.sub_manager.transport.running - and self.sub_manager.transport.reconnect_at - > dt.datetime.now( - tz=dt.UTC, - ) - and dt.datetime.now(tz=dt.UTC) > next_test_all_homes_running - ): - is_running = True - for home in self._homes: - _LOGGER.debug( - "Watchdog: Checking if home %s is alive, %s, %s", - home.home_id, - home.has_real_time_consumption, - home.rt_subscription_running, - ) - if not home.rt_subscription_running: - is_running = False - next_test_all_homes_running = dt.datetime.now(tz=dt.UTC) + dt.timedelta(seconds=60) - break - _LOGGER.debug( - "Watchdog: Home %s is alive", - home.home_id, - ) - if is_running: - _retry_count = 0 - _LOGGER.debug("Watchdog: Connection is alive") - continue - - self.sub_manager.transport.reconnect_at = dt.datetime.now(tz=dt.UTC) + dt.timedelta(seconds=self._timeout) - _LOGGER.error( - "Watchdog: Connection is down, %s", - self.sub_manager.transport.reconnect_at, - ) - - try: - if self.session is not None: - await self.sub_manager.close_async() - self.session = None - except Exception: - _LOGGER.exception("Error in watchdog close") + async def _after_connect(self) -> None: + """Hook to add custom code for subclasses. - if not self._watchdog_running: - _LOGGER.debug("Watchdog: Stopping") - return + Called after the connection has been established. + """ + await super()._after_connect() + self._tibber_connected.set() - self._create_sub_manager() - try: - self.session = await self.sub_manager.connect_async() - await self._resubscribe_homes() - except Exception as err: # noqa: BLE001 - delay_seconds = min( - random.SystemRandom().randint(1, 30) + _retry_count**2, - 5 * 60, - ) - _retry_count += 1 - _LOGGER.error( - "Error in watchdog connect, retrying in %s seconds, %s: %s", - delay_seconds, - _retry_count, - err, - exc_info=_retry_count > 1, - ) - await asyncio.sleep(delay_seconds) - else: - _LOGGER.debug("Watchdog: Reconnected successfully") - await asyncio.sleep(60) - - async def _resubscribe_homes(self) -> None: - """Resubscribe to all homes.""" - _LOGGER.debug("Resubscribing to homes") - await asyncio.gather(*[home.rt_resubscribe() for home in self._homes]) - - def add_home(self, home: TibberHome) -> bool: - """Add home to real time subscription.""" - if home.has_real_time_consumption is False: - return False - if home in self._homes: - return False - self._homes.append(home) - return True - - @property - def subscription_running(self) -> bool: - """Is real time subscription running.""" - return ( - self.sub_manager is not None - and isinstance(self.sub_manager.transport, TibberWebsocketsTransport) - and self.sub_manager.transport.running - and self.session is not None - ) + async def close(self) -> None: + """Close the websocket connection. - @property - def should_restore_connection(self) -> bool: - """Whether realtime subscriptions should be restored after a reset.""" - return self.subscription_running or self._watchdog_runner is not None + This method is only called by the client. + """ + await self._fail(TransportClosed(f"Tibber websocket closed by {self._user_agent}")) + await self.wait_closed() - @property - def sub_endpoint(self) -> str | None: - """Get subscription endpoint.""" - return self._sub_endpoint + async def _close_hook(self) -> None: + """Hook called by WebsocketsTransportBase on connection close. - @sub_endpoint.setter - def sub_endpoint(self, sub_endpoint: str) -> None: - """Set subscription endpoint.""" - self._sub_endpoint = sub_endpoint - if self.sub_manager is not None and isinstance(self.sub_manager.transport, TibberWebsocketsTransport): - self.sub_manager = Client( - transport=TibberWebsocketsTransport( - sub_endpoint, - self._access_token, - self._user_agent, - ssl=self._ssl_context, - ), - ) + This method is called when the connection is closed + for any reason (not only by the client). + """ + self._tibber_connected.clear() + await super()._close_hook() diff --git a/tibber/websocket_transport.py b/tibber/websocket_transport.py deleted file mode 100644 index b4387b6..0000000 --- a/tibber/websocket_transport.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Websocket transport for Tibber.""" - -import asyncio -import datetime as dt -import logging -from ssl import SSLContext - -from gql.transport.exceptions import TransportClosed -from gql.transport.websockets import WebsocketsTransport -from websockets.asyncio.connection import State - -_LOGGER = logging.getLogger(__name__) - - -class TibberWebsocketsTransport(WebsocketsTransport): - """Tibber websockets transport.""" - - def __init__(self, url: str, access_token: str, user_agent: str, ssl: SSLContext | bool = True) -> None: - """Initialize TibberWebsocketsTransport. - Configures the gql.transport.websockets logger to WARNING level to suppress - verbose INFO-level messages (<<< and >>> websocket traffic logs). - """ - logging.getLogger("gql.transport.websockets").setLevel(logging.WARNING) - - super().__init__( - url=url, - init_payload={"token": access_token}, - headers={"User-Agent": user_agent}, - ping_interval=30, - ssl=ssl, - ) - self._user_agent: str = user_agent - self._timeout: int = 90 - self.reconnect_at: dt.datetime = dt.datetime.now(tz=dt.UTC) + dt.timedelta(seconds=self._timeout) - - @property - def running(self) -> bool: - """Is real time subscription running.""" - return ( - hasattr(self, "adapter") - and hasattr(self.adapter, "websocket") - and self.adapter.websocket is not None - and self.adapter.websocket.state is State.OPEN - ) - - async def _receive(self) -> str: - """Wait the next message from the websocket connection.""" - try: - msg = await asyncio.wait_for(super()._receive(), timeout=self._timeout) - except TimeoutError: - _LOGGER.error("No data received from Tibber for %s seconds", self._timeout) - raise - self.reconnect_at = dt.datetime.now(tz=dt.UTC) + dt.timedelta(seconds=self._timeout) - return msg - - async def close(self) -> None: - """Close the websocket connection.""" - await self._fail(TransportClosed(f"Tibber websocket closed by {self._user_agent}")) - await self.wait_closed() From bf7313a054bb13bdbcef228cbb997791691be120 Mon Sep 17 00:00:00 2001 From: Martin Hjelmare Date: Sat, 4 Apr 2026 22:01:14 +0200 Subject: [PATCH 02/14] Add home rt watchdog --- tibber/exceptions.py | 4 ++++ tibber/home.py | 48 ++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/tibber/exceptions.py b/tibber/exceptions.py index b2f2715..431feeb 100644 --- a/tibber/exceptions.py +++ b/tibber/exceptions.py @@ -11,6 +11,10 @@ class SubscriptionEndpointMissingError(TibberError): """Exception raised when subscription endpoint is missing.""" +class SubscriptionFailedError(TibberError): + """Exception raised when subscription fails.""" + + class UserAgentMissingError(TibberError): """Exception raised when user agent is missing.""" diff --git a/tibber/home.py b/tibber/home.py index 7d995b7..1c63605 100644 --- a/tibber/home.py +++ b/tibber/home.py @@ -7,13 +7,15 @@ import contextlib import datetime as dt import logging +import random +import time import warnings from typing import TYPE_CHECKING, Any from gql import gql from .const import RESOLUTION_DAILY, RESOLUTION_HOURLY, RESOLUTION_MONTHLY, RESOLUTION_WEEKLY -from .exceptions import WebsocketReconnectedError, WebsocketTransportError +from .exceptions import SubscriptionFailedError, WebsocketReconnectedError, WebsocketTransportError from .gql_queries import ( HISTORIC_DATA, HISTORIC_DATA_DATE, @@ -31,6 +33,7 @@ MIN_IN_HOUR: int = 60 MIN_IN_QUARTER: int = 15 +RT_SUBSCRIPTION_TIMEOUT = 60 class HourlyData: @@ -79,8 +82,10 @@ def __init__(self, home_id: str, tibber_control: Tibber) -> None: self._hourly_consumption_data = HourlyData() self._hourly_production_data = HourlyData(production=True) - self._rt_listener: asyncio.Task[Any] | None = None + self._last_rt_data_received: float | None = None + self._rt_listener: asyncio.Task[None] | None = None self._rt_callback: Callable[..., Any] | None = None + self._rt_subscription_timeout_task: asyncio.Task[None] | None = None self._has_real_time_consumption: None | bool = None self._real_time_consumption_suggested_disabled: dt.datetime | None = None self._resubscribe_task: asyncio.Task[None] | None = None @@ -399,6 +404,9 @@ async def rt_subscribe( self._rt_callback = callback await self._tibber_control.realtime.connect() self._rt_listener = asyncio.create_task(self._start_listen(callback, on_error=on_error)) + self._rt_subscription_timeout_task = asyncio.create_task( + self._rt_subscription_timeout(callback, on_error=on_error), + ) async def _start_listen( self, @@ -417,6 +425,7 @@ async def _start_listen( data = {"data": _data} with contextlib.suppress(KeyError): data = self._add_extra_data(data) + self._last_rt_data_received = time.time() _LOGGER.debug( "Data received for %s: %s", self.home_id, @@ -527,6 +536,41 @@ def rt_unsubscribe(self) -> None: return self._rt_listener.cancel() self._rt_listener = None + if self._rt_subscription_timeout_task is not None: + self._rt_subscription_timeout_task.cancel() + self._rt_subscription_timeout_task = None + self._last_rt_data_received = None + + async def _rt_subscription_timeout( + self, + callback: Callable[..., Any], + *, + on_error: Callable[[Exception], None] | None = None, + ) -> None: + """Resubscribe if realtime subscription is unresponsive.""" + while True: + # Add some random time to avoid all homes resubscribing at the same time + # if there is an issue with the subscription + await asyncio.sleep(RT_SUBSCRIPTION_TIMEOUT + random.random() * RT_SUBSCRIPTION_TIMEOUT) # noqa: S311 + if ( + self._last_rt_data_received is None + or time.time() - self._last_rt_data_received <= RT_SUBSCRIPTION_TIMEOUT + ): + continue + if on_error: + on_error( + SubscriptionFailedError( + f"No real time data received for home {self.home_id} " + f"in the last {RT_SUBSCRIPTION_TIMEOUT} seconds", + ), + ) + else: + _LOGGER.error( + "No real time data received for home %s in the last %d seconds, resubscribing", + self.home_id, + RT_SUBSCRIPTION_TIMEOUT, + ) + await self._rt_resubscribe(callback=callback, on_error=on_error) @property def rt_subscription_running(self) -> bool: From 71630354a1cc1b48ac2f74ef94840b29e4a39094 Mon Sep 17 00:00:00 2001 From: Martin Hjelmare Date: Wed, 8 Apr 2026 10:22:06 +0200 Subject: [PATCH 03/14] Log error when connection closed waiting for reconnect --- tibber/realtime.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tibber/realtime.py b/tibber/realtime.py index fe358df..c492d12 100644 --- a/tibber/realtime.py +++ b/tibber/realtime.py @@ -135,10 +135,11 @@ async def subscribe( async for result in self._session.subscribe(request): yield result except TransportError as err: - _LOGGER.debug("%s: %s", err.__class__.__name__, err) self.subscription_running = False self._tibber_connected.clear() if isinstance(err, TransportConnectionFailed): + level = logging.DEBUG if on_error is not None else logging.ERROR + _LOGGER.log(level, "%s: %s", err.__class__.__name__, err) if on_error: on_error(err) _LOGGER.debug("Waiting for reconnect") From 035cd85de6871d9cc9c7625f3687f9e74998388f Mon Sep 17 00:00:00 2001 From: Martin Hjelmare Date: Wed, 8 Apr 2026 11:06:40 +0200 Subject: [PATCH 04/14] Schedule resubscribe instead of awaiting it --- tibber/home.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tibber/home.py b/tibber/home.py index 1c63605..a65a8ac 100644 --- a/tibber/home.py +++ b/tibber/home.py @@ -451,9 +451,7 @@ async def _start_listen( if on_error is not None: on_error(err) - if self._resubscribe_task is not None: - self._resubscribe_task.cancel() - self._resubscribe_task = asyncio.create_task(self._rt_resubscribe(callback, on_error=on_error)) + self._schedule_resubscribe(callback, on_error=on_error) def _add_extra_data(self, data: dict[str, Any]) -> dict[str, Any]: """Add extra data to live subscription result.""" @@ -488,6 +486,16 @@ def _add_extra_data(self, data: dict[str, Any]) -> dict[str, Any]: self._hourly_consumption_data.peak_hour_time = _timestamp return data + def _schedule_resubscribe( + self, + callback: Callable[..., Any], + *, + on_error: Callable[[Exception], None] | None = None, + ) -> None: + if self._resubscribe_task is not None: + self._resubscribe_task.cancel() + self._resubscribe_task = asyncio.create_task(self._rt_resubscribe(callback=callback, on_error=on_error)) + async def rt_resubscribe(self) -> None: """Resubscribe to Tibber data. @@ -570,7 +578,7 @@ async def _rt_subscription_timeout( self.home_id, RT_SUBSCRIPTION_TIMEOUT, ) - await self._rt_resubscribe(callback=callback, on_error=on_error) + self._schedule_resubscribe(callback, on_error=on_error) @property def rt_subscription_running(self) -> bool: From 82c930da1560d8fb65fc321a9afb185afbd29152 Mon Sep 17 00:00:00 2001 From: Martin Hjelmare Date: Fri, 17 Apr 2026 16:58:04 +0200 Subject: [PATCH 05/14] Handle missing first data in _last_rt_data_received --- tibber/home.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tibber/home.py b/tibber/home.py index a65a8ac..204f408 100644 --- a/tibber/home.py +++ b/tibber/home.py @@ -561,8 +561,8 @@ async def _rt_subscription_timeout( # if there is an issue with the subscription await asyncio.sleep(RT_SUBSCRIPTION_TIMEOUT + random.random() * RT_SUBSCRIPTION_TIMEOUT) # noqa: S311 if ( - self._last_rt_data_received is None - or time.time() - self._last_rt_data_received <= RT_SUBSCRIPTION_TIMEOUT + self._last_rt_data_received is not None + and time.time() - self._last_rt_data_received <= RT_SUBSCRIPTION_TIMEOUT ): continue if on_error: From 3b834d4e0ae79274f73beeb16a28cca67b117cad Mon Sep 17 00:00:00 2001 From: Martin Hjelmare Date: Fri, 17 Apr 2026 17:19:02 +0200 Subject: [PATCH 06/14] Simplify home subscription method parameters --- tibber/home.py | 53 ++++++++++++++++++-------------------------------- 1 file changed, 19 insertions(+), 34 deletions(-) diff --git a/tibber/home.py b/tibber/home.py index 204f408..642b20c 100644 --- a/tibber/home.py +++ b/tibber/home.py @@ -85,6 +85,7 @@ def __init__(self, home_id: str, tibber_control: Tibber) -> None: self._last_rt_data_received: float | None = None self._rt_listener: asyncio.Task[None] | None = None self._rt_callback: Callable[..., Any] | None = None + self._rt_on_error: Callable[[Exception], None] | None = None self._rt_subscription_timeout_task: asyncio.Task[None] | None = None self._has_real_time_consumption: None | bool = None self._real_time_consumption_suggested_disabled: dt.datetime | None = None @@ -402,19 +403,17 @@ async def rt_subscribe( raise RuntimeError("Already subscribed to real time data, call rt_unsubscribe first") _LOGGER.debug("Subscribe, %s", self.home_id) self._rt_callback = callback + self._rt_on_error = on_error await self._tibber_control.realtime.connect() - self._rt_listener = asyncio.create_task(self._start_listen(callback, on_error=on_error)) + self._rt_listener = asyncio.create_task(self._start_listen()) self._rt_subscription_timeout_task = asyncio.create_task( - self._rt_subscription_timeout(callback, on_error=on_error), + self._rt_subscription_timeout(), ) - async def _start_listen( - self, - callback: Callable[..., Any], - *, - on_error: Callable[[Exception], None] | None = None, - ) -> None: + async def _start_listen(self) -> None: """Subscribe to Tibber.""" + callback = self._rt_callback + on_error = self._rt_on_error try: async for _data in self._tibber_control.realtime.subscribe( gql( @@ -451,7 +450,7 @@ async def _start_listen( if on_error is not None: on_error(err) - self._schedule_resubscribe(callback, on_error=on_error) + self._schedule_resubscribe() def _add_extra_data(self, data: dict[str, Any]) -> dict[str, Any]: """Add extra data to live subscription result.""" @@ -486,38 +485,28 @@ def _add_extra_data(self, data: dict[str, Any]) -> dict[str, Any]: self._hourly_consumption_data.peak_hour_time = _timestamp return data - def _schedule_resubscribe( - self, - callback: Callable[..., Any], - *, - on_error: Callable[[Exception], None] | None = None, - ) -> None: + def _schedule_resubscribe(self) -> None: if self._resubscribe_task is not None: self._resubscribe_task.cancel() - self._resubscribe_task = asyncio.create_task(self._rt_resubscribe(callback=callback, on_error=on_error)) + self._resubscribe_task = asyncio.create_task(self._rt_resubscribe()) async def rt_resubscribe(self) -> None: """Resubscribe to Tibber data. Deprecated. Resubscription will happen automatically. """ - if self._rt_callback is None: - raise RuntimeError("No callback set for rt_resubscribe, call rt_subscribe first") - warnings.warn( "TibberHome.rt_resubscribe is deprecated, resubscription will happen automatically", DeprecationWarning, stacklevel=2, ) - await self._rt_resubscribe(self._rt_callback) + await self._rt_resubscribe() - async def _rt_resubscribe( - self, - callback: Callable[..., Any], - *, - on_error: Callable[[Exception], None] | None = None, - ) -> None: + async def _rt_resubscribe(self) -> None: """Resubscribe to Tibber data.""" + if (callback := self._rt_callback) is None: + raise RuntimeError("No callback set for rt_resubscribe, call rt_subscribe first") + if self._rt_listener is None: _LOGGER.debug("No active subscription to resubscribe for %s", self.home_id) return @@ -535,7 +524,7 @@ async def _rt_resubscribe( with contextlib.suppress(Exception): await self._tibber_control.update_info() - await self.rt_subscribe(callback, on_error=on_error) + await self.rt_subscribe(callback, on_error=self._rt_on_error) def rt_unsubscribe(self) -> None: """Unsubscribe to Tibber data.""" @@ -549,13 +538,9 @@ def rt_unsubscribe(self) -> None: self._rt_subscription_timeout_task = None self._last_rt_data_received = None - async def _rt_subscription_timeout( - self, - callback: Callable[..., Any], - *, - on_error: Callable[[Exception], None] | None = None, - ) -> None: + async def _rt_subscription_timeout(self) -> None: """Resubscribe if realtime subscription is unresponsive.""" + on_error = self._rt_on_error while True: # Add some random time to avoid all homes resubscribing at the same time # if there is an issue with the subscription @@ -578,7 +563,7 @@ async def _rt_subscription_timeout( self.home_id, RT_SUBSCRIPTION_TIMEOUT, ) - self._schedule_resubscribe(callback, on_error=on_error) + self._schedule_resubscribe() @property def rt_subscription_running(self) -> bool: From 521952ddd15502f2cd4b9875f23c2c90113ebdc2 Mon Sep 17 00:00:00 2001 From: Martin Hjelmare Date: Fri, 17 Apr 2026 19:10:09 +0200 Subject: [PATCH 07/14] Always set subscription url and check realtime status before subscribing --- tibber/home.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tibber/home.py b/tibber/home.py index 642b20c..2c86482 100644 --- a/tibber/home.py +++ b/tibber/home.py @@ -398,12 +398,17 @@ async def rt_subscribe( """Connect to Tibber and subscribe to Tibber real time subscription. :param callback: The function to call when data is received. + :param on_error: The function to call when an error occurs. """ if self._rt_listener is not None: raise RuntimeError("Already subscribed to real time data, call rt_unsubscribe first") _LOGGER.debug("Subscribe, %s", self.home_id) self._rt_callback = callback self._rt_on_error = on_error + await self._rt_resubscribe() + + async def _rt_subscribe(self) -> None: + """Subscribe to Tibber real time subscription.""" await self._tibber_control.realtime.connect() self._rt_listener = asyncio.create_task(self._start_listen()) self._rt_subscription_timeout_task = asyncio.create_task( @@ -504,13 +509,6 @@ async def rt_resubscribe(self) -> None: async def _rt_resubscribe(self) -> None: """Resubscribe to Tibber data.""" - if (callback := self._rt_callback) is None: - raise RuntimeError("No callback set for rt_resubscribe, call rt_subscribe first") - - if self._rt_listener is None: - _LOGGER.debug("No active subscription to resubscribe for %s", self.home_id) - return - _LOGGER.debug("Resubscribe, %s", self.home_id) self.rt_unsubscribe() @@ -524,7 +522,7 @@ async def _rt_resubscribe(self) -> None: with contextlib.suppress(Exception): await self._tibber_control.update_info() - await self.rt_subscribe(callback, on_error=self._rt_on_error) + await self._rt_subscribe() def rt_unsubscribe(self) -> None: """Unsubscribe to Tibber data.""" From e29f982628d0f603b6c9cbbdca385fbc693e51da Mon Sep 17 00:00:00 2001 From: Martin Hjelmare Date: Fri, 17 Apr 2026 20:07:49 +0200 Subject: [PATCH 08/14] Minimize realtime consumption enabled query --- tibber/gql_queries.py | 12 ++++++++++++ tibber/home.py | 13 ++++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/tibber/gql_queries.py b/tibber/gql_queries.py index 601e890..caef1b6 100644 --- a/tibber/gql_queries.py +++ b/tibber/gql_queries.py @@ -112,6 +112,18 @@ }} }} """ +REAL_TIME_CONSUMPTION_ENABLED = """ + { + viewer { + home(id: "%s") { + id + features { + realTimeConsumptionEnabled + } + } + } + } + """ UPDATE_CURRENT_PRICE = """ { viewer { diff --git a/tibber/home.py b/tibber/home.py index 2c86482..43fa479 100644 --- a/tibber/home.py +++ b/tibber/home.py @@ -21,6 +21,7 @@ HISTORIC_DATA_DATE, HISTORIC_PRICE, LIVE_SUBSCRIBE, + REAL_TIME_CONSUMPTION_ENABLED, UPDATE_INFO_PRICE, ) @@ -247,6 +248,16 @@ async def update_info_and_price_info(self) -> None: _LOGGER.error("Malformed price info data for home %s: %s", self._home_id, err) self.price_total = {} + async def update_real_time_consumption_enabled(self) -> None: + """Update the real time consumption enabled status.""" + if not (data := await self._tibber_control.execute(REAL_TIME_CONSUMPTION_ENABLED % self._home_id)): + _LOGGER.error("Could not get the data.") + return + self.info["viewer"]["home"]["features"]["realTimeConsumptionEnabled"] = data["viewer"]["home"]["features"][ + "realTimeConsumptionEnabled" + ] + self._update_has_real_time_consumption() + def _update_has_real_time_consumption(self) -> None: try: _has_real_time_consumption = self.info["viewer"]["home"]["features"]["realTimeConsumptionEnabled"] @@ -513,7 +524,7 @@ async def _rt_resubscribe(self) -> None: self.rt_unsubscribe() with contextlib.suppress(Exception): - await self.update_info() # Update home info to check if real time is enabled + await self.update_real_time_consumption_enabled() if not self.has_real_time_consumption: _LOGGER.debug("Home %s does not have real time consumption enabled", self.home_id) return From 9f1316e52935d8ae9b2aa33ff50d2e4d5fba88e0 Mon Sep 17 00:00:00 2001 From: Martin Hjelmare Date: Fri, 17 Apr 2026 21:05:48 +0200 Subject: [PATCH 09/14] Reconnect websocket on subscription timeout according to guide --- tibber/home.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tibber/home.py b/tibber/home.py index 43fa479..ad9abf3 100644 --- a/tibber/home.py +++ b/tibber/home.py @@ -538,10 +538,9 @@ async def _rt_resubscribe(self) -> None: def rt_unsubscribe(self) -> None: """Unsubscribe to Tibber data.""" _LOGGER.debug("Unsubscribe, %s", self.home_id) - if self._rt_listener is None: - return - self._rt_listener.cancel() - self._rt_listener = None + if self._rt_listener is not None: + self._rt_listener.cancel() + self._rt_listener = None if self._rt_subscription_timeout_task is not None: self._rt_subscription_timeout_task.cancel() self._rt_subscription_timeout_task = None @@ -568,10 +567,13 @@ async def _rt_subscription_timeout(self) -> None: ) else: _LOGGER.error( - "No real time data received for home %s in the last %d seconds, resubscribing", + "No real time data received for home %s in the last %d seconds, reconnecting and resubscribing", self.home_id, RT_SUBSCRIPTION_TIMEOUT, ) + self._rt_listener.cancel() + self._rt_listener = None + await self._tibber_control.realtime.reconnect() self._schedule_resubscribe() @property From 97ed62ca51702a6dc77f8f14425498bc820c5cfe Mon Sep 17 00:00:00 2001 From: Martin Hjelmare Date: Fri, 17 Apr 2026 21:41:50 +0200 Subject: [PATCH 10/14] Add callback check back --- tibber/home.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tibber/home.py b/tibber/home.py index ad9abf3..026d6fe 100644 --- a/tibber/home.py +++ b/tibber/home.py @@ -516,6 +516,8 @@ async def rt_resubscribe(self) -> None: DeprecationWarning, stacklevel=2, ) + if self._rt_callback is None: + raise RuntimeError("No callback set for rt_resubscribe, call rt_subscribe first") await self._rt_resubscribe() async def _rt_resubscribe(self) -> None: From 8b85c35e3ebab991288846a66806b947a725114a Mon Sep 17 00:00:00 2001 From: Martin Hjelmare Date: Fri, 17 Apr 2026 22:05:10 +0200 Subject: [PATCH 11/14] Fix test_rt_subscribe_on_error_called_on_exception --- test/test_home.py | 86 ++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 82 insertions(+), 4 deletions(-) diff --git a/test/test_home.py b/test/test_home.py index fa12668..c23b588 100644 --- a/test/test_home.py +++ b/test/test_home.py @@ -11,7 +11,7 @@ import tibber from tibber.exceptions import WebsocketReconnectedError, WebsocketTransportError -from tibber.gql_queries import INFO, UPDATE_INFO_PRICE +from tibber.gql_queries import INFO, REAL_TIME_CONSUMPTION_ENABLED from tibber.realtime import TibberRT if TYPE_CHECKING: @@ -132,13 +132,33 @@ def callback(data: dict) -> None: ( False, [ + # Initial subscription (first call returns True) call( "https://api.tibber.com/v1-beta/gql", headers={ "Authorization": "Bearer test-token", "User-Agent": "test", }, - data={"query": UPDATE_INFO_PRICE % HOME_ID, "variables": {}}, + data={"query": REAL_TIME_CONSUMPTION_ENABLED % HOME_ID, "variables": {}}, + timeout=aiohttp.ClientTimeout(total=10), + ), + call( + "https://api.tibber.com/v1-beta/gql", + headers={ + "Authorization": "Bearer test-token", + "User-Agent": "test", + }, + data={"query": INFO, "variables": {}}, + timeout=aiohttp.ClientTimeout(total=10), + ), + # Resubscription (returns False, so no INFO call) + call( + "https://api.tibber.com/v1-beta/gql", + headers={ + "Authorization": "Bearer test-token", + "User-Agent": "test", + }, + data={"query": REAL_TIME_CONSUMPTION_ENABLED % HOME_ID, "variables": {}}, timeout=aiohttp.ClientTimeout(total=10), ), ], @@ -146,13 +166,33 @@ def callback(data: dict) -> None: ( True, [ + # Initial subscription (first call returns True) + call( + "https://api.tibber.com/v1-beta/gql", + headers={ + "Authorization": "Bearer test-token", + "User-Agent": "test", + }, + data={"query": REAL_TIME_CONSUMPTION_ENABLED % HOME_ID, "variables": {}}, + timeout=aiohttp.ClientTimeout(total=10), + ), call( "https://api.tibber.com/v1-beta/gql", headers={ "Authorization": "Bearer test-token", "User-Agent": "test", }, - data={"query": UPDATE_INFO_PRICE % HOME_ID, "variables": {}}, + data={"query": INFO, "variables": {}}, + timeout=aiohttp.ClientTimeout(total=10), + ), + # Resubscription (returns True) + call( + "https://api.tibber.com/v1-beta/gql", + headers={ + "Authorization": "Bearer test-token", + "User-Agent": "test", + }, + data={"query": REAL_TIME_CONSUMPTION_ENABLED % HOME_ID, "variables": {}}, timeout=aiohttp.ClientTimeout(total=10), ), call( @@ -185,7 +225,45 @@ async def test_rt_subscribe_on_error_called_on_exception( http_calls: list, ) -> None: """on_error must be called when subscribe raises an exception.""" - home._has_real_time_consumption = real_time_consumption # noqa: SLF001 + # Initialize info structure so update_real_time_consumption_enabled can update it + home.info = { + "viewer": { + "home": { + "features": {"realTimeConsumptionEnabled": real_time_consumption}, + }, + }, + } + + # Track which call number we're on to return different responses + call_count = 0 + + def make_response(rt_enabled: bool) -> MagicMock: + mock_response = MagicMock() + mock_response.status = 200 + mock_response.content_type = "application/json" + mock_response.json = AsyncMock( + return_value={ + "data": { + "viewer": { + "home": { + "id": HOME_ID, + "features": {"realTimeConsumptionEnabled": rt_enabled}, + }, + }, + }, + }, + ) + return mock_response + + async def post_side_effect(*args: Any, **kwargs: Any) -> MagicMock: # noqa: ARG001, ANN401 + nonlocal call_count + call_count += 1 + # First call (initial subscription) always returns True to start subscription + # Subsequent calls (resubscription) return the real_time_consumption value + return make_response(True if call_count == 1 else real_time_consumption) + + mock_websession.post.side_effect = post_side_effect + wait_for_events = asyncio.Event() wait_for_events.set() # allow subscribe to raise immediately From dda72b515d0b67a81930b128fb2a7391af71fda3 Mon Sep 17 00:00:00 2001 From: Martin Hjelmare Date: Fri, 17 Apr 2026 23:07:11 +0200 Subject: [PATCH 12/14] Wait for http calls in tests --- test/test_home.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/test/test_home.py b/test/test_home.py index c23b588..3344cea 100644 --- a/test/test_home.py +++ b/test/test_home.py @@ -236,6 +236,7 @@ async def test_rt_subscribe_on_error_called_on_exception( # Track which call number we're on to return different responses call_count = 0 + resubscribe_called = asyncio.Event() def make_response(rt_enabled: bool) -> MagicMock: mock_response = MagicMock() @@ -258,9 +259,12 @@ def make_response(rt_enabled: bool) -> MagicMock: async def post_side_effect(*args: Any, **kwargs: Any) -> MagicMock: # noqa: ARG001, ANN401 nonlocal call_count call_count += 1 - # First call (initial subscription) always returns True to start subscription + # First two calls (initial subscription) always returns True to start subscription # Subsequent calls (resubscription) return the real_time_consumption value - return make_response(True if call_count == 1 else real_time_consumption) + if call_count <= 2: + return make_response(True) + resubscribe_called.set() + return make_response(real_time_consumption) mock_websession.post.side_effect = post_side_effect @@ -286,7 +290,8 @@ def on_error(exc: Exception) -> None: await asyncio.wait_for(on_error_called.wait(), timeout=1.0) assert caught == [error] - # resubscription should have been triggered + # resubscription should have been triggered - wait for HTTP calls to complete + await asyncio.wait_for(resubscribe_called.wait(), timeout=1.0) assert mock_websession.post.call_count == len(http_calls) assert mock_websession.post.call_args_list == http_calls assert home.rt_subscription_running is real_time_consumption From d98513d38cc3b8920e319d6df1909ebecaa85965 Mon Sep 17 00:00:00 2001 From: Martin Hjelmare Date: Fri, 17 Apr 2026 23:55:04 +0200 Subject: [PATCH 13/14] Wait some random time before resubscribing --- test/test_home.py | 3 ++- tibber/home.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/test/test_home.py b/test/test_home.py index 3344cea..da445bd 100644 --- a/test/test_home.py +++ b/test/test_home.py @@ -4,7 +4,7 @@ import asyncio from typing import TYPE_CHECKING, Any -from unittest.mock import AsyncMock, MagicMock, call, create_autospec +from unittest.mock import AsyncMock, MagicMock, call, create_autospec, patch import aiohttp import pytest @@ -216,6 +216,7 @@ def callback(data: dict) -> None: RuntimeError("unexpected"), ], ) +@patch("tibber.home.RESUBSCRIBE_WAIT_TIME", 0) async def test_rt_subscribe_on_error_called_on_exception( mock_websession: MagicMock, home: tibber.TibberHome, diff --git a/tibber/home.py b/tibber/home.py index 026d6fe..4770e20 100644 --- a/tibber/home.py +++ b/tibber/home.py @@ -35,6 +35,7 @@ MIN_IN_HOUR: int = 60 MIN_IN_QUARTER: int = 15 RT_SUBSCRIPTION_TIMEOUT = 60 +RESUBSCRIBE_WAIT_TIME = 60 class HourlyData: @@ -466,6 +467,7 @@ async def _start_listen(self) -> None: if on_error is not None: on_error(err) + await asyncio.sleep(random.random() * RESUBSCRIBE_WAIT_TIME) # noqa: S311 self._schedule_resubscribe() def _add_extra_data(self, data: dict[str, Any]) -> dict[str, Any]: From baa999e50a1495646b9ce2c9d88956df2a54aad8 Mon Sep 17 00:00:00 2001 From: Martin Hjelmare Date: Sat, 18 Apr 2026 00:25:39 +0200 Subject: [PATCH 14/14] Handle connect timeout correctly --- test/test_realtime.py | 13 ++++--------- tibber/realtime.py | 25 +++++++++++++++---------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/test/test_realtime.py b/test/test_realtime.py index ac08a8d..97d987b 100644 --- a/test/test_realtime.py +++ b/test/test_realtime.py @@ -57,6 +57,7 @@ def create_client( async def mock_connect_async(**kwargs: Any) -> MagicMock: # noqa: ANN401, ARG001 session = mock_client.session = MagicMock(spec=AsyncClientSession) mock_client.transport.adapter.websocket = MagicMock(state=State.OPEN) + await asyncio.sleep(0) # Simulate some delay in connecting return session mock_client.connect_async = AsyncMock(wraps=mock_connect_async) @@ -320,24 +321,18 @@ async def test_set_access_token_reconnects_with_new_token( @pytest.mark.parametrize("timeout", [0]) -async def test_connect_timeout_leaves_no_session_and_subscription_not_running( +async def test_connect_timeout_leaves_with_session_and_subscription_not_running( mock_client: MagicMock, tibber_rt: TibberRT, ) -> None: - """When connect_async times out, subscription_running must remain False and no session is set.""" - - async def slow_connect(**kwargs: Any) -> Any: # noqa: ANN401, ARG001 - await asyncio.sleep(9999) - - mock_client.connect_async = AsyncMock(side_effect=slow_connect) - + """When connect_async times out, a session should be set and subscription_running must remain False.""" await tibber_rt.connect() assert tibber_rt.subscription_running is False await tibber_rt.disconnect() - mock_client.close_async.assert_not_awaited() + mock_client.close_async.assert_awaited_once() assert tibber_rt.subscription_running is False diff --git a/tibber/realtime.py b/tibber/realtime.py index c492d12..8553a95 100644 --- a/tibber/realtime.py +++ b/tibber/realtime.py @@ -4,7 +4,7 @@ import logging from collections.abc import AsyncGenerator, Callable from ssl import SSLContext -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from gql import Client, GraphQLRequest from gql.transport.exceptions import TransportClosed, TransportConnectionFailed, TransportError @@ -87,16 +87,18 @@ async def _connect(self) -> None: self._client = self._create_client() try: - self._session = await asyncio.wait_for( - self._client.connect_async( - reconnecting=True, - retry_connect=retry( - wait=wait_exponential_jitter( - initial=MIN_RECONNECT_INTERVAL, - max=MAX_RECONNECT_INTERVAL, - jitter=MAX_RECONNECT_INTERVAL, + await asyncio.wait_for( + asyncio.shield( + self._client.connect_async( + reconnecting=True, + retry_connect=retry( + wait=wait_exponential_jitter( + initial=MIN_RECONNECT_INTERVAL, + max=MAX_RECONNECT_INTERVAL, + jitter=MAX_RECONNECT_INTERVAL, + ), + before_sleep=before_sleep_log(_LOGGER, logging.INFO), ), - before_sleep=before_sleep_log(_LOGGER, logging.INFO), ), ), timeout=self._timeout, @@ -107,6 +109,9 @@ async def _connect(self) -> None: else: self.subscription_running = True + # The client session is set even if the connection times out. + self._session = cast("AsyncClientSession", self._client.session) + async def reconnect(self) -> None: """Reconnect the websocket client.""" async with LOCK_CONNECT: