LiteLLM reads OLLAMA_API_BASE, not OLLAMA_HOST.
- llm_router.py: DEFAULT_OLLAMA_HOST → DEFAULT_OLLAMA_API_BASE, param ollama_host → ollama_api_base
- planner.py: env var os.getenv("OLLAMA_HOST") → os.getenv("OLLAMA_API_BASE"), param renamed accordingly
- /opt/homelab/config/planner-agent/.env on SOLARIA updated in-place (not in git)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
450 lines · 15 KiB · Python
"""
|
||
llm_router.py — LLM routing with local-first fallback chain.
|
||
|
||
Routing strategy:
|
||
1. Local Qwen via Ollama (piha:11434, timeout 8 s)
|
||
2. claude-haiku-4-5 (Anthropic cloud, timeout 30 s)
|
||
3. claude-sonnet-4-6 (Anthropic cloud, timeout 30 s)
|
||
|
||
A model is rejected when it:
|
||
- times out
|
||
- raises any API / network exception
|
||
- returns text matching a refusal pattern
|
||
- returns JSON that fails the caller-supplied JSON Schema
|
||
|
||
After every call (success or full-chain failure) a metrics event is
|
||
published to the Redis channel "llm_router_metrics".
|
||
|
||
Usage
|
||
-----
|
||
router = LLMRouter()
|
||
result = await router.route(
|
||
messages=[{"role": "user", "content": "What should I do?"}],
|
||
schema={"type": "object", "required": ["action"], "properties": {...}},
|
||
)
|
||
print(result.model_used, result.content)
|
||
await router.close()
|
||
"""
|
||
|
||
import asyncio
import json
import logging
import re
import time
from dataclasses import dataclass, field
from typing import Any, Optional

import litellm
import redis.asyncio as aioredis
from jsonschema import validate, ValidationError


logger = logging.getLogger("llm_router")

# Silence LiteLLM's own debug-info output; this module logs through
# ``logger`` above instead.
litellm.suppress_debug_info = True


# ---------------------------------------------------------------------------
# Refusal patterns — any substring match (case-insensitive) triggers fallback
# ---------------------------------------------------------------------------

# Order matters: LLMRouter._detect_refusal reports the first entry that hits.
REFUSAL_PATTERNS: list[str] = [
    # Polish for "I don't know".
    "nie wiem",
    # Outright refusals.
    "I cannot", "I can't",
    # AI self-reference, typically a preamble to a refusal.
    "as an AI",
    # Admissions of ignorance.
    "I don't know",
    # Statements of inability, with and without contraction.
    "I'm not able", "I am not able",
    "I'm unable", "I am unable",
    "beyond my capabilities",
]


# ---------------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------------


@dataclass
class ModelConfig:
    """Settings for a single entry in the router's fallback chain."""

    # LiteLLM model identifier, e.g. "ollama/qwen2.5:7b".
    name: str
    # Hard wall-clock budget for one attempt, in seconds.
    timeout: float
    # Optional API base URL override (the default Ollama entry sets this).
    api_base: Optional[str] = None
    # Additional keyword arguments forwarded to the completion call.
    extra_kwargs: dict = field(default_factory=dict)

    def __str__(self) -> str:
        location = "" if not self.api_base else f" @ {self.api_base}"
        return "".join((self.name, location, f" (timeout={self.timeout}s)"))


@dataclass
class AttemptRecord:
    """What happened to one model during a single ``route()`` call."""

    # Model identifier as configured in the chain.
    model: str
    # "success", "rejected" (call error / refusal) or "invalid"
    # (response failed JSON parsing or schema validation).
    outcome: str
    # Failure explanation; None when the attempt succeeded.
    reason: Optional[str]
    # Duration of this attempt in milliseconds.
    latency_ms: int


@dataclass
class RouteResult:
    """Outcome of ``LLMRouter.route()``."""

    # Parsed JSON when a schema was supplied, otherwise the raw string;
    # None when every model in the chain failed.
    content: Any
    # Unparsed model output ("" when every model failed).
    raw_text: str
    # Name of the model whose answer was accepted, or "none".
    model_used: str
    # One record per model tried, in chain order.
    attempts: list[AttemptRecord]
    # Wall-clock milliseconds from the start of route() until this result
    # was built.
    latency_ms: int

    @property
    def succeeded(self) -> bool:
        """True unless the whole chain was exhausted."""
        return not (self.model_used == "none")

    def to_dict(self) -> dict:
        """Metrics-friendly summary (content and raw_text are omitted)."""
        attempt_rows = [
            {
                "model": rec.model,
                "outcome": rec.outcome,
                "reason": rec.reason,
                "latency_ms": rec.latency_ms,
            }
            for rec in self.attempts
        ]
        return {
            "model_used": self.model_used,
            "latency_ms": self.latency_ms,
            "attempts": attempt_rows,
        }


