"""Unit tests for AutomationFailuresCheck.""" from __future__ import annotations from pathlib import Path from unittest.mock import AsyncMock, MagicMock import pytest from ha_diag.checks.automation_failures import AutomationFailuresCheck, _is_trace_failure from ha_diag.config import Settings from ha_diag.models import HAEventType, Severity from ha_diag.storage import Storage # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _make_settings(**overrides) -> Settings: defaults: dict = { "ha_url": "http://test.local:8123", "ha_token": "test", "node_name": "test-node", "location_tag": "test-loc", "alert_cooldown_hours": 0.0, "automation_failure_threshold": 3, "check_interval": 60, "check_interval_unavailable": 3600, } defaults.update(overrides) return Settings(**defaults) def _make_client(states=None, traces_by_id=None, states_error=None): client = MagicMock() if states_error: client.get_states = AsyncMock(side_effect=states_error) else: client.get_states = AsyncMock(return_value=states or []) traces_map = traces_by_id or {} async def _get_traces(eid): if eid not in traces_map: raise Exception(f"404 for {eid}") return traces_map[eid] client.get_automation_traces = AsyncMock(side_effect=_get_traces) return client def _auto_state(entity_id: str, state: str = "on", friendly_name: str | None = None) -> dict: attrs: dict = {} if friendly_name: attrs["friendly_name"] = friendly_name return {"entity_id": entity_id, "state": state, "attributes": attrs} def _trace(error: str | None = None, state: str = "stopped") -> dict: return { "run_id": "abc", "timestamp": "2026-05-27T10:00:00+00:00", "trigger": "state", "state": state if error is None else "stopped", "error": error, } def _fail(error: str = "Script error") -> dict: return _trace(error=error) def _ok() -> dict: return _trace(error=None) # 
# ---------------------------------------------------------------------------
# _is_trace_failure unit tests
# ---------------------------------------------------------------------------


def test_trace_with_error_is_failure():
    assert _is_trace_failure({"error": "Something went wrong"}) is True


def test_trace_with_state_failed_is_failure():
    assert _is_trace_failure({"state": "failed", "error": None}) is True


def test_trace_with_null_error_is_success():
    assert _is_trace_failure({"error": None, "state": "stopped"}) is False


def test_trace_with_empty_string_error_is_success():
    # An empty error string must not be treated as a reported error.
    assert _is_trace_failure({"error": "", "state": "stopped"}) is False


def test_trace_with_no_keys_is_success():
    # A trace with neither "error" nor "state" counts as a success.
    assert _is_trace_failure({}) is False


# ---------------------------------------------------------------------------
# AutomationFailuresCheck.run() tests
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_no_automations_returns_empty(storage: Storage):
    check = AutomationFailuresCheck(_make_client([]), storage, _make_settings())
    assert await check.run() == []


@pytest.mark.asyncio
async def test_disabled_automation_skipped(storage: Storage):
    # The disabled automation is given traces that WOULD trigger an event
    # (threshold=3, three failures). Without them, _make_client raises for the
    # unknown id and the graceful fetch-error path alone would yield [], so the
    # test could not detect a missing disabled-state filter.
    states = [_auto_state("automation.morning_lights", state="off")]
    traces = {"automation.morning_lights": [_fail(), _fail(), _fail()]}
    check = AutomationFailuresCheck(_make_client(states, traces), storage, _make_settings())
    assert await check.run() == []


@pytest.mark.asyncio
async def test_automation_with_no_traces_skipped(storage: Storage):
    # No entry in the traces map: _make_client raises for missing keys, so this
    # exercises the trace-fetch error path, which must be skipped gracefully.
    states = [_auto_state("automation.morning_lights")]
    check = AutomationFailuresCheck(_make_client(states, {}), storage, _make_settings())
    assert await check.run() == []


@pytest.mark.asyncio
async def test_automation_with_empty_trace_list_skipped(storage: Storage):
    # The fetch succeeds but returns no runs at all. Zero traces is below the
    # threshold; this also guards against a vacuously-true all([]) window check.
    states = [_auto_state("automation.morning_lights")]
    traces = {"automation.morning_lights": []}
    check = AutomationFailuresCheck(_make_client(states, traces), storage, _make_settings())
    assert await check.run() == []


