homelab-codex-ws/services/planner-agent/tests/test_planner.py

605 lines
22 KiB
Python
Raw Normal View History

feat(planner-agent): main loop with LLM routing and HITL action proposals services/planner-agent/src/planner.py: - PlannerAgent: async Redis pub/sub on health_events + world_updates - Pipeline: receive event → cooldown gate → LLMRouter → write pending action → emit remediation_started filesystem event - CooldownTracker: 5-min suppression per svc_key (configurable via env) - parse_event(): accepts node-agent shape A and world_updates shape B - PROPOSAL_SCHEMA: jsonschema enforced by LLMRouter before accepting response - SYSTEM_PROMPT: homelab topology + action rules (chelsty always requires_human, disk_pressure always notify, confidence<0.7 → requires_human) - write_pending_action(): atomic tmp→rename write, executor-compatible format - emit_event(): async wrapper around filesystem event write (no control-plane import) - _emit_event_sync() reads NODE_NAME at call time (not import) for testability - Benign events (service_healthy, node_online, ...) silently skipped - LLM chain failure: no cooldown recorded so next event can retry services/planner-agent/tests/test_planner.py (49 tests, 0 network): - TestCooldownTracker: 7 tests (ready/not-ready/elapsed/reset/independence) - TestHealthEvent, TestActionProposal, TestMapActionToExecutorType - TestParseEvent: both event shapes, missing fields, timestamp formats - TestBuildMessages: system prompt rules, payload inclusion - TestPlannerHandleEvent: benign skip, cooldown block, ignore/restart/redeploy/ notify proposals, remediation event emission, LLM failure isolation, requires_human propagation, cooldown recording, model name in proposal - TestPlannerDispatch: valid JSON, invalid JSON, non-string data, missing node - TestWritePendingAction, TestEmitEvent: filesystem integration with tmp_path services/planner-agent/service.yaml: owner_node: solaria, dependencies: [redis, ollama] services/planner-agent/docker-compose.yml: env + healthcheck services/planner-agent/Dockerfile: python:3.11-slim 
services/planner-agent/healthcheck.sh: heartbeat file age check (300s) services/planner-agent/requirements.txt: litellm, redis, jsonschema, structlog Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-27 19:11:39 +02:00
"""
Unit tests for planner.py.
All Redis and LLMRouter interactions are mocked, so no network access is
required; the only disk I/O is confined to pytest's ``tmp_path`` fixture.
Run:
pytest services/planner-agent/tests/test_planner.py -v
"""
import asyncio
import json
import sys
import time
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch, call
import pytest
# Allow importing from src/ without installation
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from planner import (
ActionProposal,
CooldownTracker,
HealthEvent,
PlannerAgent,
build_messages,
map_action_to_executor_type,
parse_event,
write_pending_action,
emit_event,
PROPOSAL_SCHEMA,
)
from llm_router import AttemptRecord, RouteResult
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_route_result(
    action: str = "restart",
    service: str = "mosquitto",
    node: str = "piha",
    reason: str = "Container is stopped",
    confidence: float = 0.9,
    requires_human: bool = False,
    model: str = "ollama/qwen2.5:7b",
) -> RouteResult:
    """Build a RouteResult describing one successful routing attempt.

    The keyword arguments become the proposal carried in ``content``;
    ``raw_text`` is the JSON serialisation of that same proposal, and
    ``model`` is used both as ``model_used`` and in the single attempt record.
    """
    latency_ms = 120
    proposal = dict(
        action=action,
        service=service,
        node=node,
        reason=reason,
        confidence=confidence,
        requires_human=requires_human,
    )
    attempt = AttemptRecord(model, "success", None, latency_ms)
    return RouteResult(
        content=proposal,
        raw_text=json.dumps(proposal),
        model_used=model,
        attempts=[attempt],
        latency_ms=latency_ms,
    )
def _health_event(
    node: str = "piha",
    service: str = "mosquitto",
    event_type: str = "service_unhealthy",
    severity: str = "error",
    payload: dict | None = None,
) -> HealthEvent:
    """Build a HealthEvent stamped with the current time.

    Args:
        node: Node name stored on the event.
        service: Service name stored on the event.
        event_type: Event type string stored on the event.
        severity: Severity string stored on the event.
        payload: Optional payload; ``None`` (or any falsy value) is replaced
            by a fresh empty dict so events never share a mutable default.

    Returns:
        A HealthEvent whose ``timestamp`` is ``time.time()`` at call time.
    """
    return HealthEvent(
        node=node,
        service=service,
        event_type=event_type,
        severity=severity,
        payload=payload or {},
        timestamp=time.time(),
    )
