diff --git a/pyproject.toml b/pyproject.toml index ab98971..b363013 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ classifiers = [ dependencies = [ "aiohttp>=3.0.6", "gql>=4.0.0", + "tenacity>=9.0.0", "websockets>=14.0.0", ] dynamic = ["version"] diff --git a/test/test_home.py b/test/test_home.py new file mode 100644 index 0000000..da445bd --- /dev/null +++ b/test/test_home.py @@ -0,0 +1,365 @@ +"""Tests for TibberHome.""" + +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, Any +from unittest.mock import AsyncMock, MagicMock, call, create_autospec, patch + +import aiohttp +import pytest + +import tibber +from tibber.exceptions import WebsocketReconnectedError, WebsocketTransportError +from tibber.gql_queries import INFO, REAL_TIME_CONSUMPTION_ENABLED +from tibber.realtime import TibberRT + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + +HOME_ID = "test-home-id" + + +@pytest.fixture +def tibber_connection(mock_websession: MagicMock) -> tibber.Tibber: + tibber_client = tibber.Tibber( + access_token="test-token", + websession=mock_websession, + user_agent="test", + ) + tibber_client._user_agent = "test" # noqa: SLF001 + return tibber_client + + +@pytest.fixture +def mock_websession() -> MagicMock: + session = MagicMock(spec=aiohttp.ClientSession) + session.post = AsyncMock() + return session + + +@pytest.fixture +def mock_realtime(tibber_connection: tibber.Tibber) -> MagicMock: + rt = create_autospec(TibberRT, instance=True, subscription_running=False) + rt.connect = AsyncMock(side_effect=lambda: setattr(rt, "subscription_running", True)) + tibber_connection.realtime = rt + return rt + + +@pytest.fixture +def home(tibber_connection: tibber.Tibber) -> tibber.TibberHome: + home = tibber.TibberHome(HOME_ID, tibber_connection) + home._has_real_time_consumption = True # noqa: SLF001 + return home + + +def _make_blocking_subscribe( + yielded: list[Any], +) -> tuple[asyncio.Event, Any]: + """Return (release_event, subscribe_fn) that yields *yielded* then blocks.""" + release = asyncio.Event() + + async def subscribe(*args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]: # noqa: ANN401, ARG001 + for item in yielded: + yield item + await release.wait() + + return release, subscribe + + +async def test_rt_subscribe_connects_and_calls_callback( + home: tibber.TibberHome, + mock_realtime: MagicMock, +) -> None: + """Test that rt_subscribe connects via realtime and delivers subscription data to the callback.""" + sample_data = {"key": "value"} + _, subscribe_fn = _make_blocking_subscribe([sample_data]) + mock_realtime.subscribe = subscribe_fn + + received: list[dict] = [] + callback_called = asyncio.Event() + + def callback(data: dict) -> None: + received.append(data) + callback_called.set() + + await home.rt_subscribe(callback) + await asyncio.wait_for(callback_called.wait(), timeout=1.0) + + mock_realtime.connect.assert_awaited_once() + assert received == [{"data": sample_data}] + assert home.rt_subscription_running + + home.rt_unsubscribe() + assert not home.rt_subscription_running + + +async def test_rt_unsubscribe_noop_when_not_subscribed(home: tibber.TibberHome) -> None: + """Calling rt_unsubscribe on a fresh home must not raise.""" + assert not home.rt_subscription_running + home.rt_unsubscribe() # should be a no-op + assert not home.rt_subscription_running + + +async def test_rt_subscribe_multiple_items_all_delivered( + home: tibber.TibberHome, + mock_realtime: MagicMock, +) -> None: + """All items yielded by subscribe must be delivered to the callback in order.""" + items = [{"n": 1}, {"n": 2}, {"n": 3}] + _, subscribe_fn = _make_blocking_subscribe(items) + mock_realtime.subscribe = subscribe_fn + + received: list[dict] = [] + all_received = asyncio.Event() + + def callback(data: dict) -> None: + received.append(data) + if len(received) == len(items): + all_received.set() + + await home.rt_subscribe(callback) + await asyncio.wait_for(all_received.wait(), timeout=1.0) + + assert received == [{"data": item} for item in items] + + home.rt_unsubscribe() + + +@pytest.mark.parametrize( + ("real_time_consumption", "http_calls"), + [ + ( + False, + [ + # Initial subscription (first call returns True) + call( + "https://api.tibber.com/v1-beta/gql", + headers={ + "Authorization": "Bearer test-token", + "User-Agent": "test", + }, + data={"query": REAL_TIME_CONSUMPTION_ENABLED % HOME_ID, "variables": {}}, + timeout=aiohttp.ClientTimeout(total=10), + ), + call( + "https://api.tibber.com/v1-beta/gql", + headers={ + "Authorization": "Bearer test-token", + "User-Agent": "test", + }, + data={"query": INFO, "variables": {}}, + timeout=aiohttp.ClientTimeout(total=10), + ), + # Resubscription (returns False, so no INFO call) + call( + "https://api.tibber.com/v1-beta/gql", + headers={ + "Authorization": "Bearer test-token", + "User-Agent": "test", + }, + data={"query": REAL_TIME_CONSUMPTION_ENABLED % HOME_ID, "variables": {}}, + timeout=aiohttp.ClientTimeout(total=10), + ), + ], + ), + ( + True, + [ + # Initial subscription (first call returns True) + call( + "https://api.tibber.com/v1-beta/gql", + headers={ + "Authorization": "Bearer test-token", + "User-Agent": "test", + }, + data={"query": REAL_TIME_CONSUMPTION_ENABLED % HOME_ID, "variables": {}}, + timeout=aiohttp.ClientTimeout(total=10), + ), + call( + "https://api.tibber.com/v1-beta/gql", + headers={ + "Authorization": "Bearer test-token", + "User-Agent": "test", + }, + data={"query": INFO, "variables": {}}, + timeout=aiohttp.ClientTimeout(total=10), + ), + # Resubscription (returns True) + call( + "https://api.tibber.com/v1-beta/gql", + headers={ + "Authorization": "Bearer test-token", + "User-Agent": "test", + }, + data={"query": REAL_TIME_CONSUMPTION_ENABLED % HOME_ID, "variables": {}}, + timeout=aiohttp.ClientTimeout(total=10), + ), + call( + "https://api.tibber.com/v1-beta/gql", + headers={ + "Authorization": "Bearer test-token", + "User-Agent": "test", + }, + data={"query": INFO, "variables": {}}, + timeout=aiohttp.ClientTimeout(total=10), + ), + ], + ), + ], +) +@pytest.mark.parametrize( + "error", + [ + WebsocketReconnectedError("reconnected"), + WebsocketTransportError("transport error"), + RuntimeError("unexpected"), + ], +) +@patch("tibber.home.RESUBSCRIBE_WAIT_TIME", 0) +async def test_rt_subscribe_on_error_called_on_exception( + mock_websession: MagicMock, + home: tibber.TibberHome, + mock_realtime: MagicMock, + error: Exception, + real_time_consumption: bool, + http_calls: list, +) -> None: + """on_error must be called when subscribe raises an exception.""" + # Initialize info structure so update_real_time_consumption_enabled can update it + home.info = { + "viewer": { + "home": { + "features": {"realTimeConsumptionEnabled": real_time_consumption}, + }, + }, + } + + # Track which call number we're on to return different responses + call_count = 0 + resubscribe_called = asyncio.Event() + + def make_response(rt_enabled: bool) -> MagicMock: + mock_response = MagicMock() + mock_response.status = 200 + mock_response.content_type = "application/json" + mock_response.json = AsyncMock( + return_value={ + "data": { + "viewer": { + "home": { + "id": HOME_ID, + "features": {"realTimeConsumptionEnabled": rt_enabled}, + }, + }, + }, + }, + ) + return mock_response + + async def post_side_effect(*args: Any, **kwargs: Any) -> MagicMock: # noqa: ARG001, ANN401 + nonlocal call_count + call_count += 1 + # First two calls (initial subscription) always returns True to start subscription + # Subsequent calls (resubscription) return the real_time_consumption value + if call_count <= 2: + return make_response(True) + resubscribe_called.set() + return make_response(real_time_consumption) + + mock_websession.post.side_effect = post_side_effect + + wait_for_events = asyncio.Event() + wait_for_events.set() # allow subscribe to raise immediately + + async def subscribe_raises(*args: Any, **kwargs: Any) -> AsyncGenerator: # noqa: ANN401, ARG001 + await wait_for_events.wait() + raise error + yield + + mock_realtime.subscribe = subscribe_raises + + on_error_called = asyncio.Event() + caught: list[Exception] = [] + + def on_error(exc: Exception) -> None: + caught.append(exc) + on_error_called.set() + wait_for_events.clear() # allow test to control the flow after error is caught + + await home.rt_subscribe(MagicMock(), on_error=on_error) + await asyncio.wait_for(on_error_called.wait(), timeout=1.0) + + assert caught == [error] + # resubscription should have been triggered - wait for HTTP calls to complete + await asyncio.wait_for(resubscribe_called.wait(), timeout=1.0) + assert mock_websession.post.call_count == len(http_calls) + assert mock_websession.post.call_args_list == http_calls + assert home.rt_subscription_running is real_time_consumption + + home.rt_unsubscribe() + + assert not home.rt_subscription_running + + +async def test_rt_subscribe_no_crash_when_subscribe_raises_without_on_error( + home: tibber.TibberHome, + mock_realtime: MagicMock, +) -> None: + """_start_listen must not propagate exceptions when no on_error is provided.""" + + async def subscribe_raises(*args: Any, **kwargs: Any) -> AsyncGenerator: # noqa: ANN401, ARG001 + raise WebsocketTransportError("transport error") + yield + + mock_realtime.subscribe = subscribe_raises + + callback = MagicMock() + await home.rt_subscribe(callback) + + # give the listener task a chance to run and finish without raising + await asyncio.sleep(0) + await asyncio.sleep(0) + + callback.assert_not_called() + home.rt_unsubscribe() + + +async def test_rt_resubscribe_raises_without_prior_subscribe(home: tibber.TibberHome) -> None: + """rt_resubscribe must raise RuntimeError when rt_subscribe has not been called.""" + with pytest.raises(RuntimeError, match="rt_subscribe"): + await home.rt_resubscribe() + + +async def test_rt_subscribe_raises_when_already_subscribed( + home: tibber.TibberHome, + mock_realtime: MagicMock, +) -> None: + """rt_subscribe must raise RuntimeError when called while already subscribed.""" + _, subscribe_fn = _make_blocking_subscribe([]) + mock_realtime.subscribe = subscribe_fn + + callback = MagicMock() + await home.rt_subscribe(callback) + + with pytest.raises(RuntimeError, match="rt_unsubscribe"): + await home.rt_subscribe(callback) + + home.rt_unsubscribe() + + +async def test_rt_resubscribe_emits_deprecation_warning( + home: tibber.TibberHome, + mock_realtime: MagicMock, +) -> None: + """rt_resubscribe must emit a DeprecationWarning.""" + _, subscribe_fn = _make_blocking_subscribe([]) + mock_realtime.subscribe = subscribe_fn + + callback = MagicMock() + await home.rt_subscribe(callback) + + with pytest.warns(DeprecationWarning, match="deprecated"): + await home.rt_resubscribe() + + home.rt_unsubscribe() diff --git a/test/test_realtime.py b/test/test_realtime.py index bdc070d..97d987b 100644 --- a/test/test_realtime.py +++ b/test/test_realtime.py @@ -4,31 +4,37 @@ import asyncio import json -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any from unittest.mock import AsyncMock, MagicMock, patch import pytest from gql.client import AsyncClientSession, Client from gql.transport.common.adapters.websockets import WebSocketsAdapter +from gql.transport.exceptions import TransportConnectionFailed, TransportError from websockets.asyncio.connection import State -from tibber.realtime import TibberRT -from tibber.websocket_transport import TibberWebsocketsTransport +from tibber.exceptions import SubscriptionEndpointMissingError, WebsocketReconnectedError, WebsocketTransportError +from tibber.realtime import TibberRT, TibberWebsocketsTransport if TYPE_CHECKING: from collections.abc import Generator +@pytest.fixture +def timeout() -> int: + return 30 + + @pytest.fixture(name="tibber_rt") -def tibber_rt_fixture() -> TibberRT: +async def tibber_rt_fixture(mock_client: MagicMock, timeout: int) -> TibberRT: # noqa: ARG001, ASYNC109 """Create a TibberRT instance for testing.""" tibber_rt = TibberRT( access_token="test_token", - timeout=30, + timeout=timeout, user_agent="test_agent", ssl=True, ) - tibber_rt.sub_endpoint = "wss://test.endpoint" + await tibber_rt.set_subscription_endpoint("wss://test.endpoint") return tibber_rt @@ -51,6 +57,7 @@ def create_client( async def mock_connect_async(**kwargs: Any) -> MagicMock: # noqa: ANN401, ARG001 session = mock_client.session = MagicMock(spec=AsyncClientSession) mock_client.transport.adapter.websocket = MagicMock(state=State.OPEN) + await asyncio.sleep(0) # Simulate some delay in connecting return session mock_client.connect_async = AsyncMock(wraps=mock_connect_async) @@ -66,12 +73,14 @@ async def test_connect_disconnect( # Should not raise await tibber_rt.disconnect() + mock_client.close_async.assert_not_awaited() + # First connect - transport not running, so connect_async should be called await tibber_rt.connect() mock_client.connect_async.assert_awaited_once() - # Second connect should not call connect_async again since subscription_running is True + # Second connect should not call connect_async again since the client is already connected await tibber_rt.connect() # connect_async should still only have been called once @@ -83,20 +92,15 @@ async def test_connect_disconnect( async def test_subscription_running( - mock_client: MagicMock, tibber_rt: TibberRT, ) -> None: - """Test subscription_running.""" + """Test subscription running.""" assert tibber_rt.subscription_running is False await tibber_rt.connect() assert tibber_rt.subscription_running is True - mock_client.transport.adapter.websocket.state = State.CLOSED - - assert tibber_rt.subscription_running is False - await tibber_rt.disconnect() assert tibber_rt.subscription_running is False @@ -105,59 +109,83 @@ async def test_subscription_running( assert tibber_rt.subscription_running is True - mock_client.transport.adapter.websocket = None - assert tibber_rt.subscription_running is False +async def test_update_endpoint(mock_client: MagicMock) -> None: + """Test update subscription endpoint.""" + tibber_rt = TibberRT( + access_token="test_token", + timeout=30, + user_agent="test_agent", + ssl=True, + ) + with pytest.raises(SubscriptionEndpointMissingError, match="Subscription endpoint not initialized"): + await tibber_rt.connect() -async def test_update_endpoint(mock_client: MagicMock, tibber_rt: TibberRT) -> None: - """Delay endpoint replacement until the current connection is reset.""" + mock_client.reset_mock() + await tibber_rt.set_subscription_endpoint("wss://new.endpoint") await tibber_rt.connect() - assert mock_client.transport.url == "wss://test.endpoint" + assert mock_client.transport.url == "wss://new.endpoint" + assert mock_client.close_async.call_count == 0 + assert mock_client.connect_async.call_count == 1 + mock_client.reset_mock() - # Set new endpoint - tibber_rt.sub_endpoint = "wss://new.endpoint" + await tibber_rt.set_subscription_endpoint("wss://new.endpoint") - assert tibber_rt.sub_endpoint == "wss://new.endpoint" - assert mock_client.transport.url == "wss://test.endpoint" + assert mock_client.transport.url == "wss://new.endpoint" + assert mock_client.close_async.call_count == 0 + assert mock_client.connect_async.call_count == 0 + mock_client.reset_mock() await tibber_rt.disconnect() await tibber_rt.connect() assert mock_client.transport.url == "wss://new.endpoint" + assert mock_client.close_async.call_count == 1 + assert mock_client.connect_async.call_count == 1 + mock_client.reset_mock() + await tibber_rt.set_subscription_endpoint("wss://another_connected.endpoint") -async def test_close_sub_manager_skips_clients_without_session( - tibber_rt: TibberRT, -) -> None: - """Avoid calling gql close_async when the client never got a session.""" + assert mock_client.transport.url == "wss://another_connected.endpoint" + assert mock_client.close_async.call_count == 1 + assert mock_client.connect_async.call_count == 1 + mock_client.reset_mock() - class FakeClient: - def __init__(self) -> None: - self.transport = TibberWebsocketsTransport( - url="wss://test.endpoint", - access_token="test_token", - user_agent="test_agent", - ) - self.close_async = AsyncMock() + connect_event = asyncio.Event() + original_connect_async = mock_client.connect_async - mock_client = FakeClient() + async def mock_connect_async(**kwargs: Any) -> MagicMock: # noqa: ANN401 + session = await original_connect_async(**kwargs) + await connect_event.wait() + return session - tibber_rt.sub_manager = cast("Client", mock_client) + mock_client.connect_async = AsyncMock(wraps=mock_connect_async) - await tibber_rt.disconnect() + set_endpoint_task_1 = asyncio.create_task(tibber_rt.set_subscription_endpoint("wss://connected.endpoint.1")) + set_endpoint_task_2 = asyncio.create_task(tibber_rt.set_subscription_endpoint("wss://connected.endpoint.2")) - mock_client.close_async.assert_not_awaited() + await asyncio.sleep(0.1) + assert mock_client.transport.url == "wss://connected.endpoint.1" + connect_event.set() + await asyncio.gather(set_endpoint_task_1, set_endpoint_task_2) + + assert mock_client.transport.url == "wss://connected.endpoint.2" + assert mock_client.close_async.call_count == 2 + assert mock_client.connect_async.call_count == 2 async def test_websocket_transport() -> None: """Test websocket transport.""" + tibber_connected = asyncio.Event() transport = TibberWebsocketsTransport( url="wss://test.endpoint", access_token="test_token", user_agent="test_agent", + tibber_connected=tibber_connected, ) + transport.keep_alive_timeout = 0 mock_adapter = MagicMock(spec=WebSocketsAdapter) sent_messages: asyncio.Queue[str] = asyncio.Queue() @@ -180,6 +208,9 @@ async def mock_send(message: str) -> None: client = Client(transport=transport) connect_task = asyncio.create_task(client.connect_async()) + await tibber_connected.wait() + + assert tibber_connected.is_set() await connect_task await client.close_async() @@ -188,3 +219,149 @@ async def mock_send(message: str) -> None: assert mock_adapter.send.await_count == 1 mock_adapter.receive.assert_awaited() assert mock_adapter.close.await_count == 1 + assert not tibber_connected.is_set() + + +async def test_subscribe_raises_when_not_connected(tibber_rt: TibberRT) -> None: + """subscribe must raise RuntimeError when called before connect.""" + with pytest.raises(RuntimeError, match="Connect must be called before subscribe"): + await anext(tibber_rt.subscribe(MagicMock())) + + +async def test_subscribe_yields_results( + mock_client: MagicMock, + tibber_rt: TibberRT, +) -> None: + """subscribe must yield every item produced by the underlying session.""" + await tibber_rt.connect() + + sample = {"key": "value"} + + async def mock_subscribe(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401, ARG001 + yield sample + + mock_client.session.subscribe = mock_subscribe + + results = [item async for item in tibber_rt.subscribe(MagicMock())] + + assert results == [sample] + + +async def test_subscribe_transport_connection_failed_calls_on_error_and_raises_reconnected( + mock_client: MagicMock, + tibber_rt: TibberRT, +) -> None: + """TransportConnectionFailed must call on_error, wait for reconnect, then raise WebsocketReconnectedError.""" + await tibber_rt.connect() + + err = TransportConnectionFailed("connection failed") + caught: list[Exception] = [] + + def on_error(exc: Exception) -> None: + caught.append(exc) + # Unblock _tibber_connected.wait() so the generator can finish + tibber_rt._tibber_connected.set() # noqa: SLF001 + + async def failing_subscribe(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401, ARG001 + raise err + yield + + mock_client.session.subscribe = failing_subscribe + + with pytest.raises(WebsocketReconnectedError): + await anext(tibber_rt.subscribe(MagicMock(), on_error=on_error)) + + assert caught == [err] + + +async def test_subscribe_other_transport_error_raises_websocket_transport_error( + mock_client: MagicMock, + tibber_rt: TibberRT, +) -> None: + """A TransportError that is not TransportConnectionFailed must raise WebsocketTransportError.""" + await tibber_rt.connect() + + err = TransportError("generic transport error") + + async def failing_subscribe(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401, ARG001 + raise err + yield + + mock_client.session.subscribe = failing_subscribe + + with pytest.raises(WebsocketTransportError): + await anext(tibber_rt.subscribe(MagicMock())) + + +async def test_reconnect_noop_when_not_connected( + mock_client: MagicMock, + tibber_rt: TibberRT, +) -> None: + """reconnect must be a no-op when the client is not connected.""" + await tibber_rt.reconnect() + + mock_client.connect_async.assert_not_awaited() + mock_client.close_async.assert_not_awaited() + + +async def test_set_access_token_reconnects_with_new_token( + mock_client: MagicMock, + tibber_rt: TibberRT, +) -> None: + """set_access_token must update the token and reconnect so the new token is used.""" + await tibber_rt.connect() + mock_client.connect_async.reset_mock() + mock_client.close_async.reset_mock() + + await tibber_rt.set_access_token("new_token") + + mock_client.close_async.assert_awaited_once() + mock_client.connect_async.assert_awaited_once() + assert mock_client.transport.init_payload["token"] == "new_token" + + +@pytest.mark.parametrize("timeout", [0]) +async def test_connect_timeout_leaves_with_session_and_subscription_not_running( + mock_client: MagicMock, + tibber_rt: TibberRT, +) -> None: + """When connect_async times out, a session should be set and subscription_running must remain False.""" + await tibber_rt.connect() + + assert tibber_rt.subscription_running is False + + await tibber_rt.disconnect() + + mock_client.close_async.assert_awaited_once() + assert tibber_rt.subscription_running is False + + +async def test_subscribe_transport_connection_failed_without_on_error_raises_reconnected( + mock_client: MagicMock, + tibber_rt: TibberRT, +) -> None: + """TransportConnectionFailed with on_error=None must still raise WebsocketReconnectedError after reconnect.""" + await tibber_rt.connect() + + err = TransportConnectionFailed("connection failed") + + async def failing_subscribe(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401, ARG001 + raise err + yield + + mock_client.session.subscribe = failing_subscribe + + async def set_connected_after_clear() -> None: + # Wait until subscribe() clears the event, then unblock the wait() + while tibber_rt._tibber_connected.is_set(): # noqa: ASYNC110, SLF001 + await asyncio.sleep(0) + tibber_rt._tibber_connected.set() # noqa: SLF001 + + unblock_task = asyncio.create_task(set_connected_after_clear()) + + with pytest.raises(WebsocketReconnectedError): + await anext(tibber_rt.subscribe(MagicMock(), on_error=None)) + + await unblock_task + + assert tibber_rt.subscription_running is True diff --git a/test/test_tibber.py b/test/test_tibber.py index 3b74a94..9b2b508 100644 --- a/test/test_tibber.py +++ b/test/test_tibber.py @@ -9,10 +9,8 @@ import pytest import tibber -import tibber.realtime as tibber_realtime from tibber.const import RESOLUTION_DAILY from tibber.exceptions import FatalHttpExceptionError, InvalidLoginError, NotForDemoUserError -from tibber.websocket_transport import TibberWebsocketsTransport @pytest.mark.asyncio @@ -179,16 +177,14 @@ def _callback(_: dict) -> None: @pytest.mark.asyncio -async def test_set_access_token_updates_clients_without_realtime(monkeypatch: pytest.MonkeyPatch) -> None: +async def test_set_access_token_updates_clients(monkeypatch: pytest.MonkeyPatch) -> None: tibber_connection = tibber.Tibber( websession=MagicMock(), user_agent="test", ) - reconnect = AsyncMock() rt_set_access_token = AsyncMock() data_api_set_access_token = MagicMock() - monkeypatch.setattr(tibber_connection.realtime, "reconnect", reconnect) monkeypatch.setattr(tibber_connection.realtime, "set_access_token", rt_set_access_token) monkeypatch.setattr(tibber_connection.data_api, "set_access_token", data_api_set_access_token) @@ -196,79 +192,22 @@ async def test_set_access_token_updates_clients_without_realtime(monkeypatch: py rt_set_access_token.assert_awaited_once_with("new-token") data_api_set_access_token.assert_called_once_with("new-token") - reconnect.assert_not_awaited() @pytest.mark.asyncio -async def test_set_access_token_reconnects_active_realtime(monkeypatch: pytest.MonkeyPatch) -> None: +async def test_set_access_token_noop_when_token_unchanged(monkeypatch: pytest.MonkeyPatch) -> None: tibber_connection = tibber.Tibber( + access_token="existing-token", websession=MagicMock(), user_agent="test", ) - calls: list[str] = [] - - async def fake_realtime_set_access_token(_access_token: str) -> None: - calls.append("realtime.set_access_token") - - async def fake_reconnect() -> None: - calls.append("reconnect") - - monkeypatch.setattr( - type(tibber_connection.realtime), - "should_restore_connection", - property(lambda _: True), - ) - monkeypatch.setattr( - tibber_connection.realtime, - "set_access_token", - AsyncMock(side_effect=fake_realtime_set_access_token), - ) + rt_set_access_token = AsyncMock() data_api_set_access_token = MagicMock() - monkeypatch.setattr(tibber_connection.realtime, "reconnect", AsyncMock(side_effect=fake_reconnect)) - monkeypatch.setattr(tibber_connection.data_api, "set_access_token", data_api_set_access_token) - - await tibber_connection.set_access_token("new-token") - - data_api_set_access_token.assert_called_once_with("new-token") - assert calls == ["realtime.set_access_token", "reconnect"] + monkeypatch.setattr(tibber_connection.realtime, "set_access_token", rt_set_access_token) + monkeypatch.setattr(tibber_connection.data_api, "set_access_token", data_api_set_access_token) -@pytest.mark.asyncio -async def test_realtime_set_access_token_recreates_subscription_manager(monkeypatch: pytest.MonkeyPatch) -> None: - class FakeClient: - def __init__(self, transport: TibberWebsocketsTransport) -> None: - self.transport = transport - self.close_async_mock = AsyncMock() - self.close_async = self.close_async_mock - - async def mock_connect_async() -> object: - session = object() - self.session = session - return session - - self.connect_async = AsyncMock(side_effect=mock_connect_async) - - monkeypatch.setattr(tibber_realtime, "Client", FakeClient) - - realtime = tibber_realtime.TibberRT("old-token", 10, "test-agent", True) - realtime.sub_endpoint = "wss://example.test/v1-beta/gql/subscriptions" - - await realtime.connect() - old_manager = realtime.sub_manager - assert old_manager is not None - assert isinstance(old_manager, FakeClient) - assert isinstance(old_manager.transport, TibberWebsocketsTransport) - assert old_manager.transport.init_payload["token"] == "old-token" - - await realtime.set_access_token("new-token") - - old_manager.close_async_mock.assert_awaited_once_with() - assert realtime.session is None - assert realtime.sub_manager is None - - await realtime.connect() + await tibber_connection.set_access_token("existing-token") - assert realtime.sub_manager is not None - assert realtime.sub_manager is not old_manager - assert isinstance(realtime.sub_manager.transport, TibberWebsocketsTransport) - assert realtime.sub_manager.transport.init_payload["token"] == "new-token" + rt_set_access_token.assert_not_awaited() + data_api_set_access_token.assert_not_called() diff --git a/tibber/__init__.py b/tibber/__init__.py index bef72a6..8d981d1 100644 --- a/tibber/__init__.py +++ b/tibber/__init__.py @@ -65,10 +65,9 @@ def __init__( self.websession = websession self.timeout: int = timeout self._access_token: str = access_token - - self.realtime: TibberRT = TibberRT( - self._access_token, - self.timeout, + self.realtime = TibberRT( + access_token, + timeout, self._user_agent, ssl=ssl, ) @@ -110,6 +109,11 @@ async def execute( payload = {"query": document, "variables": variable_values or {}} + _LOGGER.debug( + "Executing query: %s with variables: %s", + document.replace(" ", "").replace("\n", "_"), + variable_values, + ) try: resp = await self.websession.post( API_ENDPOINT, @@ -144,8 +148,7 @@ async def update_info(self) -> None: return if sub_endpoint := viewer.get("websocketSubscriptionUrl"): - _LOGGER.debug("Using websocket subscription url %s", sub_endpoint) - self.realtime.sub_endpoint = sub_endpoint + await self.realtime.set_subscription_endpoint(sub_endpoint) self._name = viewer.get("name") self._user_id = viewer.get("userId") @@ -222,21 +225,19 @@ async def fetch_production_data_active_homes(self) -> None: ) async def rt_disconnect(self) -> None: - """Stop subscription manager. - This method simply calls the stop method of the SubscriptionManager if it is defined. - """ - return await self.realtime.disconnect() + """Stop subscription manager.""" + for home in self._homes.values(): + home.rt_unsubscribe() + await self.realtime.disconnect() async def set_access_token(self, access_token: str) -> None: + """Set access token and reauthorize clients.""" if access_token == self._access_token: return - """Set access token and reauthorize clients.""" - restore_realtime = self.realtime.should_restore_connection + _LOGGER.debug("Updating access token") self._access_token = access_token await self.realtime.set_access_token(access_token) self.data_api.set_access_token(access_token) - if restore_realtime: - await self.realtime.reconnect() @property def user_id(self) -> str | None: diff --git a/tibber/data_api.py b/tibber/data_api.py index 39a6195..4736194 100644 --- a/tibber/data_api.py +++ b/tibber/data_api.py @@ -98,6 +98,7 @@ async def _make_request( headers[aiohttp.hdrs.USER_AGENT] = self._user_agent response: aiohttp.ClientResponse | None = None + _LOGGER.debug("Request %s: %s with params %s", method, url, params) try: response = await self.websession.request( method, diff --git a/tibber/exceptions.py b/tibber/exceptions.py index 8be9a75..431feeb 100644 --- a/tibber/exceptions.py +++ b/tibber/exceptions.py @@ -3,16 +3,24 @@ from .const import API_ERR_CODE_UNKNOWN -class SubscriptionEndpointMissingError(Exception): - """Exception raised when subscription endpoint is missing""" +class TibberError(Exception): + """Base exception for Tibber errors.""" -class UserAgentMissingError(Exception): - """Exception raised when user agent is missing""" +class SubscriptionEndpointMissingError(TibberError): + """Exception raised when subscription endpoint is missing.""" -class HttpExceptionError(Exception): - """Exception base for HTTP errors +class SubscriptionFailedError(TibberError): + """Exception raised when subscription fails.""" + + +class UserAgentMissingError(TibberError): + """Exception raised when user agent is missing.""" + + +class HttpExceptionError(TibberError): + """Exception base for HTTP errors. :param status: http response code :param message: http response message if any @@ -32,11 +40,11 @@ def __init__( class FatalHttpExceptionError(HttpExceptionError): - """Exception raised for HTTP codes that are non-retriable""" + """Exception raised for HTTP codes that are non-retriable.""" class RetryableHttpExceptionError(HttpExceptionError): - """Exception raised for HTTP codes that are possible to retry""" + """Exception raised for HTTP codes that are possible to retry.""" class RateLimitExceededError(RetryableHttpExceptionError): @@ -53,3 +61,15 @@ class InvalidLoginError(FatalHttpExceptionError): class NotForDemoUserError(FatalHttpExceptionError): """Exception raised when trying to use a feature not available for demo users""" + + +class WebsocketError(TibberError): + """Base exception for Tibber websocket errors.""" + + +class WebsocketReconnectedError(WebsocketError): + """Exception raised when websocket has been reconnected.""" + + +class WebsocketTransportError(WebsocketError): + """Exception raised when websocket transport fails.""" diff --git a/tibber/gql_queries.py b/tibber/gql_queries.py index 601e890..caef1b6 100644 --- a/tibber/gql_queries.py +++ b/tibber/gql_queries.py @@ -112,6 +112,18 @@ }} }} """ +REAL_TIME_CONSUMPTION_ENABLED = """ + { + viewer { + home(id: "%s") { + id + features { + realTimeConsumptionEnabled + } + } + } + } + """ UPDATE_CURRENT_PRICE = """ { viewer { diff --git a/tibber/home.py b/tibber/home.py index 7ed1786..4770e20 100644 --- a/tibber/home.py +++ b/tibber/home.py @@ -7,16 +7,21 @@ import contextlib import datetime as dt import logging +import random +import time +import warnings from typing import TYPE_CHECKING, Any from gql import gql from .const import RESOLUTION_DAILY, RESOLUTION_HOURLY, RESOLUTION_MONTHLY, RESOLUTION_WEEKLY +from .exceptions import SubscriptionFailedError, WebsocketReconnectedError, WebsocketTransportError from .gql_queries import ( HISTORIC_DATA, HISTORIC_DATA_DATE, HISTORIC_PRICE, LIVE_SUBSCRIBE, + REAL_TIME_CONSUMPTION_ENABLED, UPDATE_INFO_PRICE, ) @@ -29,6 +34,8 @@ MIN_IN_HOUR: int = 60 MIN_IN_QUARTER: int = 15 +RT_SUBSCRIPTION_TIMEOUT = 60 +RESUBSCRIBE_WAIT_TIME = 60 class HourlyData: @@ -75,14 +82,16 @@ def __init__(self, home_id: str, tibber_control: Tibber) -> None: self.info: dict[str, dict[Any, Any]] = {} self.last_data_timestamp: dt.datetime | None = None - self._hourly_consumption_data: HourlyData = HourlyData() - self._hourly_production_data: HourlyData = HourlyData(production=True) - self._last_rt_data_received: dt.datetime = dt.datetime.now(tz=dt.UTC) - self._rt_listener: None | asyncio.Task[Any] = None + self._hourly_consumption_data = HourlyData() + self._hourly_production_data = HourlyData(production=True) + self._last_rt_data_received: float | None = None + self._rt_listener: asyncio.Task[None] | None = None self._rt_callback: Callable[..., Any] | None = None - self._rt_stopped: bool = True + self._rt_on_error: Callable[[Exception], None] | None = None + self._rt_subscription_timeout_task: asyncio.Task[None] | None = None self._has_real_time_consumption: None | bool = None self._real_time_consumption_suggested_disabled: dt.datetime | None = None + self._resubscribe_task: asyncio.Task[None] | None = None async def _fetch_data(self, hourly_data: HourlyData) -> None: """Update hourly consumption or production data asynchronously.""" @@ -240,6 +249,16 @@ async def update_info_and_price_info(self) -> None: _LOGGER.error("Malformed price info data for home %s: %s", self._home_id, err) self.price_total = {} + async def update_real_time_consumption_enabled(self) -> None: + """Update the real time consumption enabled status.""" + if not (data := await self._tibber_control.execute(REAL_TIME_CONSUMPTION_ENABLED % self._home_id)): + _LOGGER.error("Could not get the data.") + return + self.info["viewer"]["home"]["features"]["realTimeConsumptionEnabled"] = data["viewer"]["home"]["features"][ + "realTimeConsumptionEnabled" + ] + self._update_has_real_time_consumption() + def _update_has_real_time_consumption(self) -> None: try: _has_real_time_consumption = self.info["viewer"]["home"]["features"]["realTimeConsumptionEnabled"] @@ -382,120 +401,189 @@ def current_price_data(self) -> tuple[float | None, dt.datetime | None, float | return round(price_total, 3), price_time, price_rank return None, None, None - async def rt_subscribe(self, callback: Callable[..., Any]) -> None: + async def rt_subscribe( + self, + callback: Callable[..., Any], + *, + on_error: Callable[[Exception], None] | None = None, + ) -> None: """Connect to Tibber and subscribe to Tibber real time subscription. :param callback: The function to call when data is received. + :param on_error: The function to call when an error occurs. """ + if self._rt_listener is not None: + raise RuntimeError("Already subscribed to real time data, call rt_unsubscribe first") + _LOGGER.debug("Subscribe, %s", self.home_id) + self._rt_callback = callback + self._rt_on_error = on_error + await self._rt_resubscribe() - def _add_extra_data(data: dict[str, Any]) -> dict[str, Any]: - live_data = data["data"]["liveMeasurement"] - _timestamp = dt.datetime.fromisoformat(live_data["timestamp"]).astimezone(self._tibber_control.time_zone) - while self._rt_power and self._rt_power[0][0] < _timestamp - dt.timedelta(minutes=5): - self._rt_power.pop(0) - - self._rt_power.append((_timestamp, live_data["power"] / 1000)) - if "lastMeterProduction" in live_data: - live_data["lastMeterProduction"] = max(0, live_data["lastMeterProduction"] or 0) + async def _rt_subscribe(self) -> None: + """Subscribe to Tibber real time subscription.""" + await self._tibber_control.realtime.connect() + self._rt_listener = asyncio.create_task(self._start_listen()) + self._rt_subscription_timeout_task = asyncio.create_task( + self._rt_subscription_timeout(), + ) - if ( - (power_production := live_data.get("powerProduction")) - and power_production > 0 - and live_data.get("power") is None + async def _start_listen(self) -> None: + """Subscribe to Tibber.""" + callback = self._rt_callback + on_error = self._rt_on_error + try: + async for _data in self._tibber_control.realtime.subscribe( + gql( + LIVE_SUBSCRIBE % self.home_id, + ), + on_error=on_error, ): - live_data["power"] = 0 - - if live_data.get("power", 0) > 0 and live_data.get("powerProduction") is None: - live_data["powerProduction"] = 0 - - current_hour = live_data["accumulatedConsumptionLastHour"] - if current_hour is not None: - power = sum(p[1] for p in self._rt_power) / len(self._rt_power) - live_data["estimatedHourConsumption"] = round( - current_hour + power * (3600 - (_timestamp.minute * 60 + _timestamp.second)) / 3600, - 3, + data = {"data": _data} + with contextlib.suppress(KeyError): + data = self._add_extra_data(data) + self._last_rt_data_received = time.time() + _LOGGER.debug( + "Data received for %s: %s", + self.home_id, + data, ) - if self._hourly_consumption_data.peak_hour and current_hour > self._hourly_consumption_data.peak_hour: - self._hourly_consumption_data.peak_hour = round(current_hour, 2) - self._hourly_consumption_data.peak_hour_time = _timestamp - return data - - async def _start() -> None: - """Subscribe to Tibber.""" - for _ in range(30): - if self._rt_stopped: - _LOGGER.debug("Stopping rt_subscribe") - return - if self._tibber_control.realtime.subscription_running: - break - - _LOGGER.debug("Waiting for rt_connect") - await asyncio.sleep(1) + callback(data) + except WebsocketReconnectedError as err: + _LOGGER.debug("Websocket reconnected for home %s, restarting subscription", self.home_id) + if on_error: + on_error(err) + except Exception as err: + if not isinstance(err, WebsocketTransportError): + _LOGGER.exception("Error in subscription") else: - _LOGGER.error("rt not running") - return + level = logging.DEBUG if on_error is not None else logging.ERROR + _LOGGER.log( + level, + "Error in subscription for home %s: %s: %s", + self.home_id, + err.__class__.__name__, + err, + ) + if on_error is not None: + on_error(err) - try: - session = self._tibber_control.realtime.session - if session is None or not hasattr(session, "subscribe"): - _LOGGER.error("Session is not connected or does not support subscribe method") - return - async for _data in session.subscribe( - gql(LIVE_SUBSCRIBE % self.home_id), - ): - data = {"data": _data} - with contextlib.suppress(KeyError): - data = _add_extra_data(data) - callback(data) - self._last_rt_data_received = dt.datetime.now(tz=dt.UTC) - _LOGGER.debug( - "Data received for %s: %s", - self.home_id, - data, - ) - if self._rt_stopped or not self._tibber_control.realtime.subscription_running: - _LOGGER.debug("Stopping rt_subscribe loop") - return - except Exception: - _LOGGER.exception("Error in rt_subscribe") + await asyncio.sleep(random.random() * RESUBSCRIBE_WAIT_TIME) # noqa: S311 + self._schedule_resubscribe() - self._rt_callback = callback - self._tibber_control.realtime.add_home(self) - await self._tibber_control.realtime.connect() - self._rt_listener = asyncio.create_task(_start()) - self._rt_stopped = False + def _add_extra_data(self, data: dict[str, Any]) -> dict[str, Any]: + """Add extra data to live subscription result.""" + live_data = data["data"]["liveMeasurement"] + _timestamp = dt.datetime.fromisoformat(live_data["timestamp"]).astimezone(self._tibber_control.time_zone) + while self._rt_power and self._rt_power[0][0] < _timestamp - dt.timedelta(minutes=5): + self._rt_power.pop(0) + + self._rt_power.append((_timestamp, live_data["power"] / 1000)) + if "lastMeterProduction" in live_data: + live_data["lastMeterProduction"] = max(0, live_data["lastMeterProduction"] or 0) + + if ( + (power_production := live_data.get("powerProduction")) + and power_production > 0 + and live_data.get("power") is None + ): + live_data["power"] = 0 + + if live_data.get("power", 0) > 0 and live_data.get("powerProduction") is None: + live_data["powerProduction"] = 0 + + current_hour = live_data["accumulatedConsumptionLastHour"] + if current_hour is not None: + power = sum(p[1] for p in self._rt_power) / len(self._rt_power) + live_data["estimatedHourConsumption"] = round( + current_hour + power * (3600 - (_timestamp.minute * 60 + _timestamp.second)) / 3600, + 3, + ) + if self._hourly_consumption_data.peak_hour and current_hour > self._hourly_consumption_data.peak_hour: + self._hourly_consumption_data.peak_hour = round(current_hour, 2) + self._hourly_consumption_data.peak_hour_time = _timestamp + return data + + def _schedule_resubscribe(self) -> None: + if self._resubscribe_task is not None: + self._resubscribe_task.cancel() + self._resubscribe_task = asyncio.create_task(self._rt_resubscribe()) async def rt_resubscribe(self) -> None: - """Resubscribe to Tibber data.""" - self.rt_unsubscribe() - _LOGGER.debug("Resubscribe, %s", self.home_id) - await asyncio.gather( - *[ - self.update_info(), - self._tibber_control.update_info(), - ], - return_exceptions=False, + """Resubscribe to Tibber data. + + Deprecated. Resubscription will happen automatically. + """ + warnings.warn( + "TibberHome.rt_resubscribe is deprecated, resubscription will happen automatically", + DeprecationWarning, + stacklevel=2, ) if self._rt_callback is None: - _LOGGER.warning("No callback set for rt_resubscribe") + raise RuntimeError("No callback set for rt_resubscribe, call rt_subscribe first") + await self._rt_resubscribe() + + async def _rt_resubscribe(self) -> None: + """Resubscribe to Tibber data.""" + _LOGGER.debug("Resubscribe, %s", self.home_id) + self.rt_unsubscribe() + + with contextlib.suppress(Exception): + await self.update_real_time_consumption_enabled() + if not self.has_real_time_consumption: + _LOGGER.debug("Home %s does not have real time consumption enabled", self.home_id) return - await self.rt_subscribe(self._rt_callback) + + # Update info to set websocket subscription url + with contextlib.suppress(Exception): + await self._tibber_control.update_info() + + await self._rt_subscribe() def rt_unsubscribe(self) -> None: """Unsubscribe to Tibber data.""" _LOGGER.debug("Unsubscribe, %s", self.home_id) - self._rt_stopped = True - if self._rt_listener is None: - return - self._rt_listener.cancel() - self._rt_listener = None + if self._rt_listener is not None: + self._rt_listener.cancel() + self._rt_listener = None + if self._rt_subscription_timeout_task is not None: + self._rt_subscription_timeout_task.cancel() + self._rt_subscription_timeout_task = None + self._last_rt_data_received = None + + async def _rt_subscription_timeout(self) -> None: + """Resubscribe if realtime subscription is unresponsive.""" + on_error = self._rt_on_error + while True: + # Add some random time to avoid all homes resubscribing at the same time + # if there is an issue with the subscription + await asyncio.sleep(RT_SUBSCRIPTION_TIMEOUT + random.random() * RT_SUBSCRIPTION_TIMEOUT) # noqa: S311 + if ( + self._last_rt_data_received is not None + and time.time() - self._last_rt_data_received <= RT_SUBSCRIPTION_TIMEOUT + ): + continue + if on_error: + on_error( + SubscriptionFailedError( + f"No real time data received for home {self.home_id} " + f"in the last {RT_SUBSCRIPTION_TIMEOUT} seconds", + ), + ) + else: + _LOGGER.error( + "No real time data received for home %s in the last %d seconds, reconnecting and resubscribing", + self.home_id, + RT_SUBSCRIPTION_TIMEOUT, + ) + self._rt_listener.cancel() + self._rt_listener = None + await self._tibber_control.realtime.reconnect() + self._schedule_resubscribe() @property def rt_subscription_running(self) -> bool: """Is real time subscription running.""" - if not self._tibber_control.realtime.subscription_running: - return False - return not self._last_rt_data_received < dt.datetime.now(tz=dt.UTC) - dt.timedelta(seconds=60) + return self._tibber_control.realtime.subscription_running and self._rt_listener is not None async def get_historic_data( self, diff --git a/tibber/realtime.py b/tibber/realtime.py index 53fe89e..8553a95 100644 --- a/tibber/realtime.py +++ b/tibber/realtime.py @@ -1,19 +1,27 @@ """Tibber RT connection.""" import asyncio -import datetime as dt import logging -import random +from collections.abc import AsyncGenerator, Callable from ssl import SSLContext -from typing import Any +from typing import TYPE_CHECKING, Any, cast -from gql import Client +from gql import Client, GraphQLRequest +from gql.transport.exceptions import TransportClosed, TransportConnectionFailed, TransportError +from gql.transport.websockets import WebsocketsTransport +from tenacity import before_sleep_log, retry, wait_exponential_jitter -from .exceptions import SubscriptionEndpointMissingError -from .home import TibberHome -from .websocket_transport import TibberWebsocketsTransport +from .exceptions import SubscriptionEndpointMissingError, WebsocketReconnectedError, WebsocketTransportError +if TYPE_CHECKING: + from gql.client import AsyncClientSession + +KEEP_ALIVE_TIMEOUT = 90 LOCK_CONNECT = asyncio.Lock() +MIN_RECONNECT_INTERVAL = 1 +MAX_RECONNECT_INTERVAL = 60 +PING_INTERVAL = 30 +PONG_TIMEOUT = 20 _LOGGER = logging.getLogger(__name__) @@ -32,224 +40,176 @@ def __init__(self, access_token: str, timeout: int, user_agent: str, ssl: SSLCon self._timeout: int = timeout self._user_agent: str = user_agent self._ssl_context = ssl - self._sub_endpoint: str | None = None - self._homes: list[TibberHome] = [] - self._watchdog_runner: None | asyncio.Task[Any] = None - self._watchdog_running: bool = False - - self.sub_manager: Client | None = None - self.session: Any | None = None - - async def disconnect(self) -> None: - """Stop subscription manager. - This method simply calls the stop method of the SubscriptionManager if it is defined. - """ - _LOGGER.debug("Stopping subscription manager") - await self._reset_connection(unsubscribe_homes=True) - - async def _reset_connection(self, unsubscribe_homes: bool = False) -> None: - """Reset websocket connection state.""" - if self._watchdog_runner is not None: - _LOGGER.debug("Stopping watchdog") - self._watchdog_running = False - self._watchdog_runner.cancel() - self._watchdog_runner = None - if unsubscribe_homes: - for home in self._homes: - home.rt_unsubscribe() - try: - await self._close_sub_manager() - finally: - self.session = None - self.sub_manager = None - - async def connect(self) -> None: - """Start subscription manager.""" - self._create_sub_manager() - - assert self.sub_manager is not None - - async with LOCK_CONNECT: - if self.subscription_running: - return - if self._watchdog_runner is None: - _LOGGER.debug("Starting watchdog") - self._watchdog_running = True - self._watchdog_runner = asyncio.create_task(self._watchdog()) - self.session = await self.sub_manager.connect_async() - - async def reconnect(self) -> None: - """Reconnect and resubscribe all homes.""" - await self.connect() - await self._resubscribe_homes() - - async def set_access_token(self, access_token: str) -> None: - """Set access token.""" - reconnect_running = self.subscription_running or self._watchdog_runner is not None - self._access_token = access_token - await self._reset_connection(unsubscribe_homes=reconnect_running) - - def _build_sub_manager(self) -> Client: - """Create a subscription manager for the current websocket endpoint.""" - if self.sub_endpoint is None: - raise SubscriptionEndpointMissingError("Subscription endpoint not initialized") - + self._tibber_connected = asyncio.Event() + self._client: Client | None = None + self.subscription_running = False + self._session: AsyncClientSession | None = None + + def _create_client(self) -> Client: + """Create a new gql Client with the current transport settings.""" + self._tibber_connected.clear() return Client( transport=TibberWebsocketsTransport( - self.sub_endpoint, + self._sub_endpoint, self._access_token, self._user_agent, ssl=self._ssl_context, + tibber_connected=self._tibber_connected, ), ) - def _sub_manager_has_session(self) -> bool: - """Return True if the current gql client owns a session.""" - return self.sub_manager is not None and hasattr(self.sub_manager, "session") + async def disconnect(self) -> None: + """Disconnect the websocket client.""" + _LOGGER.debug("Stopping subscription manager") + async with LOCK_CONNECT: + await self._disconnect() - async def _close_sub_manager(self) -> None: - """Close the current gql client if it has an active session object.""" - if self.sub_manager is None: - return + async def _disconnect(self) -> None: + """Disconnect the websocket client.""" + if self._client is not None and self._session is not None: + await self._client.close_async() + self._session = None + self.subscription_running = False - if not self._sub_manager_has_session(): - _LOGGER.debug( - "Skipping subscription manager close because the gql client has no session", - ) - return + async def connect(self) -> None: + """Connect the websocket client.""" + async with LOCK_CONNECT: + await self._connect() - await self.sub_manager.close_async() + async def _connect(self) -> None: + """Connect the websocket client.""" + if self._sub_endpoint is None: + raise SubscriptionEndpointMissingError("Subscription endpoint not initialized") - def _create_sub_manager(self) -> None: - if self.sub_manager is not None: + if self.subscription_running or self._session: return - self.sub_manager = self._build_sub_manager() - - async def _watchdog(self) -> None: - """Watchdog to keep connection alive.""" - assert self.sub_manager is not None - assert isinstance(self.sub_manager.transport, TibberWebsocketsTransport) - - await asyncio.sleep(60) - - _retry_count = 0 - next_test_all_homes_running = dt.datetime.now(tz=dt.UTC) - while self._watchdog_running: - await asyncio.sleep(5) - if ( - self.sub_manager.transport.running - and self.sub_manager.transport.reconnect_at - > dt.datetime.now( - tz=dt.UTC, - ) - and dt.datetime.now(tz=dt.UTC) > next_test_all_homes_running - ): - is_running = True - for home in self._homes: - _LOGGER.debug( - "Watchdog: Checking if home %s is alive, %s, %s", - home.home_id, - home.has_real_time_consumption, - home.rt_subscription_running, - ) - if not home.rt_subscription_running: - is_running = False - next_test_all_homes_running = dt.datetime.now(tz=dt.UTC) + dt.timedelta(seconds=60) - break - _LOGGER.debug( - "Watchdog: Home %s is alive", - home.home_id, - ) - if is_running: - _retry_count = 0 - _LOGGER.debug("Watchdog: Connection is alive") - continue - - self.sub_manager.transport.reconnect_at = dt.datetime.now(tz=dt.UTC) + dt.timedelta(seconds=self._timeout) - _LOGGER.error( - "Watchdog: Connection is down, %s", - self.sub_manager.transport.reconnect_at, + + self._client = self._create_client() + try: + await asyncio.wait_for( + asyncio.shield( + self._client.connect_async( + reconnecting=True, + retry_connect=retry( + wait=wait_exponential_jitter( + initial=MIN_RECONNECT_INTERVAL, + max=MAX_RECONNECT_INTERVAL, + jitter=MAX_RECONNECT_INTERVAL, + ), + before_sleep=before_sleep_log(_LOGGER, logging.INFO), + ), + ), + ), + timeout=self._timeout, ) + except TimeoutError as err: + _LOGGER.debug("Timeout connecting to websocket: %s", err) + # The connection will be retried by the reconnecting task + else: + self.subscription_running = True - try: - await self._close_sub_manager() - except Exception: - _LOGGER.exception("Error in watchdog close") - finally: - self.session = None - self.sub_manager = None + # The client session is set even if the connection times out. + self._session = cast("AsyncClientSession", self._client.session) - if not self._watchdog_running: - _LOGGER.debug("Watchdog: Stopping") + async def reconnect(self) -> None: + """Reconnect the websocket client.""" + async with LOCK_CONNECT: + if self._session is None: return + _LOGGER.debug("Reconnecting websocket client") + await self._disconnect() + await self._connect() - self._create_sub_manager() - assert self.sub_manager is not None - try: - self.session = await self.sub_manager.connect_async() - await self._resubscribe_homes() - except Exception as err: # noqa: BLE001 - delay_seconds = min( - random.SystemRandom().randint(1, 30) + _retry_count**2, - 5 * 60, - ) - _retry_count += 1 - _LOGGER.error( - "Error in watchdog connect, retrying in %s seconds, %s: %s", - delay_seconds, - _retry_count, - err, - exc_info=_retry_count > 1, - ) - await asyncio.sleep(delay_seconds) - else: - _LOGGER.debug("Watchdog: Reconnected successfully") - await asyncio.sleep(60) - - async def _resubscribe_homes(self) -> None: - """Resubscribe to all homes.""" - _LOGGER.debug("Resubscribing to homes") - await asyncio.gather(*[home.rt_resubscribe() for home in self._homes]) - - def add_home(self, home: TibberHome) -> bool: - """Add home to real time subscription.""" - if home.has_real_time_consumption is False: - return False - if home in self._homes: - return False - self._homes.append(home) - return True - - @property - def subscription_running(self) -> bool: - """Is real time subscription running.""" - return ( - self.sub_manager is not None - and isinstance(self.sub_manager.transport, TibberWebsocketsTransport) - and self.sub_manager.transport.running - and self.session is not None + async def set_access_token(self, access_token: str) -> None: + """Set access token.""" + self._access_token = access_token + await self.reconnect() + + async def subscribe( + self, + request: GraphQLRequest, + *, + on_error: Callable[[Exception], None] | None = None, + ) -> AsyncGenerator[dict[str, Any], None]: + """Subscribe to a GraphQL query.""" + if self._session is None: + raise RuntimeError("Connect must be called before subscribe") + + try: + async for result in self._session.subscribe(request): + yield result + except TransportError as err: + self.subscription_running = False + self._tibber_connected.clear() + if isinstance(err, TransportConnectionFailed): + level = logging.DEBUG if on_error is not None else logging.ERROR + _LOGGER.log(level, "%s: %s", err.__class__.__name__, err) + if on_error: + on_error(err) + _LOGGER.debug("Waiting for reconnect") + await self._tibber_connected.wait() + self.subscription_running = True + _LOGGER.info("Reconnected") + raise WebsocketReconnectedError("Websocket reconnected") from err + raise WebsocketTransportError(err) from err + + async def set_subscription_endpoint(self, url: str) -> None: + """Set subscription endpoint.""" + old_url = self._sub_endpoint + if url == old_url: + return + _LOGGER.debug("Updating subscription endpoint to %s", url) + self._sub_endpoint = url + await self.reconnect() + + +class TibberWebsocketsTransport(WebsocketsTransport): + """Tibber websockets transport.""" + + def __init__( + self, + url: str, + access_token: str, + user_agent: str, + *, + ssl: SSLContext | bool = True, + tibber_connected: asyncio.Event, + ) -> None: + """Initialize TibberWebsocketsTransport.""" + super().__init__( + url=url, + init_payload={"token": access_token}, + headers={"User-Agent": user_agent}, + ssl=ssl, + keep_alive_timeout=KEEP_ALIVE_TIMEOUT, + ping_interval=PING_INTERVAL, + pong_timeout=PONG_TIMEOUT, ) + self._tibber_connected = tibber_connected + self._user_agent = user_agent + + async def _after_connect(self) -> None: + """Hook to add custom code for subclasses. - @property - def should_restore_connection(self) -> bool: - """Whether realtime subscriptions should be restored after a reset.""" - return self.subscription_running or self._watchdog_runner is not None + Called after the connection has been established. + """ + await super()._after_connect() + self._tibber_connected.set() - @property - def sub_endpoint(self) -> str | None: - """Get subscription endpoint.""" - return self._sub_endpoint + async def close(self) -> None: + """Close the websocket connection. - @sub_endpoint.setter - def sub_endpoint(self, sub_endpoint: str) -> None: - """Set subscription endpoint.""" - self._sub_endpoint = sub_endpoint - if self.sub_manager is not None and isinstance(self.sub_manager.transport, TibberWebsocketsTransport): - if self.session is not None or self._sub_manager_has_session(): - _LOGGER.debug( - "Delaying websocket subscription url update until the next reconnect", - ) - return + This method is only called by the client. + """ + await self._fail(TransportClosed(f"Tibber websocket closed by {self._user_agent}")) + await self.wait_closed() + + async def _close_hook(self) -> None: + """Hook called by WebsocketsTransportBase on connection close. - self.sub_manager = self._build_sub_manager() + This method is called when the connection is closed + for any reason (not only by the client). + """ + self._tibber_connected.clear() + await super()._close_hook() diff --git a/tibber/websocket_transport.py b/tibber/websocket_transport.py deleted file mode 100644 index b4387b6..0000000 --- a/tibber/websocket_transport.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Websocket transport for Tibber.""" - -import asyncio -import datetime as dt -import logging -from ssl import SSLContext - -from gql.transport.exceptions import TransportClosed -from gql.transport.websockets import WebsocketsTransport -from websockets.asyncio.connection import State - -_LOGGER = logging.getLogger(__name__) - - -class TibberWebsocketsTransport(WebsocketsTransport): - """Tibber websockets transport.""" - - def __init__(self, url: str, access_token: str, user_agent: str, ssl: SSLContext | bool = True) -> None: - """Initialize TibberWebsocketsTransport. - Configures the gql.transport.websockets logger to WARNING level to suppress - verbose INFO-level messages (<<< and >>> websocket traffic logs). - """ - logging.getLogger("gql.transport.websockets").setLevel(logging.WARNING) - - super().__init__( - url=url, - init_payload={"token": access_token}, - headers={"User-Agent": user_agent}, - ping_interval=30, - ssl=ssl, - ) - self._user_agent: str = user_agent - self._timeout: int = 90 - self.reconnect_at: dt.datetime = dt.datetime.now(tz=dt.UTC) + dt.timedelta(seconds=self._timeout) - - @property - def running(self) -> bool: - """Is real time subscription running.""" - return ( - hasattr(self, "adapter") - and hasattr(self.adapter, "websocket") - and self.adapter.websocket is not None - and self.adapter.websocket.state is State.OPEN - ) - - async def _receive(self) -> str: - """Wait the next message from the websocket connection.""" - try: - msg = await asyncio.wait_for(super()._receive(), timeout=self._timeout) - except TimeoutError: - _LOGGER.error("No data received from Tibber for %s seconds", self._timeout) - raise - self.reconnect_at = dt.datetime.now(tz=dt.UTC) + dt.timedelta(seconds=self._timeout) - return msg - - async def close(self) -> None: - """Close the websocket connection.""" - await self._fail(TransportClosed(f"Tibber websocket closed by {self._user_agent}")) - await self.wait_closed()