# ---------------------------------------------------------------------------
# Metrics
# ---------------------------------------------------------------------------


class ModelMetrics:
    """Per-model counters of attempt outcomes.

    Outcomes recorded by ``LLMRouter.route``:

    - ``"success"``: the model's response was accepted.
    - ``"fallback"``: the model was rejected (call error, refusal, or
      invalid JSON/schema) and it was *not* the last model in the chain,
      so the router moved on. This is counted whether or not a later
      model then succeeded.
    - ``"error"``: the model was rejected and it was the last model in
      the chain, i.e. the whole chain failed.

    Counters live in plain dicts with no locking. No method awaits, so each
    update is atomic with respect to other asyncio tasks on the same event
    loop; concurrent use from multiple threads is not protected.
    """

    # Outcome keys pre-populated for every model so snapshots have a stable
    # shape even before a given outcome has occurred.
    _DEFAULT_OUTCOMES: tuple[str, ...] = ("success", "fallback", "error")

    def __init__(self) -> None:
        self._counts: dict[str, dict[str, int]] = {}

    def record(self, model: str, outcome: str) -> None:
        """Increment the ``outcome`` counter for ``model``.

        Outcome strings outside the standard three are counted under their
        own key rather than rejected.
        """
        counts = self._counts.get(model)
        if counts is None:
            counts = self._counts[model] = dict.fromkeys(self._DEFAULT_OUTCOMES, 0)
        counts[outcome] = counts.get(outcome, 0) + 1

    def snapshot(self) -> dict[str, dict[str, int]]:
        """Return a copy of all counters that the caller may mutate freely."""
        return {model: dict(counts) for model, counts in self._counts.items()}

    def total_calls(self, model: str) -> int:
        """Total attempts recorded for ``model`` (0 if never seen)."""
        return sum(self._counts.get(model, {}).values())

    def success_rate(self, model: str) -> Optional[float]:
        """Fraction of ``model``'s attempts that succeeded, or None if it has none."""
        total = self.total_calls(model)
        if total == 0:
            return None
        return self._counts[model].get("success", 0) / total


# ---------------------------------------------------------------------------
# Router
# ---------------------------------------------------------------------------


