services/planner-agent/src/llm_router.py:
- LLMRouter: async routing via litellm; chain = Qwen/Ollama → haiku → sonnet
- Timeouts: 8s local, 30s cloud; asyncio.wait_for belt-and-suspenders
- Rejection triggers: timeout, API error, refusal patterns, JSON schema fail
- JSON fence extraction: recovers valid JSON from ```json / ``` fenced blocks
- ModelMetrics: per-model success/fallback/error counters + success_rate()
- Redis publish to 'llm_router_metrics' after every call (failure-safe)
- redis_url=None disables Redis (useful in tests / edge nodes)
- context= param adds caller label to all log lines for tracing

services/planner-agent/tests/test_llm_router.py:
- 34 tests, 0 network calls (litellm + Redis fully mocked)
- Covers: primary success, JSON error fallback, refusal fallback, timeout fallback, API exception fallback, all-fail RuntimeError, schema validation, fence extraction, metrics recording, Redis publish, Redis failure isolation

services/planner-agent/requirements.txt:
- litellm>=1.40.0, redis>=5.0.0, jsonschema>=4.21.0

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
450 lines
16 KiB
Python
450 lines
16 KiB
Python
"""
|
|
Unit tests for llm_router.py.
|
|
|
|
All LLM and Redis calls are mocked — no network required.
|
|
|
|
Run:
|
|
pip install pytest pytest-asyncio litellm jsonschema redis
|
|
pytest services/planner-agent/tests/test_llm_router.py -v
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
# Allow importing from src/ without installation
|
|
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
|
|
|
from llm_router import (
|
|
AttemptRecord,
|
|
LLMRouter,
|
|
ModelConfig,
|
|
ModelMetrics,
|
|
RouteResult,
|
|
_extract_json_from_fence,
|
|
)
|
|
|
|
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
|
|
|
|
def _fake_completion(content: str):
|
|
"""Build a minimal litellm-style response object."""
|
|
msg = MagicMock()
|
|
msg.content = content
|
|
choice = MagicMock()
|
|
choice.message = msg
|
|
resp = MagicMock()
|
|
resp.choices = [choice]
|
|
return resp
|
|
|
|
|
|
def _chain_of(*models: tuple[str, float]) -> list[ModelConfig]:
    """Build a minimal test chain with no api_base.

    Args:
        *models: ``(name, timeout)`` pairs, in the order they should appear
            in the chain.

    Returns:
        One ``ModelConfig`` per pair, built with only ``name`` and
        ``timeout`` supplied.
    """
    return [
        ModelConfig(name=model_name, timeout=seconds)
        for model_name, seconds in models
    ]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# ModelMetrics
# ---------------------------------------------------------------------------
|
|
|
|
class TestModelMetrics:
    """Counter behaviour of ``ModelMetrics``: record, snapshot, rates, totals."""

    def test_record_and_snapshot(self):
        """Recorded outcomes are tallied per model and per outcome."""
        metrics = ModelMetrics()
        recorded = (
            ("qwen", "success"),
            ("qwen", "success"),
            ("qwen", "fallback"),
            ("haiku", "success"),
        )
        for model, outcome in recorded:
            metrics.record(model, outcome)

        snapshot = metrics.snapshot()
        assert snapshot["qwen"]["success"] == 2
        assert snapshot["qwen"]["fallback"] == 1
        assert snapshot["haiku"]["success"] == 1

    def test_success_rate(self):
        """One success out of two recorded calls gives a rate of 0.5."""
        metrics = ModelMetrics()
        for outcome in ("success", "fallback"):
            metrics.record("model-a", outcome)

        assert metrics.success_rate("model-a") == 0.5

    def test_success_rate_unknown_model(self):
        """A model with no recorded calls has no success rate."""
        metrics = ModelMetrics()
        rate = metrics.success_rate("ghost")
        assert rate is None

    def test_total_calls(self):
        """Every recorded outcome counts towards the total."""
        metrics = ModelMetrics()
        for outcome in ("success", "error"):
            metrics.record("x", outcome)

        assert metrics.total_calls("x") == 2

    def test_snapshot_is_copy(self):
        """Mutating a snapshot must not leak back into the live counters."""
        metrics = ModelMetrics()
        metrics.record("x", "success")

        detached = metrics.snapshot()
        detached["x"]["success"] = 999  # tamper with the copy only

        fresh = metrics.snapshot()
        assert fresh["x"]["success"] == 1  # live counters are unchanged
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# RouteResult
# ---------------------------------------------------------------------------
|
|
|
|
class TestRouteResult:
    """``RouteResult.succeeded`` and ``RouteResult.to_dict`` behaviour.

    ``RouteResult`` and ``AttemptRecord`` are built positionally here; the
    field names behind those positions are defined in ``llm_router`` and are
    not restated in this file.
    """

    def test_succeeded(self):
        """A result built with a non-None first argument reports success."""
        outcome = RouteResult("hello", "hello", "model-a", [], 100)
        assert outcome.succeeded is True

    def test_not_succeeded(self):
        """A result built with ``None`` as first argument reports failure."""
        outcome = RouteResult(None, "", "none", [], 100)
        assert outcome.succeeded is False

    def test_to_dict_structure(self):
        """``to_dict`` exposes the model used and the serialised attempts."""
        single_attempt = AttemptRecord("m", "success", None, 50)
        outcome = RouteResult("x", "x", "m", [single_attempt], 50)

        as_dict = outcome.to_dict()
        attempts = as_dict["attempts"]

        assert as_dict["model_used"] == "m"
        assert len(attempts) == 1
        assert attempts[0]["outcome"] == "success"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# _extract_json_from_fence
# ---------------------------------------------------------------------------
|
|
|
|
class TestExtractJsonFromFence:
    """Recovery of JSON payloads wrapped in Markdown code fences."""

    def test_json_fence(self):
        """A ```json fence surrounded by chatter yields the parsed object."""
        reply = 'Sure!\n```json\n{"a": 1}\n```\nDone.'
        extracted = _extract_json_from_fence(reply)
        assert extracted == {"a": 1}

    def test_plain_fence(self):
        """A fence without a language tag is parsed as well."""
        reply = "```\n[1, 2]\n```"
        extracted = _extract_json_from_fence(reply)
        assert extracted == [1, 2]

    def test_no_fence(self):
        """Text without any fence yields ``None``."""
        extracted = _extract_json_from_fence("no json here")
        assert extracted is None

    def test_broken_fence(self):
        """A fence whose content is not valid JSON yields ``None``."""
        extracted = _extract_json_from_fence("```json\n{broken```")
        assert extracted is None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# LLMRouter — validation & refusal detection
# ---------------------------------------------------------------------------
|
|
|
|
class TestLLMRouterValidation:
    """Checks for ``LLMRouter._validate`` and ``LLMRouter._detect_refusal``."""

    def setup_method(self):
        """Give every test its own router with Redis disabled."""
        self.router = LLMRouter(redis_url=None)

    def test_validate_no_schema_returns_text(self):
        """Without a schema the raw text is returned and no error is set."""
        parsed, error = self.router._validate("hello world", schema=None)

        assert parsed == "hello world"
        assert error is None

    def test_validate_valid_json(self):
        """JSON satisfying the schema is returned parsed, with no error."""
        action_schema = {"type": "object", "required": ["action"]}

        parsed, error = self.router._validate(
            '{"action": "redeploy"}', action_schema
        )

        assert error is None
        assert parsed == {"action": "redeploy"}

    def test_validate_invalid_json(self):
        """Unparseable text is reported as a ``JSONDecodeError``."""
        object_schema = {"type": "object"}

        _, error = self.router._validate("not json {", object_schema)

        assert error is not None
        assert "JSONDecodeError" in error

    def test_validate_schema_violation(self):
        """Valid JSON missing a required key is a ``SchemaValidationError``."""
        action_schema = {"type": "object", "required": ["action"]}

        _, error = self.router._validate('{"other": 1}', action_schema)

        assert error is not None
        assert "SchemaValidationError" in error

    def test_validate_extracts_fenced_json(self):
        """JSON wrapped in a ```json fence is still accepted."""
        action_schema = {"type": "object", "required": ["action"]}
        fenced_reply = '```json\n{"action": "restart"}\n```'

        parsed, error = self.router._validate(fenced_reply, action_schema)

        assert error is None
        assert parsed == {"action": "restart"}

    def test_detect_refusal_nie_wiem(self):
        """The Polish "nie wiem" refusal pattern is recognised."""
        matched = self.router._detect_refusal("Nie wiem co mam zrobić")
        assert matched == "nie wiem"

    def test_detect_refusal_as_an_ai(self):
        """Some refusal pattern fires on an "As an AI ..." reply."""
        # The text contains both "as an AI" and "I cannot"; the first match
        # wins. Only the fact that a refusal IS detected is asserted, not
        # which pattern produced it.
        matched = self.router._detect_refusal("As an AI I cannot help")
        assert matched is not None

    def test_detect_refusal_none(self):
        """A cooperative reply is not flagged as a refusal."""
        matched = self.router._detect_refusal("Sure, here is the action.")
        assert matched is None

    def test_detect_refusal_case_insensitive(self):
        """An upper-cased refusal still reports the "I cannot" pattern."""
        matched = self.router._detect_refusal("I CANNOT do that")
        assert matched == "I cannot"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# LLMRouter — routing logic (mocked litellm + Redis)
# ---------------------------------------------------------------------------
|
|
|
|
@pytest.fixture
def router_no_redis():
    """Router with a 3-model chain and Redis disabled.

    The chain order is ``local/qwen`` (8 s timeout), ``claude-haiku-test``
    (30 s), then ``claude-sonnet-test`` (30 s).
    """
    three_model_chain = _chain_of(
        ("local/qwen", 8.0),
        ("claude-haiku-test", 30.0),
        ("claude-sonnet-test", 30.0),
    )
    router = LLMRouter(redis_url=None, chain=three_model_chain)
    return router
|
|
|
|
|
|
@pytest.mark.asyncio
class TestLLMRouterRouting:
    """Fallback-chain behaviour of ``LLMRouter.route`` with litellm mocked.

    Every test patches ``litellm.acompletion`` with an ``AsyncMock``.
    Sequences of replies are scripted with ``side_effect=[...]``: each await
    consumes the next item, and an item that is an exception instance is
    raised instead of returned.
    """

    async def test_primary_success(self, router_no_redis):
        """Primary model succeeds — no fallback."""
        completion = AsyncMock(return_value=_fake_completion('{"action": "ok"}'))

        with patch("litellm.acompletion", completion):
            result = await router_no_redis.route(
                messages=[{"role": "user", "content": "test"}],
                schema={"type": "object", "required": ["action"]},
            )

        assert result.model_used == "local/qwen"
        assert result.content == {"action": "ok"}
        assert len(result.attempts) == 1
        assert result.attempts[0].outcome == "success"

    async def test_fallback_on_json_error(self, router_no_redis):
        """Primary returns bad JSON → falls back to haiku which returns valid JSON."""
        completion = AsyncMock(
            side_effect=[
                _fake_completion("not json at all"),
                _fake_completion('{"action": "restart"}'),
            ]
        )

        with patch("litellm.acompletion", completion):
            result = await router_no_redis.route(
                messages=[{"role": "user", "content": "x"}],
                schema={"type": "object", "required": ["action"]},
            )

        assert result.model_used == "claude-haiku-test"
        assert result.content == {"action": "restart"}
        assert result.attempts[0].outcome == "invalid"
        assert result.attempts[1].outcome == "success"

    async def test_fallback_on_refusal(self, router_no_redis):
        """Primary returns refusal text → falls back."""
        completion = AsyncMock(
            side_effect=[
                _fake_completion("I cannot help with that request."),
                _fake_completion("Sure! Here is the plan."),
            ]
        )

        with patch("litellm.acompletion", completion):
            result = await router_no_redis.route(
                messages=[{"role": "user", "content": "x"}],
            )

        assert result.model_used == "claude-haiku-test"
        assert result.attempts[0].outcome == "rejected"
        assert "RefusalPattern" in result.attempts[0].reason

    async def test_fallback_on_timeout(self, router_no_redis):
        """Primary times out → falls back to haiku."""
        completion = AsyncMock(
            side_effect=[
                # An exception instance in the sequence is raised, not returned.
                asyncio.TimeoutError(),
                _fake_completion("fallback response"),
            ]
        )

        with patch("litellm.acompletion", completion):
            result = await router_no_redis.route(
                messages=[{"role": "user", "content": "x"}],
            )

        assert result.model_used == "claude-haiku-test"
        assert "Timeout" in result.attempts[0].reason

    async def test_fallback_on_api_exception(self, router_no_redis):
        """Primary raises a connection error → falls back."""
        import litellm.exceptions

        completion = AsyncMock(
            side_effect=[
                litellm.exceptions.APIConnectionError(
                    message="Connection refused", llm_provider="ollama", model="qwen"
                ),
                _fake_completion("ok"),
            ]
        )

        with patch("litellm.acompletion", completion):
            result = await router_no_redis.route(
                messages=[{"role": "user", "content": "x"}],
            )

        assert result.model_used == "claude-haiku-test"
        assert "APIConnectionError" in result.attempts[0].reason

    async def test_all_models_fail_raises(self, router_no_redis):
        """All models return bad JSON → RuntimeError with attempt log."""
        completion = AsyncMock(return_value=_fake_completion("not json"))

        with patch("litellm.acompletion", completion):
            with pytest.raises(RuntimeError) as exc_info:
                await router_no_redis.route(
                    messages=[{"role": "user", "content": "x"}],
                    schema={"type": "object"},
                )

        assert "All 3 models in chain failed" in str(exc_info.value)

    async def test_schema_none_no_json_required(self, router_no_redis):
        """Without a schema, plain text responses are accepted."""
        completion = AsyncMock(return_value=_fake_completion("Here is your plan."))

        with patch("litellm.acompletion", completion):
            result = await router_no_redis.route(
                messages=[{"role": "user", "content": "x"}],
            )

        assert result.content == "Here is your plan."
        assert result.succeeded

    async def test_metrics_recorded_on_success(self, router_no_redis):
        """A successful primary call bumps that model's success counter."""
        completion = AsyncMock(return_value=_fake_completion("ok"))

        with patch("litellm.acompletion", completion):
            await router_no_redis.route([{"role": "user", "content": "x"}])

        snap = router_no_redis.metrics.snapshot()
        assert snap["local/qwen"]["success"] == 1

    async def test_metrics_fallback_recorded(self, router_no_redis):
        """Primary fails → fallback → metrics show primary=fallback, haiku=success."""
        completion = AsyncMock(
            side_effect=[
                _fake_completion("I cannot help"),  # refusal
                _fake_completion("ok"),
            ]
        )

        with patch("litellm.acompletion", completion):
            await router_no_redis.route([{"role": "user", "content": "x"}])

        snap = router_no_redis.metrics.snapshot()
        assert snap["local/qwen"]["fallback"] == 1
        assert snap["claude-haiku-test"]["success"] == 1

    async def test_context_label_in_logs(self, router_no_redis, caplog):
        """context= parameter appears in log output."""
        import logging

        completion = AsyncMock(return_value=_fake_completion("ok"))

        with patch("litellm.acompletion", completion):
            with caplog.at_level(logging.INFO, logger="llm_router"):
                await router_no_redis.route(
                    messages=[{"role": "user", "content": "x"}],
                    context="supervisor.reconcile",
                )

        # getMessage() renders msg % args on demand, so the check does not
        # depend on a handler having populated record.message beforehand.
        assert any(
            "supervisor.reconcile" in record.getMessage()
            for record in caplog.records
        )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# LLMRouter — Redis metrics publish
# ---------------------------------------------------------------------------
|
|
|
|
@pytest.mark.asyncio
class TestLLMRouterRedis:
    """Publishing of routing metrics to Redis, with the Redis client mocked."""

    @staticmethod
    def _router_with_mock_redis():
        """Build a single-model router wired to a mocked Redis client.

        Returns:
            A ``(router, redis_client)`` pair. ``router`` has the one-model
            chain ``m1`` (8 s timeout) and ``redis_client`` is the
            ``AsyncMock`` installed as ``router._redis``.
        """
        router = LLMRouter(
            redis_url="redis://localhost:6379",
            chain=_chain_of(("m1", 8.0)),
        )
        redis_client = AsyncMock()
        # Replace the router's client so publishes land on the mock.
        router._redis = redis_client
        return router, redis_client

    async def test_metrics_published_on_success(self):
        """A successful call publishes one payload on ``llm_router_metrics``."""
        router, redis_client = self._router_with_mock_redis()
        completion = AsyncMock(return_value=_fake_completion("ok"))

        with patch("litellm.acompletion", completion):
            await router.route([{"role": "user", "content": "x"}])

        redis_client.publish.assert_awaited_once()
        channel, payload_str = redis_client.publish.call_args.args
        assert channel == "llm_router_metrics"

        payload = json.loads(payload_str)
        assert payload["model_used"] == "m1"
        assert "metrics_snapshot" in payload
        assert "timestamp" in payload

    async def test_redis_failure_does_not_raise(self):
        """A broken Redis must never break the LLM call result."""
        router, redis_client = self._router_with_mock_redis()
        redis_client.publish.side_effect = ConnectionError("Redis down")
        completion = AsyncMock(return_value=_fake_completion("ok"))

        with patch("litellm.acompletion", completion):
            result = await router.route([{"role": "user", "content": "x"}])

        assert result.succeeded  # LLM call still returned
        # Guard against a vacuous pass: the failing publish must actually
        # have been attempted for this test to exercise failure isolation.
        redis_client.publish.assert_awaited()

    async def test_metrics_published_on_full_failure(self):
        """Metrics are published even when every model in the chain fails."""
        router, redis_client = self._router_with_mock_redis()
        completion = AsyncMock(return_value=_fake_completion("not json"))

        with patch("litellm.acompletion", completion):
            with pytest.raises(RuntimeError):
                await router.route(
                    messages=[{"role": "user", "content": "x"}],
                    schema={"type": "object"},
                )

        # Metrics must still be published even when we raise
        redis_client.publish.assert_awaited_once()
        _, payload_str = redis_client.publish.call_args.args
        payload = json.loads(payload_str)
        assert payload["model_used"] == "none"
|