feat(ha-diag-agent): UnavailableEntitiesCheck with root cause dedup

- shared aiohttp ClientSession in HAClient (Phase 1 Flag #2 fixed):
  make_session() factory, session injected at startup, closed on shutdown
- Check.run() → list[CheckResult]: clean multi-event interface
- first real diagnostic check: entity unavailable > 24h
  (INSERT OR IGNORE baseline preserves first-seen timestamp)
- root cause grouping: emit ha_integration_failed instead of N entity
events when ≥50% of an integration's entities are unavailable
(and at least 3 of them are)
- alert deduplication via SQLite cooldown window (default 6h)
- recovery clears baseline + dedup for immediate re-alert
- configurable thresholds: duration, integration %, cooldown
- 38 unit tests + 7 integration tests (42 pass, 3 skip w/o live HA)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Oskar Kapala 2026-05-29 13:41:55 +02:00
parent 07bd498fd6
commit 20f6761a67
15 changed files with 1154 additions and 167 deletions

View file

@ -4,13 +4,21 @@
# Home Assistant connection (required) # Home Assistant connection (required)
HA_URL=http://homeassistant.local:8123 HA_URL=http://homeassistant.local:8123
HA_TOKEN=your-long-lived-token-here HA_TOKEN=your-long-lived-token-here
HA_TIMEOUT=10.0
# Node identity # Node identity
NODE_NAME=piha NODE_NAME=piha
LOCATION_TAG=ken LOCATION_TAG=ken
# Timing # Check intervals (seconds)
CHECK_INTERVAL=60 CHECK_INTERVAL=60 # heartbeat check
CHECK_INTERVAL_UNAVAILABLE=3600 # entity availability check (1h)
# Unavailable entities thresholds
UNAVAILABLE_THRESHOLD_HOURS=24 # alert after N hours unavailable
INTEGRATION_FAILURE_THRESHOLD_PCT=0.5 # fraction of integration entities
INTEGRATION_FAILURE_MIN_ENTITIES=3 # minimum count for integration event
ALERT_COOLDOWN_HOURS=6 # suppress re-alert within N hours
# API server # API server
PORT=8087 PORT=8087

View file

@ -28,10 +28,15 @@ service:
runtime: runtime:
env_vars: env_vars:
- HA_TOKEN # long-lived HA access token (required) - HA_TOKEN # long-lived HA access token (required)
- HA_URL # http://homeassistant.local:8123 - HA_URL # http://homeassistant.local:8123
- NODE_NAME # canonical node name: piha, chelsty-infra, ... - NODE_NAME # canonical node name: piha, chelsty-infra
- LOCATION_TAG # human site label: ken, chelsty, ... - LOCATION_TAG # human site label: ken, chelsty
- CHECK_INTERVAL # seconds between check cycles (default: 60) - CHECK_INTERVAL # heartbeat interval seconds (default: 60)
- PORT # FastAPI port (default: 8087) - CHECK_INTERVAL_UNAVAILABLE # entity check interval seconds (default: 3600)
- LOG_LEVEL # default: info - UNAVAILABLE_THRESHOLD_HOURS # alert threshold (default: 24)
- INTEGRATION_FAILURE_THRESHOLD_PCT # fraction threshold (default: 0.5)
- INTEGRATION_FAILURE_MIN_ENTITIES # min count for integration event (default: 3)
- ALERT_COOLDOWN_HOURS # re-alert suppression (default: 6)
- PORT # FastAPI port (default: 8087)
- LOG_LEVEL # default: info

View file

@ -11,8 +11,10 @@ class Check(ABC):
name: str # unique slug used in /trigger/<name> and check_history name: str # unique slug used in /trigger/<name> and check_history
@abstractmethod @abstractmethod
async def run(self) -> CheckResult: async def run(self) -> list[CheckResult]:
"""Execute the check and return a result. """Execute the check and return results.
The caller is responsible for emitting events when result.event_type is set. Empty list means the check passed cleanly.
Each CheckResult with event_type set causes an event to be emitted.
The caller (runner in main.py) handles emission and history recording.
""" """

View file

@ -6,10 +6,9 @@ from .base import Check
class HeartbeatCheck(Check): class HeartbeatCheck(Check):
"""Pings HA /api/ to verify the API is reachable. """Pings HA /api/ to verify the REST API is reachable.
Validates the end-to-end pipeline: HA client check result event emitter. Validates the end-to-end pipeline: shared HAClient check event emitter.
Real diagnostic checks (entity availability, system health, etc.) come in Phase 3.
""" """
name = "heartbeat" name = "heartbeat"
@ -17,31 +16,23 @@ class HeartbeatCheck(Check):
def __init__(self, ha_client: HAClient) -> None: def __init__(self, ha_client: HAClient) -> None:
self._client = ha_client self._client = ha_client
async def run(self) -> CheckResult: async def run(self) -> list[CheckResult]:
try: try:
async with self._client: data = await self._client.get_api_status()
data = await self._client.get_api_status()
if isinstance(data, dict) and "message" in data: if isinstance(data, dict) and "message" in data:
return CheckResult( return []
healthy=True, return [CheckResult(
event_type=None,
severity=Severity.info,
message="HA API reachable",
payload={"response": data},
)
return CheckResult(
healthy=False, healthy=False,
event_type=HAEventType.ha_websocket_dead, event_type=HAEventType.ha_websocket_dead,
severity=Severity.error, severity=Severity.error,
message=f"HA API returned unexpected response: {data!r}", message=f"HA API returned unexpected response: {data!r}",
payload={"response": str(data)}, payload={"response": str(data)},
) )]
except Exception as exc: except Exception as exc:
return CheckResult( return [CheckResult(
healthy=False, healthy=False,
event_type=HAEventType.ha_websocket_dead, event_type=HAEventType.ha_websocket_dead,
severity=Severity.error, severity=Severity.error,
message=f"HA API unreachable: {exc}", message=f"HA API unreachable: {exc}",
payload={"error": str(exc)}, payload={"error": str(exc)},
) )]

View file

@ -0,0 +1,243 @@
from __future__ import annotations
import time
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any
from ..ha_client import HAClient
from ..models import CheckResult, HAEventType, Severity
from ..storage import Storage
from .base import Check
if TYPE_CHECKING:
from ..config import Settings
_BAD_STATES = frozenset({"unavailable", "unknown"})
class UnavailableEntitiesCheck(Check):
    """Detect entities stuck in an ``unavailable``/``unknown`` state.

    Each run performs these steps, in order:

    1. Fetch all entity states from Home Assistant.
    2. Recoveries: for every tracked entity that now reports a healthy
       state, clear its stored baseline and its alert dedup record.
    3. Record every currently-bad entity in the SQLite baseline. The storage
       layer uses INSERT OR IGNORE, so the first-seen timestamp of an
       already-tracked entity is preserved.
    4. Select entities that have been bad for at least
       ``unavailable_threshold_hours`` and whose alert is not inside the
       ``alert_cooldown_hours`` window.
    5. Root-cause grouping: when at least
       ``integration_failure_threshold_pct`` of an integration's entities are
       bad (and at least ``integration_failure_min_entities`` of them), emit
       one ``ha_integration_failed`` result instead of N individual
       ``ha_entity_unavailable_long`` results.
    6. Alert dedup: every emitted alert is recorded in storage so it is not
       re-emitted within the cooldown window.
    """

    name = "unavailable_entities"

    def __init__(
        self,
        ha_client: HAClient,
        storage: Storage,
        settings: "Settings",
    ) -> None:
        self._client = ha_client
        self._storage = storage
        self._settings = settings

    # ------------------------------------------------------------------
    # Public entry point
    # ------------------------------------------------------------------

    async def run(self) -> list[CheckResult]:
        """Run the check and return one result per alert to emit.

        Returns an empty list when nothing needs alerting. If the state list
        cannot be fetched, a single ``ha_websocket_dead`` result is returned
        and no stored tracking data is touched.
        """
        now = time.time()
        try:
            all_states = await self._client.get_states()
        except Exception as exc:
            return [CheckResult(
                healthy=False,
                event_type=HAEventType.ha_websocket_dead,
                severity=Severity.error,
                message=f"Failed to fetch entity states: {exc}",
                payload={"error": str(exc)},
            )]

        integration_map, area_map = await self._load_registry()

        # Split the states into bad and healthy in a single pass.
        unavailable: dict[str, dict[str, Any]] = {}
        available_ids: set[str] = set()
        for state in all_states:
            if state["state"] in _BAD_STATES:
                unavailable[state["entity_id"]] = state
            else:
                available_ids.add(state["entity_id"])

        # Handle recoveries first so an entity that came back starts from a
        # clean slate (no baseline, no dedup record).
        for eid in await self._storage.get_all_tracked_entity_ids():
            if eid in available_ids:
                await self._handle_recovery(eid)

        # Record new/continuing bad entities; the storage layer keeps the
        # original timestamp for entities that are already tracked.
        for eid, state_data in unavailable.items():
            await self._storage.set_entity_unavailable_since(
                eid, state_data["state"], now
            )

        # Determine which entities have exceeded the alert threshold.
        to_alert: list[dict[str, Any]] = []
        cooldown_s = self._settings.alert_cooldown_hours * 3600
        threshold_h = self._settings.unavailable_threshold_hours
        for eid, state_data in unavailable.items():
            first_at = await self._storage.get_entity_first_unavailable_at(eid)
            if first_at is None:
                continue
            duration_h = (now - first_at) / 3600
            if duration_h < threshold_h:
                continue
            if await self._storage.was_alert_sent(
                f"entity_unavailable:{eid}", cooldown_s
            ):
                continue
            to_alert.append({
                "entity_id": eid,
                "state": state_data["state"],
                "first_at": first_at,
                "duration_h": duration_h,
                "domain": eid.split(".")[0],
                "integration": integration_map.get(eid),
                "area_id": area_map.get(eid),
            })

        if not to_alert:
            return []
        return await self._build_results(to_alert, all_states, integration_map)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    async def _load_registry(
        self,
    ) -> tuple[dict[str, str], dict[str, str]]:
        """Fetch the entity registry; return ``(integration_map, area_map)``.

        Both maps are keyed by entity_id. A missing or empty ``platform`` /
        ``area_id`` is stored as ``""``. Best effort: any failure (request
        error or unexpected payload shape) yields two empty dicts, which
        disables integration grouping for this run.
        """
        try:
            registry = await self._client.get_entity_registry()
            integration_map: dict[str, str] = {}
            area_map: dict[str, str] = {}
            for entry in registry:
                if "entity_id" not in entry:
                    continue
                eid = entry["entity_id"]
                integration_map[eid] = entry.get("platform") or ""
                area_map[eid] = entry.get("area_id") or ""
            return integration_map, area_map
        except Exception:
            return {}, {}

    async def _handle_recovery(self, entity_id: str) -> None:
        """Forget a recovered entity: drop its baseline and its dedup record."""
        await self._storage.clear_entity_unavailable(entity_id)
        # Clear dedup so the next unavailability triggers an alert immediately
        await self._storage.clear_alert(f"entity_unavailable:{entity_id}")

    @staticmethod
    def _entity_result(entity: dict[str, Any]) -> CheckResult:
        """Build the per-entity ``ha_entity_unavailable_long`` result."""
        eid = entity["entity_id"]
        since_iso = (
            datetime.fromtimestamp(entity["first_at"], tz=timezone.utc)
            .isoformat()
            .replace("+00:00", "Z")
        )
        payload: dict[str, Any] = {
            "entity_id": eid,
            "state": entity["state"],
            "since": since_iso,
            "duration_hours": round(entity["duration_h"], 1),
            "domain": entity["domain"],
        }
        # Optional context is only included when it is known.
        if entity["integration"]:
            payload["integration"] = entity["integration"]
        if entity["area_id"]:
            payload["area"] = entity["area_id"]
        return CheckResult(
            healthy=False,
            event_type=HAEventType.ha_entity_unavailable_long,
            severity=Severity.warning,
            message=(
                f"Entity {eid} unavailable for "
                f"{entity['duration_h']:.1f}h"
            ),
            payload=payload,
        )

    async def _build_results(
        self,
        to_alert: list[dict[str, Any]],
        all_states: list[dict[str, Any]],
        integration_map: dict[str, str],
    ) -> list[CheckResult]:
        """Turn alertable entities into results, grouping by root cause.

        Integration-level results come first, followed by per-entity results
        for entities not covered by an integration. Each emitted alert is
        marked as sent in storage before the results are returned.
        """
        results: list[CheckResult] = []
        handled: set[str] = set()

        # Build per-integration stats across ALL entities (not just to_alert)
        total_per_integ: dict[str, int] = {}
        unav_per_integ: dict[str, list[str]] = {}
        for state in all_states:
            eid = state["entity_id"]
            integ = integration_map.get(eid)
            if not integ:
                continue
            total_per_integ[integ] = total_per_integ.get(integ, 0) + 1
            if state["state"] in _BAD_STATES:
                unav_per_integ.setdefault(integ, []).append(eid)

        # Group the alertable entity ids by integration once, instead of
        # rescanning to_alert for every integration.
        alerting_per_integ: dict[str, list[str]] = {}
        for entity in to_alert:
            if entity["integration"]:
                alerting_per_integ.setdefault(
                    entity["integration"], []
                ).append(entity["entity_id"])

        min_ent = self._settings.integration_failure_min_entities
        threshold_pct = self._settings.integration_failure_threshold_pct
        cooldown_s = self._settings.alert_cooldown_hours * 3600

        # Integration-level events
        for integ, unav_ids in unav_per_integ.items():
            alerting_ids = alerting_per_integ.get(integ)
            if not alerting_ids:
                # No entity of this integration is due for an alert yet.
                continue
            total = total_per_integ.get(integ, 0)
            pct = len(unav_ids) / total if total else 0
            if pct < threshold_pct or len(unav_ids) < min_ent:
                continue
            alert_key = f"integration_failed:{integ}"
            if await self._storage.was_alert_sent(alert_key, cooldown_s):
                # Already reported recently: suppress the per-entity alerts
                # too, they share the same root cause.
                handled.update(alerting_ids)
                continue
            results.append(CheckResult(
                healthy=False,
                event_type=HAEventType.ha_integration_failed,
                severity=Severity.error,
                message=(
                    f"Integration '{integ}' appears down: "
                    f"{len(unav_ids)}/{total} entities unavailable"
                ),
                payload={
                    "integration": integ,
                    "affected_entities": unav_ids,
                    "unavailable_count": len(unav_ids),
                    "total_count": total,
                    "unavailable_pct": round(pct, 2),
                },
            ))
            await self._storage.mark_alert_sent(alert_key)
            handled.update(alerting_ids)

        # Per-entity events for entities not covered by an integration event
        for entity in to_alert:
            eid = entity["entity_id"]
            if eid in handled:
                continue
            results.append(self._entity_result(entity))
            await self._storage.mark_alert_sent(f"entity_unavailable:{eid}")

        return results

View file

@ -11,13 +11,30 @@ _CONFIG_YAML = Path("/config/ha-diag-agent.yaml")
class Settings(BaseSettings): class Settings(BaseSettings):
# HA connection
ha_url: str = "http://homeassistant.local:8123" ha_url: str = "http://homeassistant.local:8123"
ha_token: str = "" ha_token: str = ""
ha_timeout: float = 10.0
# Node identity
node_name: str = "unknown" node_name: str = "unknown"
location_tag: str = "default" location_tag: str = "default"
check_interval: int = 60
# Intervals (seconds)
check_interval: int = 60 # heartbeat check interval
check_interval_unavailable: int = 3600 # unavailable entities check interval
# Unavailable entities check thresholds
unavailable_threshold_hours: float = 24.0 # alert after N hours unavailable
integration_failure_threshold_pct: float = 0.5 # % of integration entities unavailable
integration_failure_min_entities: int = 3 # min count to trigger integration event
alert_cooldown_hours: float = 6.0 # don't re-alert same entity within N hours
# API server
port: int = 8087 port: int = 8087
log_level: str = "info" log_level: str = "info"
# Runtime paths (inside container)
events_dir: Path = Path("/events") events_dir: Path = Path("/events")
data_dir: Path = Path("/data") data_dir: Path = Path("/data")

View file

@ -5,72 +5,74 @@ from typing import Any
import aiohttp import aiohttp
class HAClient: def make_session(token: str, timeout: float = 10.0) -> aiohttp.ClientSession:
"""Async Home Assistant REST API client using long-lived token auth.""" """Create a pre-configured ClientSession for use with HAClient."""
return aiohttp.ClientSession(
def __init__(self, base_url: str, token: str, timeout: float = 10.0) -> None: headers={
self._base_url = base_url.rstrip("/")
self._headers = {
"Authorization": f"Bearer {token}", "Authorization": f"Bearer {token}",
"Content-Type": "application/json", "Content-Type": "application/json",
} },
self._timeout = aiohttp.ClientTimeout(total=timeout) timeout=aiohttp.ClientTimeout(total=timeout),
self._session: aiohttp.ClientSession | None = None )
async def __aenter__(self) -> "HAClient":
self._session = aiohttp.ClientSession(
headers=self._headers,
timeout=self._timeout,
)
return self
async def __aexit__(self, *_: Any) -> None: class HAClient:
if self._session: """Async Home Assistant REST API client.
await self._session.close()
self._session = None
def _session_or_raise(self) -> aiohttp.ClientSession: Session lifecycle is managed externally the caller creates the session
if self._session is None: via make_session() at startup and closes it on shutdown. HAClient is a
raise RuntimeError("HAClient must be used as an async context manager") session-borrower: it never opens or closes the session it receives.
return self._session """
def __init__(self, base_url: str, session: aiohttp.ClientSession) -> None:
self._base_url = base_url.rstrip("/")
self._session = session
async def get_api_status(self) -> dict[str, Any]: async def get_api_status(self) -> dict[str, Any]:
"""GET /api/ — returns {"message": "API running."} when HA is up.""" """GET /api/ — returns {"message": "API running."} when HA is up."""
async with self._session_or_raise().get(f"{self._base_url}/api/") as resp: async with self._session.get(f"{self._base_url}/api/") as resp:
resp.raise_for_status() resp.raise_for_status()
return await resp.json() return await resp.json()
async def get_states(self) -> list[dict[str, Any]]: async def get_states(self) -> list[dict[str, Any]]:
"""GET /api/states — full entity state list.""" """GET /api/states — full entity state list."""
async with self._session_or_raise().get(f"{self._base_url}/api/states") as resp: async with self._session.get(f"{self._base_url}/api/states") as resp:
resp.raise_for_status() resp.raise_for_status()
return await resp.json() return await resp.json()
async def get_system_health(self) -> dict[str, Any]: async def get_system_health(self) -> dict[str, Any]:
"""GET /api/system_health — per-integration health summary.""" """GET /api/system_health — per-integration health summary."""
async with self._session_or_raise().get( async with self._session.get(f"{self._base_url}/api/system_health") as resp:
f"{self._base_url}/api/system_health"
) as resp:
resp.raise_for_status() resp.raise_for_status()
return await resp.json() return await resp.json()
async def get_config(self) -> dict[str, Any]: async def get_config(self) -> dict[str, Any]:
"""GET /api/config — HA configuration including version.""" """GET /api/config — HA configuration including version."""
async with self._session_or_raise().get(f"{self._base_url}/api/config") as resp: async with self._session.get(f"{self._base_url}/api/config") as resp:
resp.raise_for_status()
return await resp.json()
async def get_entity_registry(self) -> list[dict[str, Any]]:
    """GET /api/config/entity_registry — entity registry entries.

    Returns the decoded JSON body of the response; a non-success HTTP
    status raises via ``raise_for_status()``. Entries are expected to carry
    entity_id, platform (integration name), area_id, config_entry_id, and
    other metadata.

    NOTE(review): Home Assistant documents the entity registry listing under
    its WebSocket API (``config/entity_registry/list``); confirm this REST
    path exists on the targeted HA version.
    """
    registry_url = f"{self._base_url}/api/config/entity_registry"
    async with self._session.get(registry_url) as response:
        response.raise_for_status()
        return await response.json()
async def get_automation_traces(self, automation_id: str) -> list[dict[str, Any]]: async def get_automation_traces(self, automation_id: str) -> list[dict[str, Any]]:
"""GET /api/trace/automation/<id> — last run traces for an automation.""" """GET /api/trace/automation/<id> — last run traces for an automation."""
url = f"{self._base_url}/api/trace/automation/{automation_id}" url = f"{self._base_url}/api/trace/automation/{automation_id}"
async with self._session_or_raise().get(url) as resp: async with self._session.get(url) as resp:
resp.raise_for_status() resp.raise_for_status()
return await resp.json() return await resp.json()
async def get_error_log(self) -> str: async def get_error_log(self) -> str:
"""GET /api/error_log — plaintext error log.""" """GET /api/error_log — plaintext error log."""
async with self._session_or_raise().get( async with self._session.get(f"{self._base_url}/api/error_log") as resp:
f"{self._base_url}/api/error_log"
) as resp:
resp.raise_for_status() resp.raise_for_status()
return await resp.text() return await resp.text()

View file

@ -12,9 +12,10 @@ from apscheduler.schedulers.asyncio import AsyncIOScheduler
from .api import app, register_checks from .api import app, register_checks
from .checks.heartbeat import HeartbeatCheck from .checks.heartbeat import HeartbeatCheck
from .checks.unavailable_entities import UnavailableEntitiesCheck
from .config import Settings from .config import Settings
from .event_emitter import EventEmitter from .event_emitter import EventEmitter
from .ha_client import HAClient from .ha_client import HAClient, make_session
from .storage import Storage from .storage import Storage
_log = structlog.get_logger() _log = structlog.get_logger()
@ -34,32 +35,42 @@ def _configure_structlog(log_level: str) -> None:
logging.basicConfig(level=getattr(logging, log_level.upper(), logging.INFO)) logging.basicConfig(level=getattr(logging, log_level.upper(), logging.INFO))
async def _run_check_and_emit(check, emitter: EventEmitter, storage: Storage) -> None: async def _run_check_and_emit(
check, emitter: EventEmitter, storage: Storage
) -> None:
"""Run a check, emit events for each result, and record to check_history."""
try: try:
result = await check.run() results = await check.run()
healthy = not any(r.event_type for r in results)
summary = f"{len(results)} issue(s)" if results else "ok"
await storage.record_check( await storage.record_check(
check_name=check.name, check_name=check.name,
ran_at=time.time(), ran_at=time.time(),
healthy=result.healthy, healthy=healthy,
message=result.message, message=summary,
payload=json.dumps(result.payload), payload=json.dumps([r.model_dump() for r in results]),
) )
if result.event_type:
emitter.emit( for result in results:
event_type=result.event_type, if result.event_type:
severity=result.severity.value, emitter.emit(
service="homeassistant", event_type=result.event_type,
message=result.message, severity=result.severity.value,
payload=result.payload, service="homeassistant",
) message=result.message,
_log.warning( payload=result.payload,
"check_unhealthy", )
check=check.name, _log.warning(
event=result.event_type, "check_unhealthy",
msg=result.message, check=check.name,
) event=result.event_type,
else: msg=result.message,
)
if healthy:
_log.info("check_ok", check=check.name) _log.info("check_ok", check=check.name)
except Exception as exc: except Exception as exc:
_log.error("check_error", check=check.name, error=str(exc), exc_info=True) _log.error("check_error", check=check.name, error=str(exc), exc_info=True)
@ -71,30 +82,49 @@ async def run(settings: Settings) -> None:
node=settings.node_name, node=settings.node_name,
location=settings.location_tag, location=settings.location_tag,
ha_url=settings.ha_url, ha_url=settings.ha_url,
interval=settings.check_interval, heartbeat_interval=settings.check_interval,
unavailable_interval=settings.check_interval_unavailable,
) )
storage = Storage(settings.data_dir / "ha_diag.db") storage = Storage(settings.data_dir / "ha_diag.db")
await storage.open() await storage.open()
emitter = EventEmitter(settings.events_dir, settings.node_name, settings.location_tag) emitter = EventEmitter(settings.events_dir, settings.node_name, settings.location_tag)
ha_client = HAClient(settings.ha_url, settings.ha_token)
checks = [HeartbeatCheck(ha_client)] # Shared session — created once at startup, closed on shutdown
register_checks(checks, settings.node_name, settings.location_tag) session = make_session(settings.ha_token, settings.ha_timeout)
ha_client = HAClient(settings.ha_url, session)
heartbeat = HeartbeatCheck(ha_client)
unavailable = UnavailableEntitiesCheck(ha_client, storage, settings)
all_checks = [heartbeat, unavailable]
register_checks(all_checks, settings.node_name, settings.location_tag)
scheduler = AsyncIOScheduler() scheduler = AsyncIOScheduler()
for check in checks: scheduler.add_job(
scheduler.add_job( _run_check_and_emit,
_run_check_and_emit, "interval",
"interval", seconds=settings.check_interval,
seconds=settings.check_interval, args=[heartbeat, emitter, storage],
args=[check, emitter, storage], id="check_heartbeat",
id=f"check_{check.name}", next_run_time=datetime.now(),
next_run_time=datetime.now(), )
) scheduler.add_job(
_run_check_and_emit,
"interval",
seconds=settings.check_interval_unavailable,
args=[unavailable, emitter, storage],
id="check_unavailable_entities",
next_run_time=datetime.now(),
)
scheduler.start() scheduler.start()
_log.info("scheduler_started", checks=[c.name for c in checks]) _log.info(
"scheduler_started",
checks=[c.name for c in all_checks],
heartbeat_interval=settings.check_interval,
unavailable_interval=settings.check_interval_unavailable,
)
config = uvicorn.Config( config = uvicorn.Config(
app, app,
@ -108,6 +138,7 @@ async def run(settings: Settings) -> None:
finally: finally:
scheduler.shutdown(wait=False) scheduler.shutdown(wait=False)
await storage.close() await storage.close()
await session.close()
def main() -> None: def main() -> None:

View file

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import time
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@ -8,7 +9,11 @@ import aiosqlite
_SCHEMA = """ _SCHEMA = """
CREATE TABLE IF NOT EXISTS entity_baseline ( CREATE TABLE IF NOT EXISTS entity_baseline (
entity_id TEXT PRIMARY KEY, entity_id TEXT PRIMARY KEY,
-- state when entity first entered unavailable/unknown
state TEXT NOT NULL, state TEXT NOT NULL,
-- timestamp when the entity FIRST entered its current bad state (INSERT OR IGNORE)
first_seen REAL NOT NULL,
-- kept for legacy compat; not used by UnavailableEntitiesCheck
attributes TEXT NOT NULL DEFAULT '{}', attributes TEXT NOT NULL DEFAULT '{}',
updated_at REAL NOT NULL updated_at REAL NOT NULL
); );
@ -28,6 +33,10 @@ CREATE TABLE IF NOT EXISTS alerts_sent (
); );
""" """
_MIGRATE_ENTITY_BASELINE = """
ALTER TABLE entity_baseline ADD COLUMN first_seen REAL NOT NULL DEFAULT 0;
"""
class Storage: class Storage:
def __init__(self, db_path: Path) -> None: def __init__(self, db_path: Path) -> None:
@ -39,6 +48,11 @@ class Storage:
self._db = await aiosqlite.connect(self._db_path) self._db = await aiosqlite.connect(self._db_path)
self._db.row_factory = aiosqlite.Row self._db.row_factory = aiosqlite.Row
await self._db.executescript(_SCHEMA) await self._db.executescript(_SCHEMA)
# Add first_seen column to existing databases that pre-date Phase 3
try:
await self._db.execute(_MIGRATE_ENTITY_BASELINE)
except Exception:
pass # column already exists
await self._db.commit() await self._db.commit()
async def close(self) -> None: async def close(self) -> None:
@ -52,22 +66,66 @@ class Storage:
return self._db return self._db
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# entity_baseline # entity_baseline — tracks entities currently in bad state
# ------------------------------------------------------------------ # ------------------------------------------------------------------
async def set_entity_unavailable_since(
    self, entity_id: str, state: str, first_seen: float
) -> None:
    """Start tracking an entity that entered unavailable/unknown state.

    The row is written with INSERT OR IGNORE, so an entity that is already
    tracked keeps its original first_seen timestamp and durations are
    measured from the first observation. For a new row, updated_at is
    initialised to first_seen as well.
    """
    insert_sql = """
        INSERT OR IGNORE INTO entity_baseline
        (entity_id, state, first_seen, attributes, updated_at)
        VALUES (?, ?, ?, '{}', ?)
        """
    params = (entity_id, state, first_seen, first_seen)
    await self._conn().execute(insert_sql, params)
    await self._conn().commit()
async def get_entity_first_unavailable_at(self, entity_id: str) -> float | None:
    """Look up the first_seen timestamp of a tracked entity.

    Returns the timestamp as a float, or None when the entity has no row in
    entity_baseline.
    """
    query = "SELECT first_seen FROM entity_baseline WHERE entity_id = ?"
    async with self._conn().execute(query, (entity_id,)) as cursor:
        record = await cursor.fetchone()
    if not record:
        return None
    return float(record["first_seen"])
async def clear_entity_unavailable(self, entity_id: str) -> None:
    """Stop tracking an entity (it has recovered) by deleting its baseline row."""
    delete_sql = "DELETE FROM entity_baseline WHERE entity_id = ?"
    await self._conn().execute(delete_sql, (entity_id,))
    await self._conn().commit()
async def get_all_tracked_entity_ids(self) -> list[str]:
    """List the entity_id of every row currently stored in entity_baseline."""
    query = "SELECT entity_id FROM entity_baseline"
    async with self._conn().execute(query) as cursor:
        records = await cursor.fetchall()
    tracked_ids = []
    for record in records:
        tracked_ids.append(record["entity_id"])
    return tracked_ids
# Legacy upsert — kept for backwards compat with existing callers
async def upsert_entity_baseline( async def upsert_entity_baseline(
self, entity_id: str, state: str, attributes: str, updated_at: float self, entity_id: str, state: str, attributes: str, updated_at: float
) -> None: ) -> None:
await self._conn().execute( await self._conn().execute(
""" """
INSERT INTO entity_baseline (entity_id, state, attributes, updated_at) INSERT INTO entity_baseline (entity_id, state, first_seen, attributes, updated_at)
VALUES (?, ?, ?, ?) VALUES (?, ?, ?, ?, ?)
ON CONFLICT(entity_id) DO UPDATE SET ON CONFLICT(entity_id) DO UPDATE SET
state = excluded.state, state = excluded.state,
attributes = excluded.attributes, attributes = excluded.attributes,
updated_at = excluded.updated_at updated_at = excluded.updated_at
""", """,
(entity_id, state, attributes, updated_at), (entity_id, state, updated_at, attributes, updated_at),
) )
await self._conn().commit() await self._conn().commit()
@ -104,8 +162,6 @@ class Storage:
# ------------------------------------------------------------------ # ------------------------------------------------------------------
async def was_alert_sent(self, alert_key: str, within_seconds: float) -> bool: async def was_alert_sent(self, alert_key: str, within_seconds: float) -> bool:
import time
cutoff = time.time() - within_seconds cutoff = time.time() - within_seconds
async with self._conn().execute( async with self._conn().execute(
"SELECT sent_at FROM alerts_sent WHERE alert_key = ? AND sent_at > ?", "SELECT sent_at FROM alerts_sent WHERE alert_key = ? AND sent_at > ?",
@ -114,8 +170,6 @@ class Storage:
return (await cur.fetchone()) is not None return (await cur.fetchone()) is not None
async def mark_alert_sent(self, alert_key: str) -> None: async def mark_alert_sent(self, alert_key: str) -> None:
import time
await self._conn().execute( await self._conn().execute(
""" """
INSERT INTO alerts_sent (alert_key, sent_at) VALUES (?, ?) INSERT INTO alerts_sent (alert_key, sent_at) VALUES (?, ?)
@ -124,3 +178,10 @@ class Storage:
(alert_key, time.time()), (alert_key, time.time()),
) )
await self._conn().commit() await self._conn().commit()
async def clear_alert(self, alert_key: str) -> None:
    """Remove the dedup record for alert_key so its next occurrence alerts at once."""
    delete_sql = "DELETE FROM alerts_sent WHERE alert_key = ?"
    await self._conn().execute(delete_sql, (alert_key,))
    await self._conn().commit()

View file

@ -1,7 +1,6 @@
"""Shared fixtures for ha-diag-agent tests.""" """Shared fixtures for ha-diag-agent tests."""
from __future__ import annotations from __future__ import annotations
import os
from pathlib import Path from pathlib import Path
from typing import AsyncGenerator from typing import AsyncGenerator
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
@ -14,7 +13,7 @@ from ha_diag.storage import Storage
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Event dir fixture # Filesystem fixtures
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -26,7 +25,7 @@ def tmp_events_dir(tmp_path: Path) -> Path:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Storage fixture (in-memory via tmp SQLite) # Storage fixture (tmp SQLite — fast, no mocking)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -55,9 +54,9 @@ def emitter(tmp_events_dir: Path) -> EventEmitter:
@pytest.fixture @pytest.fixture
def mock_ha_client(): def mock_ha_client():
"""HAClient mock that behaves as an async context manager.""" """Plain HAClient mock — no context manager, just async methods."""
client = MagicMock() client = MagicMock()
client.__aenter__ = AsyncMock(return_value=client)
client.__aexit__ = AsyncMock(return_value=None)
client.get_api_status = AsyncMock(return_value={"message": "API running."}) client.get_api_status = AsyncMock(return_value={"message": "API running."})
client.get_states = AsyncMock(return_value=[])
client.get_entity_registry = AsyncMock(return_value=[])
return client return client

View file

@ -11,43 +11,49 @@ import pytest
from ha_diag.checks.heartbeat import HeartbeatCheck from ha_diag.checks.heartbeat import HeartbeatCheck
from ha_diag.event_emitter import EventEmitter from ha_diag.event_emitter import EventEmitter
from ha_diag.ha_client import HAClient from ha_diag.ha_client import HAClient, make_session
from ha_diag.models import HAEventType
@pytest.mark.integration @pytest.mark.integration
async def test_heartbeat_ken_healthy(ha_ken_url: str, ha_token: str, tmp_path): async def test_heartbeat_ken_healthy(ha_ken_url: str, ha_token: str):
client = HAClient(ha_ken_url, ha_token) async with make_session(ha_token) as session:
check = HeartbeatCheck(client) client = HAClient(ha_ken_url, session)
result = await check.run() check = HeartbeatCheck(client)
assert result.healthy is True, f"HA ken not healthy: {result.message}" results = await check.run()
assert result.event_type is None assert results == [], f"HA ken not healthy: {results}"
@pytest.mark.integration @pytest.mark.integration
async def test_heartbeat_chelsty_healthy(ha_chelsty_url: str, ha_token: str): async def test_heartbeat_chelsty_healthy(ha_chelsty_url: str, ha_token: str):
client = HAClient(ha_chelsty_url, ha_token) async with make_session(ha_token) as session:
check = HeartbeatCheck(client) client = HAClient(ha_chelsty_url, session)
result = await check.run() check = HeartbeatCheck(client)
assert result.healthy is True, f"HA chelsty not healthy: {result.message}" results = await check.run()
assert result.event_type is None assert results == [], f"HA chelsty not healthy: {results}"
@pytest.mark.integration @pytest.mark.integration
async def test_heartbeat_emits_event_on_failure(tmp_path): async def test_heartbeat_emits_event_on_failure():
client = HAClient("http://127.0.0.1:19999", "bad-token") # nothing here """Connecting to a closed port should yield ha_websocket_dead."""
check = HeartbeatCheck(client) async with make_session("bad-token") as session:
result = await check.run() client = HAClient("http://127.0.0.1:19999", session) # nothing here
assert result.healthy is False check = HeartbeatCheck(client)
assert result.event_type == "ha_websocket_dead" results = await check.run()
assert len(results) == 1
assert results[0].event_type == HAEventType.ha_websocket_dead
@pytest.mark.integration @pytest.mark.integration
async def test_heartbeat_event_written_to_filesystem(ha_ken_url: str, ha_token: str, tmp_path): async def test_heartbeat_event_written_to_filesystem(
ha_ken_url: str, ha_token: str, tmp_path
):
emitter = EventEmitter(tmp_path / "events", node_name="test-piha", location_tag="ken") emitter = EventEmitter(tmp_path / "events", node_name="test-piha", location_tag="ken")
client = HAClient(ha_ken_url, ha_token) async with make_session(ha_token) as session:
check = HeartbeatCheck(client) client = HAClient(ha_ken_url, session)
result = await check.run() check = HeartbeatCheck(client)
results = await check.run()
assert result.healthy is True # Healthy HA → no events
# No event emitted for a healthy result assert results == []
assert not list((tmp_path / "events").glob("*.json")) or result.event_type is None assert not list((tmp_path / "events").glob("*.json"))

View file

@ -0,0 +1,192 @@
"""Functional integration test for UnavailableEntitiesCheck.
Uses aioresponses for HA HTTP (controlled, deterministic) and real aiosqlite +
EventEmitter (tests the full agent pipeline end-to-end without a live HA).
Marked 'integration' because it exercises the complete multi-component stack.
For a live-HA variant, start the ken testenv Docker instances, set
TEST_HA_TOKEN, and extend with tests that call real HA endpoints.
"""
from __future__ import annotations
import json
import time
from pathlib import Path
from typing import AsyncGenerator
import pytest
import pytest_asyncio
from aioresponses import aioresponses
from ha_diag.checks.unavailable_entities import UnavailableEntitiesCheck
from ha_diag.config import Settings
from ha_diag.event_emitter import EventEmitter
from ha_diag.ha_client import HAClient, make_session
from ha_diag.models import HAEventType
from ha_diag.storage import Storage
HA_URL = "http://ha-test-ken:8123"
def _settings(**overrides) -> Settings:
    """Build the ``Settings`` object used by the pipeline tests.

    Both hour-based values (unavailable threshold and alert cooldown) default
    to ``0.0``. Any keyword argument replaces the default with the same name.
    """
    return Settings(
        **{
            "ha_url": HA_URL,
            "ha_token": "test-token",
            "node_name": "piha",
            "location_tag": "ken",
            "unavailable_threshold_hours": 0.0,
            "integration_failure_threshold_pct": 0.5,
            "integration_failure_min_entities": 3,
            "alert_cooldown_hours": 0.0,
            "check_interval": 60,
            "check_interval_unavailable": 3600,
            # Caller-supplied values win over the defaults above.
            **overrides,
        }
    )
@pytest_asyncio.fixture
async def storage(tmp_path: Path) -> AsyncGenerator[Storage, None]:
    """Yield an opened ``Storage`` on a database file under ``tmp_path``.

    The storage is closed once the requesting test has finished.
    """
    db = Storage(tmp_path / "integration_test.db")
    await db.open()
    yield db
    await db.close()
@pytest.fixture
def events_dir(tmp_path: Path) -> Path:
    """Create an ``events`` directory under ``tmp_path`` and return its path."""
    target = tmp_path / "events"
    target.mkdir()
    return target
@pytest.mark.integration
async def test_full_pipeline_integration_event(storage: Storage, events_dir: Path):
    """3/3 zha entities unavailable → ha_integration_failed, 1 event file on disk."""
    zha_ids = [f"light.test_{n}" for n in range(3)]
    states_payload = [
        {"entity_id": eid, "state": "unavailable", "attributes": {}}
        for eid in zha_ids
    ]
    states_payload.append({"entity_id": "sensor.ok", "state": "on", "attributes": {}})
    # Only the unavailable entities are registered, all on the zha platform.
    registry_payload = [
        {"entity_id": eid, "platform": "zha", "area_id": "living_room"}
        for eid in zha_ids
    ]

    # Seed each baseline 25 h in the past.
    for eid in zha_ids:
        await storage.set_entity_unavailable_since(
            eid, "unavailable", time.time() - 25 * 3600
        )

    emitter = EventEmitter(events_dir, node_name="piha", location_tag="ken")

    with aioresponses() as mocked:
        mocked.get(f"{HA_URL}/api/states", payload=states_payload)
        mocked.get(f"{HA_URL}/api/config/entity_registry", payload=registry_payload)
        async with make_session("test-token") as session:
            check = UnavailableEntitiesCheck(
                HAClient(HA_URL, session), storage, _settings()
            )
            results = await check.run()

    # 3/3 zha entities (100% >= 50%, count 3 >= 3) → integration event
    assert len(results) == 1
    event = results[0]
    assert event.event_type == HAEventType.ha_integration_failed
    assert event.payload["integration"] == "zha"

    emitter.emit(
        event_type=event.event_type,
        severity=event.severity.value,
        service="homeassistant",
        message=event.message,
        payload=event.payload,
    )

    written = list(events_dir.glob("*.json"))
    assert len(written) == 1
    on_disk = json.loads(written[0].read_text())
    assert on_disk["node"] == "piha"
    assert on_disk["payload"]["location_tag"] == "ken"
    assert on_disk["payload"]["integration"] == "zha"
    assert on_disk["type"] == "ha_integration_failed"
@pytest.mark.integration
async def test_full_pipeline_individual_entity_events(
    storage: Storage, events_dir: Path
):
    """2 unavailable entities from different integrations → 2 individual events."""
    # entity_id → platform for the two unavailable entities
    platforms = {"light.zha_one": "zha", "sensor.mqtt_one": "mqtt"}
    states_payload = [
        {"entity_id": eid, "state": "unavailable", "attributes": {}}
        for eid in platforms
    ]
    states_payload.append({"entity_id": "switch.ok", "state": "on", "attributes": {}})
    registry_payload = [
        {"entity_id": eid, "platform": platform, "area_id": ""}
        for eid, platform in platforms.items()
    ]

    # Seed each baseline 25 h in the past.
    for eid in platforms:
        await storage.set_entity_unavailable_since(
            eid, "unavailable", time.time() - 25 * 3600
        )

    emitter = EventEmitter(events_dir, node_name="piha", location_tag="ken")

    with aioresponses() as mocked:
        mocked.get(f"{HA_URL}/api/states", payload=states_payload)
        mocked.get(f"{HA_URL}/api/config/entity_registry", payload=registry_payload)
        async with make_session("test-token") as session:
            check = UnavailableEntitiesCheck(
                HAClient(HA_URL, session), storage, _settings()
            )
            results = await check.run()

    # Both integrations have only 1 entity each → below min_entities threshold
    assert len(results) == 2
    assert all(
        event.event_type == HAEventType.ha_entity_unavailable_long
        for event in results
    )

    for event in results:
        emitter.emit(
            event_type=event.event_type,
            severity=event.severity.value,
            service="homeassistant",
            message=event.message,
            payload=event.payload,
        )

    written = list(events_dir.glob("*.json"))
    assert len(written) == 2
    for path in written:
        body = json.loads(path.read_text())["payload"]
        assert body["location_tag"] == "ken"
        assert "entity_id" in body
        assert "since" in body
        assert body["since"].endswith("Z")
@pytest.mark.integration
async def test_recovery_removes_tracking(storage: Storage, events_dir: Path):
    """Entity recovers between check cycles → baseline cleared, no event next cycle."""
    eid = "light.recoverable"

    async def run_cycle(entity_state: str):
        """Run one check cycle with ``eid`` in ``entity_state`` and an empty registry."""
        with aioresponses() as mocked:
            mocked.get(
                f"{HA_URL}/api/states",
                payload=[{"entity_id": eid, "state": entity_state, "attributes": {}}],
            )
            mocked.get(f"{HA_URL}/api/config/entity_registry", payload=[])
            async with make_session("test-token") as session:
                check = UnavailableEntitiesCheck(
                    HAClient(HA_URL, session), storage, _settings()
                )
                return await check.run()

    await storage.set_entity_unavailable_since(
        eid, "unavailable", time.time() - 25 * 3600
    )

    # Cycle 1: entity unavailable → event
    first_cycle = await run_cycle("unavailable")
    assert len(first_cycle) == 1

    # Cycle 2: entity recovered → no event, baseline cleared
    second_cycle = await run_cycle("on")
    assert second_cycle == []
    assert await storage.get_entity_first_unavailable_at(eid) is None

View file

@ -4,7 +4,7 @@ from __future__ import annotations
import pytest import pytest
from aioresponses import aioresponses from aioresponses import aioresponses
from ha_diag.ha_client import HAClient from ha_diag.ha_client import HAClient, make_session
HA_URL = "http://homeassistant.test:8123" HA_URL = "http://homeassistant.test:8123"
TOKEN = "test-token" TOKEN = "test-token"
@ -14,7 +14,8 @@ TOKEN = "test-token"
async def test_get_api_status_ok(): async def test_get_api_status_ok():
with aioresponses() as m: with aioresponses() as m:
m.get(f"{HA_URL}/api/", payload={"message": "API running."}) m.get(f"{HA_URL}/api/", payload={"message": "API running."})
async with HAClient(HA_URL, TOKEN) as client: async with make_session(TOKEN) as session:
client = HAClient(HA_URL, session)
result = await client.get_api_status() result = await client.get_api_status()
assert result == {"message": "API running."} assert result == {"message": "API running."}
@ -23,7 +24,8 @@ async def test_get_api_status_ok():
async def test_get_api_status_unauthorized(): async def test_get_api_status_unauthorized():
with aioresponses() as m: with aioresponses() as m:
m.get(f"{HA_URL}/api/", status=401) m.get(f"{HA_URL}/api/", status=401)
async with HAClient(HA_URL, TOKEN) as client: async with make_session(TOKEN) as session:
client = HAClient(HA_URL, session)
with pytest.raises(Exception): with pytest.raises(Exception):
await client.get_api_status() await client.get_api_status()
@ -33,7 +35,8 @@ async def test_get_states_returns_list():
payload = [{"entity_id": "light.living_room", "state": "on"}] payload = [{"entity_id": "light.living_room", "state": "on"}]
with aioresponses() as m: with aioresponses() as m:
m.get(f"{HA_URL}/api/states", payload=payload) m.get(f"{HA_URL}/api/states", payload=payload)
async with HAClient(HA_URL, TOKEN) as client: async with make_session(TOKEN) as session:
client = HAClient(HA_URL, session)
states = await client.get_states() states = await client.get_states()
assert isinstance(states, list) assert isinstance(states, list)
assert states[0]["entity_id"] == "light.living_room" assert states[0]["entity_id"] == "light.living_room"
@ -44,13 +47,34 @@ async def test_get_config_returns_dict():
payload = {"version": "2024.1.0", "location_name": "Home"} payload = {"version": "2024.1.0", "location_name": "Home"}
with aioresponses() as m: with aioresponses() as m:
m.get(f"{HA_URL}/api/config", payload=payload) m.get(f"{HA_URL}/api/config", payload=payload)
async with HAClient(HA_URL, TOKEN) as client: async with make_session(TOKEN) as session:
client = HAClient(HA_URL, session)
config = await client.get_config() config = await client.get_config()
assert config["version"] == "2024.1.0" assert config["version"] == "2024.1.0"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_session_required_without_context_manager(): async def test_get_entity_registry_returns_list():
client = HAClient(HA_URL, TOKEN) payload = [
with pytest.raises(RuntimeError, match="context manager"): {"entity_id": "light.hall", "platform": "zha", "area_id": "hallway"},
await client.get_api_status() {"entity_id": "sensor.temp", "platform": "mqtt", "area_id": None},
]
with aioresponses() as m:
m.get(f"{HA_URL}/api/config/entity_registry", payload=payload)
async with make_session(TOKEN) as session:
client = HAClient(HA_URL, session)
registry = await client.get_entity_registry()
assert len(registry) == 2
assert registry[0]["platform"] == "zha"
@pytest.mark.asyncio
async def test_make_session_sets_auth_header():
    """make_session configures the Bearer token as a session-level header."""
    secret = "my-secret-token"
    with aioresponses() as mocked:
        mocked.get(f"{HA_URL}/api/", payload={"message": "API running."})
        async with make_session(secret) as session:
            await HAClient(HA_URL, session).get_api_status()
            # Inspects the session's configured headers, not the mocked request.
            auth_header = session.headers.get("Authorization")
            assert auth_header == "Bearer my-secret-token"

View file

@ -11,8 +11,6 @@ from ha_diag.models import HAEventType, Severity
def _make_client(api_status=None, side_effect=None): def _make_client(api_status=None, side_effect=None):
client = MagicMock() client = MagicMock()
client.__aenter__ = AsyncMock(return_value=client)
client.__aexit__ = AsyncMock(return_value=None)
if side_effect: if side_effect:
client.get_api_status = AsyncMock(side_effect=side_effect) client.get_api_status = AsyncMock(side_effect=side_effect)
else: else:
@ -21,45 +19,44 @@ def _make_client(api_status=None, side_effect=None):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_heartbeat_ok(): async def test_heartbeat_ok_returns_empty_list():
client = _make_client(api_status={"message": "API running."}) client = _make_client(api_status={"message": "API running."})
check = HeartbeatCheck(client) check = HeartbeatCheck(client)
result = await check.run() results = await check.run()
assert result.healthy is True assert results == []
assert result.event_type is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_heartbeat_connection_error(): async def test_heartbeat_connection_error():
client = _make_client(side_effect=ConnectionError("refused")) client = _make_client(side_effect=ConnectionError("refused"))
check = HeartbeatCheck(client) check = HeartbeatCheck(client)
result = await check.run() results = await check.run()
assert result.healthy is False assert len(results) == 1
assert result.event_type == HAEventType.ha_websocket_dead assert results[0].healthy is False
assert result.severity == Severity.error assert results[0].event_type == HAEventType.ha_websocket_dead
assert "refused" in result.message assert results[0].severity == Severity.error
assert "refused" in results[0].message
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_heartbeat_unexpected_response(): async def test_heartbeat_unexpected_response():
client = _make_client(api_status={"unexpected": "key"}) client = _make_client(api_status={"unexpected": "key"})
check = HeartbeatCheck(client) check = HeartbeatCheck(client)
result = await check.run() results = await check.run()
assert result.healthy is False assert len(results) == 1
assert result.event_type == HAEventType.ha_websocket_dead assert results[0].event_type == HAEventType.ha_websocket_dead
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_heartbeat_timeout(): async def test_heartbeat_timeout():
client = _make_client(side_effect=TimeoutError("timed out")) client = _make_client(side_effect=TimeoutError("timed out"))
check = HeartbeatCheck(client) check = HeartbeatCheck(client)
result = await check.run() results = await check.run()
assert result.healthy is False assert len(results) == 1
assert result.event_type == HAEventType.ha_websocket_dead assert results[0].event_type == HAEventType.ha_websocket_dead
assert "timed out" in result.message assert "timed out" in results[0].message
def test_heartbeat_check_name(): def test_heartbeat_check_name():
client = MagicMock() check = HeartbeatCheck(MagicMock())
check = HeartbeatCheck(client)
assert check.name == "heartbeat" assert check.name == "heartbeat"

View file

@ -0,0 +1,409 @@
"""Unit tests for UnavailableEntitiesCheck."""
from __future__ import annotations
import time
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
import pytest
from ha_diag.checks.unavailable_entities import UnavailableEntitiesCheck
from ha_diag.config import Settings
from ha_diag.models import HAEventType
from ha_diag.storage import Storage
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_settings(**overrides) -> Settings:
    """Return ``Settings`` with test-friendly defaults; keyword arguments override them.

    The defaults use a zero-hour unavailable threshold and a zero-hour alert
    cooldown, so tests only set those when they are the behaviour under test.
    """
    return Settings(
        **{
            "ha_url": "http://test.local:8123",
            "ha_token": "test",
            "node_name": "test-node",
            "location_tag": "test-loc",
            "unavailable_threshold_hours": 0.0,  # alert immediately
            "integration_failure_threshold_pct": 0.5,
            "integration_failure_min_entities": 3,
            "alert_cooldown_hours": 0.0,  # no dedup window in most tests
            "check_interval": 60,
            "check_interval_unavailable": 3600,
            # Caller-supplied values win over the defaults above.
            **overrides,
        }
    )
def _make_state(entity_id: str, state: str = "on") -> dict:
    """Return a state dict for ``entity_id`` with the given ``state`` and empty attributes."""
    return dict(entity_id=entity_id, state=state, attributes={})
def _make_registry_entry(entity_id: str, platform: str, area_id: str = "") -> dict:
    """Return a registry entry dict tying ``entity_id`` to ``platform`` and ``area_id``."""
    return dict(entity_id=entity_id, platform=platform, area_id=area_id)
def _make_client(states=None, registry=None, states_error=None):
    """Return a mock client exposing async ``get_states`` / ``get_entity_registry``.

    ``get_states`` raises ``states_error`` when one is given; otherwise it
    returns ``states`` (an empty list when omitted). ``get_entity_registry``
    returns ``registry`` (an empty list when omitted).
    """
    get_states = (
        AsyncMock(side_effect=states_error)
        if states_error
        else AsyncMock(return_value=states or [])
    )
    mock_client = MagicMock()
    mock_client.get_states = get_states
    mock_client.get_entity_registry = AsyncMock(return_value=registry or [])
    return mock_client
# ---------------------------------------------------------------------------
# Basic unavailability detection
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_no_unavailable_entities_returns_empty(storage: Storage):
    """Entities in "on"/"off" states produce no results."""
    healthy_states = [_make_state("light.a", "on"), _make_state("sensor.b", "off")]
    client = _make_client(healthy_states)
    check = UnavailableEntitiesCheck(client, storage, _make_settings())
    results = await check.run()
    assert results == []
@pytest.mark.asyncio
async def test_first_cycle_records_baseline_no_event(storage: Storage):
    """First observation of unavailable entity: record, don't alert yet."""
    entity = "light.kitchen"
    client = _make_client([_make_state(entity, "unavailable")])
    # A 1 h threshold keeps a first-seen entity below the alert threshold.
    one_hour_threshold = _make_settings(unavailable_threshold_hours=1.0)
    check = UnavailableEntitiesCheck(client, storage, one_hour_threshold)

    assert await check.run() == []

    # The baseline timestamp must now exist in storage.
    recorded_at = await storage.get_entity_first_unavailable_at(entity)
    assert recorded_at is not None
@pytest.mark.asyncio
async def test_unavailable_below_threshold_no_event(storage: Storage):
    """An entity that only just became unavailable stays silent under a 24 h threshold."""
    entity = "light.kitchen"
    client = _make_client([_make_state(entity, "unavailable")])
    check = UnavailableEntitiesCheck(
        client, storage, _make_settings(unavailable_threshold_hours=24.0)
    )

    # Baseline is "now", i.e. far short of 24 h.
    await storage.set_entity_unavailable_since(entity, "unavailable", time.time())

    assert await check.run() == []
@pytest.mark.asyncio
async def test_unavailable_above_threshold_emits_event(storage: Storage):
    """A baseline 25 h old yields one ha_entity_unavailable_long event with full payload."""
    entity = "light.kitchen"
    client = _make_client([_make_state(entity, "unavailable")])
    check = UnavailableEntitiesCheck(client, storage, _make_settings())

    # Seed baseline as if 25h ago
    seeded_at = time.time() - 25 * 3600
    await storage.set_entity_unavailable_since(entity, "unavailable", seeded_at)

    results = await check.run()

    assert len(results) == 1
    event = results[0]
    assert event.event_type == HAEventType.ha_entity_unavailable_long
    payload = event.payload
    assert payload["entity_id"] == "light.kitchen"
    assert payload["duration_hours"] == pytest.approx(25.0, abs=0.1)
    assert payload["domain"] == "light"
@pytest.mark.asyncio
async def test_unknown_state_treated_as_unavailable(storage: Storage):
    """State "unknown" is reported like "unavailable", with the state kept in the payload."""
    entity = "sensor.temp"
    seeded_at = time.time() - 25 * 3600
    await storage.set_entity_unavailable_since(entity, "unknown", seeded_at)

    client = _make_client([_make_state(entity, "unknown")])
    check = UnavailableEntitiesCheck(client, storage, _make_settings())
    results = await check.run()

    assert len(results) == 1
    assert results[0].payload["state"] == "unknown"
@pytest.mark.asyncio
async def test_payload_contains_since_timestamp(storage: Storage):
    """The event payload carries a ``since`` timestamp string ending in "Z"."""
    first_at = time.time() - 27 * 3600
    await storage.set_entity_unavailable_since("light.k", "unavailable", first_at)
    states = [_make_state("light.k", "unavailable")]
    check = UnavailableEntitiesCheck(
        _make_client(states), storage, _make_settings()
    )
    results = await check.run()
    assert len(results) == 1
    assert "since" in results[0].payload
    since = results[0].payload["since"]
    assert isinstance(since, str)
    # A bare substring check would accept a "Z" anywhere in the value;
    # the UTC designator has to be the suffix.
    assert since.endswith("Z")
# ---------------------------------------------------------------------------
# Recovery
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_recovery_clears_baseline(storage: Storage):
    """A tracked entity that reports "on" again loses its stored baseline."""
    entity = "light.k"
    await storage.set_entity_unavailable_since(entity, "unavailable", time.time())

    # Entity is now back online
    client = _make_client([_make_state(entity, "on")])
    check = UnavailableEntitiesCheck(client, storage, _make_settings())
    await check.run()

    baseline = await storage.get_entity_first_unavailable_at(entity)
    assert baseline is None
@pytest.mark.asyncio
async def test_recovery_clears_alert_dedup(storage: Storage):
    """Recovery removes the alert-sent marker as well as the baseline."""
    entity = "light.k"
    dedup_key = "entity_unavailable:light.k"
    await storage.set_entity_unavailable_since(
        entity, "unavailable", time.time() - 25 * 3600
    )
    await storage.mark_alert_sent(dedup_key)

    # Entity recovers
    client = _make_client([_make_state(entity, "on")])
    check = UnavailableEntitiesCheck(client, storage, _make_settings())
    await check.run()

    # Alert dedup should be gone
    still_marked = await storage.was_alert_sent(dedup_key, 9999)
    assert not still_marked
# ---------------------------------------------------------------------------
# Alert cooldown / deduplication
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_cooldown_prevents_duplicate_event(storage: Storage):
    """With a 6 h cooldown the second consecutive run stays silent."""
    entity = "light.k"
    await storage.set_entity_unavailable_since(
        entity, "unavailable", time.time() - 25 * 3600
    )
    with_cooldown = _make_settings(alert_cooldown_hours=6.0)
    client = _make_client([_make_state(entity, "unavailable")])
    check = UnavailableEntitiesCheck(client, storage, with_cooldown)

    first_run = await check.run()
    assert len(first_run) == 1  # first alert fires

    second_run = await check.run()
    assert second_run == []  # cooldown active
@pytest.mark.asyncio
async def test_no_cooldown_allows_repeat_event(storage: Storage):
    """With a zero cooldown every run re-emits the event."""
    entity = "light.k"
    await storage.set_entity_unavailable_since(
        entity, "unavailable", time.time() - 25 * 3600
    )
    without_cooldown = _make_settings(alert_cooldown_hours=0.0)
    client = _make_client([_make_state(entity, "unavailable")])
    check = UnavailableEntitiesCheck(client, storage, without_cooldown)

    # Run both cycles first, then compare.
    first_run = await check.run()
    second_run = await check.run()

    assert len(first_run) == 1
    assert len(second_run) == 1
# ---------------------------------------------------------------------------
# Integration root-cause grouping
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_integration_failure_emits_single_event(storage: Storage):
    """5/8 entities from zha unavailable → ha_integration_failed, not 5 entity events."""
    zha_entities = [f"light.zha_{n}" for n in range(8)]
    down, up = zha_entities[:5], zha_entities[5:]
    states = [_make_state(eid, "unavailable") for eid in down]
    states += [_make_state(eid, "on") for eid in up]
    registry = [_make_registry_entry(eid, "zha") for eid in zha_entities]

    # Seed baselines for unavailable entities as 25h ago
    for eid in down:
        await storage.set_entity_unavailable_since(
            eid, "unavailable", time.time() - 25 * 3600
        )

    settings = _make_settings(
        integration_failure_threshold_pct=0.5,
        integration_failure_min_entities=3,
    )
    check = UnavailableEntitiesCheck(_make_client(states, registry), storage, settings)
    results = await check.run()

    assert len(results) == 1
    event = results[0]
    assert event.event_type == HAEventType.ha_integration_failed
    payload = event.payload
    assert payload["integration"] == "zha"
    assert payload["unavailable_count"] == 5
    assert payload["total_count"] == 8
    assert set(payload["affected_entities"]) == set(down)
@pytest.mark.asyncio
async def test_integration_failure_below_pct_threshold(storage: Storage):
    """3/8 entities from zha unavailable (37.5%) → per-entity events, not integration event.

    The unavailable count (3) deliberately meets
    ``integration_failure_min_entities``, so the percentage threshold is the
    only condition that can prevent grouping. With fewer than 3 unavailable
    entities the count threshold would also block it, and the test would pass
    even if the percentage comparison were broken.
    """
    zha_entities = [f"light.zha_{i}" for i in range(8)]
    unavailable = zha_entities[:3]
    states = [
        _make_state(eid, "unavailable" if eid in unavailable else "on")
        for eid in zha_entities
    ]
    registry = [_make_registry_entry(eid, "zha") for eid in zha_entities]
    for eid in unavailable:
        await storage.set_entity_unavailable_since(
            eid, "unavailable", time.time() - 25 * 3600
        )
    settings = _make_settings(
        integration_failure_threshold_pct=0.5,
        integration_failure_min_entities=3,
    )
    check = UnavailableEntitiesCheck(
        _make_client(states, registry), storage, settings
    )
    results = await check.run()
    # Count threshold met (3 >= 3) but 37.5% < 50% → individual events
    assert len(results) == 3
    assert all(r.event_type == HAEventType.ha_entity_unavailable_long for r in results)
    assert {r.payload["entity_id"] for r in results} == set(unavailable)
@pytest.mark.asyncio
async def test_integration_failure_below_count_threshold(storage: Storage):
    """3/6 entities unavailable (50%) but min_entities=5 → per-entity events."""
    zha_entities = [f"light.zha_{i}" for i in range(6)]
    states = [
        _make_state(eid, "unavailable" if i < 3 else "on")
        for i, eid in enumerate(zha_entities)
    ]
    registry = [_make_registry_entry(eid, "zha") for eid in zha_entities]
    for eid in zha_entities[:3]:
        await storage.set_entity_unavailable_since(
            eid, "unavailable", time.time() - 25 * 3600
        )
    settings = _make_settings(
        integration_failure_threshold_pct=0.5,
        integration_failure_min_entities=5,  # need 5, only have 3
    )
    check = UnavailableEntitiesCheck(
        _make_client(states, registry), storage, settings
    )
    results = await check.run()
    # Pin the count first: all() over an empty list is vacuously True, so
    # without this the test would pass when no events are emitted at all.
    assert len(results) == 3
    assert all(r.event_type == HAEventType.ha_entity_unavailable_long for r in results)
@pytest.mark.asyncio
async def test_entity_without_integration_gets_individual_event(storage: Storage):
    """Entity not in entity registry gets per-entity event regardless of integration grouping."""
    entity = "light.mystery"
    await storage.set_entity_unavailable_since(
        entity, "unavailable", time.time() - 25 * 3600
    )

    # Empty registry — no integration info
    client = _make_client([_make_state(entity, "unavailable")], [])
    check = UnavailableEntitiesCheck(client, storage, _make_settings())
    results = await check.run()

    assert len(results) == 1
    event = results[0]
    assert event.event_type == HAEventType.ha_entity_unavailable_long
    assert "integration" not in event.payload
@pytest.mark.asyncio
async def test_mixed_integrations_correctly_partitioned(storage: Storage):
    """5 zha entities unavailable (triggers integration event) + 1 mqtt entity (individual)."""
    zha_entities = [f"light.zha_{i}" for i in range(8)]
    mqtt_entity = "sensor.mqtt_temp"
    states = (
        [_make_state(eid, "unavailable" if i < 5 else "on") for i, eid in enumerate(zha_entities)]
        + [_make_state(mqtt_entity, "unavailable")]
    )
    registry = (
        [_make_registry_entry(eid, "zha") for eid in zha_entities]
        + [_make_registry_entry(mqtt_entity, "mqtt")]
    )
    # Seed a 25 h old baseline for every unavailable entity.
    for eid in zha_entities[:5]:
        await storage.set_entity_unavailable_since(
            eid, "unavailable", time.time() - 25 * 3600
        )
    await storage.set_entity_unavailable_since(
        mqtt_entity, "unavailable", time.time() - 25 * 3600
    )
    settings = _make_settings(
        integration_failure_threshold_pct=0.5,
        integration_failure_min_entities=3,
    )
    check = UnavailableEntitiesCheck(
        _make_client(states, registry), storage, settings
    )
    results = await check.run()
    event_types = {r.event_type for r in results}
    assert HAEventType.ha_integration_failed in event_types
    assert HAEventType.ha_entity_unavailable_long in event_types
    # Exactly 2 events: 1 integration + 1 individual mqtt entity
    assert len(results) == 2
# ---------------------------------------------------------------------------
# Error handling
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_ha_client_error_returns_dead_event(storage: Storage):
    """A ConnectionError from get_states is reported as one ha_websocket_dead event."""
    failing_client = _make_client(states_error=ConnectionError("HA down"))
    check = UnavailableEntitiesCheck(failing_client, storage, _make_settings())

    results = await check.run()

    assert len(results) == 1
    assert results[0].event_type == HAEventType.ha_websocket_dead
@pytest.mark.asyncio
async def test_registry_failure_falls_back_gracefully(storage: Storage):
    """Registry endpoint failure → individual entity events without integration info."""
    entity = "light.k"
    client = _make_client([_make_state(entity, "unavailable")])
    # Replace the default registry mock with one that raises.
    client.get_entity_registry = AsyncMock(
        side_effect=Exception("registry unavailable")
    )
    await storage.set_entity_unavailable_since(
        entity, "unavailable", time.time() - 25 * 3600
    )

    check = UnavailableEntitiesCheck(client, storage, _make_settings())
    results = await check.run()

    assert len(results) == 1
    event = results[0]
    assert event.event_type == HAEventType.ha_entity_unavailable_long
    assert "integration" not in event.payload
# ---------------------------------------------------------------------------
# Area / integration in payload
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_area_included_in_payload_when_known(storage: Storage):
    """A registry entry with an area and platform surfaces both in the payload."""
    entity = "light.hall"
    await storage.set_entity_unavailable_since(
        entity, "unavailable", time.time() - 25 * 3600
    )
    client = _make_client(
        [_make_state(entity, "unavailable")],
        [_make_registry_entry(entity, "zha", "hallway")],
    )
    check = UnavailableEntitiesCheck(client, storage, _make_settings())

    results = await check.run()

    assert len(results) == 1
    payload = results[0].payload
    assert payload.get("area") == "hallway"
    assert payload.get("integration") == "zha"
@pytest.mark.asyncio
async def test_area_omitted_when_unknown(storage: Storage):
    """A registry entry with an empty ``area_id`` leaves ``area`` out of the payload."""
    await storage.set_entity_unavailable_since(
        "light.k", "unavailable", time.time() - 25 * 3600
    )
    states = [_make_state("light.k", "unavailable")]
    registry = [_make_registry_entry("light.k", "zha", "")]
    check = UnavailableEntitiesCheck(
        _make_client(states, registry), storage, _make_settings()
    )
    results = await check.run()
    # Pin the count so a missing event fails as an assertion, not as an
    # IndexError on results[0].
    assert len(results) == 1
    assert "area" not in results[0].payload