def _mock_router(result: RouteResult) -> MagicMock:
    """Return a stand-in router whose awaited ``route()`` yields ``result``.

    ``close`` is also an AsyncMock so that it can be awaited.
    """
    # Keyword arguments to MagicMock are set as attributes on the new mock.
    return MagicMock(
        route=AsyncMock(return_value=result),
        close=AsyncMock(),
    )
# ---------------------------------------------------------------------------
# CooldownTracker
# ---------------------------------------------------------------------------
class TestCooldownTracker:
    """Checks for CooldownTracker's record / is_ready / remaining_seconds / reset."""

    def test_initially_ready(self):
        """A key that was never recorded is reported as ready."""
        tracker = CooldownTracker(cooldown_seconds=60)
        ready = tracker.is_ready("piha/mosquitto")
        assert ready is True

    def test_not_ready_after_record(self):
        """Recording a key makes it not ready within the cooldown window."""
        tracker = CooldownTracker(cooldown_seconds=300)
        tracker.record("piha/mosquitto")
        ready = tracker.is_ready("piha/mosquitto")
        assert ready is False

    def test_ready_after_elapsed(self):
        """Once more than the cooldown has passed, the key is ready again."""
        tracker = CooldownTracker(cooldown_seconds=1)
        tracker.record("piha/mosquitto")
        # Wait slightly longer than the 1-second cooldown.
        time.sleep(1.1)
        ready = tracker.is_ready("piha/mosquitto")
        assert ready is True

    def test_remaining_seconds_decreases(self):
        """Right after recording, the remaining time lies in (0, cooldown]."""
        tracker = CooldownTracker(cooldown_seconds=60)
        tracker.record("piha/mosquitto")
        remaining = tracker.remaining_seconds("piha/mosquitto")
        assert 0 < remaining <= 60

    def test_remaining_zero_when_never_recorded(self):
        """An unknown key has exactly 0.0 seconds remaining."""
        tracker = CooldownTracker()
        remaining = tracker.remaining_seconds("ghost/svc")
        assert remaining == 0.0

    def test_reset_clears_cooldown(self):
        """reset() makes a previously recorded key ready again."""
        tracker = CooldownTracker(cooldown_seconds=300)
        tracker.record("piha/mosquitto")
        assert tracker.is_ready("piha/mosquitto") is False
        tracker.reset("piha/mosquitto")
        assert tracker.is_ready("piha/mosquitto") is True

    def test_independent_keys(self):
        """Recording one key leaves a different key ready."""
        tracker = CooldownTracker(cooldown_seconds=300)
        tracker.record("piha/mosquitto")
        recorded_ready = tracker.is_ready("piha/mosquitto")
        other_ready = tracker.is_ready("solaria/ollama")
        assert recorded_ready is False
        assert other_ready is True
# ---------------------------------------------------------------------------
# HealthEvent
# ---------------------------------------------------------------------------
class TestHealthEvent:
    """Checks for HealthEvent's derived key and string rendering."""

    def test_svc_key(self):
        """svc_key is ``<node>/<service>``."""
        event = _health_event("piha", "mosquitto")
        assert event.svc_key == "piha/mosquitto"

    def test_str_repr(self):
        """str() mentions both the event type and the node/service key."""
        event = _health_event("vps", "observer", "service_unhealthy", "error")
        assert "service_unhealthy" in str(event)
        assert "vps/observer" in str(event)
# ---------------------------------------------------------------------------
# ActionProposal.to_action_file
# ---------------------------------------------------------------------------
class TestActionProposal:
    """Checks for the dict produced by ActionProposal.to_action_file()."""

    def _sample(self, **kwargs) -> ActionProposal:
        """Build a baseline restart proposal; ``kwargs`` override any field."""
        fields = {
            "action_id": "plan-piha-mosquitto-123",
            "type": "container_restart",
            "action": "restart",
            "service": "mosquitto",
            "node": "piha",
            "reason": "Container stopped unexpectedly",
            "confidence": 0.9,
            "requires_human": False,
            "risk_level": "low",
        }
        return ActionProposal(**{**fields, **kwargs})

    def test_to_action_file_keys(self):
        """Every key listed below is present in the action file."""
        action_file = self._sample().to_action_file()
        required_keys = (
            "action_id", "type", "node", "service", "risk_level",
            "confidence", "requires_human", "status", "timestamp",
            "source_event", "llm_model", "llm_attempts", "payload",
        )
        for key in required_keys:
            assert key in action_file, f"missing key: {key}"

    def test_status_pending(self):
        """The action file reports status "pending"."""
        action_file = self._sample().to_action_file()
        assert action_file["status"] == "pending"

    def test_payload_contains_action_and_reason(self):
        """The nested payload carries the action and the reason text."""
        payload = self._sample().to_action_file()["payload"]
        assert payload["action"] == "restart"
        assert "Container stopped" in payload["reason"]

    def test_description_fallback_to_reason(self):
        """An empty description is replaced by the proposal's reason."""
        proposal = self._sample(description="")
        action_file = proposal.to_action_file()
        assert action_file["description"] == proposal.reason