class LLMRouter:
    """Route LLM calls through a local-first fallback chain.

    Parameters
    ----------
    redis_url:
        Redis connection URL for metrics publishing.
        Set to None to disable Redis (useful in tests / local dev).
    ollama_api_base:
        Base URL of the Ollama API. The router does not read the environment
        itself; callers typically pass the value of the ``OLLAMA_API_BASE``
        environment variable (the name LiteLLM recognises).
        Defaults to SATURN's Tailscale address.
    ollama_model:
        Model tag as known to Ollama (e.g. "qwen2.5:7b").
    chain:
        Override the entire fallback chain. When None the default
        Qwen → haiku → sonnet chain is used. A supplied list is stored
        as-is (not copied), and ``ollama_api_base`` / ``ollama_model`` are
        then ignored.
    """

    DEFAULT_OLLAMA_API_BASE = "http://100.108.208.3:11434"
    DEFAULT_OLLAMA_MODEL = "qwen2.5:7b"
    DEFAULT_REDIS_URL = "redis://100.108.208.3:6379"

    def __init__(
        self,
        redis_url: Optional[str] = DEFAULT_REDIS_URL,
        ollama_api_base: str = DEFAULT_OLLAMA_API_BASE,
        ollama_model: str = DEFAULT_OLLAMA_MODEL,
        chain: Optional[list[ModelConfig]] = None,
    ) -> None:
        if chain is not None:
            self.chain = chain
        else:
            self.chain = [
                ModelConfig(
                    name=f"ollama/{ollama_model}",
                    timeout=8.0,
                    api_base=ollama_api_base,
                ),
                ModelConfig(
                    name="claude-haiku-4-5-20251001",
                    timeout=30.0,
                ),
                ModelConfig(
                    name="claude-sonnet-4-6",
                    timeout=30.0,
                ),
            ]

        self.metrics = ModelMetrics()
        self._redis_url = redis_url
        # Created lazily by _get_redis() on the first metrics publish.
        self._redis: Optional[aioredis.Redis] = None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def route(
        self,
        messages: list[dict],
        schema: Optional[dict] = None,
        context: Optional[str] = None,
    ) -> RouteResult:
        """Try each model in order; return the first valid response.

        Parameters
        ----------
        messages:
            OpenAI-style message list, e.g.
            [{"role": "user", "content": "..."}]
        schema:
            Optional JSON Schema dict. When provided the model's response
            must be valid JSON that conforms to the schema.
        context:
            Optional free-text caller label included in log lines (e.g.
            "supervisor.reconcile") for easier tracing.

        Returns
        -------
        RouteResult
            Result from the first model whose response was accepted. Every
            attempt updates ``self.metrics``; a metrics event is published
            before returning.

        Raises
        ------
        RuntimeError
            When every model in the chain fails. The exception message
            contains a JSON-formatted attempt log. A metrics event is
            published before raising.
        """
        tag = f"[{context}] " if context else ""
        start = time.monotonic()
        attempts: list[AttemptRecord] = []

        for i, cfg in enumerate(self.chain):
            is_last = i == len(self.chain) - 1
            attempt_start = time.monotonic()

            logger.info(
                "%s[llm_router] attempt %d/%d: %s",
                tag, i + 1, len(self.chain), cfg,
            )

            raw_text, call_error = await self._call_model(cfg, messages)
            attempt_ms = round((time.monotonic() - attempt_start) * 1000)

            if call_error:
                self._record_failure(
                    attempts, cfg, "rejected", call_error, attempt_ms,
                    is_last=is_last, tag=tag,
                )
                continue

            parsed, schema_error = self._validate(raw_text, schema)
            if schema_error:
                self._record_failure(
                    attempts, cfg, "invalid", schema_error, attempt_ms,
                    is_last=is_last, tag=tag,
                )
                continue

            # ── success ───────────────────────────────────────────────
            self.metrics.record(cfg.name, "success")
            total_ms = round((time.monotonic() - start) * 1000)
            logger.info(
                "%s[llm_router] %s → success [attempt %dms, total %dms]",
                tag, cfg.name, attempt_ms, total_ms,
            )
            attempts.append(AttemptRecord(
                model=cfg.name, outcome="success",
                reason=None, latency_ms=attempt_ms,
            ))
            result = RouteResult(
                content=parsed,
                raw_text=raw_text,
                model_used=cfg.name,
                attempts=attempts,
                latency_ms=total_ms,
            )
            await self._publish_metrics(result)
            return result

        # ── all models exhausted ──────────────────────────────────────
        total_ms = round((time.monotonic() - start) * 1000)
        result = RouteResult(
            content=None,
            raw_text="",
            model_used="none",
            attempts=attempts,
            latency_ms=total_ms,
        )
        await self._publish_metrics(result)

        attempt_log = json.dumps(
            [{"model": a.model, "reason": a.reason} for a in attempts],
            indent=2,
        )
        raise RuntimeError(
            f"{tag}[llm_router] All {len(self.chain)} models in chain failed.\n"
            f"Attempts:\n{attempt_log}"
        )

    async def close(self) -> None:
        """Release the Redis connection."""
        if self._redis is not None:
            await self._redis.aclose()
            self._redis = None

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _record_failure(
        self,
        attempts: list[AttemptRecord],
        cfg: ModelConfig,
        outcome: str,
        reason: str,
        attempt_ms: int,
        *,
        is_last: bool,
        tag: str,
    ) -> None:
        """Book-keep one failed attempt: metrics counter, warning log, attempt record.

        ``outcome`` ("rejected" or "invalid") is used both as the
        AttemptRecord outcome and as the verb in the log line.
        """
        # The last model in the chain failing means the whole route failed.
        self.metrics.record(cfg.name, "error" if is_last else "fallback")
        logger.warning(
            "%s[llm_router] %s → %s (%s) [%dms]",
            tag, cfg.name, outcome, reason, attempt_ms,
        )
        attempts.append(AttemptRecord(
            model=cfg.name, outcome=outcome,
            reason=reason, latency_ms=attempt_ms,
        ))

    async def _call_model(
        self,
        cfg: ModelConfig,
        messages: list[dict],
    ) -> tuple[str, Optional[str]]:
        """Invoke one model. Returns (raw_text, error_reason|None).

        Provider failures never propagate: every ``Exception`` is turned
        into an error reason (task cancellation, a BaseException, still
        propagates). ``raw_text`` is "" when the call itself failed, and
        the model's text when only a refusal pattern matched.
        """
        kwargs: dict[str, Any] = {
            "model": cfg.name,
            "messages": messages,
            "timeout": cfg.timeout,
            **cfg.extra_kwargs,
        }
        if cfg.api_base:
            kwargs["api_base"] = cfg.api_base

        try:
            # LiteLLM forwards `timeout` to the underlying HTTP client; the
            # outer asyncio.wait_for is a backstop that cancels the call if
            # that timeout does not fire, so one model can never hang route().
            resp = await asyncio.wait_for(
                litellm.acompletion(**kwargs),
                timeout=cfg.timeout + 2,  # +2 s grace for HTTP overhead
            )
            text = (resp.choices[0].message.content or "").strip()
        except (asyncio.TimeoutError, litellm.exceptions.Timeout):
            return "", f"Timeout after {cfg.timeout}s"
        except litellm.exceptions.APIConnectionError as e:
            return "", f"APIConnectionError: {e}"
        except litellm.exceptions.AuthenticationError as e:
            return "", f"AuthenticationError: {e}"
        except Exception as e:
            # Any other provider / response-shape failure is a rejection so
            # the router can fall through to the next model.
            return "", f"{type(e).__name__}: {e}"

        # Check for refusals in the model's own text
        refusal = self._detect_refusal(text)
        if refusal:
            return text, f"RefusalPattern matched: '{refusal}'"
        return text, None

    @staticmethod
    def _detect_refusal(text: str) -> Optional[str]:
        """Return the first REFUSAL_PATTERNS entry found in ``text``, or None.

        Matching is a case-insensitive substring test. Typographic
        apostrophes (U+2018 / U+2019) are folded to ASCII ``'`` first so that
        e.g. "I can’t" is caught by the "I can't" pattern.
        """
        haystack = text.lower().replace("\u2019", "'").replace("\u2018", "'")
        for pattern in REFUSAL_PATTERNS:
            if pattern.lower() in haystack:
                return pattern
        return None

    @staticmethod
    def _validate(
        text: str,
        schema: Optional[dict],
    ) -> tuple[Any, Optional[str]]:
        """Parse and validate the model response.

        Returns (parsed_content, error_reason|None).

        When schema is None, returns (text, None) unchanged — only refusal
        detection (already done in _call_model) applies. Otherwise ``text``
        must be JSON, either bare or inside a ``` fence (see
        _extract_json_from_fence), and must conform to ``schema``.

        ``jsonschema.validate`` also checks the schema itself; an invalid
        schema raises ``jsonschema.SchemaError``, which is not caught here
        and therefore propagates out of route().
        """
        if schema is None:
            return text, None

        try:
            parsed = json.loads(text)
        except json.JSONDecodeError as exc:
            # Models often wrap JSON in a markdown code fence.
            extracted = _extract_json_from_fence(text)
            if extracted is None:
                return None, f"JSONDecodeError: {exc}"
            parsed = extracted

        try:
            validate(instance=parsed, schema=schema)
        except ValidationError as exc:
            return None, f"SchemaValidationError: {exc.message}"

        return parsed, None

    async def _get_redis(self) -> Optional[aioredis.Redis]:
        """Return the shared Redis client, creating it on first use.

        Returns None when metrics publishing is disabled (redis_url is None).
        """
        if self._redis_url is None:
            return None
        if self._redis is None:
            self._redis = aioredis.from_url(
                self._redis_url,
                decode_responses=True,
                socket_connect_timeout=2,
                socket_timeout=2,
            )
        return self._redis

    async def _publish_metrics(self, result: RouteResult) -> None:
        """Best-effort publish to Redis channel 'llm_router_metrics'.

        The publish is awaited inline, so it can delay the caller by up to
        the 2 s connect/socket timeouts configured in _get_redis(). Any
        failure is logged and swallowed.
        """
        payload = {
            **result.to_dict(),
            "metrics_snapshot": self.metrics.snapshot(),
            "timestamp": time.time(),
        }
        try:
            r = await self._get_redis()
            if r is not None:
                await r.publish("llm_router_metrics", json.dumps(payload))
        except Exception as exc:
            # Never let a metrics failure break the caller
            logger.warning("[llm_router] metrics publish failed: %s", exc)


# ---------------------------------------------------------------------------
|
||
# Utility
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _extract_json_from_fence(text: str) -> Optional[Any]:
|
||
"""Extract JSON from a ```json ... ``` markdown code fence, if present."""
|
||
import re
|
||
match = re.search(r"```(?:json)?\s*(\{.*?\}|\[.*?\])\s*```", text, re.DOTALL)
|
||
if match:
|
||
try:
|
||
return json.loads(match.group(1))
|
||
except json.JSONDecodeError:
|
||
pass
|
||
return None
|