@pytest.mark.asyncio
async def test_fewer_traces_than_threshold_skipped(storage: Storage):
    states = [_auto_state("automation.a")]
    traces = {"automation.a": [_fail(), _fail()]}  # 2 failures, threshold=3
    check = AutomationFailuresCheck(_make_client(states, traces), storage, _make_settings())
    assert await check.run() == []


@pytest.mark.asyncio
async def test_all_recent_failed_emits_event(storage: Storage):
    states = [_auto_state("automation.a", friendly_name="Morning Lights")]
    traces = {"automation.a": [_fail("step failed"), _fail("timeout"), _fail("no device")]}
    check = AutomationFailuresCheck(_make_client(states, traces), storage, _make_settings())
    results = await check.run()
    assert len(results) == 1
    r = results[0]
    assert r.event_type == HAEventType.ha_automation_failing
    assert r.severity == Severity.warning
    assert r.payload["entity_id"] == "automation.a"
    assert r.payload["friendly_name"] == "Morning Lights"
    assert r.payload["total_recent_failures"] == 3
    assert len(r.payload["last_failures"]) == 3


@pytest.mark.asyncio
async def test_partial_failures_no_event(storage: Storage):
    states = [_auto_state("automation.a")]
    # 2 failures, 1 success in recent 3 → not all failed
    traces = {"automation.a": [_fail(), _ok(), _fail()]}
    check = AutomationFailuresCheck(_make_client(states, traces), storage, _make_settings())
    assert await check.run() == []


@pytest.mark.asyncio
async def test_cooldown_prevents_duplicate_event(storage: Storage):
    states = [_auto_state("automation.a")]
    traces = {"automation.a": [_fail(), _fail(), _fail()]}
    settings = _make_settings(alert_cooldown_hours=6.0)
    check = AutomationFailuresCheck(_make_client(states, traces), storage, settings)
    r1 = await check.run()
    # Second run within the 6h cooldown must not re-emit for the same entity.
    r2 = await check.run()
    assert len(r1) == 1
    assert r2 == []


@pytest.mark.asyncio
async def test_multiple_failing_automations(storage: Storage):
    states = [_auto_state("automation.a"), _auto_state("automation.b")]
    traces = {
        "automation.a": [_fail(), _fail(), _fail()],
        "automation.b": [_fail(), _fail(), _fail()],
    }
    check = AutomationFailuresCheck(_make_client(states, traces), storage, _make_settings())
    results = await check.run()
    assert len(results) == 2
    eids = {r.payload["entity_id"] for r in results}
    assert eids == {"automation.a", "automation.b"}


@pytest.mark.asyncio
async def test_states_error_returns_empty(storage: Storage):
    check = AutomationFailuresCheck(
        _make_client(states_error=ConnectionError("down")), storage, _make_settings()
    )
    assert await check.run() == []


@pytest.mark.asyncio
async def test_custom_threshold(storage: Storage):
    states = [_auto_state("automation.a")]
    # threshold=2: the first two traces (the evaluated window) both failed; the
    # trailing success lies outside that window, so an event is expected.
    # NOTE(review): assumes traces are ordered newest-first — confirm.
    traces = {"automation.a": [_fail(), _fail(), _ok()]}
    settings = _make_settings(automation_failure_threshold=2)
    check = AutomationFailuresCheck(_make_client(states, traces), storage, settings)
    results = await check.run()
    assert len(results) == 1
    assert results[0].payload["entity_id"] == "automation.a"


@pytest.mark.asyncio
async def test_failure_with_state_failed_field(storage: Storage):
    # Failures signalled only via state == "failed" (no error message).
    states = [_auto_state("automation.a")]
    traces = {"automation.a": [
        {"run_id": "x", "state": "failed", "error": None, "timestamp": "2026-05-27T10:00:00Z"},
        {"run_id": "y", "state": "failed", "error": None, "timestamp": "2026-05-27T09:00:00Z"},
        {"run_id": "z", "state": "failed", "error": None, "timestamp": "2026-05-27T08:00:00Z"},
    ]}
    check = AutomationFailuresCheck(_make_client(states, traces), storage, _make_settings())
    results = await check.run()
    assert len(results) == 1
    assert results[0].payload["entity_id"] == "automation.a"