feat(planner-agent): add llm_router.py with local-first fallback chain

services/planner-agent/src/llm_router.py:
- LLMRouter: async routing via litellm; chain = Qwen/Ollama → haiku → sonnet
- Timeouts: 8s local, 30s cloud; asyncio.wait_for belt-and-suspenders
- Rejection triggers: timeout, API error, refusal patterns, JSON schema fail
- JSON fence extraction: recovers valid JSON from  blocks
- ModelMetrics: per-model success/fallback/error counters + success_rate()
- Redis publish to 'llm_router_metrics' after every call (failure-safe)
- redis_url=None disables Redis (useful in tests / edge nodes)
- context= param adds caller label to all log lines for tracing

services/planner-agent/tests/test_llm_router.py:
- 34 tests, 0 network calls (litellm + Redis fully mocked)
- Covers: primary success, JSON error fallback, refusal fallback,
  timeout fallback, API exception fallback, all-fail RuntimeError,
  schema validation, fence extraction, metrics recording, Redis publish,
  Redis failure isolation

services/planner-agent/requirements.txt:
- litellm>=1.40.0, redis>=5.0.0, jsonschema>=4.21.0

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Oskar Kapala 2026-05-27 18:38:06 +02:00
parent 603e10a364
commit 1bbc511bb7
3 changed files with 899 additions and 0 deletions

View file

@ -0,0 +1,3 @@
litellm>=1.40.0
redis[asyncio]>=5.0.0
jsonschema>=4.21.0

View file