# ---------------------------------------------------------------------------
# map_action_to_executor_type
# ---------------------------------------------------------------------------
class TestMapActionToExecutorType:
    """Table-driven check of map_action_to_executor_type()."""

    @pytest.mark.parametrize(
        "action,expected_type,expected_risk",
        [
            ("restart", "container_restart", "low"),
            ("redeploy", "redeploy", "guarded"),
            ("notify", "notify", "low"),
            ("ignore", "ignore", "none"),
            # An unrecognised action maps to the safe fallback.
            ("unknown", "notify", "low"),
        ],
    )
    def test_mapping(self, action, expected_type, expected_risk):
        """Each action maps to the expected (executor type, risk level) pair."""
        executor_type, risk_level = map_action_to_executor_type(action)
        assert executor_type == expected_type
        assert risk_level == expected_risk
# ---------------------------------------------------------------------------
# parse_event
# ---------------------------------------------------------------------------
class TestParseEvent:
    """Checks for parse_event() across both accepted event shapes.

    Every test that reads attributes of the parsed event first asserts that
    the result is not None, so a rejected event shows up as an assertion
    failure rather than an AttributeError.
    """

    def test_shape_a_node_agent(self):
        """Shape A uses the ``type`` key for the event type."""
        raw = {
            "type": "service_unhealthy",
            "node": "piha",
            "service": "mosquitto",
            "severity": "error",
            "payload": {"status": "exited"},
        }
        ev = parse_event(raw, "health_events")
        assert ev is not None
        assert ev.node == "piha"
        assert ev.service == "mosquitto"
        assert ev.event_type == "service_unhealthy"
        assert ev.severity == "error"
        assert ev.payload == {"status": "exited"}

    def test_shape_b_world_updates(self):
        """Shape B uses the ``event_type`` key for the event type."""
        raw = {
            "event_type": "node_offline",
            "node": "chelsty-infra",
            "service": "mosquitto",
            "severity": "critical",
        }
        ev = parse_event(raw, "world_updates")
        assert ev is not None
        assert ev.event_type == "node_offline"
        assert ev.node == "chelsty-infra"

    def test_missing_node_returns_none(self):
        """An event without ``node`` is rejected."""
        raw = {"type": "service_unhealthy", "service": "mosquitto"}
        assert parse_event(raw, "health_events") is None

    def test_missing_type_returns_none(self):
        """An event without any event-type key is rejected."""
        raw = {"node": "piha", "service": "mosquitto"}
        assert parse_event(raw, "health_events") is None

    def test_service_falls_back_to_node(self):
        """When ``service`` is absent, the node name is used as the service."""
        raw = {"type": "node_offline", "node": "piha"}
        ev = parse_event(raw, "health_events")
        assert ev is not None
        assert ev.service == "piha"

    def test_timestamp_iso_parsed(self):
        """An ISO-8601 timestamp string becomes a numeric epoch value."""
        raw = {
            "type": "service_unhealthy",
            "node": "piha",
            "service": "mosquitto",
            "timestamp": "2026-05-27T12:00:00Z",
        }
        ev = parse_event(raw, "health_events")
        assert ev is not None
        assert ev.timestamp > 1_700_000_000  # sanity: recent epoch

    def test_timestamp_numeric_accepted(self):
        """A numeric timestamp is carried through (to within one second)."""
        ts = time.time()
        raw = {
            "type": "service_unhealthy",
            "node": "piha",
            "service": "mosquitto",
            "timestamp": ts,
        }
        ev = parse_event(raw, "health_events")
        assert ev is not None
        assert abs(ev.timestamp - ts) < 1

    def test_channel_stored(self):
        """The channel passed to parse_event is kept on ``raw_channel``."""
        raw = {"type": "service_unhealthy", "node": "piha", "service": "mosquitto"}
        ev = parse_event(raw, "world_updates")
        assert ev is not None
        assert ev.raw_channel == "world_updates"
