homelab-codex-ws/services/planner-agent/src/llm_router.py

450 lines
15 KiB
Python
Raw Normal View History

"""
llm_router.py — LLM routing with a local-first fallback chain.
Routing strategy:
1. Local Qwen via Ollama (piha:11434, timeout 8 s)
2. claude-haiku-4-5 (Anthropic cloud, timeout 30 s)
3. claude-sonnet-4-6 (Anthropic cloud, timeout 30 s)
A model is rejected when it:
- times out
- raises any API / network exception
- returns text matching a refusal pattern
- returns JSON that fails the caller-supplied JSON Schema
After every call (success or full-chain failure) a metrics event is
published to the Redis channel "llm_router_metrics".
Usage
-----
router = LLMRouter()
result = await router.route(
messages=[{"role": "user", "content": "What should I do?"}],
schema={"type": "object", "required": ["action"], "properties": {...}},
)
print(result.model_used, result.content)
await router.close()
"""
import asyncio
import json
import logging
import re
import time
from dataclasses import dataclass, field
from typing import Any, Optional

import litellm
import redis.asyncio as aioredis
from jsonschema import validate, ValidationError
# Module-wide logger; the explicit name keeps it addressable as "llm_router"
# in logging configuration.
logger = logging.getLogger("llm_router")
# Enable LiteLLM's suppress_debug_info flag to cut down LiteLLM's own
# debug/help output.
litellm.suppress_debug_info = True
# ---------------------------------------------------------------------------
# Refusal patterns — a case-insensitive match of any of these in a model's
# response (see LLMRouter._detect_refusal) triggers fallback to the next model.
# Patterns containing an apostrophe are listed twice: once with ASCII "'" and
# once with the typographic apostrophe U+2019, which model output may use
# instead. The ASCII form comes first so it is the one reported for ASCII text.
# ---------------------------------------------------------------------------
REFUSAL_PATTERNS: list[str] = [
    "nie wiem",  # Polish for "I don't know"
    "I cannot",
    "I can't",
    "I can\u2019t",
    "as an AI",
    "I don't know",
    "I don\u2019t know",
    "I'm not able",
    "I\u2019m not able",
    "I am not able",
    "I'm unable",
    "I\u2019m unable",
    "I am unable",
    "beyond my capabilities",
]
# ---------------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------------
@dataclass
class ModelConfig:
    """One entry in the router's fallback chain.

    ``name`` is the LiteLLM model string (e.g. "ollama/qwen2.5:7b"),
    ``timeout`` is the per-call time limit in seconds, ``api_base`` optionally
    overrides the API base URL (the default Ollama entry sets it), and
    ``extra_kwargs`` holds additional keyword arguments forwarded to the
    completion call.
    """

    name: str
    timeout: float
    api_base: Optional[str] = None
    extra_kwargs: dict = field(default_factory=dict)

    def __str__(self) -> str:
        # Human-readable label used in router log lines.
        label = self.name
        if self.api_base:
            label += f" @ {self.api_base}"
        return f"{label} (timeout={self.timeout}s)"
@dataclass
class AttemptRecord:
    """Outcome of trying a single model during LLMRouter.route()."""

    model: str  # ModelConfig.name of the model that was tried
    outcome: str  # one of "success" | "rejected" | "invalid"
    reason: Optional[str]  # why the attempt failed; None on success
    latency_ms: int  # wall-clock duration of this one attempt
@dataclass
class RouteResult:
    """What LLMRouter.route() hands back for one request."""

    content: Any  # parsed JSON when a schema was given, otherwise the raw str
    raw_text: str  # unparsed model output ("" when no model succeeded)
    model_used: str  # name of the accepted model, or "none" if all failed
    attempts: list[AttemptRecord]  # one record per model tried, in order
    latency_ms: int  # wall-clock time from the first attempt to return

    @property
    def succeeded(self) -> bool:
        """True unless every model in the chain was rejected."""
        return self.model_used != "none"

    def to_dict(self) -> dict:
        """Summarise routing for metrics; ``content``/``raw_text`` are omitted."""
        attempt_log = []
        for rec in self.attempts:
            attempt_log.append({
                "model": rec.model,
                "outcome": rec.outcome,
                "reason": rec.reason,
                "latency_ms": rec.latency_ms,
            })
        summary = {"model_used": self.model_used, "latency_ms": self.latency_ms}
        summary["attempts"] = attempt_log
        return summary
