homelab-codex-ws/services/planner-agent/tests/test_llm_router.py
Oskar Kapala 1bbc511bb7 feat(planner-agent): add llm_router.py with local-first fallback chain
services/planner-agent/src/llm_router.py:
- LLMRouter: async routing via litellm; chain = Qwen/Ollama → haiku → sonnet
- Timeouts: 8s local, 30s cloud; asyncio.wait_for belt-and-suspenders
- Rejection triggers: timeout, API error, refusal patterns, JSON schema fail
- JSON fence extraction: recovers valid JSON from Markdown code fences (```json … ``` or bare ``` … ``` blocks)
- ModelMetrics: per-model success/fallback/error counters + success_rate()
- Redis publish to 'llm_router_metrics' after every call (failure-safe)
- redis_url=None disables Redis (useful in tests / edge nodes)
- context= param adds caller label to all log lines for tracing

services/planner-agent/tests/test_llm_router.py:
- 34 tests, 0 network calls (litellm + Redis fully mocked)
- Covers: primary success, JSON error fallback, refusal fallback,
  timeout fallback, API exception fallback, all-fail RuntimeError,
  schema validation, fence extraction, metrics recording, Redis publish,
  Redis failure isolation

services/planner-agent/requirements.txt:
- litellm>=1.40.0, redis>=5.0.0, jsonschema>=4.21.0

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-27 18:38:06 +02:00

450 lines
16 KiB
Python

