feat(ha-diag-agent): WebSocketMonitor for real-time HA liveness

- persistent WS connection to HA with auth + state_changed subscription
- watchdog detects silence > 5min → emits ha_websocket_dead
- immediate ha_websocket_dead on disconnect, exponential reconnect with jitter
- cooldown prevents alert spam (10min repeat window while HA stays down)
- ha_websocket_recovered emitted on reconnect after a dead alert (allows
  supervisor to clear active incidents in Phase 5)
- new monitors/ subpackage for long-running tasks (vs interval checks/)
- /health endpoint now includes ws_connected field
- 26 unit tests, 3 integration tests (real HA + container stop/restart)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Oskar Kapala 2026-05-29 15:00:18 +02:00
parent 3499b2f280
commit 31b48d162a
10 changed files with 1112 additions and 7 deletions

View file

@ -10,12 +10,21 @@ no direct supervisor integration, events processed by the VPS observer.
## Architecture ## Architecture
``` ```
APScheduler (every CHECK_INTERVAL s) APScheduler (interval-based REST checks)
└─ HeartbeatCheck → pings /api/, emits ha_websocket_dead on failure ├─ HeartbeatCheck → pings /api/, emits ha_websocket_dead on failure
[Phase 3: EntityUnavailableCheck, SystemHealthCheck, UpdateCheck, ...] ├─ UnavailableEntitiesCheck → entity unavailable > threshold
├─ SystemHealthCheck → /api/system_health per-integration status
├─ AutomationFailuresCheck → automation last-run error traces
└─ UpdatesAvailableCheck → pending HA/integration updates
WebSocketMonitor (persistent, long-running — Phase 4b)
└─ Maintains a live WS subscription to state_changed events
Any traffic = HA is alive. Watchdog fires ha_websocket_dead on
silence > 5min or on disconnect. Emits ha_websocket_recovered
when the connection is restored after a dead alert.
FastAPI (port 8087) FastAPI (port 8087)
GET /health → liveness probe GET /health → liveness probe (includes ws_connected field)
POST /trigger/<check> → run a named check on demand POST /trigger/<check> → run a named check on demand
SQLite (/data/ha_diag.db) SQLite (/data/ha_diag.db)
@ -24,11 +33,15 @@ SQLite (/data/ha_diag.db)
alerts_sent → dedup gate for alert events alerts_sent → dedup gate for alert events
``` ```
The WebSocketMonitor is the only persistent-connection component; all other
checks are APScheduler intervals (stateless REST polls).
## Event Types ## Event Types
| Type | Severity | Trigger | | Type | Severity | Trigger |
|------|----------|---------| |------|----------|---------|
| `ha_websocket_dead` | error | HA /api/ unreachable | | `ha_websocket_dead` | error | WS disconnect, silence > 5min, or /api/ unreachable |
| `ha_websocket_recovered` | info | WS reconnected after a dead alert (clears incident) |
| `ha_integration_failed` | error | Integration in error state | | `ha_integration_failed` | error | Integration in error state |
| `ha_entity_unavailable_long` | warning | Entity unavailable > threshold | | `ha_entity_unavailable_long` | warning | Entity unavailable > threshold |
| `ha_automation_failing` | warning | Automation last run errored | | `ha_automation_failing` | warning | Automation last run errored |
@ -37,6 +50,7 @@ SQLite (/data/ha_diag.db)
| `ha_system_health_degraded` | warning | System health check failed | | `ha_system_health_degraded` | warning | System health check failed |
Event routing in supervisor (Phase 5) maps these to `notify` actions. Event routing in supervisor (Phase 5) maps these to `notify` actions.
`ha_websocket_recovered` should be routed to clear any active `ha_websocket_dead` incident.
## Deployment model ## Deployment model

View file

@ -6,11 +6,13 @@ from fastapi import FastAPI, HTTPException
if TYPE_CHECKING: if TYPE_CHECKING:
from .checks.base import Check from .checks.base import Check
from .monitors.base import Monitor
app = FastAPI(title="ha-diag-agent", version="0.1.0") app = FastAPI(title="ha-diag-agent", version="0.1.0")
# Populated by main.py during startup # Populated by main.py during startup
_checks: dict[str, "Check"] = {} _checks: dict[str, "Check"] = {}
_ws_monitor: "Monitor | None" = None
_node_name: str = "unknown" _node_name: str = "unknown"
_location_tag: str = "default" _location_tag: str = "default"
@ -22,14 +24,22 @@ def register_checks(checks: list["Check"], node_name: str, location_tag: str) ->
_location_tag = location_tag _location_tag = location_tag
def register_ws_monitor(monitor: "Monitor") -> None:
global _ws_monitor
_ws_monitor = monitor
@app.get("/health") @app.get("/health")
async def health() -> dict: async def health() -> dict:
return { response: dict = {
"status": "ok", "status": "ok",
"node": _node_name, "node": _node_name,
"location_tag": _location_tag, "location_tag": _location_tag,
"checks": list(_checks.keys()), "checks": list(_checks.keys()),
} }
if _ws_monitor is not None:
response["ws_connected"] = _ws_monitor.is_healthy
return response
@app.post("/trigger/{check_name}") @app.post("/trigger/{check_name}")

View file