# ---------------------------------------------------------------------------
# build_messages
# ---------------------------------------------------------------------------
class TestBuildMessages:
    """Checks for the message list produced by build_messages()."""

    def test_returns_two_messages(self):
        """Exactly two messages come back: system first, then user."""
        messages = build_messages(_health_event())
        assert len(messages) == 2
        roles = [message["role"] for message in messages]
        assert roles == ["system", "user"]

    def test_user_message_contains_event_fields(self):
        """Node, service and event type all appear in the user message."""
        event = _health_event(
            "vps", "observer", "service_unhealthy", "error",
            payload={"exit_code": 1},
        )
        user_text = build_messages(event)[1]["content"]
        assert "vps" in user_text
        assert "observer" in user_text
        assert "service_unhealthy" in user_text

    def test_payload_included_when_present(self):
        """A payload key shows up in the user message."""
        event = _health_event(payload={"disk_pct": 95})
        user_text = build_messages(event)[1]["content"]
        assert "disk_pct" in user_text

    def test_system_prompt_contains_homelab_rules(self):
        """The system prompt mentions chelsty and requires_human."""
        system_text = build_messages(_health_event())[0]["content"]
        assert "chelsty" in system_text
        assert "requires_human" in system_text
# ---------------------------------------------------------------------------
# PlannerAgent._handle_event
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestPlannerHandleEvent:
    """Checks for PlannerAgent._handle_event with the router and writers patched.

    ``planner.write_pending_action`` and ``planner.emit_event`` are replaced
    by mocks or small fakes in each test, so these tests only observe what
    _handle_event passes to them.
    """

    def _agent(self, result: RouteResult) -> PlannerAgent:
        """Return a PlannerAgent (no Redis URL) whose router yields ``result``."""
        router = _mock_router(result)
        return PlannerAgent(redis_url=None, router=router)

    @staticmethod
    def _capturing_write(tmp_path: Path) -> tuple[list[ActionProposal], Any]:
        """Return ``(captured, fake_write)`` for patching write_pending_action.

        ``fake_write`` appends each proposal it receives to ``captured`` and
        returns the path ``tmp_path/<action_id>.json`` without writing to it.
        """
        captured: list[ActionProposal] = []

        async def fake_write(p: ActionProposal) -> Path:
            captured.append(p)
            return tmp_path / f"{p.action_id}.json"

        return captured, fake_write

    async def test_benign_event_no_proposal(self):
        """A ``service_healthy`` event never reaches write_pending_action."""
        agent = self._agent(_make_route_result())
        ev = _health_event(event_type="service_healthy")
        with patch("planner.write_pending_action", new=AsyncMock()) as mock_write:
            await agent._handle_event(ev)
        mock_write.assert_not_called()

    async def test_cooldown_blocks_duplicate(self):
        """An event for a key in cooldown is dropped before the router is called."""
        agent = self._agent(_make_route_result())
        ev = _health_event()
        agent.cooldown.record(ev.svc_key)  # simulate recent proposal
        with patch("planner.write_pending_action", new=AsyncMock()) as mock_write:
            await agent._handle_event(ev)
        mock_write.assert_not_called()
        agent.router.route.assert_not_called()

    async def test_ignore_action_no_file_written(self):
        """An ``ignore`` proposal does not produce a pending-action write."""
        agent = self._agent(_make_route_result(action="ignore", reason="Transient glitch"))
        ev = _health_event()
        with patch("planner.write_pending_action", new=AsyncMock()) as mock_write:
            await agent._handle_event(ev)
        mock_write.assert_not_called()

    async def test_ignore_records_cooldown(self):
        """An ``ignore`` proposal still puts the service key into cooldown."""
        agent = self._agent(_make_route_result(action="ignore", reason="Transient glitch"))
        ev = _health_event()
        with patch("planner.write_pending_action", new=AsyncMock()):
            await agent._handle_event(ev)
        assert not agent.cooldown.is_ready(ev.svc_key)

    async def test_restart_action_writes_pending_file(self, tmp_path):
        """A ``restart`` proposal is written once, typed ``container_restart``."""
        agent = self._agent(_make_route_result(action="restart"))
        ev = _health_event()
        captured, fake_write = self._capturing_write(tmp_path)
        with patch("planner.write_pending_action", new=fake_write), \
                patch("planner.emit_event", new=AsyncMock()):
            await agent._handle_event(ev)
        assert len(captured) == 1
        assert captured[0].action == "restart"
        assert captured[0].type == "container_restart"

    async def test_redeploy_action_risk_guarded(self, tmp_path):
        """A ``redeploy`` proposal carries risk level ``guarded``."""
        agent = self._agent(_make_route_result(action="redeploy"))
        ev = _health_event()
        captured, fake_write = self._capturing_write(tmp_path)
        with patch("planner.write_pending_action", new=fake_write), \
                patch("planner.emit_event", new=AsyncMock()):
            await agent._handle_event(ev)
        # Fail with a clear assertion instead of IndexError if nothing was written.
        assert len(captured) == 1
        assert captured[0].risk_level == "guarded"
        assert captured[0].type == "redeploy"

    async def test_remediation_started_event_emitted(self, tmp_path):
        """Exactly one ``remediation_started`` event is emitted for the service."""
        agent = self._agent(_make_route_result(action="restart"))
        ev = _health_event()
        emitted: list[tuple] = []

        async def fake_emit(event_type, severity, service, correlation_id, payload=None):
            emitted.append((event_type, service, correlation_id))

        with patch("planner.write_pending_action",
                   new=AsyncMock(return_value=tmp_path / "x.json")), \
                patch("planner.emit_event", new=fake_emit):
            await agent._handle_event(ev)
        assert len(emitted) == 1
        assert emitted[0][0] == "remediation_started"
        assert emitted[0][1] == ev.service

    async def test_llm_failure_no_file_no_cooldown(self):
        """A router exception writes nothing and leaves the key out of cooldown."""
        router = MagicMock()
        router.route = AsyncMock(side_effect=RuntimeError("all models failed"))
        router.close = AsyncMock()
        agent = PlannerAgent(redis_url=None, router=router)
        ev = _health_event()
        with patch("planner.write_pending_action", new=AsyncMock()) as mock_write:
            await agent._handle_event(ev)
        mock_write.assert_not_called()
        # Cooldown NOT recorded — next event should be able to retry
        assert agent.cooldown.is_ready(ev.svc_key) is True

    async def test_requires_human_preserved_in_proposal(self, tmp_path):
        """``requires_human=True`` from the router survives into the proposal."""
        agent = self._agent(
            _make_route_result(action="restart", requires_human=True, confidence=0.6)
        )
        ev = _health_event()
        captured, fake_write = self._capturing_write(tmp_path)
        with patch("planner.write_pending_action", new=fake_write), \
                patch("planner.emit_event", new=AsyncMock()):
            await agent._handle_event(ev)
        assert len(captured) == 1
        assert captured[0].requires_human is True

    async def test_cooldown_recorded_after_success(self, tmp_path):
        """A written proposal puts the service key into cooldown."""
        agent = self._agent(_make_route_result(action="restart"))
        ev = _health_event()
        with patch("planner.write_pending_action",
                   new=AsyncMock(return_value=tmp_path / "x.json")), \
                patch("planner.emit_event", new=AsyncMock()):
            await agent._handle_event(ev)
        assert not agent.cooldown.is_ready(ev.svc_key)

    async def test_llm_model_recorded_in_proposal(self, tmp_path):
        """The router's ``model_used`` is stored as the proposal's ``llm_model``."""
        agent = self._agent(
            _make_route_result(action="restart", model="claude-haiku-4-5-20251001")
        )
        ev = _health_event()
        captured, fake_write = self._capturing_write(tmp_path)
        with patch("planner.write_pending_action", new=fake_write), \
                patch("planner.emit_event", new=AsyncMock()):
            await agent._handle_event(ev)
        assert len(captured) == 1
        assert captured[0].llm_model == "claude-haiku-4-5-20251001"