"""
Unit tests for llm_router.py.
All LLM and Redis calls are mocked — no network required.
Run:
pip install pytest pytest-asyncio litellm jsonschema redis
pytest services/planner-agent/tests/test_llm_router.py -v
"""
import asyncio
import json
import sys
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# Allow importing from src/ without installation
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from llm_router import (
AttemptRecord,
LLMRouter,
ModelConfig,
ModelMetrics,
RouteResult,
_extract_json_from_fence,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _fake_completion(content: str):
"""Build a minimal litellm-style response object."""
msg = MagicMock()
msg.content = content
choice = MagicMock()
choice.message = msg
resp = MagicMock()
resp.choices = [choice]
return resp
def _chain_of(*models: tuple[str, float]) -> list[ModelConfig]:
    """Turn ``(name, timeout)`` pairs into a ModelConfig chain.

    No ``api_base`` is passed. Input order is preserved, so the first pair
    becomes the first entry of the chain.
    """
    return [
        ModelConfig(name=model_name, timeout=model_timeout)
        for model_name, model_timeout in models
    ]
# ---------------------------------------------------------------------------
# ModelMetrics
# ---------------------------------------------------------------------------
class TestModelMetrics:
    """Counter bookkeeping in ModelMetrics: record, snapshot, rates, totals."""

    def test_record_and_snapshot(self):
        metrics = ModelMetrics()
        for model, outcome in (
            ("qwen", "success"),
            ("qwen", "success"),
            ("qwen", "fallback"),
            ("haiku", "success"),
        ):
            metrics.record(model, outcome)
        counts = metrics.snapshot()
        assert counts["qwen"]["success"] == 2
        assert counts["qwen"]["fallback"] == 1
        assert counts["haiku"]["success"] == 1

    def test_success_rate(self):
        metrics = ModelMetrics()
        metrics.record("model-a", "success")
        metrics.record("model-a", "fallback")
        # One success out of two recorded calls.
        assert metrics.success_rate("model-a") == 0.5

    def test_success_rate_unknown_model(self):
        # A model that was never recorded has no rate at all (None, not 0.0).
        assert ModelMetrics().success_rate("ghost") is None

    def test_total_calls(self):
        metrics = ModelMetrics()
        metrics.record("x", "success")
        metrics.record("x", "error")
        assert metrics.total_calls("x") == 2

    def test_snapshot_is_copy(self):
        metrics = ModelMetrics()
        metrics.record("x", "success")
        copied = metrics.snapshot()
        copied["x"]["success"] = 999  # mutate the returned copy ...
        assert metrics.snapshot()["x"]["success"] == 1  # ... must not leak back
# ---------------------------------------------------------------------------
# RouteResult
# ---------------------------------------------------------------------------
class TestRouteResult:
    """RouteResult.succeeded and RouteResult.to_dict()."""

    def test_succeeded(self):
        outcome = RouteResult("hello", "hello", "model-a", [], 100)
        assert outcome.succeeded is True

    def test_not_succeeded(self):
        # First positional field is None (presumably the parsed content),
        # which should be reported as a failed route.
        outcome = RouteResult(None, "", "none", [], 100)
        assert outcome.succeeded is False

    def test_to_dict_structure(self):
        only_attempt = AttemptRecord("m", "success", None, 50)
        as_dict = RouteResult("x", "x", "m", [only_attempt], 50).to_dict()
        assert as_dict["model_used"] == "m"
        serialized_attempts = as_dict["attempts"]
        assert len(serialized_attempts) == 1
        assert serialized_attempts[0]["outcome"] == "success"
# ---------------------------------------------------------------------------
# _extract_json_from_fence
# ---------------------------------------------------------------------------
class TestExtractJsonFromFence:
    """_extract_json_from_fence: pull JSON out of Markdown code fences."""

    def test_json_fence(self):
        wrapped = 'Sure!\n```json\n{"a": 1}\n```\nDone.'
        # Chatter before and after the fence does not get in the way.
        assert _extract_json_from_fence(wrapped) == {"a": 1}

    def test_plain_fence(self):
        # A fence without a language tag is accepted too.
        assert _extract_json_from_fence("```\n[1, 2]\n```") == [1, 2]

    def test_no_fence(self):
        assert _extract_json_from_fence("no json here") is None

    def test_broken_fence(self):
        # Fenced but unparseable content yields None.
        extracted = _extract_json_from_fence("```json\n{broken```")
        assert extracted is None
# ---------------------------------------------------------------------------
# LLMRouter — validation & refusal detection
# ---------------------------------------------------------------------------
class TestLLMRouterValidation:
    """LLMRouter._validate (JSON / schema checks) and _detect_refusal."""

    def setup_method(self):
        # Only helper methods are exercised here, so Redis is disabled.
        self.router = LLMRouter(redis_url=None)

    def test_validate_no_schema_returns_text(self):
        # Without a schema the raw text comes back unchanged.
        value, problem = self.router._validate("hello world", schema=None)
        assert problem is None
        assert value == "hello world"

    def test_validate_valid_json(self):
        needs_action = {"type": "object", "required": ["action"]}
        value, problem = self.router._validate(
            '{"action": "redeploy"}', needs_action
        )
        assert problem is None
        assert value == {"action": "redeploy"}

    def test_validate_invalid_json(self):
        _, problem = self.router._validate("not json {", {"type": "object"})
        assert problem is not None
        assert "JSONDecodeError" in problem

    def test_validate_schema_violation(self):
        needs_action = {"type": "object", "required": ["action"]}
        _, problem = self.router._validate('{"other": 1}', needs_action)
        assert problem is not None
        assert "SchemaValidationError" in problem

    def test_validate_extracts_fenced_json(self):
        needs_action = {"type": "object", "required": ["action"]}
        fenced = '```json\n{"action": "restart"}\n```'
        value, problem = self.router._validate(fenced, needs_action)
        assert problem is None
        assert value == {"action": "restart"}

    def test_detect_refusal_nie_wiem(self):
        # Polish for "I don't know what to do".
        matched = self.router._detect_refusal("Nie wiem co mam zrobić")
        assert matched == "nie wiem"

    def test_detect_refusal_as_an_ai(self):
        # The text contains both "as an AI" and "I cannot"; the first match
        # wins. Only detection itself is asserted, not which pattern fires.
        assert self.router._detect_refusal("As an AI I cannot help") is not None

    def test_detect_refusal_none(self):
        assert self.router._detect_refusal("Sure, here is the action.") is None

    def test_detect_refusal_case_insensitive(self):
        # Upper-case input still matches; the pattern's own spelling is returned.
        assert self.router._detect_refusal("I CANNOT do that") == "I cannot"
# ---------------------------------------------------------------------------
# LLMRouter — routing logic (mocked litellm + Redis)
# ---------------------------------------------------------------------------
@pytest.fixture
def router_no_redis():
    """LLMRouter over a three-model test chain, with Redis disabled.

    Chain order (timeouts in seconds): local/qwen (8) -> claude-haiku-test (30)
    -> claude-sonnet-test (30).
    """
    three_models = _chain_of(
        ("local/qwen", 8.0),
        ("claude-haiku-test", 30.0),
        ("claude-sonnet-test", 30.0),
    )
    router = LLMRouter(redis_url=None, chain=three_models)
    return router
@pytest.mark.asyncio
class TestLLMRouterRouting:
    """route() across the three-model fallback chain, with litellm mocked.

    Every test patches ``litellm.acompletion``, so no network calls are made.
    Multi-step scenarios use ``AsyncMock(side_effect=[...])``:
    - each await consumes the next list item;
    - exception instances in the list are raised instead of returned.
    The mock's ``await_count`` then pins how far down the chain the router went.
    """

    async def test_primary_success(self, router_no_redis):
        """Primary model succeeds — no fallback."""
        fake = AsyncMock(return_value=_fake_completion('{"action": "ok"}'))
        with patch("litellm.acompletion", fake):
            result = await router_no_redis.route(
                messages=[{"role": "user", "content": "test"}],
                schema={"type": "object", "required": ["action"]},
            )
        assert result.model_used == "local/qwen"
        assert result.content == {"action": "ok"}
        assert len(result.attempts) == 1
        assert result.attempts[0].outcome == "success"
        assert fake.await_count == 1  # fallback models were never contacted

    async def test_fallback_on_json_error(self, router_no_redis):
        """Primary returns bad JSON → falls back to haiku which returns valid JSON."""
        fake = AsyncMock(
            side_effect=[
                _fake_completion("not json at all"),
                _fake_completion('{"action": "restart"}'),
            ]
        )
        with patch("litellm.acompletion", fake):
            result = await router_no_redis.route(
                messages=[{"role": "user", "content": "x"}],
                schema={"type": "object", "required": ["action"]},
            )
        assert result.model_used == "claude-haiku-test"
        assert result.content == {"action": "restart"}
        assert result.attempts[0].outcome == "invalid"
        assert result.attempts[1].outcome == "success"
        assert fake.await_count == 2

    async def test_fallback_on_refusal(self, router_no_redis):
        """Primary returns refusal text → falls back."""
        fake = AsyncMock(
            side_effect=[
                _fake_completion("I cannot help with that request."),
                _fake_completion("Sure! Here is the plan."),
            ]
        )
        with patch("litellm.acompletion", fake):
            result = await router_no_redis.route(
                messages=[{"role": "user", "content": "x"}],
            )
        assert result.model_used == "claude-haiku-test"
        assert result.attempts[0].outcome == "rejected"
        assert "RefusalPattern" in result.attempts[0].reason
        assert fake.await_count == 2

    async def test_fallback_on_timeout(self, router_no_redis):
        """Primary times out → falls back to haiku."""
        fake = AsyncMock(
            side_effect=[
                asyncio.TimeoutError(),
                _fake_completion("fallback response"),
            ]
        )
        with patch("litellm.acompletion", fake):
            result = await router_no_redis.route(
                messages=[{"role": "user", "content": "x"}],
            )
        assert result.model_used == "claude-haiku-test"
        assert "Timeout" in result.attempts[0].reason
        assert fake.await_count == 2

    async def test_fallback_on_api_exception(self, router_no_redis):
        """Primary raises a connection error → falls back."""
        import litellm.exceptions

        connection_refused = litellm.exceptions.APIConnectionError(
            message="Connection refused", llm_provider="ollama", model="qwen"
        )
        fake = AsyncMock(side_effect=[connection_refused, _fake_completion("ok")])
        with patch("litellm.acompletion", fake):
            result = await router_no_redis.route(
                messages=[{"role": "user", "content": "x"}],
            )
        assert result.model_used == "claude-haiku-test"
        assert "APIConnectionError" in result.attempts[0].reason
        assert fake.await_count == 2

    async def test_all_models_fail_raises(self, router_no_redis):
        """All models return bad JSON → RuntimeError with attempt log."""
        fake = AsyncMock(return_value=_fake_completion("not json"))
        with patch("litellm.acompletion", fake):
            with pytest.raises(RuntimeError) as exc_info:
                await router_no_redis.route(
                    messages=[{"role": "user", "content": "x"}],
                    schema={"type": "object"},
                )
        assert "All 3 models in chain failed" in str(exc_info.value)
        assert fake.await_count == 3  # every model in the chain was tried once

    async def test_schema_none_no_json_required(self, router_no_redis):
        """Without a schema, plain text responses are accepted."""
        with patch(
            "litellm.acompletion",
            AsyncMock(return_value=_fake_completion("Here is your plan.")),
        ):
            result = await router_no_redis.route(
                messages=[{"role": "user", "content": "x"}],
            )
        assert result.content == "Here is your plan."
        assert result.succeeded

    async def test_metrics_recorded_on_success(self, router_no_redis):
        """A first-try success is counted as a success for the primary model."""
        with patch(
            "litellm.acompletion",
            AsyncMock(return_value=_fake_completion("ok")),
        ):
            await router_no_redis.route([{"role": "user", "content": "x"}])
        snap = router_no_redis.metrics.snapshot()
        assert snap["local/qwen"]["success"] == 1

    async def test_metrics_fallback_recorded(self, router_no_redis):
        """Primary fails → fallback → metrics show primary=fallback, haiku=success."""
        fake = AsyncMock(
            side_effect=[
                _fake_completion("I cannot help"),  # refusal
                _fake_completion("ok"),
            ]
        )
        with patch("litellm.acompletion", fake):
            await router_no_redis.route([{"role": "user", "content": "x"}])
        snap = router_no_redis.metrics.snapshot()
        assert snap["local/qwen"]["fallback"] == 1
        assert snap["claude-haiku-test"]["success"] == 1

    async def test_context_label_in_logs(self, router_no_redis, caplog):
        """context= parameter appears in log output."""
        import logging

        with patch(
            "litellm.acompletion",
            AsyncMock(return_value=_fake_completion("ok")),
        ), caplog.at_level(logging.INFO, logger="llm_router"):
            await router_no_redis.route(
                messages=[{"role": "user", "content": "x"}],
                context="supervisor.reconcile",
            )
        # getMessage() renders the record directly instead of relying on a
        # formatter having populated the .message attribute.
        assert any(
            "supervisor.reconcile" in record.getMessage()
            for record in caplog.records
        )
# ---------------------------------------------------------------------------
# LLMRouter — Redis metrics publish
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestLLMRouterRedis:
    """Metrics publishing from route() to Redis.

    Each test builds a router with a Redis URL and a one-model chain ("m1").
    It then replaces ``router._redis`` with an AsyncMock, so ``publish`` calls
    can be inspected, or made to fail, without talking to a real server.
    """

    @staticmethod
    def _router_with_mock_redis(publish_error=None):
        """Return ``(router, mock_redis)``; *publish_error* makes publish raise."""
        router = LLMRouter(
            redis_url="redis://localhost:6379", chain=_chain_of(("m1", 8.0))
        )
        mock_redis = AsyncMock()
        if publish_error is not None:
            mock_redis.publish.side_effect = publish_error
        router._redis = mock_redis
        return router, mock_redis

    async def test_metrics_published_on_success(self):
        router, mock_redis = self._router_with_mock_redis()
        with patch(
            "litellm.acompletion",
            AsyncMock(return_value=_fake_completion("ok")),
        ):
            await router.route([{"role": "user", "content": "x"}])
        mock_redis.publish.assert_awaited_once()
        channel, payload_str = mock_redis.publish.await_args.args
        assert channel == "llm_router_metrics"
        payload = json.loads(payload_str)
        assert payload["model_used"] == "m1"
        assert "metrics_snapshot" in payload
        assert "timestamp" in payload

    async def test_redis_failure_does_not_raise(self):
        """A broken Redis must never break the LLM call result."""
        router, mock_redis = self._router_with_mock_redis(
            publish_error=ConnectionError("Redis down")
        )
        with patch(
            "litellm.acompletion",
            AsyncMock(return_value=_fake_completion("ok")),
        ):
            result = await router.route([{"role": "user", "content": "x"}])
        # Make sure the failing publish really ran; otherwise this test would
        # pass vacuously without exercising the error path at all.
        mock_redis.publish.assert_awaited()
        assert result.succeeded  # LLM call still returned
        assert result.model_used == "m1"

    async def test_metrics_published_on_full_failure(self):
        """Metrics must still be published when route() raises."""
        router, mock_redis = self._router_with_mock_redis()
        with patch(
            "litellm.acompletion",
            AsyncMock(return_value=_fake_completion("not json")),
        ):
            with pytest.raises(RuntimeError):
                await router.route(
                    messages=[{"role": "user", "content": "x"}],
                    schema={"type": "object"},
                )
        mock_redis.publish.assert_awaited_once()
        _, payload_str = mock_redis.publish.await_args.args
        payload = json.loads(payload_str)
        assert payload["model_used"] == "none"