feat(planner-agent): add llm_router.py with local-first fallback chain
services/planner-agent/src/llm_router.py: - LLMRouter: async routing via litellm; chain = Qwen/Ollama → haiku → sonnet - Timeouts: 8s local, 30s cloud; asyncio.wait_for belt-and-suspenders - Rejection triggers: timeout, API error, refusal patterns, JSON schema fail - JSON fence extraction: recovers valid JSON from blocks - ModelMetrics: per-model success/fallback/error counters + success_rate() - Redis publish to 'llm_router_metrics' after every call (failure-safe) - redis_url=None disables Redis (useful in tests / edge nodes) - context= param adds caller label to all log lines for tracing services/planner-agent/tests/test_llm_router.py: - 34 tests, 0 network calls (litellm + Redis fully mocked) - Covers: primary success, JSON error fallback, refusal fallback, timeout fallback, API exception fallback, all-fail RuntimeError, schema validation, fence extraction, metrics recording, Redis publish, Redis failure isolation services/planner-agent/requirements.txt: - litellm>=1.40.0, redis>=5.0.0, jsonschema>=4.21.0 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
603e10a364
commit
1bbc511bb7
3
services/planner-agent/requirements.txt
Normal file
3
services/planner-agent/requirements.txt
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
litellm>=1.40.0
|
||||
redis[asyncio]>=5.0.0
|
||||
jsonschema>=4.21.0
|
||||
447
services/planner-agent/src/llm_router.py
Normal file
447
services/planner-agent/src/llm_router.py
Normal file
|
|
@ -0,0 +1,447 @@
|
|||
"""
|
||||
llm_router.py — LLM routing with local-first fallback chain.
|
||||
|
||||
Routing strategy:
|
||||
1. Local Qwen via Ollama (piha:11434, timeout 8 s)
|
||||
2. claude-haiku-4-5 (Anthropic cloud, timeout 30 s)
|
||||
3. claude-sonnet-4-6 (Anthropic cloud, timeout 30 s)
|
||||
|
||||
A model is rejected when it:
|
||||
- times out
|
||||
- raises any API / network exception
|
||||
- returns text matching a refusal pattern
|
||||
- returns JSON that fails the caller-supplied JSON Schema
|
||||
|
||||
After every call (success or full-chain failure) a metrics event is
|
||||
published to the Redis channel "llm_router_metrics".
|
||||
|
||||
Usage
|
||||
-----
|
||||
router = LLMRouter()
|
||||
result = await router.route(
|
||||
messages=[{"role": "user", "content": "What should I do?"}],
|
||||
schema={"type": "object", "required": ["action"], "properties": {...}},
|
||||
)
|
||||
print(result.model_used, result.content)
|
||||
await router.close()
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
import litellm
|
||||
import redis.asyncio as aioredis
|
||||
from jsonschema import validate, ValidationError
|
||||
|
||||
litellm.suppress_debug_info = True
|
||||
|
||||
logger = logging.getLogger("llm_router")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Refusal patterns — any substring match (case-insensitive) triggers fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
REFUSAL_PATTERNS: list[str] = [
|
||||
"nie wiem",
|
||||
"I cannot",
|
||||
"I can't",
|
||||
"as an AI",
|
||||
"I don't know",
|
||||
"I'm not able",
|
||||
"I am not able",
|
||||
"I'm unable",
|
||||
"I am unable",
|
||||
"beyond my capabilities",
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data structures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
"""Configuration for one model in the fallback chain."""
|
||||
name: str # litellm model string, e.g. "ollama/qwen2.5:7b"
|
||||
timeout: float # hard wall-clock timeout in seconds
|
||||
api_base: Optional[str] = None # override API base URL (Ollama needs this)
|
||||
extra_kwargs: dict = field(default_factory=dict)
|
||||
|
||||
def __str__(self) -> str:
|
||||
base = f" @ {self.api_base}" if self.api_base else ""
|
||||
return f"{self.name}{base} (timeout={self.timeout}s)"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttemptRecord:
|
||||
model: str
|
||||
outcome: str # "success" | "rejected" | "invalid"
|
||||
reason: Optional[str] # None on success
|
||||
latency_ms: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class RouteResult:
|
||||
"""Return value of LLMRouter.route()."""
|
||||
content: Any # parsed JSON (if schema given) or raw str
|
||||
raw_text: str
|
||||
model_used: str # "none" if every model failed
|
||||
attempts: list[AttemptRecord]
|
||||
latency_ms: int # wall-clock from first attempt to return
|
||||
|
||||
@property
|
||||
def succeeded(self) -> bool:
|
||||
return self.model_used != "none"
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"model_used": self.model_used,
|
||||
"latency_ms": self.latency_ms,
|
||||
"attempts": [
|
||||
{
|
||||
"model": a.model,
|
||||
"outcome": a.outcome,
|
||||
"reason": a.reason,
|
||||
"latency_ms": a.latency_ms,
|
||||
}
|
||||
for a in self.attempts
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Metrics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ModelMetrics:
|
||||
"""Thread-safe-ish counter per model × outcome.
|
||||
|
||||
Outcomes: "success", "fallback", "error"
|
||||
("fallback" = rejected but another model succeeded after it;
|
||||
"error" = rejected and it was the last in chain or chain exhausted)
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._counts: dict[str, dict[str, int]] = {}
|
||||
|
||||
def record(self, model: str, outcome: str) -> None:
|
||||
if model not in self._counts:
|
||||
self._counts[model] = {"success": 0, "fallback": 0, "error": 0}
|
||||
self._counts[model][outcome] = self._counts[model].get(outcome, 0) + 1
|
||||
|
||||
def snapshot(self) -> dict[str, dict[str, int]]:
|
||||
return {m: dict(c) for m, c in self._counts.items()}
|
||||
|
||||
def total_calls(self, model: str) -> int:
|
||||
return sum(self._counts.get(model, {}).values())
|
||||
|
||||
def success_rate(self, model: str) -> Optional[float]:
|
||||
counts = self._counts.get(model, {})
|
||||
total = sum(counts.values())
|
||||
if total == 0:
|
||||
return None
|
||||
return counts.get("success", 0) / total
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Router
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class LLMRouter:
|
||||
"""Route LLM calls through a local-first fallback chain.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
redis_url:
|
||||
Redis connection URL for metrics publishing.
|
||||
Set to None to disable Redis (useful in tests / local dev).
|
||||
ollama_host:
|
||||
Base URL of the Ollama API. Defaults to piha's Tailscale address.
|
||||
ollama_model:
|
||||
Model tag as known to Ollama (e.g. "qwen2.5:7b").
|
||||
chain:
|
||||
Override the entire fallback chain. When None the default
|
||||
Qwen → haiku → sonnet chain is used.
|
||||
"""
|
||||
|
||||
DEFAULT_OLLAMA_HOST = "http://100.108.208.3:11434"
|
||||
DEFAULT_OLLAMA_MODEL = "qwen2.5:7b"
|
||||
DEFAULT_REDIS_URL = "redis://100.108.208.3:6379"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_url: Optional[str] = DEFAULT_REDIS_URL,
|
||||
ollama_host: str = DEFAULT_OLLAMA_HOST,
|
||||
ollama_model: str = DEFAULT_OLLAMA_MODEL,
|
||||
chain: Optional[list[ModelConfig]] = None,
|
||||
) -> None:
|
||||
if chain is not None:
|
||||
self.chain = chain
|
||||
else:
|
||||
self.chain = [
|
||||
ModelConfig(
|
||||
name=f"ollama/{ollama_model}",
|
||||
timeout=8.0,
|
||||
api_base=ollama_host,
|
||||
),
|
||||
ModelConfig(
|
||||
name="claude-haiku-4-5-20251001",
|
||||
timeout=30.0,
|
||||
),
|
||||
ModelConfig(
|
||||
name="claude-sonnet-4-6",
|
||||
timeout=30.0,
|
||||
),
|
||||
]
|
||||
|
||||
self.metrics = ModelMetrics()
|
||||
self._redis_url = redis_url
|
||||
self._redis: Optional[aioredis.Redis] = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def route(
|
||||
self,
|
||||
messages: list[dict],
|
||||
schema: Optional[dict] = None,
|
||||
context: Optional[str] = None,
|
||||
) -> RouteResult:
|
||||
"""Try each model in order; return the first valid response.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
messages:
|
||||
OpenAI-style message list, e.g.
|
||||
[{"role": "user", "content": "..."}]
|
||||
schema:
|
||||
Optional JSON Schema dict. When provided the model's response
|
||||
must be valid JSON that conforms to the schema.
|
||||
context:
|
||||
Optional free-text caller label included in log lines (e.g.
|
||||
"supervisor.reconcile") for easier tracing.
|
||||
|
||||
Raises
|
||||
------
|
||||
RuntimeError
|
||||
When every model in the chain fails. The exception message
|
||||
contains a JSON-formatted attempt log.
|
||||
"""
|
||||
tag = f"[{context}] " if context else ""
|
||||
start = time.monotonic()
|
||||
attempts: list[AttemptRecord] = []
|
||||
|
||||
for i, cfg in enumerate(self.chain):
|
||||
is_last = i == len(self.chain) - 1
|
||||
attempt_start = time.monotonic()
|
||||
|
||||
logger.info(
|
||||
f"{tag}[llm_router] attempt {i+1}/{len(self.chain)}: {cfg}"
|
||||
)
|
||||
|
||||
raw_text, call_error = await self._call_model(cfg, messages)
|
||||
attempt_ms = round((time.monotonic() - attempt_start) * 1000)
|
||||
|
||||
if call_error:
|
||||
self.metrics.record(cfg.name, "error" if is_last else "fallback")
|
||||
logger.warning(
|
||||
f"{tag}[llm_router] {cfg.name} → rejected "
|
||||
f"({call_error}) [{attempt_ms}ms]"
|
||||
)
|
||||
attempts.append(AttemptRecord(
|
||||
model=cfg.name, outcome="rejected",
|
||||
reason=call_error, latency_ms=attempt_ms,
|
||||
))
|
||||
continue
|
||||
|
||||
parsed, schema_error = self._validate(raw_text, schema)
|
||||
if schema_error:
|
||||
self.metrics.record(cfg.name, "error" if is_last else "fallback")
|
||||
logger.warning(
|
||||
f"{tag}[llm_router] {cfg.name} → invalid "
|
||||
f"({schema_error}) [{attempt_ms}ms]"
|
||||
)
|
||||
attempts.append(AttemptRecord(
|
||||
model=cfg.name, outcome="invalid",
|
||||
reason=schema_error, latency_ms=attempt_ms,
|
||||
))
|
||||
continue
|
||||
|
||||
# ── success ───────────────────────────────────────────────
|
||||
self.metrics.record(cfg.name, "success")
|
||||
total_ms = round((time.monotonic() - start) * 1000)
|
||||
logger.info(
|
||||
f"{tag}[llm_router] {cfg.name} → success "
|
||||
f"[attempt {attempt_ms}ms, total {total_ms}ms]"
|
||||
)
|
||||
attempts.append(AttemptRecord(
|
||||
model=cfg.name, outcome="success",
|
||||
reason=None, latency_ms=attempt_ms,
|
||||
))
|
||||
result = RouteResult(
|
||||
content=parsed,
|
||||
raw_text=raw_text,
|
||||
model_used=cfg.name,
|
||||
attempts=attempts,
|
||||
latency_ms=total_ms,
|
||||
)
|
||||
await self._publish_metrics(result)
|
||||
return result
|
||||
|
||||
# ── all models exhausted ──────────────────────────────────────
|
||||
total_ms = round((time.monotonic() - start) * 1000)
|
||||
result = RouteResult(
|
||||
content=None,
|
||||
raw_text="",
|
||||
model_used="none",
|
||||
attempts=attempts,
|
||||
latency_ms=total_ms,
|
||||
)
|
||||
await self._publish_metrics(result)
|
||||
|
||||
attempt_log = json.dumps(
|
||||
[{"model": a.model, "reason": a.reason} for a in attempts],
|
||||
indent=2,
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"{tag}[llm_router] All {len(self.chain)} models in chain failed.\n"
|
||||
f"Attempts:\n{attempt_log}"
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Release the Redis connection."""
|
||||
if self._redis is not None:
|
||||
await self._redis.aclose()
|
||||
self._redis = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _call_model(
|
||||
self,
|
||||
cfg: ModelConfig,
|
||||
messages: list[dict],
|
||||
) -> tuple[str, Optional[str]]:
|
||||
"""Invoke one model. Returns (raw_text, error_reason|None)."""
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": cfg.name,
|
||||
"messages": messages,
|
||||
"timeout": cfg.timeout,
|
||||
**cfg.extra_kwargs,
|
||||
}
|
||||
if cfg.api_base:
|
||||
kwargs["api_base"] = cfg.api_base
|
||||
|
||||
try:
|
||||
# asyncio.wait_for as belt-and-suspenders — litellm timeout
|
||||
# is passed to the underlying HTTP client, but asyncio task
|
||||
# cancellation ensures we never block the event loop.
|
||||
resp = await asyncio.wait_for(
|
||||
litellm.acompletion(**kwargs),
|
||||
timeout=cfg.timeout + 2, # +2 s grace for HTTP overhead
|
||||
)
|
||||
text = (resp.choices[0].message.content or "").strip()
|
||||
except asyncio.TimeoutError:
|
||||
return "", f"Timeout after {cfg.timeout}s"
|
||||
except litellm.exceptions.Timeout:
|
||||
return "", f"Timeout after {cfg.timeout}s"
|
||||
except litellm.exceptions.APIConnectionError as e:
|
||||
return "", f"APIConnectionError: {e}"
|
||||
except litellm.exceptions.AuthenticationError as e:
|
||||
return "", f"AuthenticationError: {e}"
|
||||
except Exception as e:
|
||||
return "", f"{type(e).__name__}: {e}"
|
||||
|
||||
# Check for refusals in the model's own text
|
||||
refusal = self._detect_refusal(text)
|
||||
if refusal:
|
||||
return text, f"RefusalPattern matched: '{refusal}'"
|
||||
return text, None
|
||||
|
||||
@staticmethod
|
||||
def _detect_refusal(text: str) -> Optional[str]:
|
||||
"""Return the first matching refusal pattern, or None."""
|
||||
lower = text.lower()
|
||||
for pattern in REFUSAL_PATTERNS:
|
||||
if pattern.lower() in lower:
|
||||
return pattern
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _validate(
|
||||
text: str,
|
||||
schema: Optional[dict],
|
||||
) -> tuple[Any, Optional[str]]:
|
||||
"""Parse and validate the model response.
|
||||
|
||||
Returns (parsed_content, error_reason|None).
|
||||
When schema is None, returns (raw_text, None) — only refusal
|
||||
detection (already done in _call_model) applies.
|
||||
"""
|
||||
if schema is None:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
parsed = json.loads(text)
|
||||
except json.JSONDecodeError as exc:
|
||||
# Try to extract JSON from a markdown code fence
|
||||
extracted = _extract_json_from_fence(text)
|
||||
if extracted is not None:
|
||||
parsed = extracted
|
||||
else:
|
||||
return None, f"JSONDecodeError: {exc}"
|
||||
|
||||
try:
|
||||
validate(instance=parsed, schema=schema)
|
||||
except ValidationError as exc:
|
||||
return None, f"SchemaValidationError: {exc.message}"
|
||||
|
||||
return parsed, None
|
||||
|
||||
async def _get_redis(self) -> Optional[aioredis.Redis]:
|
||||
if self._redis_url is None:
|
||||
return None
|
||||
if self._redis is None:
|
||||
self._redis = aioredis.from_url(
|
||||
self._redis_url,
|
||||
decode_responses=True,
|
||||
socket_connect_timeout=2,
|
||||
socket_timeout=2,
|
||||
)
|
||||
return self._redis
|
||||
|
||||
async def _publish_metrics(self, result: RouteResult) -> None:
|
||||
"""Non-blocking publish to Redis channel 'llm_router_metrics'."""
|
||||
payload = {
|
||||
**result.to_dict(),
|
||||
"metrics_snapshot": self.metrics.snapshot(),
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
try:
|
||||
r = await self._get_redis()
|
||||
if r is not None:
|
||||
await r.publish("llm_router_metrics", json.dumps(payload))
|
||||
except Exception as exc:
|
||||
# Never let a metrics failure break the caller
|
||||
logger.warning(f"[llm_router] metrics publish failed: {exc}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Utility
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _extract_json_from_fence(text: str) -> Optional[Any]:
|
||||
"""Extract JSON from a ```json ... ``` markdown code fence, if present."""
|
||||
import re
|
||||
match = re.search(r"```(?:json)?\s*(\{.*?\}|\[.*?\])\s*```", text, re.DOTALL)
|
||||
if match:
|
||||
try:
|
||||
return json.loads(match.group(1))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return None
|
||||
449
services/planner-agent/tests/test_llm_router.py
Normal file
449
services/planner-agent/tests/test_llm_router.py
Normal file
|
|
@ -0,0 +1,449 @@
|
|||
"""
|
||||
Unit tests for llm_router.py.
|
||||
|
||||
All LLM and Redis calls are mocked — no network required.
|
||||
|
||||
Run:
|
||||
pip install pytest pytest-asyncio litellm jsonschema redis
|
||||
pytest services/planner-agent/tests/test_llm_router.py -v
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Allow importing from src/ without installation
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
||||
|
||||
from llm_router import (
|
||||
AttemptRecord,
|
||||
LLMRouter,
|
||||
ModelConfig,
|
||||
ModelMetrics,
|
||||
RouteResult,
|
||||
_extract_json_from_fence,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _fake_completion(content: str):
|
||||
"""Build a minimal litellm-style response object."""
|
||||
msg = MagicMock()
|
||||
msg.content = content
|
||||
choice = MagicMock()
|
||||
choice.message = msg
|
||||
resp = MagicMock()
|
||||
resp.choices = [choice]
|
||||
return resp
|
||||
|
||||
|
||||
def _chain_of(*models: tuple[str, float]) -> list[ModelConfig]:
|
||||
"""Build a minimal test chain with no api_base."""
|
||||
return [ModelConfig(name=name, timeout=timeout) for name, timeout in models]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ModelMetrics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestModelMetrics:
|
||||
def test_record_and_snapshot(self):
|
||||
m = ModelMetrics()
|
||||
m.record("qwen", "success")
|
||||
m.record("qwen", "success")
|
||||
m.record("qwen", "fallback")
|
||||
m.record("haiku", "success")
|
||||
|
||||
snap = m.snapshot()
|
||||
assert snap["qwen"]["success"] == 2
|
||||
assert snap["qwen"]["fallback"] == 1
|
||||
assert snap["haiku"]["success"] == 1
|
||||
|
||||
def test_success_rate(self):
|
||||
m = ModelMetrics()
|
||||
m.record("model-a", "success")
|
||||
m.record("model-a", "fallback")
|
||||
assert m.success_rate("model-a") == 0.5
|
||||
|
||||
def test_success_rate_unknown_model(self):
|
||||
m = ModelMetrics()
|
||||
assert m.success_rate("ghost") is None
|
||||
|
||||
def test_total_calls(self):
|
||||
m = ModelMetrics()
|
||||
m.record("x", "success")
|
||||
m.record("x", "error")
|
||||
assert m.total_calls("x") == 2
|
||||
|
||||
def test_snapshot_is_copy(self):
|
||||
m = ModelMetrics()
|
||||
m.record("x", "success")
|
||||
snap = m.snapshot()
|
||||
snap["x"]["success"] = 999 # mutate the copy
|
||||
assert m.snapshot()["x"]["success"] == 1 # original unchanged
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RouteResult
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRouteResult:
|
||||
def test_succeeded(self):
|
||||
r = RouteResult("hello", "hello", "model-a", [], 100)
|
||||
assert r.succeeded is True
|
||||
|
||||
def test_not_succeeded(self):
|
||||
r = RouteResult(None, "", "none", [], 100)
|
||||
assert r.succeeded is False
|
||||
|
||||
def test_to_dict_structure(self):
|
||||
attempt = AttemptRecord("m", "success", None, 50)
|
||||
r = RouteResult("x", "x", "m", [attempt], 50)
|
||||
d = r.to_dict()
|
||||
assert d["model_used"] == "m"
|
||||
assert len(d["attempts"]) == 1
|
||||
assert d["attempts"][0]["outcome"] == "success"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _extract_json_from_fence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExtractJsonFromFence:
|
||||
def test_json_fence(self):
|
||||
text = 'Sure!\n```json\n{"a": 1}\n```\nDone.'
|
||||
assert _extract_json_from_fence(text) == {"a": 1}
|
||||
|
||||
def test_plain_fence(self):
|
||||
text = "```\n[1, 2]\n```"
|
||||
assert _extract_json_from_fence(text) == [1, 2]
|
||||
|
||||
def test_no_fence(self):
|
||||
assert _extract_json_from_fence("no json here") is None
|
||||
|
||||
def test_broken_fence(self):
|
||||
assert _extract_json_from_fence("```json\n{broken```") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLMRouter — validation & refusal detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLLMRouterValidation:
|
||||
def setup_method(self):
|
||||
self.router = LLMRouter(redis_url=None)
|
||||
|
||||
def test_validate_no_schema_returns_text(self):
|
||||
parsed, err = self.router._validate("hello world", schema=None)
|
||||
assert parsed == "hello world"
|
||||
assert err is None
|
||||
|
||||
def test_validate_valid_json(self):
|
||||
schema = {"type": "object", "required": ["action"]}
|
||||
parsed, err = self.router._validate('{"action": "redeploy"}', schema)
|
||||
assert err is None
|
||||
assert parsed == {"action": "redeploy"}
|
||||
|
||||
def test_validate_invalid_json(self):
|
||||
schema = {"type": "object"}
|
||||
_, err = self.router._validate("not json {", schema)
|
||||
assert err is not None
|
||||
assert "JSONDecodeError" in err
|
||||
|
||||
def test_validate_schema_violation(self):
|
||||
schema = {"type": "object", "required": ["action"]}
|
||||
_, err = self.router._validate('{"other": 1}', schema)
|
||||
assert err is not None
|
||||
assert "SchemaValidationError" in err
|
||||
|
||||
def test_validate_extracts_fenced_json(self):
|
||||
schema = {"type": "object", "required": ["action"]}
|
||||
text = '```json\n{"action": "restart"}\n```'
|
||||
parsed, err = self.router._validate(text, schema)
|
||||
assert err is None
|
||||
assert parsed == {"action": "restart"}
|
||||
|
||||
def test_detect_refusal_nie_wiem(self):
|
||||
assert self.router._detect_refusal("Nie wiem co mam zrobić") == "nie wiem"
|
||||
|
||||
def test_detect_refusal_as_an_ai(self):
|
||||
# Text contains both "as an AI" and "I cannot"; first match wins.
|
||||
# We only assert that a refusal IS detected, not which pattern fires.
|
||||
assert self.router._detect_refusal("As an AI I cannot help") is not None
|
||||
|
||||
def test_detect_refusal_none(self):
|
||||
assert self.router._detect_refusal("Sure, here is the action.") is None
|
||||
|
||||
def test_detect_refusal_case_insensitive(self):
|
||||
assert self.router._detect_refusal("I CANNOT do that") == "I cannot"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLMRouter — routing logic (mocked litellm + Redis)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
|
||||
def router_no_redis():
|
||||
"""Router with a 3-model chain and Redis disabled."""
|
||||
chain = _chain_of(
|
||||
("local/qwen", 8.0),
|
||||
("claude-haiku-test", 30.0),
|
||||
("claude-sonnet-test", 30.0),
|
||||
)
|
||||
return LLMRouter(redis_url=None, chain=chain)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestLLMRouterRouting:
|
||||
async def test_primary_success(self, router_no_redis):
|
||||
"""Primary model succeeds — no fallback."""
|
||||
with patch(
|
||||
"litellm.acompletion",
|
||||
AsyncMock(return_value=_fake_completion('{"action": "ok"}')),
|
||||
):
|
||||
result = await router_no_redis.route(
|
||||
messages=[{"role": "user", "content": "test"}],
|
||||
schema={"type": "object", "required": ["action"]},
|
||||
)
|
||||
|
||||
assert result.model_used == "local/qwen"
|
||||
assert result.content == {"action": "ok"}
|
||||
assert len(result.attempts) == 1
|
||||
assert result.attempts[0].outcome == "success"
|
||||
|
||||
async def test_fallback_on_json_error(self, router_no_redis):
|
||||
"""Primary returns bad JSON → falls back to haiku which returns valid JSON."""
|
||||
responses = [
|
||||
_fake_completion("not json at all"),
|
||||
_fake_completion('{"action": "restart"}'),
|
||||
]
|
||||
call_count = 0
|
||||
|
||||
async def fake_acompletion(**kwargs):
|
||||
nonlocal call_count
|
||||
r = responses[call_count]
|
||||
call_count += 1
|
||||
return r
|
||||
|
||||
with patch("litellm.acompletion", fake_acompletion):
|
||||
result = await router_no_redis.route(
|
||||
messages=[{"role": "user", "content": "x"}],
|
||||
schema={"type": "object", "required": ["action"]},
|
||||
)
|
||||
|
||||
assert result.model_used == "claude-haiku-test"
|
||||
assert result.content == {"action": "restart"}
|
||||
assert result.attempts[0].outcome == "invalid"
|
||||
assert result.attempts[1].outcome == "success"
|
||||
|
||||
async def test_fallback_on_refusal(self, router_no_redis):
|
||||
"""Primary returns refusal text → falls back."""
|
||||
responses = [
|
||||
_fake_completion("I cannot help with that request."),
|
||||
_fake_completion("Sure! Here is the plan."),
|
||||
]
|
||||
idx = 0
|
||||
|
||||
async def fake_acompletion(**kwargs):
|
||||
nonlocal idx
|
||||
r = responses[idx]
|
||||
idx += 1
|
||||
return r
|
||||
|
||||
with patch("litellm.acompletion", fake_acompletion):
|
||||
result = await router_no_redis.route(
|
||||
messages=[{"role": "user", "content": "x"}],
|
||||
)
|
||||
|
||||
assert result.model_used == "claude-haiku-test"
|
||||
assert result.attempts[0].outcome == "rejected"
|
||||
assert "RefusalPattern" in result.attempts[0].reason
|
||||
|
||||
async def test_fallback_on_timeout(self, router_no_redis):
|
||||
"""Primary times out → falls back to haiku."""
|
||||
call_count = 0
|
||||
|
||||
async def fake_acompletion(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise asyncio.TimeoutError()
|
||||
return _fake_completion("fallback response")
|
||||
|
||||
with patch("litellm.acompletion", fake_acompletion):
|
||||
result = await router_no_redis.route(
|
||||
messages=[{"role": "user", "content": "x"}],
|
||||
)
|
||||
|
||||
assert result.model_used == "claude-haiku-test"
|
||||
assert "Timeout" in result.attempts[0].reason
|
||||
|
||||
async def test_fallback_on_api_exception(self, router_no_redis):
|
||||
"""Primary raises a connection error → falls back."""
|
||||
import litellm.exceptions
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def fake_acompletion(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise litellm.exceptions.APIConnectionError(
|
||||
message="Connection refused", llm_provider="ollama", model="qwen"
|
||||
)
|
||||
return _fake_completion("ok")
|
||||
|
||||
with patch("litellm.acompletion", fake_acompletion):
|
||||
result = await router_no_redis.route(
|
||||
messages=[{"role": "user", "content": "x"}],
|
||||
)
|
||||
|
||||
assert result.model_used == "claude-haiku-test"
|
||||
assert "APIConnectionError" in result.attempts[0].reason
|
||||
|
||||
async def test_all_models_fail_raises(self, router_no_redis):
|
||||
"""All models return bad JSON → RuntimeError with attempt log."""
|
||||
with patch(
|
||||
"litellm.acompletion",
|
||||
AsyncMock(return_value=_fake_completion("not json")),
|
||||
):
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
await router_no_redis.route(
|
||||
messages=[{"role": "user", "content": "x"}],
|
||||
schema={"type": "object"},
|
||||
)
|
||||
|
||||
assert "All 3 models in chain failed" in str(exc_info.value)
|
||||
|
||||
async def test_schema_none_no_json_required(self, router_no_redis):
|
||||
"""Without a schema, plain text responses are accepted."""
|
||||
with patch(
|
||||
"litellm.acompletion",
|
||||
AsyncMock(return_value=_fake_completion("Here is your plan.")),
|
||||
):
|
||||
result = await router_no_redis.route(
|
||||
messages=[{"role": "user", "content": "x"}],
|
||||
)
|
||||
|
||||
assert result.content == "Here is your plan."
|
||||
assert result.succeeded
|
||||
|
||||
async def test_metrics_recorded_on_success(self, router_no_redis):
|
||||
with patch(
|
||||
"litellm.acompletion",
|
||||
AsyncMock(return_value=_fake_completion("ok")),
|
||||
):
|
||||
await router_no_redis.route([{"role": "user", "content": "x"}])
|
||||
|
||||
snap = router_no_redis.metrics.snapshot()
|
||||
assert snap["local/qwen"]["success"] == 1
|
||||
|
||||
async def test_metrics_fallback_recorded(self, router_no_redis):
|
||||
"""Primary fails → fallback → metrics show primary=fallback, haiku=success."""
|
||||
responses = [
|
||||
_fake_completion("I cannot help"), # refusal
|
||||
_fake_completion("ok"),
|
||||
]
|
||||
idx = 0
|
||||
|
||||
async def fake_acompletion(**kwargs):
|
||||
nonlocal idx
|
||||
r = responses[idx]; idx += 1
|
||||
return r
|
||||
|
||||
with patch("litellm.acompletion", fake_acompletion):
|
||||
await router_no_redis.route([{"role": "user", "content": "x"}])
|
||||
|
||||
snap = router_no_redis.metrics.snapshot()
|
||||
assert snap["local/qwen"]["fallback"] == 1
|
||||
assert snap["claude-haiku-test"]["success"] == 1
|
||||
|
||||
async def test_context_label_in_logs(self, router_no_redis, caplog):
|
||||
"""context= parameter appears in log output."""
|
||||
import logging
|
||||
with patch(
|
||||
"litellm.acompletion",
|
||||
AsyncMock(return_value=_fake_completion("ok")),
|
||||
):
|
||||
with caplog.at_level(logging.INFO, logger="llm_router"):
|
||||
await router_no_redis.route(
|
||||
messages=[{"role": "user", "content": "x"}],
|
||||
context="supervisor.reconcile",
|
||||
)
|
||||
|
||||
assert any("supervisor.reconcile" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLMRouter — Redis metrics publish
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestLLMRouterRedis:
|
||||
async def test_metrics_published_on_success(self):
|
||||
chain = _chain_of(("m1", 8.0),)
|
||||
router = LLMRouter(redis_url="redis://localhost:6379", chain=chain)
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
router._redis = mock_redis
|
||||
|
||||
with patch(
|
||||
"litellm.acompletion",
|
||||
AsyncMock(return_value=_fake_completion("ok")),
|
||||
):
|
||||
await router.route([{"role": "user", "content": "x"}])
|
||||
|
||||
mock_redis.publish.assert_awaited_once()
|
||||
channel, payload_str = mock_redis.publish.call_args[0]
|
||||
assert channel == "llm_router_metrics"
|
||||
payload = json.loads(payload_str)
|
||||
assert payload["model_used"] == "m1"
|
||||
assert "metrics_snapshot" in payload
|
||||
assert "timestamp" in payload
|
||||
|
||||
async def test_redis_failure_does_not_raise(self):
|
||||
"""A broken Redis must never break the LLM call result."""
|
||||
chain = _chain_of(("m1", 8.0),)
|
||||
router = LLMRouter(redis_url="redis://localhost:6379", chain=chain)
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.publish.side_effect = ConnectionError("Redis down")
|
||||
router._redis = mock_redis
|
||||
|
||||
with patch(
|
||||
"litellm.acompletion",
|
||||
AsyncMock(return_value=_fake_completion("ok")),
|
||||
):
|
||||
result = await router.route([{"role": "user", "content": "x"}])
|
||||
|
||||
assert result.succeeded # LLM call still returned
|
||||
|
||||
async def test_metrics_published_on_full_failure(self):
|
||||
chain = _chain_of(("m1", 8.0),)
|
||||
router = LLMRouter(redis_url="redis://localhost:6379", chain=chain)
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
router._redis = mock_redis
|
||||
|
||||
with patch(
|
||||
"litellm.acompletion",
|
||||
AsyncMock(return_value=_fake_completion("not json")),
|
||||
):
|
||||
with pytest.raises(RuntimeError):
|
||||
await router.route(
|
||||
messages=[{"role": "user", "content": "x"}],
|
||||
schema={"type": "object"},
|
||||
)
|
||||
|
||||
# Metrics must still be published even when we raise
|
||||
mock_redis.publish.assert_awaited_once()
|
||||
_, payload_str = mock_redis.publish.call_args[0]
|
||||
payload = json.loads(payload_str)
|
||||
assert payload["model_used"] == "none"
|
||||
Loading…
Reference in a new issue