# ---------------------------------------------------------------------------
# PlannerAgent._dispatch
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestPlannerDispatch:
    """Checks for PlannerAgent._dispatch with _handle_event replaced by a mock."""

    def _agent(self) -> PlannerAgent:
        """Return a PlannerAgent (no Redis URL) with a default mocked router."""
        return PlannerAgent(
            redis_url=None,
            router=_mock_router(_make_route_result()),
        )

    async def test_valid_json_dispatched(self):
        """A well-formed JSON event is handed to _handle_event exactly once."""
        agent = self._agent()
        event = {
            "type": "service_unhealthy",
            "node": "piha",
            "service": "mosquitto",
            "severity": "error",
        }
        message = {"channel": "health_events", "data": json.dumps(event)}
        with patch.object(agent, "_handle_event", new=AsyncMock()) as handler:
            await agent._dispatch(message)
        handler.assert_awaited_once()

    async def test_invalid_json_skipped(self):
        """Malformed JSON in ``data`` never reaches _handle_event."""
        agent = self._agent()
        message = {"channel": "health_events", "data": "{not valid json"}
        with patch.object(agent, "_handle_event", new=AsyncMock()) as handler:
            await agent._dispatch(message)
        handler.assert_not_called()

    async def test_non_string_data_skipped(self):
        """An integer ``data`` value never reaches _handle_event."""
        agent = self._agent()
        message = {"channel": "health_events", "data": 42}
        with patch.object(agent, "_handle_event", new=AsyncMock()) as handler:
            await agent._dispatch(message)
        handler.assert_not_called()

    async def test_missing_node_skipped(self):
        """A JSON event lacking ``node`` never reaches _handle_event."""
        agent = self._agent()
        event = {"type": "service_unhealthy", "service": "mosquitto"}
        message = {"channel": "health_events", "data": json.dumps(event)}
        with patch.object(agent, "_handle_event", new=AsyncMock()) as handler:
            await agent._dispatch(message)
        handler.assert_not_called()