# ---------------------------------------------------------------------------
# Metrics
# ---------------------------------------------------------------------------
class ModelMetrics:
    """Per-model outcome counters kept by the router.

    Outcome labels:
      "success"  - the model's answer was accepted;
      "fallback" - the model was rejected but another model succeeded after it;
      "error"    - the model was rejected and it was the last in the chain,
                   or the chain was exhausted.
    Any other label passed to record() is counted as well.

    No locking is used: counters are plain dict updates and no method awaits,
    so use from a single asyncio event loop is fine, but sharing an instance
    across threads is not guaranteed safe.
    """

    def __init__(self) -> None:
        self._counts: dict[str, dict[str, int]] = {}

    def record(self, model: str, outcome: str) -> None:
        """Increment the counter for ``outcome`` on ``model``."""
        bucket = self._counts.setdefault(
            model, {"success": 0, "fallback": 0, "error": 0}
        )
        bucket[outcome] = bucket.get(outcome, 0) + 1

    def snapshot(self) -> dict[str, dict[str, int]]:
        """Return a copy of all counters, safe for the caller to mutate."""
        return {model: counts.copy() for model, counts in self._counts.items()}

    def total_calls(self, model: str) -> int:
        """Total recorded outcomes for ``model`` (0 if never seen)."""
        counts = self._counts.get(model)
        return 0 if counts is None else sum(counts.values())

    def success_rate(self, model: str) -> Optional[float]:
        """Fraction of outcomes that were "success", or None with no data."""
        total = self.total_calls(model)
        if total == 0:
            return None
        return self._counts[model].get("success", 0) / total