@ -45,6 +45,15 @@ class Settings(BaseSettings):
updates_check_minute: int = 0 updates_check_minute: int = 0
updates_cooldown_days: int = 7 # don't re-alert same update within N days updates_cooldown_days: int = 7 # don't re-alert same update within N days
# WebSocket monitor
websocket_enabled: bool = True
websocket_silence_threshold_seconds: int = 300 # 5 min
websocket_watchdog_interval_seconds: int = 30
websocket_reconnect_initial_delay: float = 1.0
websocket_reconnect_max_delay: float = 60.0
websocket_reconnect_jitter: float = 0.2 # ±20% of delay
websocket_down_alert_repeat_minutes: int = 10
# API server # API server
port: int = 8087 port: int = 8087
log_level: str = "info" log_level: str = "info"

View file

@ -10,7 +10,7 @@ import structlog
import uvicorn import uvicorn
from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.schedulers.asyncio import AsyncIOScheduler
from .api import app, register_checks from .api import app, register_checks, register_ws_monitor
from .checks.automation_failures import AutomationFailuresCheck from .checks.automation_failures import AutomationFailuresCheck
from .checks.heartbeat import HeartbeatCheck from .checks.heartbeat import HeartbeatCheck
from .checks.system_health import SystemHealthCheck from .checks.system_health import SystemHealthCheck
@ -19,6 +19,7 @@ from .checks.updates_available import UpdatesAvailableCheck, UpdatesDigestCheck
from .config import Settings from .config import Settings
from .event_emitter import EventEmitter from .event_emitter import EventEmitter
from .ha_client import HAClient, make_session from .ha_client import HAClient, make_session
from .monitors import WebSocketMonitor
from .storage import Storage from .storage import Storage
_log = structlog.get_logger() _log = structlog.get_logger()
@ -112,6 +113,15 @@ async def run(settings: Settings) -> None:
updates_daily, updates_digest] updates_daily, updates_digest]
register_checks(all_checks, settings.node_name, settings.location_tag) register_checks(all_checks, settings.node_name, settings.location_tag)
ws_monitor = WebSocketMonitor(
ha_url=settings.ha_url,
token=settings.ha_token,
settings=settings,
emitter=emitter,
session=session,
)
register_ws_monitor(ws_monitor)
scheduler = AsyncIOScheduler() scheduler = AsyncIOScheduler()
scheduler.add_job( scheduler.add_job(
_run_check_and_emit, "interval", _run_check_and_emit, "interval",
@ -167,6 +177,8 @@ async def run(settings: Settings) -> None:
updates_hour=settings.updates_check_hour, updates_hour=settings.updates_check_hour,
) )
await ws_monitor.start()
config = uvicorn.Config( config = uvicorn.Config(
app, app,
host="0.0.0.0", host="0.0.0.0",
@ -177,6 +189,7 @@ async def run(settings: Settings) -> None:
try: try:
await server.serve() await server.serve()
finally: finally:
await ws_monitor.stop()
scheduler.shutdown(wait=False) scheduler.shutdown(wait=False)
await storage.close() await storage.close()
await session.close() await session.close()

View file

@ -16,6 +16,7 @@ class HAEventType(str, Enum):
ha_integration_failed = "ha_integration_failed" ha_integration_failed = "ha_integration_failed"
ha_entity_unavailable_long = "ha_entity_unavailable_long" ha_entity_unavailable_long = "ha_entity_unavailable_long"
ha_websocket_dead = "ha_websocket_dead" ha_websocket_dead = "ha_websocket_dead"
ha_websocket_recovered = "ha_websocket_recovered"
ha_automation_failing = "ha_automation_failing" ha_automation_failing = "ha_automation_failing"
ha_update_available = "ha_update_available" ha_update_available = "ha_update_available"
ha_recorder_lag = "ha_recorder_lag" ha_recorder_lag = "ha_recorder_lag"

View file

@ -0,0 +1,4 @@
from .base import Monitor
from .websocket_monitor import WebSocketMonitor
__all__ = ["Monitor", "WebSocketMonitor"]

View file

@ -0,0 +1,24 @@
from __future__ import annotations
from abc import ABC, abstractmethod
class Monitor(ABC):
"""Base class for long-running background monitors.
Unlike checks (one-shot, APScheduler-driven), monitors maintain
persistent state connections, subscriptions, background tasks.
"""
@abstractmethod
async def start(self) -> None:
"""Spawn background task(s). Idempotent if already started."""
@abstractmethod
async def stop(self) -> None:
"""Cancel background tasks and wait for cleanup."""
@property
@abstractmethod
def is_healthy(self) -> bool:
"""True when the monitor is running and its connection is live."""

View file

@ -0,0 +1,286 @@
from __future__ import annotations
import asyncio
import json
import random
import time
from datetime import datetime, timezone
import aiohttp
import structlog
from ..config import Settings
from ..event_emitter import EventEmitter
from ..models import HAEventType, Severity
from .base import Monitor
_log = structlog.get_logger().bind(monitor="websocket")
class _AuthError(Exception):
"""Raised when HA returns auth_invalid during the WS handshake."""
def _make_ws_url(ha_url: str) -> str:
if ha_url.startswith("https://"):
base = ha_url.replace("https://", "wss://", 1)
else:
base = ha_url.replace("http://", "ws://", 1)
return base.rstrip("/") + "/api/websocket"
class WebSocketMonitor(Monitor):
"""Persistent WebSocket connection to HA for real-time liveness monitoring.
Subscribes to state_changed events any traffic proves HA is alive.
The watchdog fires ha_websocket_dead when the connection is silent for
longer than silence_threshold, or immediately on disconnect.
ha_websocket_recovered is emitted when the connection is restored after
a dead alert was sent (allows supervisor to clear active incidents).
"""
def __init__(
self,
ha_url: str,
token: str,
settings: Settings,
emitter: EventEmitter,
session: aiohttp.ClientSession,
) -> None:
self._ws_url = _make_ws_url(ha_url)
self._token = token
self._settings = settings
self._emitter = emitter
self._session = session
self._state: str = "disconnected"
self._last_event_monotonic: float = time.monotonic()
# 0.0 means no ha_websocket_dead has been emitted yet (for this session)
self._last_dead_alert_at: float = 0.0
self._stopping = False
self._msg_id = 0
self._main_task: asyncio.Task | None = None
self._watchdog_task: asyncio.Task | None = None
# ------------------------------------------------------------------
# Monitor ABC
# ------------------------------------------------------------------
async def start(self) -> None:
if not self._settings.websocket_enabled:
_log.info("ws_monitor_disabled")
return
self._stopping = False
self._last_event_monotonic = time.monotonic()
self._main_task = asyncio.create_task(
self._connection_loop(), name="ws_connection_loop"
)
self._watchdog_task = asyncio.create_task(
self._watchdog_loop(), name="ws_watchdog"
)
_log.info("ws_monitor_started", ws_url=self._ws_url)
async def stop(self) -> None:
self._stopping = True
self._state = "stopped"
tasks = [t for t in [self._main_task, self._watchdog_task] if t is not None]
for t in tasks:
t.cancel()
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
self._main_task = None
self._watchdog_task = None
_log.info("ws_monitor_stopped")
@property
def is_healthy(self) -> bool:
if not self._settings.websocket_enabled:
return True # disabled monitors are not unhealthy
return self._state == "subscribed"
# ------------------------------------------------------------------
# Connection loop — reconnects with exponential back-off
# ------------------------------------------------------------------
async def _connection_loop(self) -> None:
delay = float(self._settings.websocket_reconnect_initial_delay)
while not self._stopping:
self._state = "connecting"
clean_close = False
try:
await self._connect_and_listen()
clean_close = True
delay = float(self._settings.websocket_reconnect_initial_delay)
except asyncio.CancelledError:
raise
except _AuthError as exc:
_log.error("ws_auth_failed", error=str(exc))
# Auth failures won't self-heal on fast retry — jump to max delay
delay = float(self._settings.websocket_reconnect_max_delay)
except Exception as exc:
_log.warning("ws_connect_error", error=str(exc))
self._state = "disconnected"
if not self._stopping:
self._on_disconnected()
if self._stopping:
break
if clean_close:
wait = 1.0 # brief pause before reconnecting after a clean HA close
else:
jitter_range = delay * self._settings.websocket_reconnect_jitter
wait = max(0.1, delay + random.uniform(-jitter_range, jitter_range))
delay = min(delay * 2, float(self._settings.websocket_reconnect_max_delay))
_log.debug("ws_reconnect_wait", seconds=round(wait, 2))
await asyncio.sleep(wait)
# ------------------------------------------------------------------
# Connect, auth, subscribe, receive
# ------------------------------------------------------------------
async def _connect_and_listen(self) -> None:
# Override the session-level timeout: WS must stay open indefinitely,
# only the initial TCP connect should be bounded.
ws_timeout = aiohttp.ClientTimeout(total=None, connect=10.0, sock_connect=10.0)
async with self._session.ws_connect(
self._ws_url,
timeout=ws_timeout,
heartbeat=30.0,
) as ws:
self._state = "authenticating"
# Receive auth_required
try:
msg = await asyncio.wait_for(ws.receive_json(), timeout=10.0)
except (asyncio.TimeoutError, TypeError, json.JSONDecodeError) as exc:
raise ConnectionError(f"Failed to receive auth_required: {exc}") from exc
if msg.get("type") != "auth_required":
raise ConnectionError(
f"Unexpected initial message type: {msg.get('type')!r}"
)
await ws.send_json({"type": "auth", "access_token": self._token})
# Receive auth_ok or auth_invalid
try:
msg = await asyncio.wait_for(ws.receive_json(), timeout=10.0)
except (asyncio.TimeoutError, TypeError, json.JSONDecodeError) as exc:
raise ConnectionError(f"Failed to receive auth response: {exc}") from exc
if msg.get("type") == "auth_invalid":
raise _AuthError(msg.get("message", "auth_invalid"))
if msg.get("type") != "auth_ok":
raise ConnectionError(
f"Unexpected auth response type: {msg.get('type')!r}"
)
# Subscribe to state_changed events
self._msg_id += 1
await ws.send_json({
"id": self._msg_id,
"type": "subscribe_events",
"event_type": "state_changed",
})
# Mark connected — capture prior dead state before resetting
prev_dead_at = self._last_dead_alert_at
self._state = "subscribed"
self._last_event_monotonic = time.monotonic()
# Emit recovery if this reconnect follows a dead alert
if prev_dead_at > 0.0:
self._last_dead_alert_at = 0.0
self._emit_recovered()
_log.info("ws_subscribed", ws_url=self._ws_url)
# Receive loop — any TEXT message proves HA is alive
async for raw in ws:
if self._stopping:
break
if raw.type == aiohttp.WSMsgType.TEXT:
self._last_event_monotonic = time.monotonic()
elif raw.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSE):
_log.warning("ws_closed_by_server", msg_type=raw.type.name)
break
# ------------------------------------------------------------------
# Watchdog loop — detects silence while the WS appears connected
# ------------------------------------------------------------------
async def _watchdog_loop(self) -> None:
while not self._stopping:
try:
await asyncio.sleep(self._settings.websocket_watchdog_interval_seconds)
except asyncio.CancelledError:
raise
if self._state != "subscribed":
continue # disconnects are handled by the connection loop
now = time.monotonic()
silent_secs = now - self._last_event_monotonic
if silent_secs <= self._settings.websocket_silence_threshold_seconds:
continue
cooldown = self._settings.websocket_down_alert_repeat_minutes * 60
if self._last_dead_alert_at == 0.0 or (now - self._last_dead_alert_at) >= cooldown:
self._emitter.emit(
event_type=HAEventType.ha_websocket_dead.value,
severity=Severity.error.value,
service="homeassistant",
message=(
f"HA WebSocket silent for {silent_secs:.0f}s — no events received"
),
payload=self._dead_payload(silent_secs),
)
self._last_dead_alert_at = now
_log.warning("ws_silent_dead_emitted", silent_seconds=round(silent_secs))
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _on_disconnected(self) -> None:
"""Emit ha_websocket_dead on connection loss, respecting cooldown."""
if self._stopping:
return
now = time.monotonic()
cooldown = self._settings.websocket_down_alert_repeat_minutes * 60
if self._last_dead_alert_at == 0.0 or (now - self._last_dead_alert_at) >= cooldown:
silent_secs = now - self._last_event_monotonic
self._emitter.emit(
event_type=HAEventType.ha_websocket_dead.value,
severity=Severity.error.value,
service="homeassistant",
message=f"HA WebSocket disconnected — silent for {silent_secs:.0f}s",
payload=self._dead_payload(silent_secs),
)
self._last_dead_alert_at = now
_log.warning("ws_dead_emitted", silent_seconds=round(silent_secs))
def _emit_recovered(self) -> None:
self._emitter.emit(
event_type=HAEventType.ha_websocket_recovered.value,
severity=Severity.info.value,
service="homeassistant",
message="HA WebSocket reconnected and receiving events",
payload={"connection_state": "subscribed"},
)
_log.info("ws_recovered_emitted")
def _dead_payload(self, silent_secs: float) -> dict:
event_age = time.monotonic() - self._last_event_monotonic
last_event_wall = time.time() - event_age
return {
"silent_seconds": round(silent_secs),
"last_event_at": datetime.fromtimestamp(
last_event_wall, tz=timezone.utc
).isoformat(),
"connection_state": self._state,
}

View file

@ -0,0 +1,186 @@
"""Integration tests for WebSocketMonitor against real HA instances.
Requires:
docker compose -f tests/integration/docker-compose.ken.yml up -d
tests/integration/scripts/wait-for-ha.sh http://localhost:8123
TEST_HA_TOKEN=<long-lived-token> pytest tests/ -m integration
Container stop/restart tests additionally need Docker access from the host.
"""
from __future__ import annotations
import asyncio
import subprocess
import time
from pathlib import Path
import pytest
from ha_diag.config import Settings
from ha_diag.event_emitter import EventEmitter
from ha_diag.models import HAEventType
from ha_diag.monitors.websocket_monitor import WebSocketMonitor
from ha_diag.ha_client import make_session
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_settings(ha_url: str, ha_token: str, **overrides) -> Settings:
defaults: dict = {
"ha_url": ha_url,
"ha_token": ha_token,
"node_name": "test-piha",
"location_tag": "ken",
"websocket_enabled": True,
"websocket_silence_threshold_seconds": 30, # low for fast test
"websocket_watchdog_interval_seconds": 5,
"websocket_reconnect_initial_delay": 1.0,
"websocket_reconnect_max_delay": 10.0,
"websocket_reconnect_jitter": 0.0,
"websocket_down_alert_repeat_minutes": 0, # always re-alert
}
defaults.update(overrides)
return Settings(**defaults)
def _emitted_types(events_dir: Path) -> list[str]:
return [
__import__("json").loads(f.read_text())["type"]
for f in sorted(events_dir.glob("*.json"))
]
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.integration
async def test_ws_normal_operation_no_false_alerts(
ha_ken_url: str, ha_token: str, tmp_path: Path
):
"""Normal operation: monitor connects, subscribes, no dead alerts emitted."""
events_dir = tmp_path / "events"
events_dir.mkdir()
settings = _make_settings(ha_ken_url, ha_token)
emitter = EventEmitter(events_dir, node_name="test-piha", location_tag="ken")
async with make_session(ha_token) as session:
monitor = WebSocketMonitor(
ha_url=ha_ken_url,
token=ha_token,
settings=settings,
emitter=emitter,
session=session,
)
await monitor.start()
await asyncio.sleep(5) # let it connect and settle
assert monitor.is_healthy, "Monitor should be subscribed and healthy"
await monitor.stop()
# No dead alerts during normal operation
types = _emitted_types(events_dir)
assert HAEventType.ha_websocket_dead.value not in types, (
f"Unexpected dead alerts during normal operation: {types}"
)
@pytest.mark.integration
async def test_ws_dead_emitted_when_ha_stops(ha_ken_url: str, ha_token: str, tmp_path: Path):
"""Stopping the HA container triggers ha_websocket_dead."""
events_dir = tmp_path / "events"
events_dir.mkdir()
settings = _make_settings(ha_ken_url, ha_token)
emitter = EventEmitter(events_dir, node_name="test-piha", location_tag="ken")
async with make_session(ha_token) as session:
monitor = WebSocketMonitor(
ha_url=ha_ken_url,
token=ha_token,
settings=settings,
emitter=emitter,
session=session,
)
await monitor.start()
# Wait for initial subscription
for _ in range(20):
if monitor.is_healthy:
break
await asyncio.sleep(0.5)
assert monitor.is_healthy, "Monitor did not subscribe within 10s"
# Stop HA container
subprocess.run(
["docker", "stop", "ha-test-ken"],
check=True, capture_output=True, timeout=30,
)
try:
# Wait for dead alert (up to 15s)
deadline = time.monotonic() + 15
while time.monotonic() < deadline:
types = _emitted_types(events_dir)
if HAEventType.ha_websocket_dead.value in types:
break
await asyncio.sleep(0.5)
types = _emitted_types(events_dir)
assert HAEventType.ha_websocket_dead.value in types, (
"ha_websocket_dead not emitted after HA container stopped"
)
finally:
await monitor.stop()
subprocess.run(
["docker", "start", "ha-test-ken"],
check=False, capture_output=True, timeout=30,
)
@pytest.mark.integration
async def test_ws_recovered_after_ha_restart(ha_ken_url: str, ha_token: str, tmp_path: Path):
"""After HA restarts, monitor reconnects and emits ha_websocket_recovered."""
events_dir = tmp_path / "events"
events_dir.mkdir()
settings = _make_settings(ha_ken_url, ha_token)
emitter = EventEmitter(events_dir, node_name="test-piha", location_tag="ken")
async with make_session(ha_token) as session:
monitor = WebSocketMonitor(
ha_url=ha_ken_url,
token=ha_token,
settings=settings,
emitter=emitter,
session=session,
)
await monitor.start()
for _ in range(20):
if monitor.is_healthy:
break
await asyncio.sleep(0.5)
assert monitor.is_healthy
# Stop then restart HA
subprocess.run(["docker", "stop", "ha-test-ken"], check=True, timeout=30)
await asyncio.sleep(2)
subprocess.run(["docker", "start", "ha-test-ken"], check=True, timeout=30)
try:
# Wait for recovery (up to 60s — HA takes time to start)
deadline = time.monotonic() + 60
while time.monotonic() < deadline:
types = _emitted_types(events_dir)
if HAEventType.ha_websocket_recovered.value in types:
break
await asyncio.sleep(1.0)
types = _emitted_types(events_dir)
assert HAEventType.ha_websocket_dead.value in types, (
"ha_websocket_dead not emitted after container stop"
)
assert HAEventType.ha_websocket_recovered.value in types, (
"ha_websocket_recovered not emitted after HA restarted"
)
finally:
await monitor.stop()

View file

@ -0,0 +1,558 @@
"""Unit tests for WebSocketMonitor."""
from __future__ import annotations
import asyncio
import time
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import aiohttp
import pytest
from ha_diag.config import Settings
from ha_diag.event_emitter import EventEmitter
from ha_diag.models import HAEventType
from ha_diag.monitors.websocket_monitor import (
WebSocketMonitor,
_AuthError,
_make_ws_url,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_settings(**overrides) -> Settings:
defaults: dict = {
"ha_url": "http://test.local:8123",
"ha_token": "test-token",
"node_name": "test-node",
"location_tag": "test-loc",
"websocket_enabled": True,
"websocket_silence_threshold_seconds": 300,
"websocket_watchdog_interval_seconds": 30,
"websocket_reconnect_initial_delay": 1.0,
"websocket_reconnect_max_delay": 60.0,
"websocket_reconnect_jitter": 0.0,
"websocket_down_alert_repeat_minutes": 10,
}
defaults.update(overrides)
return Settings(**defaults)
class FakeWS:
"""Fake aiohttp ClientWebSocketResponse for unit tests."""
def __init__(self, auth_messages: list, event_messages: list | None = None):
self._auth_queue = list(auth_messages)
self._event_queue = list(event_messages or [])
self.sent: list = []
async def receive_json(self) -> dict:
if not self._auth_queue:
raise ConnectionError("FakeWS: no more auth messages")
return self._auth_queue.pop(0)
async def send_json(self, data: dict) -> None:
self.sent.append(data)
def __aiter__(self):
return self
async def __anext__(self):
if not self._event_queue:
raise StopAsyncIteration
item = self._event_queue.pop(0)
if isinstance(item, BaseException):
raise item
return item
def _text_msg(data: str = '{"type":"event"}') -> aiohttp.WSMessage:
return aiohttp.WSMessage(type=aiohttp.WSMsgType.TEXT, data=data, extra=None)
def _close_msg() -> aiohttp.WSMessage:
return aiohttp.WSMessage(type=aiohttp.WSMsgType.CLOSE, data=b"", extra=None)
def _mock_session(fake_ws: FakeWS) -> MagicMock:
cm = MagicMock()
cm.__aenter__ = AsyncMock(return_value=fake_ws)
cm.__aexit__ = AsyncMock(return_value=False)
session = MagicMock()
session.ws_connect.return_value = cm
return session
def _make_monitor(
settings: Settings | None = None,
session=None,
emitter: EventEmitter | None = None,
tmp_path: Path | None = None,
) -> WebSocketMonitor:
if settings is None:
settings = _make_settings()
if emitter is None:
p = (tmp_path or Path("/tmp/ws_test_events")).absolute()
p.mkdir(parents=True, exist_ok=True)
emitter = EventEmitter(p, node_name="test-node")
if session is None:
session = MagicMock()
return WebSocketMonitor(
ha_url=settings.ha_url,
token=settings.ha_token,
settings=settings,
emitter=emitter,
session=session,
)
# ---------------------------------------------------------------------------
# URL derivation
# ---------------------------------------------------------------------------
def test_make_ws_url_http():
assert _make_ws_url("http://ha.local:8123") == "ws://ha.local:8123/api/websocket"
def test_make_ws_url_https():
assert _make_ws_url("https://ha.example.com") == "wss://ha.example.com/api/websocket"
def test_make_ws_url_strips_trailing_slash():
assert _make_ws_url("http://ha.local:8123/") == "ws://ha.local:8123/api/websocket"
# ---------------------------------------------------------------------------
# Auth flow (via _connect_and_listen)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_normal_auth_sends_correct_messages(tmp_path):
"""Happy path: sends auth + subscribe, ends in subscribed state."""
fake_ws = FakeWS(
[{"type": "auth_required"}, {"type": "auth_ok"}],
[_text_msg('{"type":"result","id":1,"success":true}')],
)
monitor = _make_monitor(session=_mock_session(fake_ws), tmp_path=tmp_path)
await monitor._connect_and_listen()
assert fake_ws.sent[0] == {"type": "auth", "access_token": "test-token"}
assert fake_ws.sent[1]["type"] == "subscribe_events"
assert fake_ws.sent[1]["event_type"] == "state_changed"
assert monitor._state == "subscribed"
@pytest.mark.asyncio
async def test_last_event_monotonic_updated_on_text_message(tmp_path):
"""Receiving TEXT messages updates last_event_monotonic."""
fake_ws = FakeWS(
[{"type": "auth_required"}, {"type": "auth_ok"}],
[_text_msg(), _text_msg()],
)
monitor = _make_monitor(session=_mock_session(fake_ws), tmp_path=tmp_path)
before = time.monotonic()
await monitor._connect_and_listen()
assert monitor._last_event_monotonic >= before
@pytest.mark.asyncio
async def test_auth_invalid_raises_auth_error(tmp_path):
"""auth_invalid → _AuthError propagates."""
fake_ws = FakeWS([
{"type": "auth_required"},
{"type": "auth_invalid", "message": "invalid token"},
])
monitor = _make_monitor(session=_mock_session(fake_ws), tmp_path=tmp_path)
with pytest.raises(_AuthError, match="invalid token"):
await monitor._connect_and_listen()
@pytest.mark.asyncio
async def test_unexpected_initial_message_raises(tmp_path):
"""Anything other than auth_required on connect → ConnectionError."""
fake_ws = FakeWS([{"type": "unexpected"}])
monitor = _make_monitor(session=_mock_session(fake_ws), tmp_path=tmp_path)
with pytest.raises(ConnectionError, match="Unexpected initial"):
await monitor._connect_and_listen()
@pytest.mark.asyncio
async def test_empty_auth_queue_raises_connection_error(tmp_path):
"""Connection closed before auth_required → ConnectionError."""
fake_ws = FakeWS([])
monitor = _make_monitor(session=_mock_session(fake_ws), tmp_path=tmp_path)
with pytest.raises(ConnectionError):
await monitor._connect_and_listen()
# ---------------------------------------------------------------------------
# Disconnect / dead alerts (_on_disconnected)
# ---------------------------------------------------------------------------
def test_on_disconnected_emits_ha_websocket_dead(tmp_path):
emitter = MagicMock()
monitor = _make_monitor(emitter=emitter, tmp_path=tmp_path)
monitor._state = "disconnected"
monitor._on_disconnected()
emitter.emit.assert_called_once()
assert emitter.emit.call_args[1]["event_type"] == HAEventType.ha_websocket_dead.value
def test_on_disconnected_within_cooldown_suppresses_second_emit(tmp_path):
emitter = MagicMock()
monitor = _make_monitor(
settings=_make_settings(websocket_down_alert_repeat_minutes=10),
emitter=emitter,
tmp_path=tmp_path,
)
monitor._state = "disconnected"
monitor._on_disconnected() # first emit
emitter.emit.reset_mock()
monitor._on_disconnected() # within cooldown → suppressed
emitter.emit.assert_not_called()
def test_on_disconnected_after_cooldown_emits_again(tmp_path):
emitter = MagicMock()
monitor = _make_monitor(
settings=_make_settings(websocket_down_alert_repeat_minutes=10),
emitter=emitter,
tmp_path=tmp_path,
)
monitor._state = "disconnected"
monitor._on_disconnected()
# Backdate to simulate cooldown expiry
monitor._last_dead_alert_at = time.monotonic() - (10 * 60 + 5)
emitter.emit.reset_mock()
monitor._on_disconnected()
emitter.emit.assert_called_once()
def test_on_disconnected_noop_when_stopping(tmp_path):
emitter = MagicMock()
monitor = _make_monitor(emitter=emitter, tmp_path=tmp_path)
monitor._stopping = True
monitor._on_disconnected()
emitter.emit.assert_not_called()
# ---------------------------------------------------------------------------
# Recovery
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_reconnect_after_dead_emits_recovered(tmp_path):
"""Successful reconnect after a dead alert emits ha_websocket_recovered."""
emitter = MagicMock()
fake_ws = FakeWS([{"type": "auth_required"}, {"type": "auth_ok"}], [])
settings = _make_settings()
monitor = WebSocketMonitor(
ha_url=settings.ha_url,
token=settings.ha_token,
settings=settings,
emitter=emitter,
session=_mock_session(fake_ws),
)
monitor._last_dead_alert_at = time.monotonic() - 30.0 # prior dead was sent
await monitor._connect_and_listen()
emitted_types = [c[1]["event_type"] for c in emitter.emit.call_args_list]
assert HAEventType.ha_websocket_recovered.value in emitted_types
assert monitor._last_dead_alert_at == 0.0 # reset after recovery
@pytest.mark.asyncio
async def test_no_recovered_if_no_prior_dead(tmp_path):
"""First-ever connect with no prior dead alert → no recovered emitted."""
emitter = MagicMock()
fake_ws = FakeWS([{"type": "auth_required"}, {"type": "auth_ok"}], [])
settings = _make_settings()
monitor = WebSocketMonitor(
ha_url=settings.ha_url,
token=settings.ha_token,
settings=settings,
emitter=emitter,
session=_mock_session(fake_ws),
)
monitor._last_dead_alert_at = 0.0
await monitor._connect_and_listen()
emitter.emit.assert_not_called()
# ---------------------------------------------------------------------------
# Watchdog loop
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_watchdog_emits_dead_when_silent_over_threshold(tmp_path):
"""Watchdog detects silence > threshold and emits ha_websocket_dead."""
emitter = MagicMock()
settings = _make_settings(
websocket_silence_threshold_seconds=60,
websocket_watchdog_interval_seconds=30,
websocket_down_alert_repeat_minutes=0,
)
monitor = _make_monitor(settings=settings, emitter=emitter, tmp_path=tmp_path)
monitor._state = "subscribed"
monitor._last_event_monotonic = time.monotonic() - 120.0 # 120s > 60s threshold
monitor._last_dead_alert_at = 0.0
sleep_calls = 0
async def one_iteration(t):
nonlocal sleep_calls
sleep_calls += 1
if sleep_calls >= 2:
raise asyncio.CancelledError()
with patch("asyncio.sleep", side_effect=one_iteration):
with pytest.raises(asyncio.CancelledError):
await monitor._watchdog_loop()
emitter.emit.assert_called_once()
assert emitter.emit.call_args[1]["event_type"] == HAEventType.ha_websocket_dead.value
@pytest.mark.asyncio
async def test_watchdog_no_emit_when_events_recent(tmp_path):
"""Watchdog does not emit when last event is within silence threshold."""
emitter = MagicMock()
settings = _make_settings(
websocket_silence_threshold_seconds=300,
websocket_watchdog_interval_seconds=30,
websocket_down_alert_repeat_minutes=0,
)
monitor = _make_monitor(settings=settings, emitter=emitter, tmp_path=tmp_path)
monitor._state = "subscribed"
monitor._last_event_monotonic = time.monotonic() - 10.0 # recent
sleep_calls = 0
async def one_iteration(t):
nonlocal sleep_calls
sleep_calls += 1
if sleep_calls >= 2:
raise asyncio.CancelledError()
with patch("asyncio.sleep", side_effect=one_iteration):
with pytest.raises(asyncio.CancelledError):
await monitor._watchdog_loop()
emitter.emit.assert_not_called()
@pytest.mark.asyncio
async def test_watchdog_skips_when_not_subscribed(tmp_path):
"""Watchdog does not emit when state is not 'subscribed'."""
emitter = MagicMock()
settings = _make_settings(
websocket_silence_threshold_seconds=1,
websocket_watchdog_interval_seconds=30,
websocket_down_alert_repeat_minutes=0,
)
monitor = _make_monitor(settings=settings, emitter=emitter, tmp_path=tmp_path)
monitor._state = "disconnected"
monitor._last_event_monotonic = time.monotonic() - 9999.0 # very old
sleep_calls = 0
async def one_iteration(t):
nonlocal sleep_calls
sleep_calls += 1
if sleep_calls >= 2:
raise asyncio.CancelledError()
with patch("asyncio.sleep", side_effect=one_iteration):
with pytest.raises(asyncio.CancelledError):
await monitor._watchdog_loop()
emitter.emit.assert_not_called()
@pytest.mark.asyncio
async def test_watchdog_repeat_alert_respects_cooldown(tmp_path):
"""Second watchdog dead alert fires only after cooldown."""
emitter = MagicMock()
settings = _make_settings(
websocket_silence_threshold_seconds=60,
websocket_watchdog_interval_seconds=30,
websocket_down_alert_repeat_minutes=10,
)
monitor = _make_monitor(settings=settings, emitter=emitter, tmp_path=tmp_path)
monitor._state = "subscribed"
monitor._last_event_monotonic = time.monotonic() - 3600.0 # 1hr silent
# Set last alert to just now → still within 10-min cooldown
monitor._last_dead_alert_at = time.monotonic()
sleep_calls = 0
async def one_iteration(t):
nonlocal sleep_calls
sleep_calls += 1
if sleep_calls >= 2:
raise asyncio.CancelledError()
with patch("asyncio.sleep", side_effect=one_iteration):
with pytest.raises(asyncio.CancelledError):
await monitor._watchdog_loop()
emitter.emit.assert_not_called() # within cooldown
# ---------------------------------------------------------------------------
# Reconnect backoff
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_reconnect_backoff_doubles_each_attempt(tmp_path):
"""Retry delay doubles on consecutive failures."""
delays: list[float] = []
call_count = 0
async def fail_connect():
nonlocal call_count
call_count += 1
raise ConnectionError("refused")
async def capture_sleep(t):
delays.append(t)
if call_count >= 3:
raise asyncio.CancelledError()
settings = _make_settings(
websocket_reconnect_initial_delay=1.0,
websocket_reconnect_max_delay=60.0,
websocket_reconnect_jitter=0.0,
)
monitor = _make_monitor(settings=settings, emitter=MagicMock(), tmp_path=tmp_path)
monitor._connect_and_listen = fail_connect
with patch("asyncio.sleep", side_effect=capture_sleep):
with pytest.raises(asyncio.CancelledError):
await monitor._connection_loop()
assert len(delays) >= 2
assert delays[0] == pytest.approx(1.0)
assert delays[1] == pytest.approx(2.0)
@pytest.mark.asyncio
async def test_reconnect_delay_capped_at_max(tmp_path):
"""Delay never exceeds websocket_reconnect_max_delay."""
delays: list[float] = []
call_count = 0
async def fail_connect():
nonlocal call_count
call_count += 1
raise ConnectionError("refused")
async def capture_sleep(t):
delays.append(t)
if call_count >= 8:
raise asyncio.CancelledError()
settings = _make_settings(
websocket_reconnect_initial_delay=1.0,
websocket_reconnect_max_delay=8.0,
websocket_reconnect_jitter=0.0,
)
monitor = _make_monitor(settings=settings, emitter=MagicMock(), tmp_path=tmp_path)
monitor._connect_and_listen = fail_connect
with patch("asyncio.sleep", side_effect=capture_sleep):
with pytest.raises(asyncio.CancelledError):
await monitor._connection_loop()
assert max(delays) <= 8.0
# ---------------------------------------------------------------------------
# is_healthy
# ---------------------------------------------------------------------------
def test_is_healthy_true_when_subscribed(tmp_path):
monitor = _make_monitor(settings=_make_settings(websocket_enabled=True), tmp_path=tmp_path)
monitor._state = "subscribed"
assert monitor.is_healthy is True
def test_is_healthy_false_when_disconnected(tmp_path):
monitor = _make_monitor(settings=_make_settings(websocket_enabled=True), tmp_path=tmp_path)
monitor._state = "disconnected"
assert monitor.is_healthy is False
def test_is_healthy_false_when_connecting(tmp_path):
monitor = _make_monitor(settings=_make_settings(websocket_enabled=True), tmp_path=tmp_path)
monitor._state = "connecting"
assert monitor.is_healthy is False
def test_is_healthy_true_when_disabled(tmp_path):
"""Disabled monitor reports healthy — it's off, not broken."""
monitor = _make_monitor(settings=_make_settings(websocket_enabled=False), tmp_path=tmp_path)
monitor._state = "disconnected"
assert monitor.is_healthy is True
# ---------------------------------------------------------------------------
# start / stop lifecycle
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_stop_cancels_background_tasks(tmp_path):
"""stop() cancels the main and watchdog tasks."""
async def hang():
await asyncio.sleep(9999)
monitor = _make_monitor(tmp_path=tmp_path)
monitor._main_task = asyncio.create_task(hang())
monitor._watchdog_task = asyncio.create_task(hang())
await monitor.stop()
assert monitor._main_task is None
assert monitor._watchdog_task is None
@pytest.mark.asyncio
async def test_start_no_tasks_when_disabled(tmp_path):
"""start() with websocket_enabled=False does not spawn tasks."""
monitor = _make_monitor(
settings=_make_settings(websocket_enabled=False),
tmp_path=tmp_path,
)
await monitor.start()
assert monitor._main_task is None
assert monitor._watchdog_task is None