@ -0,0 +1,447 @@
"""
llm_router.py LLM routing with local-first fallback chain.
Routing strategy:
1. Local Qwen via Ollama (piha:11434, timeout 8 s)
2. claude-haiku-4-5 (Anthropic cloud, timeout 30 s)
3. claude-sonnet-4-6 (Anthropic cloud, timeout 30 s)
A model is rejected when it:
- times out
- raises any API / network exception
- returns text matching a refusal pattern
- returns JSON that fails the caller-supplied JSON Schema
After every call (success or full-chain failure) a metrics event is
published to the Redis channel "llm_router_metrics".
Usage
-----
router = LLMRouter()
result = await router.route(
messages=[{"role": "user", "content": "What should I do?"}],
schema={"type": "object", "required": ["action"], "properties": {...}},
)
print(result.model_used, result.content)
await router.close()
"""
import asyncio
import json
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Optional
import litellm
import redis.asyncio as aioredis
from jsonschema import validate, ValidationError
litellm.suppress_debug_info = True
logger = logging.getLogger("llm_router")
# ---------------------------------------------------------------------------
# Refusal patterns — any substring match (case-insensitive) triggers fallback
# ---------------------------------------------------------------------------
REFUSAL_PATTERNS: list[str] = [
"nie wiem",
"I cannot",
"I can't",
"as an AI",
"I don't know",
"I'm not able",
"I am not able",
"I'm unable",
"I am unable",
"beyond my capabilities",
]
# ---------------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------------
@dataclass
class ModelConfig:
"""Configuration for one model in the fallback chain."""
name: str # litellm model string, e.g. "ollama/qwen2.5:7b"
timeout: float # hard wall-clock timeout in seconds
api_base: Optional[str] = None # override API base URL (Ollama needs this)
extra_kwargs: dict = field(default_factory=dict)
def __str__(self) -> str:
base = f" @ {self.api_base}" if self.api_base else ""
return f"{self.name}{base} (timeout={self.timeout}s)"
@dataclass
class AttemptRecord:
model: str
outcome: str # "success" | "rejected" | "invalid"
reason: Optional[str] # None on success
latency_ms: int
@dataclass
class RouteResult:
"""Return value of LLMRouter.route()."""
content: Any # parsed JSON (if schema given) or raw str
raw_text: str
model_used: str # "none" if every model failed
attempts: list[AttemptRecord]
latency_ms: int # wall-clock from first attempt to return
@property
def succeeded(self) -> bool:
return self.model_used != "none"
def to_dict(self) -> dict:
return {
"model_used": self.model_used,
"latency_ms": self.latency_ms,
"attempts": [
{
"model": a.model,
"outcome": a.outcome,
"reason": a.reason,
"latency_ms": a.latency_ms,
}
for a in self.attempts
],
}
# ---------------------------------------------------------------------------
# Metrics
# ---------------------------------------------------------------------------
class ModelMetrics:
"""Thread-safe-ish counter per model × outcome.
Outcomes: "success", "fallback", "error"
("fallback" = rejected but another model succeeded after it;
"error" = rejected and it was the last in chain or chain exhausted)
"""
def __init__(self) -> None:
self._counts: dict[str, dict[str, int]] = {}
def record(self, model: str, outcome: str) -> None:
if model not in self._counts:
self._counts[model] = {"success": 0, "fallback": 0, "error": 0}
self._counts[model][outcome] = self._counts[model].get(outcome, 0) + 1
def snapshot(self) -> dict[str, dict[str, int]]:
return {m: dict(c) for m, c in self._counts.items()}
def total_calls(self, model: str) -> int:
return sum(self._counts.get(model, {}).values())
def success_rate(self, model: str) -> Optional[float]:
counts = self._counts.get(model, {})
total = sum(counts.values())
if total == 0:
return None
return counts.get("success", 0) / total
# ---------------------------------------------------------------------------
# Router
# ---------------------------------------------------------------------------
class LLMRouter:
"""Route LLM calls through a local-first fallback chain.
Parameters
----------
redis_url:
Redis connection URL for metrics publishing.
Set to None to disable Redis (useful in tests / local dev).
ollama_host:
Base URL of the Ollama API. Defaults to piha's Tailscale address.
ollama_model:
Model tag as known to Ollama (e.g. "qwen2.5:7b").
chain:
Override the entire fallback chain. When None the default
Qwen haiku sonnet chain is used.
"""
DEFAULT_OLLAMA_HOST = "http://100.108.208.3:11434"
DEFAULT_OLLAMA_MODEL = "qwen2.5:7b"
DEFAULT_REDIS_URL = "redis://100.108.208.3:6379"
def __init__(
self,
redis_url: Optional[str] = DEFAULT_REDIS_URL,
ollama_host: str = DEFAULT_OLLAMA_HOST,
ollama_model: str = DEFAULT_OLLAMA_MODEL,
chain: Optional[list[ModelConfig]] = None,
) -> None:
if chain is not None:
self.chain = chain
else:
self.chain = [
ModelConfig(
name=f"ollama/{ollama_model}",
timeout=8.0,
api_base=ollama_host,
),
ModelConfig(
name="claude-haiku-4-5-20251001",
timeout=30.0,
),
ModelConfig(
name="claude-sonnet-4-6",
timeout=30.0,
),
]
self.metrics = ModelMetrics()
self._redis_url = redis_url
self._redis: Optional[aioredis.Redis] = None
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
async def route(
self,
messages: list[dict],
schema: Optional[dict] = None,
context: Optional[str] = None,
) -> RouteResult:
"""Try each model in order; return the first valid response.
Parameters
----------
messages:
OpenAI-style message list, e.g.
[{"role": "user", "content": "..."}]
schema:
Optional JSON Schema dict. When provided the model's response
must be valid JSON that conforms to the schema.
context:
Optional free-text caller label included in log lines (e.g.
"supervisor.reconcile") for easier tracing.
Raises
------
RuntimeError
When every model in the chain fails. The exception message
contains a JSON-formatted attempt log.
"""
tag = f"[{context}] " if context else ""
start = time.monotonic()
attempts: list[AttemptRecord] = []
for i, cfg in enumerate(self.chain):
is_last = i == len(self.chain) - 1
attempt_start = time.monotonic()
logger.info(
f"{tag}[llm_router] attempt {i+1}/{len(self.chain)}: {cfg}"
)
raw_text, call_error = await self._call_model(cfg, messages)
attempt_ms = round((time.monotonic() - attempt_start) * 1000)
if call_error:
self.metrics.record(cfg.name, "error" if is_last else "fallback")
logger.warning(
f"{tag}[llm_router] {cfg.name} → rejected "
f"({call_error}) [{attempt_ms}ms]"
)
attempts.append(AttemptRecord(
model=cfg.name, outcome="rejected",
reason=call_error, latency_ms=attempt_ms,
))
continue
parsed, schema_error = self._validate(raw_text, schema)
if schema_error:
self.metrics.record(cfg.name, "error" if is_last else "fallback")
logger.warning(
f"{tag}[llm_router] {cfg.name} → invalid "
f"({schema_error}) [{attempt_ms}ms]"
)
attempts.append(AttemptRecord(
model=cfg.name, outcome="invalid",
reason=schema_error, latency_ms=attempt_ms,
))
continue
# ── success ───────────────────────────────────────────────
self.metrics.record(cfg.name, "success")
total_ms = round((time.monotonic() - start) * 1000)
logger.info(
f"{tag}[llm_router] {cfg.name} → success "
f"[attempt {attempt_ms}ms, total {total_ms}ms]"
)
attempts.append(AttemptRecord(
model=cfg.name, outcome="success",
reason=None, latency_ms=attempt_ms,
))
result = RouteResult(
content=parsed,
raw_text=raw_text,
model_used=cfg.name,
attempts=attempts,
latency_ms=total_ms,
)
await self._publish_metrics(result)
return result
# ── all models exhausted ──────────────────────────────────────
total_ms = round((time.monotonic() - start) * 1000)
result = RouteResult(
content=None,
raw_text="",
model_used="none",
attempts=attempts,
latency_ms=total_ms,
)
await self._publish_metrics(result)
attempt_log = json.dumps(
[{"model": a.model, "reason": a.reason} for a in attempts],
indent=2,
)
raise RuntimeError(
f"{tag}[llm_router] All {len(self.chain)} models in chain failed.\n"
f"Attempts:\n{attempt_log}"
)
async def close(self) -> None:
"""Release the Redis connection."""
if self._redis is not None:
await self._redis.aclose()
self._redis = None
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
async def _call_model(
self,
cfg: ModelConfig,
messages: list[dict],
) -> tuple[str, Optional[str]]:
"""Invoke one model. Returns (raw_text, error_reason|None)."""
kwargs: dict[str, Any] = {
"model": cfg.name,
"messages": messages,
"timeout": cfg.timeout,
**cfg.extra_kwargs,
}
if cfg.api_base:
kwargs["api_base"] = cfg.api_base
try:
# asyncio.wait_for as belt-and-suspenders — litellm timeout
# is passed to the underlying HTTP client, but asyncio task
# cancellation ensures we never block the event loop.
resp = await asyncio.wait_for(
litellm.acompletion(**kwargs),
timeout=cfg.timeout + 2, # +2 s grace for HTTP overhead
)
text = (resp.choices[0].message.content or "").strip()
except asyncio.TimeoutError:
return "", f"Timeout after {cfg.timeout}s"
except litellm.exceptions.Timeout:
return "", f"Timeout after {cfg.timeout}s"
except litellm.exceptions.APIConnectionError as e:
return "", f"APIConnectionError: {e}"
except litellm.exceptions.AuthenticationError as e:
return "", f"AuthenticationError: {e}"
except Exception as e:
return "", f"{type(e).__name__}: {e}"
# Check for refusals in the model's own text
refusal = self._detect_refusal(text)
if refusal:
return text, f"RefusalPattern matched: '{refusal}'"
return text, None
@staticmethod
def _detect_refusal(text: str) -> Optional[str]:
"""Return the first matching refusal pattern, or None."""
lower = text.lower()
for pattern in REFUSAL_PATTERNS:
if pattern.lower() in lower:
return pattern
return None
@staticmethod
def _validate(
text: str,
schema: Optional[dict],
) -> tuple[Any, Optional[str]]:
"""Parse and validate the model response.
Returns (parsed_content, error_reason|None).
When schema is None, returns (raw_text, None) only refusal
detection (already done in _call_model) applies.
"""
if schema is None:
return text, None
try:
parsed = json.loads(text)
except json.JSONDecodeError as exc:
# Try to extract JSON from a markdown code fence
extracted = _extract_json_from_fence(text)
if extracted is not None:
parsed = extracted
else:
return None, f"JSONDecodeError: {exc}"
try:
validate(instance=parsed, schema=schema)
except ValidationError as exc:
return None, f"SchemaValidationError: {exc.message}"
return parsed, None
async def _get_redis(self) -> Optional[aioredis.Redis]:
if self._redis_url is None:
return None
if self._redis is None:
self._redis = aioredis.from_url(
self._redis_url,
decode_responses=True,
socket_connect_timeout=2,
socket_timeout=2,
)
return self._redis
async def _publish_metrics(self, result: RouteResult) -> None:
"""Non-blocking publish to Redis channel 'llm_router_metrics'."""
payload = {
**result.to_dict(),
"metrics_snapshot": self.metrics.snapshot(),
"timestamp": time.time(),
}
try:
r = await self._get_redis()
if r is not None:
await r.publish("llm_router_metrics", json.dumps(payload))
except Exception as exc:
# Never let a metrics failure break the caller
logger.warning(f"[llm_router] metrics publish failed: {exc}")
# ---------------------------------------------------------------------------
# Utility
# ---------------------------------------------------------------------------
def _extract_json_from_fence(text: str) -> Optional[Any]:
"""Extract JSON from a ```json ... ``` markdown code fence, if present."""
import re
match = re.search(r"```(?:json)?\s*(\{.*?\}|\[.*?\])\s*```", text, re.DOTALL)
if match:
try:
return json.loads(match.group(1))
except json.JSONDecodeError:
pass
return None

