"""Unit tests for WebSocketMonitor.""" from __future__ import annotations import asyncio import time from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import aiohttp import pytest from ha_diag.config import Settings from ha_diag.event_emitter import EventEmitter from ha_diag.models import HAEventType from ha_diag.monitors.websocket_monitor import ( WebSocketMonitor, _AuthError, _make_ws_url, ) # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _make_settings(**overrides) -> Settings: defaults: dict = { "ha_url": "http://test.local:8123", "ha_token": "test-token", "node_name": "test-node", "location_tag": "test-loc", "websocket_enabled": True, "websocket_silence_threshold_seconds": 300, "websocket_watchdog_interval_seconds": 30, "websocket_reconnect_initial_delay": 1.0, "websocket_reconnect_max_delay": 60.0, "websocket_reconnect_jitter": 0.0, "websocket_down_alert_repeat_minutes": 10, } defaults.update(overrides) return Settings(**defaults) class FakeWS: """Fake aiohttp ClientWebSocketResponse for unit tests.""" def __init__(self, auth_messages: list, event_messages: list | None = None): self._auth_queue = list(auth_messages) self._event_queue = list(event_messages or []) self.sent: list = [] async def receive_json(self) -> dict: if not self._auth_queue: raise ConnectionError("FakeWS: no more auth messages") return self._auth_queue.pop(0) async def send_json(self, data: dict) -> None: self.sent.append(data) def __aiter__(self): return self async def __anext__(self): if not self._event_queue: raise StopAsyncIteration item = self._event_queue.pop(0) if isinstance(item, BaseException): raise item return item def _text_msg(data: str = '{"type":"event"}') -> aiohttp.WSMessage: return aiohttp.WSMessage(type=aiohttp.WSMsgType.TEXT, data=data, extra=None) def _close_msg() -> aiohttp.WSMessage: return aiohttp.WSMessage(type=aiohttp.WSMsgType.CLOSE, data=b"", extra=None) def _mock_session(fake_ws: FakeWS) -> MagicMock: cm = MagicMock() cm.__aenter__ = AsyncMock(return_value=fake_ws) cm.__aexit__ = AsyncMock(return_value=False) session = MagicMock() session.ws_connect.return_value = cm return session def _make_monitor( settings: Settings | None = None, session=None, emitter: EventEmitter | None = None, tmp_path: Path | None = None, ) -> WebSocketMonitor: if settings is None: settings = _make_settings() if emitter is None: p = (tmp_path or Path("/tmp/ws_test_events")).absolute() p.mkdir(parents=True, exist_ok=True) emitter = EventEmitter(p, node_name="test-node") if session is None: session = MagicMock() return WebSocketMonitor( ha_url=settings.ha_url, token=settings.ha_token, settings=settings, emitter=emitter, session=session, ) # --------------------------------------------------------------------------- # URL derivation # --------------------------------------------------------------------------- def test_make_ws_url_http(): assert _make_ws_url("http://ha.local:8123") == "ws://ha.local:8123/api/websocket" def test_make_ws_url_https(): assert _make_ws_url("https://ha.example.com") == "wss://ha.example.com/api/websocket" def test_make_ws_url_strips_trailing_slash(): assert _make_ws_url("http://ha.local:8123/") == "ws://ha.local:8123/api/websocket" # --------------------------------------------------------------------------- # Auth flow (via _connect_and_listen) # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_normal_auth_sends_correct_messages(tmp_path): """Happy path: sends auth + subscribe, ends in subscribed state.""" fake_ws = FakeWS( [{"type": "auth_required"}, {"type": "auth_ok"}], [_text_msg('{"type":"result","id":1,"success":true}')], ) monitor = _make_monitor(session=_mock_session(fake_ws), tmp_path=tmp_path) await monitor._connect_and_listen() assert fake_ws.sent[0] == {"type": "auth", "access_token": "test-token"} assert fake_ws.sent[1]["type"] == "subscribe_events" assert fake_ws.sent[1]["event_type"] == "state_changed" assert monitor._state == "subscribed" @pytest.mark.asyncio async def test_last_event_monotonic_updated_on_text_message(tmp_path): """Receiving TEXT messages updates last_event_monotonic.""" fake_ws = FakeWS( [{"type": "auth_required"}, {"type": "auth_ok"}], [_text_msg(), _text_msg()], ) monitor = _make_monitor(session=_mock_session(fake_ws), tmp_path=tmp_path) before = time.monotonic() await monitor._connect_and_listen() assert monitor._last_event_monotonic >= before @pytest.mark.asyncio async def test_auth_invalid_raises_auth_error(tmp_path): """auth_invalid → _AuthError propagates.""" fake_ws = FakeWS([ {"type": "auth_required"}, {"type": "auth_invalid", "message": "invalid token"}, ]) monitor = _make_monitor(session=_mock_session(fake_ws), tmp_path=tmp_path) with pytest.raises(_AuthError, match="invalid token"): await monitor._connect_and_listen() @pytest.mark.asyncio async def test_unexpected_initial_message_raises(tmp_path): """Anything other than auth_required on connect → ConnectionError.""" fake_ws = FakeWS([{"type": "unexpected"}]) monitor = _make_monitor(session=_mock_session(fake_ws), tmp_path=tmp_path) with pytest.raises(ConnectionError, match="Unexpected initial"): await monitor._connect_and_listen() @pytest.mark.asyncio async def test_empty_auth_queue_raises_connection_error(tmp_path): """Connection closed before auth_required → ConnectionError.""" fake_ws = FakeWS([]) monitor = _make_monitor(session=_mock_session(fake_ws), tmp_path=tmp_path) with pytest.raises(ConnectionError): await monitor._connect_and_listen() # --------------------------------------------------------------------------- # Disconnect / dead alerts (_on_disconnected) # --------------------------------------------------------------------------- def test_on_disconnected_emits_ha_websocket_dead(tmp_path): emitter = MagicMock() monitor = _make_monitor(emitter=emitter, tmp_path=tmp_path) monitor._state = "disconnected" monitor._on_disconnected() emitter.emit.assert_called_once() assert emitter.emit.call_args[1]["event_type"] == HAEventType.ha_websocket_dead.value def test_on_disconnected_within_cooldown_suppresses_second_emit(tmp_path): emitter = MagicMock() monitor = _make_monitor( settings=_make_settings(websocket_down_alert_repeat_minutes=10), emitter=emitter, tmp_path=tmp_path, ) monitor._state = "disconnected" monitor._on_disconnected() # first emit emitter.emit.reset_mock() monitor._on_disconnected() # within cooldown → suppressed emitter.emit.assert_not_called() def test_on_disconnected_after_cooldown_emits_again(tmp_path): emitter = MagicMock() monitor = _make_monitor( settings=_make_settings(websocket_down_alert_repeat_minutes=10), emitter=emitter, tmp_path=tmp_path, ) monitor._state = "disconnected" monitor._on_disconnected() # Backdate to simulate cooldown expiry monitor._last_dead_alert_at = time.monotonic() - (10 * 60 + 5) emitter.emit.reset_mock() monitor._on_disconnected() emitter.emit.assert_called_once() def test_on_disconnected_noop_when_stopping(tmp_path): emitter = MagicMock() monitor = _make_monitor(emitter=emitter, tmp_path=tmp_path) monitor._stopping = True monitor._on_disconnected() emitter.emit.assert_not_called() # --------------------------------------------------------------------------- # Recovery # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_reconnect_after_dead_emits_recovered(tmp_path): """Successful reconnect after a dead alert emits ha_websocket_recovered.""" emitter = MagicMock() fake_ws = FakeWS([{"type": "auth_required"}, {"type": "auth_ok"}], []) settings = _make_settings() monitor = WebSocketMonitor( ha_url=settings.ha_url, token=settings.ha_token, settings=settings, emitter=emitter, session=_mock_session(fake_ws), ) monitor._last_dead_alert_at = time.monotonic() - 30.0 # prior dead was sent await monitor._connect_and_listen() emitted_types = [c[1]["event_type"] for c in emitter.emit.call_args_list] assert HAEventType.ha_websocket_recovered.value in emitted_types assert monitor._last_dead_alert_at == 0.0 # reset after recovery @pytest.mark.asyncio async def test_no_recovered_if_no_prior_dead(tmp_path): """First-ever connect with no prior dead alert → no recovered emitted.""" emitter = MagicMock() fake_ws = FakeWS([{"type": "auth_required"}, {"type": "auth_ok"}], []) settings = _make_settings() monitor = WebSocketMonitor( ha_url=settings.ha_url, token=settings.ha_token, settings=settings, emitter=emitter, session=_mock_session(fake_ws), ) monitor._last_dead_alert_at = 0.0 await monitor._connect_and_listen() emitter.emit.assert_not_called() # --------------------------------------------------------------------------- # Watchdog loop # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_watchdog_emits_dead_when_silent_over_threshold(tmp_path): """Watchdog detects silence > threshold and emits ha_websocket_dead.""" emitter = MagicMock() settings = _make_settings( websocket_silence_threshold_seconds=60, websocket_watchdog_interval_seconds=30, websocket_down_alert_repeat_minutes=0, ) monitor = _make_monitor(settings=settings, emitter=emitter, tmp_path=tmp_path) monitor._state = "subscribed" monitor._last_event_monotonic = time.monotonic() - 120.0 # 120s > 60s threshold monitor._last_dead_alert_at = 0.0 sleep_calls = 0 async def one_iteration(t): nonlocal sleep_calls sleep_calls += 1 if sleep_calls >= 2: raise asyncio.CancelledError() with patch("asyncio.sleep", side_effect=one_iteration): with pytest.raises(asyncio.CancelledError): await monitor._watchdog_loop() emitter.emit.assert_called_once() assert emitter.emit.call_args[1]["event_type"] == HAEventType.ha_websocket_dead.value @pytest.mark.asyncio async def test_watchdog_no_emit_when_events_recent(tmp_path): """Watchdog does not emit when last event is within silence threshold.""" emitter = MagicMock() settings = _make_settings( websocket_silence_threshold_seconds=300, websocket_watchdog_interval_seconds=30, websocket_down_alert_repeat_minutes=0, ) monitor = _make_monitor(settings=settings, emitter=emitter, tmp_path=tmp_path) monitor._state = "subscribed" monitor._last_event_monotonic = time.monotonic() - 10.0 # recent sleep_calls = 0 async def one_iteration(t): nonlocal sleep_calls sleep_calls += 1 if sleep_calls >= 2: raise asyncio.CancelledError() with patch("asyncio.sleep", side_effect=one_iteration): with pytest.raises(asyncio.CancelledError): await monitor._watchdog_loop() emitter.emit.assert_not_called() @pytest.mark.asyncio async def test_watchdog_skips_when_not_subscribed(tmp_path): """Watchdog does not emit when state is not 'subscribed'.""" emitter = MagicMock() settings = _make_settings( websocket_silence_threshold_seconds=1, websocket_watchdog_interval_seconds=30, websocket_down_alert_repeat_minutes=0, ) monitor = _make_monitor(settings=settings, emitter=emitter, tmp_path=tmp_path) monitor._state = "disconnected" monitor._last_event_monotonic = time.monotonic() - 9999.0 # very old sleep_calls = 0 async def one_iteration(t): nonlocal sleep_calls sleep_calls += 1 if sleep_calls >= 2: raise asyncio.CancelledError() with patch("asyncio.sleep", side_effect=one_iteration): with pytest.raises(asyncio.CancelledError): await monitor._watchdog_loop() emitter.emit.assert_not_called() @pytest.mark.asyncio async def test_watchdog_repeat_alert_respects_cooldown(tmp_path): """Second watchdog dead alert fires only after cooldown.""" emitter = MagicMock() settings = _make_settings( websocket_silence_threshold_seconds=60, websocket_watchdog_interval_seconds=30, websocket_down_alert_repeat_minutes=10, ) monitor = _make_monitor(settings=settings, emitter=emitter, tmp_path=tmp_path) monitor._state = "subscribed" monitor._last_event_monotonic = time.monotonic() - 3600.0 # 1hr silent # Set last alert to just now → still within 10-min cooldown monitor._last_dead_alert_at = time.monotonic() sleep_calls = 0 async def one_iteration(t): nonlocal sleep_calls sleep_calls += 1 if sleep_calls >= 2: raise asyncio.CancelledError() with patch("asyncio.sleep", side_effect=one_iteration): with pytest.raises(asyncio.CancelledError): await monitor._watchdog_loop() emitter.emit.assert_not_called() # within cooldown # --------------------------------------------------------------------------- # Reconnect backoff # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_reconnect_backoff_doubles_each_attempt(tmp_path): """Retry delay doubles on consecutive failures.""" delays: list[float] = [] call_count = 0 async def fail_connect(): nonlocal call_count call_count += 1 raise ConnectionError("refused") async def capture_sleep(t): delays.append(t) if call_count >= 3: raise asyncio.CancelledError() settings = _make_settings( websocket_reconnect_initial_delay=1.0, websocket_reconnect_max_delay=60.0, websocket_reconnect_jitter=0.0, ) monitor = _make_monitor(settings=settings, emitter=MagicMock(), tmp_path=tmp_path) monitor._connect_and_listen = fail_connect with patch("asyncio.sleep", side_effect=capture_sleep): with pytest.raises(asyncio.CancelledError): await monitor._connection_loop() assert len(delays) >= 2 assert delays[0] == pytest.approx(1.0) assert delays[1] == pytest.approx(2.0) @pytest.mark.asyncio async def test_reconnect_delay_capped_at_max(tmp_path): """Delay never exceeds websocket_reconnect_max_delay.""" delays: list[float] = [] call_count = 0 async def fail_connect(): nonlocal call_count call_count += 1 raise ConnectionError("refused") async def capture_sleep(t): delays.append(t) if call_count >= 8: raise asyncio.CancelledError() settings = _make_settings( websocket_reconnect_initial_delay=1.0, websocket_reconnect_max_delay=8.0, websocket_reconnect_jitter=0.0, ) monitor = _make_monitor(settings=settings, emitter=MagicMock(), tmp_path=tmp_path) monitor._connect_and_listen = fail_connect with patch("asyncio.sleep", side_effect=capture_sleep): with pytest.raises(asyncio.CancelledError): await monitor._connection_loop() assert max(delays) <= 8.0 # --------------------------------------------------------------------------- # is_healthy # --------------------------------------------------------------------------- def test_is_healthy_true_when_subscribed(tmp_path): monitor = _make_monitor(settings=_make_settings(websocket_enabled=True), tmp_path=tmp_path) monitor._state = "subscribed" assert monitor.is_healthy is True def test_is_healthy_false_when_disconnected(tmp_path): monitor = _make_monitor(settings=_make_settings(websocket_enabled=True), tmp_path=tmp_path) monitor._state = "disconnected" assert monitor.is_healthy is False def test_is_healthy_false_when_connecting(tmp_path): monitor = _make_monitor(settings=_make_settings(websocket_enabled=True), tmp_path=tmp_path) monitor._state = "connecting" assert monitor.is_healthy is False def test_is_healthy_true_when_disabled(tmp_path): """Disabled monitor reports healthy — it's off, not broken.""" monitor = _make_monitor(settings=_make_settings(websocket_enabled=False), tmp_path=tmp_path) monitor._state = "disconnected" assert monitor.is_healthy is True # --------------------------------------------------------------------------- # start / stop lifecycle # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_stop_cancels_background_tasks(tmp_path): """stop() cancels the main and watchdog tasks.""" async def hang(): await asyncio.sleep(9999) monitor = _make_monitor(tmp_path=tmp_path) monitor._main_task = asyncio.create_task(hang()) monitor._watchdog_task = asyncio.create_task(hang()) await monitor.stop() assert monitor._main_task is None assert monitor._watchdog_task is None @pytest.mark.asyncio async def test_start_no_tasks_when_disabled(tmp_path): """start() with websocket_enabled=False does not spawn tasks.""" monitor = _make_monitor( settings=_make_settings(websocket_enabled=False), tmp_path=tmp_path, ) await monitor.start() assert monitor._main_task is None assert monitor._watchdog_task is None