# ---------------------------------------------------------------------------
# Router
# ---------------------------------------------------------------------------
class LLMRouter:
    """Route LLM calls through a local-first fallback chain.

    Each call walks ``self.chain`` in order and returns the first response
    that is non-empty, does not match a refusal pattern, and (when a schema
    is supplied) is JSON conforming to that schema.

    Parameters
    ----------
    redis_url:
        Redis connection URL for metrics publishing.
        Set to None to disable Redis (useful in tests / local dev).
    ollama_api_base:
        Base URL of the Ollama API, passed to LiteLLM as ``api_base`` for the
        local model. It plays the role of LiteLLM's ``OLLAMA_API_BASE``
        setting; this class does not read that environment variable itself.
        Defaults to SATURN's Tailscale address.
    ollama_model:
        Model tag as known to Ollama (e.g. "qwen2.5:7b").
    chain:
        Override the entire fallback chain. When None the default
        Qwen → haiku → sonnet chain is used.
    """

    DEFAULT_OLLAMA_API_BASE = "http://100.108.208.3:11434"
    DEFAULT_OLLAMA_MODEL = "qwen2.5:7b"
    DEFAULT_REDIS_URL = "redis://100.108.208.3:6379"

    def __init__(
        self,
        redis_url: Optional[str] = DEFAULT_REDIS_URL,
        ollama_api_base: str = DEFAULT_OLLAMA_API_BASE,
        ollama_model: str = DEFAULT_OLLAMA_MODEL,
        chain: Optional[list[ModelConfig]] = None,
    ) -> None:
        if chain is not None:
            self.chain = chain
        else:
            self.chain = [
                # Local model first, with a short timeout.
                ModelConfig(
                    name=f"ollama/{ollama_model}",
                    timeout=8.0,
                    api_base=ollama_api_base,
                ),
                ModelConfig(
                    name="claude-haiku-4-5-20251001",
                    timeout=30.0,
                ),
                ModelConfig(
                    name="claude-sonnet-4-6",
                    timeout=30.0,
                ),
            ]
        self.metrics = ModelMetrics()
        self._redis_url = redis_url
        # Created lazily by _get_redis() on the first metrics publish.
        self._redis: Optional[aioredis.Redis] = None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def route(
        self,
        messages: list[dict],
        schema: Optional[dict] = None,
        context: Optional[str] = None,
    ) -> RouteResult:
        """Try each model in order; return the first valid response.

        Parameters
        ----------
        messages:
            OpenAI-style message list, e.g.
            [{"role": "user", "content": "..."}]
        schema:
            Optional JSON Schema dict. When provided the model's response
            must be valid JSON that conforms to the schema.
        context:
            Optional free-text caller label included in log lines (e.g.
            "supervisor.reconcile") for easier tracing.

        Returns
        -------
        RouteResult
            The accepted response together with the log of every attempt.

        Raises
        ------
        RuntimeError
            When every model in the chain fails. The exception message
            contains a JSON-formatted attempt log.

        A metrics event is published (best-effort) before returning or raising.
        """
        tag = f"[{context}] " if context else ""
        n_models = len(self.chain)
        start = time.monotonic()
        attempts: list[AttemptRecord] = []
        # Models rejected so far. Their metric outcome is only known once the
        # chain settles: "fallback" if a later model succeeds, "error" if the
        # whole chain is exhausted (the meanings ModelMetrics documents), so
        # recording them is deferred until then.
        rejected: list[str] = []

        for attempt_no, cfg in enumerate(self.chain, start=1):
            attempt_start = time.monotonic()
            logger.info(
                "%s[llm_router] attempt %d/%d: %s",
                tag, attempt_no, n_models, cfg,
            )
            raw_text, call_error = await self._call_model(cfg, messages)
            attempt_ms = round((time.monotonic() - attempt_start) * 1000)

            if call_error:
                logger.warning(
                    "%s[llm_router] %s → rejected (%s) [%sms]",
                    tag, cfg.name, call_error, attempt_ms,
                )
                attempts.append(AttemptRecord(
                    model=cfg.name, outcome="rejected",
                    reason=call_error, latency_ms=attempt_ms,
                ))
                rejected.append(cfg.name)
                continue

            parsed, schema_error = self._validate(raw_text, schema)
            if schema_error:
                logger.warning(
                    "%s[llm_router] %s → invalid (%s) [%sms]",
                    tag, cfg.name, schema_error, attempt_ms,
                )
                attempts.append(AttemptRecord(
                    model=cfg.name, outcome="invalid",
                    reason=schema_error, latency_ms=attempt_ms,
                ))
                rejected.append(cfg.name)
                continue

            # ── success ───────────────────────────────────────────────
            for name in rejected:
                self.metrics.record(name, "fallback")
            self.metrics.record(cfg.name, "success")
            total_ms = round((time.monotonic() - start) * 1000)
            logger.info(
                "%s[llm_router] %s → success [attempt %sms, total %sms]",
                tag, cfg.name, attempt_ms, total_ms,
            )
            attempts.append(AttemptRecord(
                model=cfg.name, outcome="success",
                reason=None, latency_ms=attempt_ms,
            ))
            result = RouteResult(
                content=parsed,
                raw_text=raw_text,
                model_used=cfg.name,
                attempts=attempts,
                latency_ms=total_ms,
            )
            await self._publish_metrics(result)
            return result

        # ── all models exhausted ──────────────────────────────────────
        for name in rejected:
            self.metrics.record(name, "error")
        total_ms = round((time.monotonic() - start) * 1000)
        result = RouteResult(
            content=None,
            raw_text="",
            model_used="none",
            attempts=attempts,
            latency_ms=total_ms,
        )
        await self._publish_metrics(result)
        attempt_log = json.dumps(
            [{"model": a.model, "reason": a.reason} for a in attempts],
            indent=2,
        )
        raise RuntimeError(
            f"{tag}[llm_router] All {len(self.chain)} models in chain failed.\n"
            f"Attempts:\n{attempt_log}"
        )

    async def close(self) -> None:
        """Close the Redis client, if one was created, and forget it."""
        if self._redis is not None:
            await self._redis.aclose()
            self._redis = None

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    async def _call_model(
        self,
        cfg: ModelConfig,
        messages: list[dict],
    ) -> tuple[str, Optional[str]]:
        """Invoke one model.

        Returns ``(raw_text, error_reason)``; ``error_reason`` is None when
        the response is usable. Failures (timeout, API/network error, empty
        content, refusal) are reported through ``error_reason`` rather than
        raised, so route() can fall through to the next model.
        """
        kwargs: dict[str, Any] = {
            "model": cfg.name,
            "messages": messages,
            "timeout": cfg.timeout,
            **cfg.extra_kwargs,
        }
        if cfg.api_base:
            kwargs["api_base"] = cfg.api_base
        try:
            # Belt-and-suspenders: litellm forwards `timeout` to its HTTP
            # client, and the outer wait_for (with 2 s of grace for HTTP
            # overhead) guarantees this await is abandoned even if that
            # inner timeout never fires.
            resp = await asyncio.wait_for(
                litellm.acompletion(**kwargs),
                timeout=cfg.timeout + 2,
            )
            text = (resp.choices[0].message.content or "").strip()
        except (asyncio.TimeoutError, litellm.exceptions.Timeout):
            return "", f"Timeout after {cfg.timeout}s"
        except litellm.exceptions.APIConnectionError as e:
            return "", f"APIConnectionError: {e}"
        except litellm.exceptions.AuthenticationError as e:
            return "", f"AuthenticationError: {e}"
        except Exception as e:
            # Any other failure (including a malformed response object)
            # just means "try the next model".
            return "", f"{type(e).__name__}: {e}"

        if not text:
            # A blank completion is never a usable answer; without this
            # check it would be accepted as a success when no schema is set.
            return "", "EmptyResponse: model returned no content"

        refusal = self._detect_refusal(text)
        if refusal:
            return text, f"RefusalPattern matched: '{refusal}'"
        return text, None

    @staticmethod
    def _detect_refusal(text: str) -> Optional[str]:
        """Return the first refusal pattern found in ``text``, or None.

        Matching is case-insensitive and the pattern must not be glued to
        word characters on either side, so "as an AI" matches "As an AI, I..."
        but not "...has an aim..." (which contains the same letters mid-word).
        Patterns are checked in ``REFUSAL_PATTERNS`` order.
        """
        lower = text.lower()
        for pattern in REFUSAL_PATTERNS:
            # (?<!\w)/(?!\w) rather than \b so the anchors also behave if a
            # pattern ever starts or ends with a non-word character.
            needle = re.escape(pattern.lower())
            if re.search(rf"(?<!\w){needle}(?!\w)", lower):
                return pattern
        return None

    @staticmethod
    def _validate(
        text: str,
        schema: Optional[dict],
    ) -> tuple[Any, Optional[str]]:
        """Parse and validate the model response.

        Returns ``(parsed_content, error_reason)``. When ``schema`` is None
        the text is returned unchanged with no error — only the checks done
        in _call_model apply. Otherwise the text must be JSON (bare, or inside
        a markdown code fence) that conforms to ``schema``.
        """
        if schema is None:
            return text, None
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError as exc:
            # Models often wrap JSON in a markdown fence; try that first.
            extracted = _extract_json_from_fence(text)
            if extracted is None:
                return None, f"JSONDecodeError: {exc}"
            parsed = extracted
        try:
            validate(instance=parsed, schema=schema)
        except ValidationError as exc:
            return None, f"SchemaValidationError: {exc.message}"
        return parsed, None

    async def _get_redis(self) -> Optional[aioredis.Redis]:
        """Return the lazily created Redis client, or None if Redis is disabled."""
        if self._redis_url is None:
            return None
        if self._redis is None:
            self._redis = aioredis.from_url(
                self._redis_url,
                decode_responses=True,
                socket_connect_timeout=2,
                socket_timeout=2,
            )
        return self._redis

    async def _publish_metrics(self, result: RouteResult) -> None:
        """Best-effort publish to the Redis channel 'llm_router_metrics'.

        Any failure is logged and swallowed so metrics never break a caller.
        """
        payload = {
            **result.to_dict(),
            "metrics_snapshot": self.metrics.snapshot(),
            "timestamp": time.time(),
        }
        try:
            r = await self._get_redis()
            if r is not None:
                await r.publish("llm_router_metrics", json.dumps(payload))
        except Exception as exc:
            logger.warning("[llm_router] metrics publish failed: %s", exc)
# ---------------------------------------------------------------------------
# Utility
# ---------------------------------------------------------------------------
def _extract_json_from_fence(text: str) -> Optional[Any]:
"""Extract JSON from a ```json ... ``` markdown code fence, if present."""
import re
match = re.search(r"```(?:json)?\s*(\{.*?\}|\[.*?\])\s*```", text, re.DOTALL)
if match:
try:
return json.loads(match.group(1))
except json.JSONDecodeError:
pass
return None