from __future__ import annotations import time from typing import Any import aiohttp def make_session(token: str, timeout: float = 10.0) -> aiohttp.ClientSession: """Create a pre-configured ClientSession for use with HAClient.""" return aiohttp.ClientSession( headers={ "Authorization": f"Bearer {token}", "Content-Type": "application/json", }, timeout=aiohttp.ClientTimeout(total=timeout), ) class HAClient: """Async Home Assistant REST API client. Session lifecycle is managed externally — the caller creates the session via make_session() at startup and closes it on shutdown. HAClient is a session-borrower: it never opens or closes the session it receives. """ def __init__( self, base_url: str, session: aiohttp.ClientSession, entity_registry_cache_ttl: float = 300.0, ) -> None: self._base_url = base_url.rstrip("/") self._session = session self._registry_cache_ttl = entity_registry_cache_ttl self._registry_cache: list[dict[str, Any]] | None = None self._registry_fetched_at: float = 0.0 async def get_api_status(self) -> dict[str, Any]: """GET /api/ — returns {"message": "API running."} when HA is up.""" async with self._session.get(f"{self._base_url}/api/") as resp: resp.raise_for_status() return await resp.json() async def get_states(self) -> list[dict[str, Any]]: """GET /api/states — full entity state list.""" async with self._session.get(f"{self._base_url}/api/states") as resp: resp.raise_for_status() return await resp.json() async def get_system_health(self) -> dict[str, Any]: """GET /api/system_health — per-integration health summary.""" async with self._session.get(f"{self._base_url}/api/system_health") as resp: resp.raise_for_status() return await resp.json() async def get_config(self) -> dict[str, Any]: """GET /api/config — HA configuration including version.""" async with self._session.get(f"{self._base_url}/api/config") as resp: resp.raise_for_status() return await resp.json() async def get_entity_registry(self) -> list[dict[str, Any]]: """GET /api/config/entity_registry — entity registry entries. Each entry includes entity_id, platform (integration name), area_id, config_entry_id, and other metadata. Result is cached in-process for entity_registry_cache_ttl seconds to avoid hammering HA on every check cycle (Phase 3 Flag #3). """ now = time.monotonic() if ( self._registry_cache is not None and (now - self._registry_fetched_at) < self._registry_cache_ttl ): return self._registry_cache async with self._session.get( f"{self._base_url}/api/config/entity_registry" ) as resp: resp.raise_for_status() result = await resp.json() self._registry_cache = result self._registry_fetched_at = now return result def invalidate_registry_cache(self) -> None: """Force the next get_entity_registry() call to fetch fresh data.""" self._registry_cache = None self._registry_fetched_at = 0.0 async def get_automation_traces(self, automation_id: str) -> list[dict[str, Any]]: """GET /api/trace/automation/ — last run traces for an automation.""" url = f"{self._base_url}/api/trace/automation/{automation_id}" async with self._session.get(url) as resp: resp.raise_for_status() return await resp.json() async def get_error_log(self) -> str: """GET /api/error_log — plaintext error log.""" async with self._session.get(f"{self._base_url}/api/error_log") as resp: resp.raise_for_status() return await resp.text()