From 31b48d162a561b13913e79fb7332cd86d8390b48 Mon Sep 17 00:00:00 2001 From: Oskar Kapala Date: Fri, 29 May 2026 15:00:18 +0200 Subject: [PATCH] feat(ha-diag-agent): WebSocketMonitor for real-time HA liveness MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - persistent WS connection to HA with auth + state_changed subscription - watchdog detects silence > 5min → emits ha_websocket_dead - immediate ha_websocket_dead on disconnect, exponential reconnect with jitter - cooldown prevents alert spam (10min repeat window while HA stays down) - ha_websocket_recovered emitted on reconnect after a dead alert (allows supervisor to clear active incidents in Phase 5) - new monitors/ subpackage for long-running tasks (vs interval checks/) - /health endpoint now includes ws_connected field - 26 unit tests, 3 integration tests (real HA + container stop/restart) Co-Authored-By: Claude Sonnet 4.6 --- services/ha-diag-agent/README.md | 24 +- services/ha-diag-agent/src/ha_diag/api.py | 12 +- services/ha-diag-agent/src/ha_diag/config.py | 9 + services/ha-diag-agent/src/ha_diag/main.py | 15 +- services/ha-diag-agent/src/ha_diag/models.py | 1 + .../src/ha_diag/monitors/__init__.py | 4 + .../src/ha_diag/monitors/base.py | 24 + .../src/ha_diag/monitors/websocket_monitor.py | 286 +++++++++ .../test_websocket_monitor_integration.py | 186 ++++++ .../tests/test_websocket_monitor.py | 558 ++++++++++++++++++ 10 files changed, 1112 insertions(+), 7 deletions(-) create mode 100644 services/ha-diag-agent/src/ha_diag/monitors/__init__.py create mode 100644 services/ha-diag-agent/src/ha_diag/monitors/base.py create mode 100644 services/ha-diag-agent/src/ha_diag/monitors/websocket_monitor.py create mode 100644 services/ha-diag-agent/tests/integration/test_websocket_monitor_integration.py create mode 100644 services/ha-diag-agent/tests/test_websocket_monitor.py diff --git a/services/ha-diag-agent/README.md b/services/ha-diag-agent/README.md index 1955d1f..f74008b 100644 --- a/services/ha-diag-agent/README.md +++ b/services/ha-diag-agent/README.md @@ -10,12 +10,21 @@ no direct supervisor integration, events processed by the VPS observer. ## Architecture ``` -APScheduler (every CHECK_INTERVAL s) - └─ HeartbeatCheck → pings /api/, emits ha_websocket_dead on failure - [Phase 3: EntityUnavailableCheck, SystemHealthCheck, UpdateCheck, ...] +APScheduler (interval-based REST checks) + ├─ HeartbeatCheck → pings /api/, emits ha_websocket_dead on failure + ├─ UnavailableEntitiesCheck → entity unavailable > threshold + ├─ SystemHealthCheck → /api/system_health per-integration status + ├─ AutomationFailuresCheck → automation last-run error traces + └─ UpdatesAvailableCheck → pending HA/integration updates + +WebSocketMonitor (persistent, long-running — Phase 4b) + └─ Maintains a live WS subscription to state_changed events + Any traffic = HA is alive. Watchdog fires ha_websocket_dead on + silence > 5min or on disconnect. Emits ha_websocket_recovered + when the connection is restored after a dead alert. FastAPI (port 8087) - GET /health → liveness probe + GET /health → liveness probe (includes ws_connected field) POST /trigger/ → run a named check on demand SQLite (/data/ha_diag.db) @@ -24,11 +33,15 @@ SQLite (/data/ha_diag.db) alerts_sent → dedup gate for alert events ``` +The WebSocketMonitor is the only persistent-connection component; all other +checks are APScheduler intervals (stateless REST polls). + ## Event Types | Type | Severity | Trigger | |------|----------|---------| -| `ha_websocket_dead` | error | HA /api/ unreachable | +| `ha_websocket_dead` | error | WS disconnect, silence > 5min, or /api/ unreachable | +| `ha_websocket_recovered` | info | WS reconnected after a dead alert (clears incident) | | `ha_integration_failed` | error | Integration in error state | | `ha_entity_unavailable_long` | warning | Entity unavailable > threshold | | `ha_automation_failing` | warning | Automation last run errored | @@ -37,6 +50,7 @@ SQLite (/data/ha_diag.db) | `ha_system_health_degraded` | warning | System health check failed | Event routing in supervisor (Phase 5) maps these to `notify` actions. +`ha_websocket_recovered` should be routed to clear any active `ha_websocket_dead` incident. ## Deployment model diff --git a/services/ha-diag-agent/src/ha_diag/api.py b/services/ha-diag-agent/src/ha_diag/api.py index 5e7cda4..6d99ddc 100644 --- a/services/ha-diag-agent/src/ha_diag/api.py +++ b/services/ha-diag-agent/src/ha_diag/api.py @@ -6,11 +6,13 @@ from fastapi import FastAPI, HTTPException if TYPE_CHECKING: from .checks.base import Check + from .monitors.base import Monitor app = FastAPI(title="ha-diag-agent", version="0.1.0") # Populated by main.py during startup _checks: dict[str, "Check"] = {} +_ws_monitor: "Monitor | None" = None _node_name: str = "unknown" _location_tag: str = "default" @@ -22,14 +24,22 @@ def register_checks(checks: list["Check"], node_name: str, location_tag: str) -> _location_tag = location_tag +def register_ws_monitor(monitor: "Monitor") -> None: + global _ws_monitor + _ws_monitor = monitor + + @app.get("/health") async def health() -> dict: - return { + response: dict = { "status": "ok", "node": _node_name, "location_tag": _location_tag, "checks": list(_checks.keys()), } + if _ws_monitor is not None: + response["ws_connected"] = _ws_monitor.is_healthy + return response @app.post("/trigger/{check_name}") diff --git a/services/ha-diag-agent/src/ha_diag/config.py b/services/ha-diag-agent/src/ha_diag/config.py index 50b3b16..6eb8be8 100644 --- a/services/ha-diag-agent/src/ha_diag/config.py +++ b/services/ha-diag-agent/src/ha_diag/config.py @@ -45,6 +45,15 @@ class Settings(BaseSettings): updates_check_minute: int = 0 updates_cooldown_days: int = 7 # don't re-alert same update within N days + # WebSocket monitor + websocket_enabled: bool = True + websocket_silence_threshold_seconds: int = 300 # 5 min + websocket_watchdog_interval_seconds: int = 30 + websocket_reconnect_initial_delay: float = 1.0 + websocket_reconnect_max_delay: float = 60.0 + websocket_reconnect_jitter: float = 0.2 # ±20% of delay + websocket_down_alert_repeat_minutes: int = 10 + # API server port: int = 8087 log_level: str = "info" diff --git a/services/ha-diag-agent/src/ha_diag/main.py b/services/ha-diag-agent/src/ha_diag/main.py index 90881fd..d157258 100644 --- a/services/ha-diag-agent/src/ha_diag/main.py +++ b/services/ha-diag-agent/src/ha_diag/main.py @@ -10,7 +10,7 @@ import structlog import uvicorn from apscheduler.schedulers.asyncio import AsyncIOScheduler -from .api import app, register_checks +from .api import app, register_checks, register_ws_monitor from .checks.automation_failures import AutomationFailuresCheck from .checks.heartbeat import HeartbeatCheck from .checks.system_health import SystemHealthCheck @@ -19,6 +19,7 @@ from .checks.updates_available import UpdatesAvailableCheck, UpdatesDigestCheck from .config import Settings from .event_emitter import EventEmitter from .ha_client import HAClient, make_session +from .monitors import WebSocketMonitor from .storage import Storage _log = structlog.get_logger() @@ -112,6 +113,15 @@ async def run(settings: Settings) -> None: updates_daily, updates_digest] register_checks(all_checks, settings.node_name, settings.location_tag) + ws_monitor = WebSocketMonitor( + ha_url=settings.ha_url, + token=settings.ha_token, + settings=settings, + emitter=emitter, + session=session, + ) + register_ws_monitor(ws_monitor) + scheduler = AsyncIOScheduler() scheduler.add_job( _run_check_and_emit, "interval", @@ -167,6 +177,8 @@ async def run(settings: Settings) -> None: updates_hour=settings.updates_check_hour, ) + await ws_monitor.start() + config = uvicorn.Config( app, host="0.0.0.0", @@ -177,6 +189,7 @@ async def run(settings: Settings) -> None: try: await server.serve() finally: + await ws_monitor.stop() scheduler.shutdown(wait=False) await storage.close() await session.close() diff --git a/services/ha-diag-agent/src/ha_diag/models.py b/services/ha-diag-agent/src/ha_diag/models.py index 55d6477..b256b83 100644 --- a/services/ha-diag-agent/src/ha_diag/models.py +++ b/services/ha-diag-agent/src/ha_diag/models.py @@ -16,6 +16,7 @@ class HAEventType(str, Enum): ha_integration_failed = "ha_integration_failed" ha_entity_unavailable_long = "ha_entity_unavailable_long" ha_websocket_dead = "ha_websocket_dead" + ha_websocket_recovered = "ha_websocket_recovered" ha_automation_failing = "ha_automation_failing" ha_update_available = "ha_update_available" ha_recorder_lag = "ha_recorder_lag" diff --git a/services/ha-diag-agent/src/ha_diag/monitors/__init__.py b/services/ha-diag-agent/src/ha_diag/monitors/__init__.py new file mode 100644 index 0000000..2ab0e34 --- /dev/null +++ b/services/ha-diag-agent/src/ha_diag/monitors/__init__.py @@ -0,0 +1,4 @@ +from .base import Monitor +from .websocket_monitor import WebSocketMonitor + +__all__ = ["Monitor", "WebSocketMonitor"] diff --git a/services/ha-diag-agent/src/ha_diag/monitors/base.py b/services/ha-diag-agent/src/ha_diag/monitors/base.py new file mode 100644 index 0000000..d5f42f5 --- /dev/null +++ b/services/ha-diag-agent/src/ha_diag/monitors/base.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod + + +class Monitor(ABC): + """Base class for long-running background monitors. + + Unlike checks (one-shot, APScheduler-driven), monitors maintain + persistent state — connections, subscriptions, background tasks. + """ + + @abstractmethod + async def start(self) -> None: + """Spawn background task(s). Idempotent if already started.""" + + @abstractmethod + async def stop(self) -> None: + """Cancel background tasks and wait for cleanup.""" + + @property + @abstractmethod + def is_healthy(self) -> bool: + """True when the monitor is running and its connection is live.""" diff --git a/services/ha-diag-agent/src/ha_diag/monitors/websocket_monitor.py b/services/ha-diag-agent/src/ha_diag/monitors/websocket_monitor.py new file mode 100644 index 0000000..439ab7f --- /dev/null +++ b/services/ha-diag-agent/src/ha_diag/monitors/websocket_monitor.py @@ -0,0 +1,286 @@ +from __future__ import annotations + +import asyncio +import json +import random +import time +from datetime import datetime, timezone + +import aiohttp +import structlog + +from ..config import Settings +from ..event_emitter import EventEmitter +from ..models import HAEventType, Severity +from .base import Monitor + +_log = structlog.get_logger().bind(monitor="websocket") + + +class _AuthError(Exception): + """Raised when HA returns auth_invalid during the WS handshake.""" + + +def _make_ws_url(ha_url: str) -> str: + if ha_url.startswith("https://"): + base = ha_url.replace("https://", "wss://", 1) + else: + base = ha_url.replace("http://", "ws://", 1) + return base.rstrip("/") + "/api/websocket" + + +class WebSocketMonitor(Monitor): + """Persistent WebSocket connection to HA for real-time liveness monitoring. + + Subscribes to state_changed events — any traffic proves HA is alive. + The watchdog fires ha_websocket_dead when the connection is silent for + longer than silence_threshold, or immediately on disconnect. + ha_websocket_recovered is emitted when the connection is restored after + a dead alert was sent (allows supervisor to clear active incidents). + """ + + def __init__( + self, + ha_url: str, + token: str, + settings: Settings, + emitter: EventEmitter, + session: aiohttp.ClientSession, + ) -> None: + self._ws_url = _make_ws_url(ha_url) + self._token = token + self._settings = settings + self._emitter = emitter + self._session = session + + self._state: str = "disconnected" + self._last_event_monotonic: float = time.monotonic() + # 0.0 means no ha_websocket_dead has been emitted yet (for this session) + self._last_dead_alert_at: float = 0.0 + + self._stopping = False + self._msg_id = 0 + self._main_task: asyncio.Task | None = None + self._watchdog_task: asyncio.Task | None = None + + # ------------------------------------------------------------------ + # Monitor ABC + # ------------------------------------------------------------------ + + async def start(self) -> None: + if not self._settings.websocket_enabled: + _log.info("ws_monitor_disabled") + return + self._stopping = False + self._last_event_monotonic = time.monotonic() + self._main_task = asyncio.create_task( + self._connection_loop(), name="ws_connection_loop" + ) + self._watchdog_task = asyncio.create_task( + self._watchdog_loop(), name="ws_watchdog" + ) + _log.info("ws_monitor_started", ws_url=self._ws_url) + + async def stop(self) -> None: + self._stopping = True + self._state = "stopped" + tasks = [t for t in [self._main_task, self._watchdog_task] if t is not None] + for t in tasks: + t.cancel() + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + self._main_task = None + self._watchdog_task = None + _log.info("ws_monitor_stopped") + + @property + def is_healthy(self) -> bool: + if not self._settings.websocket_enabled: + return True # disabled monitors are not unhealthy + return self._state == "subscribed" + + # ------------------------------------------------------------------ + # Connection loop — reconnects with exponential back-off + # ------------------------------------------------------------------ + + async def _connection_loop(self) -> None: + delay = float(self._settings.websocket_reconnect_initial_delay) + while not self._stopping: + self._state = "connecting" + clean_close = False + try: + await self._connect_and_listen() + clean_close = True + delay = float(self._settings.websocket_reconnect_initial_delay) + except asyncio.CancelledError: + raise + except _AuthError as exc: + _log.error("ws_auth_failed", error=str(exc)) + # Auth failures won't self-heal on fast retry — jump to max delay + delay = float(self._settings.websocket_reconnect_max_delay) + except Exception as exc: + _log.warning("ws_connect_error", error=str(exc)) + + self._state = "disconnected" + if not self._stopping: + self._on_disconnected() + + if self._stopping: + break + + if clean_close: + wait = 1.0 # brief pause before reconnecting after a clean HA close + else: + jitter_range = delay * self._settings.websocket_reconnect_jitter + wait = max(0.1, delay + random.uniform(-jitter_range, jitter_range)) + delay = min(delay * 2, float(self._settings.websocket_reconnect_max_delay)) + + _log.debug("ws_reconnect_wait", seconds=round(wait, 2)) + await asyncio.sleep(wait) + + # ------------------------------------------------------------------ + # Connect, auth, subscribe, receive + # ------------------------------------------------------------------ + + async def _connect_and_listen(self) -> None: + # Override the session-level timeout: WS must stay open indefinitely, + # only the initial TCP connect should be bounded. + ws_timeout = aiohttp.ClientTimeout(total=None, connect=10.0, sock_connect=10.0) + async with self._session.ws_connect( + self._ws_url, + timeout=ws_timeout, + heartbeat=30.0, + ) as ws: + self._state = "authenticating" + + # Receive auth_required + try: + msg = await asyncio.wait_for(ws.receive_json(), timeout=10.0) + except (asyncio.TimeoutError, TypeError, json.JSONDecodeError) as exc: + raise ConnectionError(f"Failed to receive auth_required: {exc}") from exc + + if msg.get("type") != "auth_required": + raise ConnectionError( + f"Unexpected initial message type: {msg.get('type')!r}" + ) + + await ws.send_json({"type": "auth", "access_token": self._token}) + + # Receive auth_ok or auth_invalid + try: + msg = await asyncio.wait_for(ws.receive_json(), timeout=10.0) + except (asyncio.TimeoutError, TypeError, json.JSONDecodeError) as exc: + raise ConnectionError(f"Failed to receive auth response: {exc}") from exc + + if msg.get("type") == "auth_invalid": + raise _AuthError(msg.get("message", "auth_invalid")) + if msg.get("type") != "auth_ok": + raise ConnectionError( + f"Unexpected auth response type: {msg.get('type')!r}" + ) + + # Subscribe to state_changed events + self._msg_id += 1 + await ws.send_json({ + "id": self._msg_id, + "type": "subscribe_events", + "event_type": "state_changed", + }) + + # Mark connected — capture prior dead state before resetting + prev_dead_at = self._last_dead_alert_at + self._state = "subscribed" + self._last_event_monotonic = time.monotonic() + + # Emit recovery if this reconnect follows a dead alert + if prev_dead_at > 0.0: + self._last_dead_alert_at = 0.0 + self._emit_recovered() + + _log.info("ws_subscribed", ws_url=self._ws_url) + + # Receive loop — any TEXT message proves HA is alive + async for raw in ws: + if self._stopping: + break + if raw.type == aiohttp.WSMsgType.TEXT: + self._last_event_monotonic = time.monotonic() + elif raw.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSE): + _log.warning("ws_closed_by_server", msg_type=raw.type.name) + break + + # ------------------------------------------------------------------ + # Watchdog loop — detects silence while the WS appears connected + # ------------------------------------------------------------------ + + async def _watchdog_loop(self) -> None: + while not self._stopping: + try: + await asyncio.sleep(self._settings.websocket_watchdog_interval_seconds) + except asyncio.CancelledError: + raise + + if self._state != "subscribed": + continue # disconnects are handled by the connection loop + + now = time.monotonic() + silent_secs = now - self._last_event_monotonic + if silent_secs <= self._settings.websocket_silence_threshold_seconds: + continue + + cooldown = self._settings.websocket_down_alert_repeat_minutes * 60 + if self._last_dead_alert_at == 0.0 or (now - self._last_dead_alert_at) >= cooldown: + self._emitter.emit( + event_type=HAEventType.ha_websocket_dead.value, + severity=Severity.error.value, + service="homeassistant", + message=( + f"HA WebSocket silent for {silent_secs:.0f}s — no events received" + ), + payload=self._dead_payload(silent_secs), + ) + self._last_dead_alert_at = now + _log.warning("ws_silent_dead_emitted", silent_seconds=round(silent_secs)) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + def _on_disconnected(self) -> None: + """Emit ha_websocket_dead on connection loss, respecting cooldown.""" + if self._stopping: + return + now = time.monotonic() + cooldown = self._settings.websocket_down_alert_repeat_minutes * 60 + if self._last_dead_alert_at == 0.0 or (now - self._last_dead_alert_at) >= cooldown: + silent_secs = now - self._last_event_monotonic + self._emitter.emit( + event_type=HAEventType.ha_websocket_dead.value, + severity=Severity.error.value, + service="homeassistant", + message=f"HA WebSocket disconnected — silent for {silent_secs:.0f}s", + payload=self._dead_payload(silent_secs), + ) + self._last_dead_alert_at = now + _log.warning("ws_dead_emitted", silent_seconds=round(silent_secs)) + + def _emit_recovered(self) -> None: + self._emitter.emit( + event_type=HAEventType.ha_websocket_recovered.value, + severity=Severity.info.value, + service="homeassistant", + message="HA WebSocket reconnected and receiving events", + payload={"connection_state": "subscribed"}, + ) + _log.info("ws_recovered_emitted") + + def _dead_payload(self, silent_secs: float) -> dict: + event_age = time.monotonic() - self._last_event_monotonic + last_event_wall = time.time() - event_age + return { + "silent_seconds": round(silent_secs), + "last_event_at": datetime.fromtimestamp( + last_event_wall, tz=timezone.utc + ).isoformat(), + "connection_state": self._state, + } diff --git a/services/ha-diag-agent/tests/integration/test_websocket_monitor_integration.py b/services/ha-diag-agent/tests/integration/test_websocket_monitor_integration.py new file mode 100644 index 0000000..6466ac0 --- /dev/null +++ b/services/ha-diag-agent/tests/integration/test_websocket_monitor_integration.py @@ -0,0 +1,186 @@ +"""Integration tests for WebSocketMonitor against real HA instances. + +Requires: + docker compose -f tests/integration/docker-compose.ken.yml up -d + tests/integration/scripts/wait-for-ha.sh http://localhost:8123 + TEST_HA_TOKEN= pytest tests/ -m integration + +Container stop/restart tests additionally need Docker access from the host. +""" +from __future__ import annotations + +import asyncio +import subprocess +import time +from pathlib import Path + +import pytest + +from ha_diag.config import Settings +from ha_diag.event_emitter import EventEmitter +from ha_diag.models import HAEventType +from ha_diag.monitors.websocket_monitor import WebSocketMonitor +from ha_diag.ha_client import make_session + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_settings(ha_url: str, ha_token: str, **overrides) -> Settings: + defaults: dict = { + "ha_url": ha_url, + "ha_token": ha_token, + "node_name": "test-piha", + "location_tag": "ken", + "websocket_enabled": True, + "websocket_silence_threshold_seconds": 30, # low for fast test + "websocket_watchdog_interval_seconds": 5, + "websocket_reconnect_initial_delay": 1.0, + "websocket_reconnect_max_delay": 10.0, + "websocket_reconnect_jitter": 0.0, + "websocket_down_alert_repeat_minutes": 0, # always re-alert + } + defaults.update(overrides) + return Settings(**defaults) + + +def _emitted_types(events_dir: Path) -> list[str]: + return [ + __import__("json").loads(f.read_text())["type"] + for f in sorted(events_dir.glob("*.json")) + ] + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.integration +async def test_ws_normal_operation_no_false_alerts( + ha_ken_url: str, ha_token: str, tmp_path: Path +): + """Normal operation: monitor connects, subscribes, no dead alerts emitted.""" + events_dir = tmp_path / "events" + events_dir.mkdir() + settings = _make_settings(ha_ken_url, ha_token) + emitter = EventEmitter(events_dir, node_name="test-piha", location_tag="ken") + + async with make_session(ha_token) as session: + monitor = WebSocketMonitor( + ha_url=ha_ken_url, + token=ha_token, + settings=settings, + emitter=emitter, + session=session, + ) + await monitor.start() + await asyncio.sleep(5) # let it connect and settle + assert monitor.is_healthy, "Monitor should be subscribed and healthy" + await monitor.stop() + + # No dead alerts during normal operation + types = _emitted_types(events_dir) + assert HAEventType.ha_websocket_dead.value not in types, ( + f"Unexpected dead alerts during normal operation: {types}" + ) + + +@pytest.mark.integration +async def test_ws_dead_emitted_when_ha_stops(ha_ken_url: str, ha_token: str, tmp_path: Path): + """Stopping the HA container triggers ha_websocket_dead.""" + events_dir = tmp_path / "events" + events_dir.mkdir() + settings = _make_settings(ha_ken_url, ha_token) + emitter = EventEmitter(events_dir, node_name="test-piha", location_tag="ken") + + async with make_session(ha_token) as session: + monitor = WebSocketMonitor( + ha_url=ha_ken_url, + token=ha_token, + settings=settings, + emitter=emitter, + session=session, + ) + await monitor.start() + # Wait for initial subscription + for _ in range(20): + if monitor.is_healthy: + break + await asyncio.sleep(0.5) + assert monitor.is_healthy, "Monitor did not subscribe within 10s" + + # Stop HA container + subprocess.run( + ["docker", "stop", "ha-test-ken"], + check=True, capture_output=True, timeout=30, + ) + try: + # Wait for dead alert (up to 15s) + deadline = time.monotonic() + 15 + while time.monotonic() < deadline: + types = _emitted_types(events_dir) + if HAEventType.ha_websocket_dead.value in types: + break + await asyncio.sleep(0.5) + + types = _emitted_types(events_dir) + assert HAEventType.ha_websocket_dead.value in types, ( + "ha_websocket_dead not emitted after HA container stopped" + ) + finally: + await monitor.stop() + subprocess.run( + ["docker", "start", "ha-test-ken"], + check=False, capture_output=True, timeout=30, + ) + + +@pytest.mark.integration +async def test_ws_recovered_after_ha_restart(ha_ken_url: str, ha_token: str, tmp_path: Path): + """After HA restarts, monitor reconnects and emits ha_websocket_recovered.""" + events_dir = tmp_path / "events" + events_dir.mkdir() + settings = _make_settings(ha_ken_url, ha_token) + emitter = EventEmitter(events_dir, node_name="test-piha", location_tag="ken") + + async with make_session(ha_token) as session: + monitor = WebSocketMonitor( + ha_url=ha_ken_url, + token=ha_token, + settings=settings, + emitter=emitter, + session=session, + ) + await monitor.start() + for _ in range(20): + if monitor.is_healthy: + break + await asyncio.sleep(0.5) + assert monitor.is_healthy + + # Stop then restart HA + subprocess.run(["docker", "stop", "ha-test-ken"], check=True, timeout=30) + await asyncio.sleep(2) + subprocess.run(["docker", "start", "ha-test-ken"], check=True, timeout=30) + + try: + # Wait for recovery (up to 60s — HA takes time to start) + deadline = time.monotonic() + 60 + while time.monotonic() < deadline: + types = _emitted_types(events_dir) + if HAEventType.ha_websocket_recovered.value in types: + break + await asyncio.sleep(1.0) + + types = _emitted_types(events_dir) + assert HAEventType.ha_websocket_dead.value in types, ( + "ha_websocket_dead not emitted after container stop" + ) + assert HAEventType.ha_websocket_recovered.value in types, ( + "ha_websocket_recovered not emitted after HA restarted" + ) + finally: + await monitor.stop() diff --git a/services/ha-diag-agent/tests/test_websocket_monitor.py b/services/ha-diag-agent/tests/test_websocket_monitor.py new file mode 100644 index 0000000..964a24d --- /dev/null +++ b/services/ha-diag-agent/tests/test_websocket_monitor.py @@ -0,0 +1,558 @@ +"""Unit tests for WebSocketMonitor.""" +from __future__ import annotations + +import asyncio +import time +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import aiohttp +import pytest + +from ha_diag.config import Settings +from ha_diag.event_emitter import EventEmitter +from ha_diag.models import HAEventType +from ha_diag.monitors.websocket_monitor import ( + WebSocketMonitor, + _AuthError, + _make_ws_url, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_settings(**overrides) -> Settings: + defaults: dict = { + "ha_url": "http://test.local:8123", + "ha_token": "test-token", + "node_name": "test-node", + "location_tag": "test-loc", + "websocket_enabled": True, + "websocket_silence_threshold_seconds": 300, + "websocket_watchdog_interval_seconds": 30, + "websocket_reconnect_initial_delay": 1.0, + "websocket_reconnect_max_delay": 60.0, + "websocket_reconnect_jitter": 0.0, + "websocket_down_alert_repeat_minutes": 10, + } + defaults.update(overrides) + return Settings(**defaults) + + +class FakeWS: + """Fake aiohttp ClientWebSocketResponse for unit tests.""" + + def __init__(self, auth_messages: list, event_messages: list | None = None): + self._auth_queue = list(auth_messages) + self._event_queue = list(event_messages or []) + self.sent: list = [] + + async def receive_json(self) -> dict: + if not self._auth_queue: + raise ConnectionError("FakeWS: no more auth messages") + return self._auth_queue.pop(0) + + async def send_json(self, data: dict) -> None: + self.sent.append(data) + + def __aiter__(self): + return self + + async def __anext__(self): + if not self._event_queue: + raise StopAsyncIteration + item = self._event_queue.pop(0) + if isinstance(item, BaseException): + raise item + return item + + +def _text_msg(data: str = '{"type":"event"}') -> aiohttp.WSMessage: + return aiohttp.WSMessage(type=aiohttp.WSMsgType.TEXT, data=data, extra=None) + + +def _close_msg() -> aiohttp.WSMessage: + return aiohttp.WSMessage(type=aiohttp.WSMsgType.CLOSE, data=b"", extra=None) + + +def _mock_session(fake_ws: FakeWS) -> MagicMock: + cm = MagicMock() + cm.__aenter__ = AsyncMock(return_value=fake_ws) + cm.__aexit__ = AsyncMock(return_value=False) + session = MagicMock() + session.ws_connect.return_value = cm + return session + + +def _make_monitor( + settings: Settings | None = None, + session=None, + emitter: EventEmitter | None = None, + tmp_path: Path | None = None, +) -> WebSocketMonitor: + if settings is None: + settings = _make_settings() + if emitter is None: + p = (tmp_path or Path("/tmp/ws_test_events")).absolute() + p.mkdir(parents=True, exist_ok=True) + emitter = EventEmitter(p, node_name="test-node") + if session is None: + session = MagicMock() + return WebSocketMonitor( + ha_url=settings.ha_url, + token=settings.ha_token, + settings=settings, + emitter=emitter, + session=session, + ) + + +# --------------------------------------------------------------------------- +# URL derivation +# --------------------------------------------------------------------------- + + +def test_make_ws_url_http(): + assert _make_ws_url("http://ha.local:8123") == "ws://ha.local:8123/api/websocket" + + +def test_make_ws_url_https(): + assert _make_ws_url("https://ha.example.com") == "wss://ha.example.com/api/websocket" + + +def test_make_ws_url_strips_trailing_slash(): + assert _make_ws_url("http://ha.local:8123/") == "ws://ha.local:8123/api/websocket" + + +# --------------------------------------------------------------------------- +# Auth flow (via _connect_and_listen) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_normal_auth_sends_correct_messages(tmp_path): + """Happy path: sends auth + subscribe, ends in subscribed state.""" + fake_ws = FakeWS( + [{"type": "auth_required"}, {"type": "auth_ok"}], + [_text_msg('{"type":"result","id":1,"success":true}')], + ) + monitor = _make_monitor(session=_mock_session(fake_ws), tmp_path=tmp_path) + + await monitor._connect_and_listen() + + assert fake_ws.sent[0] == {"type": "auth", "access_token": "test-token"} + assert fake_ws.sent[1]["type"] == "subscribe_events" + assert fake_ws.sent[1]["event_type"] == "state_changed" + assert monitor._state == "subscribed" + + +@pytest.mark.asyncio +async def test_last_event_monotonic_updated_on_text_message(tmp_path): + """Receiving TEXT messages updates last_event_monotonic.""" + fake_ws = FakeWS( + [{"type": "auth_required"}, {"type": "auth_ok"}], + [_text_msg(), _text_msg()], + ) + monitor = _make_monitor(session=_mock_session(fake_ws), tmp_path=tmp_path) + before = time.monotonic() + + await monitor._connect_and_listen() + + assert monitor._last_event_monotonic >= before + + +@pytest.mark.asyncio +async def test_auth_invalid_raises_auth_error(tmp_path): + """auth_invalid → _AuthError propagates.""" + fake_ws = FakeWS([ + {"type": "auth_required"}, + {"type": "auth_invalid", "message": "invalid token"}, + ]) + monitor = _make_monitor(session=_mock_session(fake_ws), tmp_path=tmp_path) + + with pytest.raises(_AuthError, match="invalid token"): + await monitor._connect_and_listen() + + +@pytest.mark.asyncio +async def test_unexpected_initial_message_raises(tmp_path): + """Anything other than auth_required on connect → ConnectionError.""" + fake_ws = FakeWS([{"type": "unexpected"}]) + monitor = _make_monitor(session=_mock_session(fake_ws), tmp_path=tmp_path) + + with pytest.raises(ConnectionError, match="Unexpected initial"): + await monitor._connect_and_listen() + + +@pytest.mark.asyncio +async def test_empty_auth_queue_raises_connection_error(tmp_path): + """Connection closed before auth_required → ConnectionError.""" + fake_ws = FakeWS([]) + monitor = _make_monitor(session=_mock_session(fake_ws), tmp_path=tmp_path) + + with pytest.raises(ConnectionError): + await monitor._connect_and_listen() + + +# --------------------------------------------------------------------------- +# Disconnect / dead alerts (_on_disconnected) +# --------------------------------------------------------------------------- + + +def test_on_disconnected_emits_ha_websocket_dead(tmp_path): + emitter = MagicMock() + monitor = _make_monitor(emitter=emitter, tmp_path=tmp_path) + monitor._state = "disconnected" + + monitor._on_disconnected() + + emitter.emit.assert_called_once() + assert emitter.emit.call_args[1]["event_type"] == HAEventType.ha_websocket_dead.value + + +def test_on_disconnected_within_cooldown_suppresses_second_emit(tmp_path): + emitter = MagicMock() + monitor = _make_monitor( + settings=_make_settings(websocket_down_alert_repeat_minutes=10), + emitter=emitter, + tmp_path=tmp_path, + ) + monitor._state = "disconnected" + + monitor._on_disconnected() # first emit + emitter.emit.reset_mock() + monitor._on_disconnected() # within cooldown → suppressed + + emitter.emit.assert_not_called() + + +def test_on_disconnected_after_cooldown_emits_again(tmp_path): + emitter = MagicMock() + monitor = _make_monitor( + settings=_make_settings(websocket_down_alert_repeat_minutes=10), + emitter=emitter, + tmp_path=tmp_path, + ) + monitor._state = "disconnected" + monitor._on_disconnected() + # Backdate to simulate cooldown expiry + monitor._last_dead_alert_at = time.monotonic() - (10 * 60 + 5) + emitter.emit.reset_mock() + + monitor._on_disconnected() + + emitter.emit.assert_called_once() + + +def test_on_disconnected_noop_when_stopping(tmp_path): + emitter = MagicMock() + monitor = _make_monitor(emitter=emitter, tmp_path=tmp_path) + monitor._stopping = True + + monitor._on_disconnected() + + emitter.emit.assert_not_called() + + +# --------------------------------------------------------------------------- +# Recovery +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_reconnect_after_dead_emits_recovered(tmp_path): + """Successful reconnect after a dead alert emits ha_websocket_recovered.""" + emitter = MagicMock() + fake_ws = FakeWS([{"type": "auth_required"}, {"type": "auth_ok"}], []) + settings = _make_settings() + monitor = WebSocketMonitor( + ha_url=settings.ha_url, + token=settings.ha_token, + settings=settings, + emitter=emitter, + session=_mock_session(fake_ws), + ) + monitor._last_dead_alert_at = time.monotonic() - 30.0 # prior dead was sent + + await monitor._connect_and_listen() + + emitted_types = [c[1]["event_type"] for c in emitter.emit.call_args_list] + assert HAEventType.ha_websocket_recovered.value in emitted_types + assert monitor._last_dead_alert_at == 0.0 # reset after recovery + + +@pytest.mark.asyncio +async def test_no_recovered_if_no_prior_dead(tmp_path): + """First-ever connect with no prior dead alert → no recovered emitted.""" + emitter = MagicMock() + fake_ws = FakeWS([{"type": "auth_required"}, {"type": "auth_ok"}], []) + settings = _make_settings() + monitor = WebSocketMonitor( + ha_url=settings.ha_url, + token=settings.ha_token, + settings=settings, + emitter=emitter, + session=_mock_session(fake_ws), + ) + monitor._last_dead_alert_at = 0.0 + + await monitor._connect_and_listen() + + emitter.emit.assert_not_called() + + +# --------------------------------------------------------------------------- +# Watchdog loop +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_watchdog_emits_dead_when_silent_over_threshold(tmp_path): + """Watchdog detects silence > threshold and emits ha_websocket_dead.""" + emitter = MagicMock() + settings = _make_settings( + websocket_silence_threshold_seconds=60, + websocket_watchdog_interval_seconds=30, + websocket_down_alert_repeat_minutes=0, + ) + monitor = _make_monitor(settings=settings, emitter=emitter, tmp_path=tmp_path) + monitor._state = "subscribed" + monitor._last_event_monotonic = time.monotonic() - 120.0 # 120s > 60s threshold + monitor._last_dead_alert_at = 0.0 + + sleep_calls = 0 + + async def one_iteration(t): + nonlocal sleep_calls + sleep_calls += 1 + if sleep_calls >= 2: + raise asyncio.CancelledError() + + with patch("asyncio.sleep", side_effect=one_iteration): + with pytest.raises(asyncio.CancelledError): + await monitor._watchdog_loop() + + emitter.emit.assert_called_once() + assert emitter.emit.call_args[1]["event_type"] == HAEventType.ha_websocket_dead.value + + +@pytest.mark.asyncio +async def test_watchdog_no_emit_when_events_recent(tmp_path): + """Watchdog does not emit when last event is within silence threshold.""" + emitter = MagicMock() + settings = _make_settings( + websocket_silence_threshold_seconds=300, + websocket_watchdog_interval_seconds=30, + websocket_down_alert_repeat_minutes=0, + ) + monitor = _make_monitor(settings=settings, emitter=emitter, tmp_path=tmp_path) + monitor._state = "subscribed" + monitor._last_event_monotonic = time.monotonic() - 10.0 # recent + + sleep_calls = 0 + + async def one_iteration(t): + nonlocal sleep_calls + sleep_calls += 1 + if sleep_calls >= 2: + raise asyncio.CancelledError() + + with patch("asyncio.sleep", side_effect=one_iteration): + with pytest.raises(asyncio.CancelledError): + await monitor._watchdog_loop() + + emitter.emit.assert_not_called() + + +@pytest.mark.asyncio +async def test_watchdog_skips_when_not_subscribed(tmp_path): + """Watchdog does not emit when state is not 'subscribed'.""" + emitter = MagicMock() + settings = _make_settings( + websocket_silence_threshold_seconds=1, + websocket_watchdog_interval_seconds=30, + websocket_down_alert_repeat_minutes=0, + ) + monitor = _make_monitor(settings=settings, emitter=emitter, tmp_path=tmp_path) + monitor._state = "disconnected" + monitor._last_event_monotonic = time.monotonic() - 9999.0 # very old + + sleep_calls = 0 + + async def one_iteration(t): + nonlocal sleep_calls + sleep_calls += 1 + if sleep_calls >= 2: + raise asyncio.CancelledError() + + with patch("asyncio.sleep", side_effect=one_iteration): + with pytest.raises(asyncio.CancelledError): + await monitor._watchdog_loop() + + emitter.emit.assert_not_called() + + +@pytest.mark.asyncio +async def test_watchdog_repeat_alert_respects_cooldown(tmp_path): + """Second watchdog dead alert fires only after cooldown.""" + emitter = MagicMock() + settings = _make_settings( + websocket_silence_threshold_seconds=60, + websocket_watchdog_interval_seconds=30, + websocket_down_alert_repeat_minutes=10, + ) + monitor = _make_monitor(settings=settings, emitter=emitter, tmp_path=tmp_path) + monitor._state = "subscribed" + monitor._last_event_monotonic = time.monotonic() - 3600.0 # 1hr silent + # Set last alert to just now → still within 10-min cooldown + monitor._last_dead_alert_at = time.monotonic() + + sleep_calls = 0 + + async def one_iteration(t): + nonlocal sleep_calls + sleep_calls += 1 + if sleep_calls >= 2: + raise asyncio.CancelledError() + + with patch("asyncio.sleep", side_effect=one_iteration): + with pytest.raises(asyncio.CancelledError): + await monitor._watchdog_loop() + + emitter.emit.assert_not_called() # within cooldown + + +# --------------------------------------------------------------------------- +# Reconnect backoff +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_reconnect_backoff_doubles_each_attempt(tmp_path): + """Retry delay doubles on consecutive failures.""" + delays: list[float] = [] + call_count = 0 + + async def fail_connect(): + nonlocal call_count + call_count += 1 + raise ConnectionError("refused") + + async def capture_sleep(t): + delays.append(t) + if call_count >= 3: + raise asyncio.CancelledError() + + settings = _make_settings( + websocket_reconnect_initial_delay=1.0, + websocket_reconnect_max_delay=60.0, + websocket_reconnect_jitter=0.0, + ) + monitor = _make_monitor(settings=settings, emitter=MagicMock(), tmp_path=tmp_path) + monitor._connect_and_listen = fail_connect + + with patch("asyncio.sleep", side_effect=capture_sleep): + with pytest.raises(asyncio.CancelledError): + await monitor._connection_loop() + + assert len(delays) >= 2 + assert delays[0] == pytest.approx(1.0) + assert delays[1] == pytest.approx(2.0) + + +@pytest.mark.asyncio +async def test_reconnect_delay_capped_at_max(tmp_path): + """Delay never exceeds websocket_reconnect_max_delay.""" + delays: list[float] = [] + call_count = 0 + + async def fail_connect(): + nonlocal call_count + call_count += 1 + raise ConnectionError("refused") + + async def capture_sleep(t): + delays.append(t) + if call_count >= 8: + raise asyncio.CancelledError() + + settings = _make_settings( + websocket_reconnect_initial_delay=1.0, + websocket_reconnect_max_delay=8.0, + websocket_reconnect_jitter=0.0, + ) + monitor = _make_monitor(settings=settings, emitter=MagicMock(), tmp_path=tmp_path) + monitor._connect_and_listen = fail_connect + + with patch("asyncio.sleep", side_effect=capture_sleep): + with pytest.raises(asyncio.CancelledError): + await monitor._connection_loop() + + assert max(delays) <= 8.0 + + +# --------------------------------------------------------------------------- +# is_healthy +# --------------------------------------------------------------------------- + + +def test_is_healthy_true_when_subscribed(tmp_path): + monitor = _make_monitor(settings=_make_settings(websocket_enabled=True), tmp_path=tmp_path) + monitor._state = "subscribed" + assert monitor.is_healthy is True + + +def test_is_healthy_false_when_disconnected(tmp_path): + monitor = _make_monitor(settings=_make_settings(websocket_enabled=True), tmp_path=tmp_path) + monitor._state = "disconnected" + assert monitor.is_healthy is False + + +def test_is_healthy_false_when_connecting(tmp_path): + monitor = _make_monitor(settings=_make_settings(websocket_enabled=True), tmp_path=tmp_path) + monitor._state = "connecting" + assert monitor.is_healthy is False + + +def test_is_healthy_true_when_disabled(tmp_path): + """Disabled monitor reports healthy — it's off, not broken.""" + monitor = _make_monitor(settings=_make_settings(websocket_enabled=False), tmp_path=tmp_path) + monitor._state = "disconnected" + assert monitor.is_healthy is True + + +# --------------------------------------------------------------------------- +# start / stop lifecycle +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_stop_cancels_background_tasks(tmp_path): + """stop() cancels the main and watchdog tasks.""" + + async def hang(): + await asyncio.sleep(9999) + + monitor = _make_monitor(tmp_path=tmp_path) + monitor._main_task = asyncio.create_task(hang()) + monitor._watchdog_task = asyncio.create_task(hang()) + + await monitor.stop() + + assert monitor._main_task is None + assert monitor._watchdog_task is None + + +@pytest.mark.asyncio +async def test_start_no_tasks_when_disabled(tmp_path): + """start() with websocket_enabled=False does not spawn tasks.""" + monitor = _make_monitor( + settings=_make_settings(websocket_enabled=False), + tmp_path=tmp_path, + ) + await monitor.start() + assert monitor._main_task is None + assert monitor._watchdog_task is None