# ---------------------------------------------------------------------------
# write_pending_action (integration-style with tmp_path)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestWritePendingAction:
    """Integration-style checks: write_pending_action writes under tmp_path."""

    async def test_file_created_with_correct_content(self, tmp_path):
        """The written file exists and round-trips the proposal's key fields."""
        restart = ActionProposal(
            action_id="plan-piha-mosquitto-1000",
            type="container_restart",
            action="restart",
            service="mosquitto",
            node="piha",
            reason="Container stopped unexpectedly",
            confidence=0.95,
            requires_human=False,
            risk_level="low",
        )
        # Redirect the module's ACTIONS_DIR to the per-test temp directory.
        with patch("planner.ACTIONS_DIR", tmp_path):
            written = await write_pending_action(restart)
        assert written.exists()
        stored = json.loads(written.read_text())
        assert stored["action_id"] == "plan-piha-mosquitto-1000"
        assert stored["status"] == "pending"
        assert stored["type"] == "container_restart"
        assert stored["confidence"] == 0.95
        assert stored["requires_human"] is False

    async def test_file_is_valid_json(self, tmp_path):
        """The written file parses as JSON without raising."""
        redeploy = ActionProposal(
            action_id="x",
            type="redeploy",
            action="redeploy",
            service="ollama",
            node="solaria",
            reason="Service is broken beyond a simple restart",
            confidence=0.8,
            requires_human=True,
            risk_level="guarded",
        )
        with patch("planner.ACTIONS_DIR", tmp_path):
            written = await write_pending_action(redeploy)
        # Should not raise
        json.loads(written.read_text())
# ---------------------------------------------------------------------------
# emit_event (filesystem write)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestEmitEvent:
    """Filesystem checks for emit_event, with EVENTS_DIR redirected to tmp_path.

    Both tests override ``planner.NODE_NAME`` through ``patch`` so the
    original value is restored automatically, even if the test fails.
    """

    async def test_event_file_created(self, tmp_path):
        """One JSON file is written and carries the fields that were passed in."""
        with patch("planner.EVENTS_DIR", tmp_path), \
                patch("planner.NODE_NAME", "test-node"):
            await emit_event(
                event_type="remediation_started",
                severity="info",
                service="mosquitto",
                correlation_id="plan-abc-123",
                payload={"action": "restart"},
            )
        files = list(tmp_path.rglob("*.json"))
        assert len(files) == 1
        data = json.loads(files[0].read_text())
        assert data["type"] == "remediation_started"
        assert data["service"] == "mosquitto"
        assert data["correlation_id"] == "plan-abc-123"
        assert data["payload"]["action"] == "restart"

    async def test_event_dir_structure(self, tmp_path):
        """Events must be stored under YYYY-MM-DD/<node>/."""
        with patch("planner.EVENTS_DIR", tmp_path), \
                patch("planner.NODE_NAME", "piha"):
            await emit_event("test_event", "info", "svc", "cid-1")
        files = list(tmp_path.rglob("*.json"))
        assert len(files) == 1
        # Path: <date>/<node>/<filename>
        parts = files[0].relative_to(tmp_path).parts
        assert len(parts) == 3  # date / node / file.json
        assert parts[1] == "piha"