View file

@ -0,0 +1,449 @@
"""
Unit tests for llm_router.py.
All LLM and Redis calls are mocked no network required.
Run:
pip install pytest pytest-asyncio litellm jsonschema redis
pytest services/planner-agent/tests/test_llm_router.py -v
"""
import asyncio
import json
import sys
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# Allow importing from src/ without installation
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from llm_router import (
AttemptRecord,
LLMRouter,
ModelConfig,
ModelMetrics,
RouteResult,
_extract_json_from_fence,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _fake_completion(content: str):
"""Build a minimal litellm-style response object."""
msg = MagicMock()
msg.content = content
choice = MagicMock()
choice.message = msg
resp = MagicMock()
resp.choices = [choice]
return resp
def _chain_of(*models: tuple[str, float]) -> list[ModelConfig]:
"""Build a minimal test chain with no api_base."""
return [ModelConfig(name=name, timeout=timeout) for name, timeout in models]
# ---------------------------------------------------------------------------
# ModelMetrics
# ---------------------------------------------------------------------------
class TestModelMetrics:
def test_record_and_snapshot(self):
m = ModelMetrics()
m.record("qwen", "success")
m.record("qwen", "success")
m.record("qwen", "fallback")
m.record("haiku", "success")
snap = m.snapshot()
assert snap["qwen"]["success"] == 2
assert snap["qwen"]["fallback"] == 1
assert snap["haiku"]["success"] == 1
def test_success_rate(self):
m = ModelMetrics()
m.record("model-a", "success")
m.record("model-a", "fallback")
assert m.success_rate("model-a") == 0.5
def test_success_rate_unknown_model(self):
m = ModelMetrics()
assert m.success_rate("ghost") is None
def test_total_calls(self):
m = ModelMetrics()
m.record("x", "success")
m.record("x", "error")
assert m.total_calls("x") == 2
def test_snapshot_is_copy(self):
m = ModelMetrics()
m.record("x", "success")
snap = m.snapshot()
snap["x"]["success"] = 999 # mutate the copy
assert m.snapshot()["x"]["success"] == 1 # original unchanged
# ---------------------------------------------------------------------------
# RouteResult
# ---------------------------------------------------------------------------
class TestRouteResult:
def test_succeeded(self):
r = RouteResult("hello", "hello", "model-a", [], 100)
assert r.succeeded is True
def test_not_succeeded(self):
r = RouteResult(None, "", "none", [], 100)
assert r.succeeded is False
def test_to_dict_structure(self):
attempt = AttemptRecord("m", "success", None, 50)
r = RouteResult("x", "x", "m", [attempt], 50)
d = r.to_dict()
assert d["model_used"] == "m"
assert len(d["attempts"]) == 1
assert d["attempts"][0]["outcome"] == "success"
# ---------------------------------------------------------------------------
# _extract_json_from_fence
# ---------------------------------------------------------------------------
class TestExtractJsonFromFence:
def test_json_fence(self):
text = 'Sure!\n```json\n{"a": 1}\n```\nDone.'
assert _extract_json_from_fence(text) == {"a": 1}
def test_plain_fence(self):
text = "```\n[1, 2]\n```"
assert _extract_json_from_fence(text) == [1, 2]
def test_no_fence(self):
assert _extract_json_from_fence("no json here") is None
def test_broken_fence(self):
assert _extract_json_from_fence("```json\n{broken```") is None
# ---------------------------------------------------------------------------
# LLMRouter — validation & refusal detection
# ---------------------------------------------------------------------------
class TestLLMRouterValidation:
def setup_method(self):
self.router = LLMRouter(redis_url=None)
def test_validate_no_schema_returns_text(self):
parsed, err = self.router._validate("hello world", schema=None)
assert parsed == "hello world"
assert err is None
def test_validate_valid_json(self):
schema = {"type": "object", "required": ["action"]}
parsed, err = self.router._validate('{"action": "redeploy"}', schema)
assert err is None
assert parsed == {"action": "redeploy"}
def test_validate_invalid_json(self):
schema = {"type": "object"}
_, err = self.router._validate("not json {", schema)
assert err is not None
assert "JSONDecodeError" in err
def test_validate_schema_violation(self):
schema = {"type": "object", "required": ["action"]}
_, err = self.router._validate('{"other": 1}', schema)
assert err is not None
assert "SchemaValidationError" in err
def test_validate_extracts_fenced_json(self):
schema = {"type": "object", "required": ["action"]}
text = '```json\n{"action": "restart"}\n```'
parsed, err = self.router._validate(text, schema)
assert err is None
assert parsed == {"action": "restart"}
def test_detect_refusal_nie_wiem(self):
assert self.router._detect_refusal("Nie wiem co mam zrobić") == "nie wiem"
def test_detect_refusal_as_an_ai(self):
# Text contains both "as an AI" and "I cannot"; first match wins.
# We only assert that a refusal IS detected, not which pattern fires.
assert self.router._detect_refusal("As an AI I cannot help") is not None
def test_detect_refusal_none(self):
assert self.router._detect_refusal("Sure, here is the action.") is None
def test_detect_refusal_case_insensitive(self):
assert self.router._detect_refusal("I CANNOT do that") == "I cannot"
# ---------------------------------------------------------------------------
# LLMRouter — routing logic (mocked litellm + Redis)
# ---------------------------------------------------------------------------
@pytest.fixture
def router_no_redis():
"""Router with a 3-model chain and Redis disabled."""
chain = _chain_of(
("local/qwen", 8.0),
("claude-haiku-test", 30.0),
("claude-sonnet-test", 30.0),
)
return LLMRouter(redis_url=None, chain=chain)
@pytest.mark.asyncio
class TestLLMRouterRouting:
async def test_primary_success(self, router_no_redis):
"""Primary model succeeds — no fallback."""
with patch(
"litellm.acompletion",
AsyncMock(return_value=_fake_completion('{"action": "ok"}')),
):
result = await router_no_redis.route(
messages=[{"role": "user", "content": "test"}],
schema={"type": "object", "required": ["action"]},
)
assert result.model_used == "local/qwen"
assert result.content == {"action": "ok"}
assert len(result.attempts) == 1
assert result.attempts[0].outcome == "success"
async def test_fallback_on_json_error(self, router_no_redis):
"""Primary returns bad JSON → falls back to haiku which returns valid JSON."""
responses = [
_fake_completion("not json at all"),
_fake_completion('{"action": "restart"}'),
]
call_count = 0
async def fake_acompletion(**kwargs):
nonlocal call_count
r = responses[call_count]
call_count += 1
return r
with patch("litellm.acompletion", fake_acompletion):
result = await router_no_redis.route(
messages=[{"role": "user", "content": "x"}],
schema={"type": "object", "required": ["action"]},
)
assert result.model_used == "claude-haiku-test"
assert result.content == {"action": "restart"}
assert result.attempts[0].outcome == "invalid"
assert result.attempts[1].outcome == "success"
async def test_fallback_on_refusal(self, router_no_redis):
"""Primary returns refusal text → falls back."""
responses = [
_fake_completion("I cannot help with that request."),
_fake_completion("Sure! Here is the plan."),
]
idx = 0
async def fake_acompletion(**kwargs):
nonlocal idx
r = responses[idx]
idx += 1
return r
with patch("litellm.acompletion", fake_acompletion):
result = await router_no_redis.route(
messages=[{"role": "user", "content": "x"}],
)
assert result.model_used == "claude-haiku-test"
assert result.attempts[0].outcome == "rejected"
assert "RefusalPattern" in result.attempts[0].reason
async def test_fallback_on_timeout(self, router_no_redis):
"""Primary times out → falls back to haiku."""
call_count = 0
async def fake_acompletion(**kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
raise asyncio.TimeoutError()
return _fake_completion("fallback response")
with patch("litellm.acompletion", fake_acompletion):
result = await router_no_redis.route(
messages=[{"role": "user", "content": "x"}],
)
assert result.model_used == "claude-haiku-test"
assert "Timeout" in result.attempts[0].reason
async def test_fallback_on_api_exception(self, router_no_redis):
"""Primary raises a connection error → falls back."""
import litellm.exceptions
call_count = 0
async def fake_acompletion(**kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
raise litellm.exceptions.APIConnectionError(
message="Connection refused", llm_provider="ollama", model="qwen"
)
return _fake_completion("ok")
with patch("litellm.acompletion", fake_acompletion):
result = await router_no_redis.route(
messages=[{"role": "user", "content": "x"}],
)
assert result.model_used == "claude-haiku-test"
assert "APIConnectionError" in result.attempts[0].reason
async def test_all_models_fail_raises(self, router_no_redis):
"""All models return bad JSON → RuntimeError with attempt log."""
with patch(
"litellm.acompletion",
AsyncMock(return_value=_fake_completion("not json")),
):
with pytest.raises(RuntimeError) as exc_info:
await router_no_redis.route(
messages=[{"role": "user", "content": "x"}],
schema={"type": "object"},
)
assert "All 3 models in chain failed" in str(exc_info.value)
async def test_schema_none_no_json_required(self, router_no_redis):
"""Without a schema, plain text responses are accepted."""
with patch(
"litellm.acompletion",
AsyncMock(return_value=_fake_completion("Here is your plan.")),
):
result = await router_no_redis.route(
messages=[{"role": "user", "content": "x"}],
)
assert result.content == "Here is your plan."
assert result.succeeded
async def test_metrics_recorded_on_success(self, router_no_redis):
with patch(
"litellm.acompletion",
AsyncMock(return_value=_fake_completion("ok")),
):
await router_no_redis.route([{"role": "user", "content": "x"}])
snap = router_no_redis.metrics.snapshot()
assert snap["local/qwen"]["success"] == 1
async def test_metrics_fallback_recorded(self, router_no_redis):
"""Primary fails → fallback → metrics show primary=fallback, haiku=success."""
responses = [
_fake_completion("I cannot help"), # refusal
_fake_completion("ok"),
]
idx = 0
async def fake_acompletion(**kwargs):
nonlocal idx
r = responses[idx]; idx += 1
return r
with patch("litellm.acompletion", fake_acompletion):
await router_no_redis.route([{"role": "user", "content": "x"}])
snap = router_no_redis.metrics.snapshot()
assert snap["local/qwen"]["fallback"] == 1
assert snap["claude-haiku-test"]["success"] == 1
async def test_context_label_in_logs(self, router_no_redis, caplog):
"""context= parameter appears in log output."""
import logging
with patch(
"litellm.acompletion",
AsyncMock(return_value=_fake_completion("ok")),
):
with caplog.at_level(logging.INFO, logger="llm_router"):
await router_no_redis.route(
messages=[{"role": "user", "content": "x"}],
context="supervisor.reconcile",
)
assert any("supervisor.reconcile" in r.message for r in caplog.records)
# ---------------------------------------------------------------------------
# LLMRouter — Redis metrics publish
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestLLMRouterRedis:
async def test_metrics_published_on_success(self):
chain = _chain_of(("m1", 8.0),)
router = LLMRouter(redis_url="redis://localhost:6379", chain=chain)
mock_redis = AsyncMock()
router._redis = mock_redis
with patch(
"litellm.acompletion",
AsyncMock(return_value=_fake_completion("ok")),
):
await router.route([{"role": "user", "content": "x"}])
mock_redis.publish.assert_awaited_once()
channel, payload_str = mock_redis.publish.call_args[0]
assert channel == "llm_router_metrics"
payload = json.loads(payload_str)
assert payload["model_used"] == "m1"
assert "metrics_snapshot" in payload
assert "timestamp" in payload
async def test_redis_failure_does_not_raise(self):
"""A broken Redis must never break the LLM call result."""
chain = _chain_of(("m1", 8.0),)
router = LLMRouter(redis_url="redis://localhost:6379", chain=chain)
mock_redis = AsyncMock()
mock_redis.publish.side_effect = ConnectionError("Redis down")
router._redis = mock_redis
with patch(
"litellm.acompletion",
AsyncMock(return_value=_fake_completion("ok")),
):
result = await router.route([{"role": "user", "content": "x"}])
assert result.succeeded # LLM call still returned
async def test_metrics_published_on_full_failure(self):
chain = _chain_of(("m1", 8.0),)
router = LLMRouter(redis_url="redis://localhost:6379", chain=chain)
mock_redis = AsyncMock()
router._redis = mock_redis
with patch(
"litellm.acompletion",
AsyncMock(return_value=_fake_completion("not json")),
):
with pytest.raises(RuntimeError):
await router.route(
messages=[{"role": "user", "content": "x"}],
schema={"type": "object"},
)
# Metrics must still be published even when we raise
mock_redis.publish.assert_awaited_once()
_, payload_str = mock_redis.publish.call_args[0]
payload = json.loads(payload_str)
assert payload["model_used"] == "none"