"""
llm_router.py — LLM routing with local-first fallback chain.

Routing strategy:
  1. Local Qwen via Ollama (piha:11434, timeout 8 s)
  2. claude-haiku-4-5  (Anthropic cloud, timeout 30 s)
  3. claude-sonnet-4-6 (Anthropic cloud, timeout 30 s)

A model is rejected when it:
  - times out
  - raises any API / network exception
  - returns an empty response
  - returns text matching a refusal pattern
  - returns JSON that fails the caller-supplied JSON Schema

After every call (success or full-chain failure) a metrics event is published
to the Redis channel "llm_router_metrics".

Usage
-----
    router = LLMRouter()
    result = await router.route(
        messages=[{"role": "user", "content": "What should I do?"}],
        schema={"type": "object", "required": ["action"], "properties": {...}},
    )
    print(result.model_used, result.content)
    await router.close()
"""

import asyncio
import json
import logging
import re
import time
from dataclasses import dataclass, field
from typing import Any, Optional

import litellm
import redis.asyncio as aioredis
from jsonschema import validate, ValidationError

litellm.suppress_debug_info = True

logger = logging.getLogger("llm_router")


# ---------------------------------------------------------------------------
# Refusal patterns — any substring match (case-insensitive) triggers fallback
# ---------------------------------------------------------------------------

REFUSAL_PATTERNS: list[str] = [
    "nie wiem",  # Polish for "I don't know"
    "I cannot",
    "I can't",
    "as an AI",
    "I don't know",
    "I'm not able",
    "I am not able",
    "I'm unable",
    "I am unable",
    "beyond my capabilities",
]


# ---------------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------------

@dataclass
class ModelConfig:
    """Configuration for one model in the fallback chain."""

    name: str                          # litellm model string, e.g. "ollama/qwen2.5:7b"
    timeout: float                     # hard wall-clock timeout in seconds
    api_base: Optional[str] = None     # override API base URL (Ollama needs this)
    extra_kwargs: dict = field(default_factory=dict)  # forwarded to litellm.acompletion

    def __str__(self) -> str:
        base = f" @ {self.api_base}" if self.api_base else ""
        return f"{self.name}{base} (timeout={self.timeout}s)"


@dataclass
class AttemptRecord:
    """Outcome of a single model attempt within one route() call."""

    model: str
    outcome: str             # "success" | "rejected" | "invalid"
    reason: Optional[str]    # None on success
    latency_ms: int


@dataclass
class RouteResult:
    """Return value of LLMRouter.route()."""

    content: Any                    # parsed JSON (if schema given) or raw str
    raw_text: str
    model_used: str                 # "none" if every model failed
    attempts: list[AttemptRecord]
    latency_ms: int                 # wall-clock from first attempt to return

    @property
    def succeeded(self) -> bool:
        """True when some model in the chain produced an accepted response."""
        return self.model_used != "none"

    def to_dict(self) -> dict:
        """Return a JSON-serialisable summary (used for the metrics payload)."""
        return {
            "model_used": self.model_used,
            "latency_ms": self.latency_ms,
            "attempts": [
                {
                    "model": a.model,
                    "outcome": a.outcome,
                    "reason": a.reason,
                    "latency_ms": a.latency_ms,
                }
                for a in self.attempts
            ],
        }


# ---------------------------------------------------------------------------
# Metrics
# ---------------------------------------------------------------------------

class ModelMetrics:
    """Per-model outcome counters (model name -> outcome -> count).

    Outcomes:
      "success"  — the model produced an accepted response.
      "fallback" — the model was rejected, but a later model in the same
                   route() call succeeded.
      "error"    — the model was rejected and no model in the same route()
                   call succeeded (the chain was exhausted).

    There is no locking, so this is not thread-safe. Within a single asyncio
    event loop it is consistent, because record() contains no await points.
    """

    def __init__(self) -> None:
        self._counts: dict[str, dict[str, int]] = {}

    def record(self, model: str, outcome: str) -> None:
        """Increment the counter for *model* x *outcome*."""
        if model not in self._counts:
            self._counts[model] = {"success": 0, "fallback": 0, "error": 0}
        # .get() tolerates outcome labels outside the three pre-seeded keys.
        self._counts[model][outcome] = self._counts[model].get(outcome, 0) + 1

    def snapshot(self) -> dict[str, dict[str, int]]:
        """Return a copy of all counters, safe to serialise or mutate."""
        return {m: dict(c) for m, c in self._counts.items()}

    def total_calls(self, model: str) -> int:
        """Total number of recorded outcomes for *model* (0 if unseen)."""
        return sum(self._counts.get(model, {}).values())

    def success_rate(self, model: str) -> Optional[float]:
        """Fraction of *model*'s recorded outcomes that were "success".

        Returns None when nothing has been recorded for *model* yet.
        """
        counts = self._counts.get(model, {})
        total = sum(counts.values())
        if total == 0:
            return None
        return counts.get("success", 0) / total


# ---------------------------------------------------------------------------
# Router
# ---------------------------------------------------------------------------

class LLMRouter:
    """Route LLM calls through a local-first fallback chain.

    Parameters
    ----------
    redis_url:
        Redis connection URL for metrics publishing.
        Set to None to disable Redis (useful in tests / local dev).
    ollama_host:
        Base URL of the Ollama API.  Defaults to piha's Tailscale address.
    ollama_model:
        Model tag as known to Ollama (e.g. "qwen2.5:7b").
    chain:
        Override the entire fallback chain.  When None the default
        Qwen → haiku → sonnet chain is used.
    """

    DEFAULT_OLLAMA_HOST = "http://100.108.208.3:11434"
    DEFAULT_OLLAMA_MODEL = "qwen2.5:7b"
    DEFAULT_REDIS_URL = "redis://100.108.208.3:6379"

    def __init__(
        self,
        redis_url: Optional[str] = DEFAULT_REDIS_URL,
        ollama_host: str = DEFAULT_OLLAMA_HOST,
        ollama_model: str = DEFAULT_OLLAMA_MODEL,
        chain: Optional[list[ModelConfig]] = None,
    ) -> None:
        if chain is not None:
            self.chain = chain
        else:
            self.chain = [
                ModelConfig(
                    name=f"ollama/{ollama_model}",
                    timeout=8.0,
                    api_base=ollama_host,
                ),
                ModelConfig(
                    name="claude-haiku-4-5-20251001",
                    timeout=30.0,
                ),
                ModelConfig(
                    name="claude-sonnet-4-6",
                    timeout=30.0,
                ),
            ]

        self.metrics = ModelMetrics()
        self._redis_url = redis_url
        # Created lazily on first publish; see _get_redis().
        self._redis: Optional[aioredis.Redis] = None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def route(
        self,
        messages: list[dict],
        schema: Optional[dict] = None,
        context: Optional[str] = None,
    ) -> RouteResult:
        """Try each model in order; return the first valid response.

        Parameters
        ----------
        messages:
            OpenAI-style message list, e.g.
            [{"role": "user", "content": "..."}]
        schema:
            Optional JSON Schema dict.  When provided the model's response
            must be valid JSON that conforms to the schema.
        context:
            Optional free-text caller label included in log lines
            (e.g. "supervisor.reconcile") for easier tracing.

        Returns
        -------
        RouteResult
            The first accepted response, along with the log of every attempt.

        Raises
        ------
        RuntimeError
            When every model in the chain fails.  The exception message
            contains a JSON-formatted attempt log.
        """
        tag = f"[{context}] " if context else ""
        chain_len = len(self.chain)
        start = time.monotonic()
        attempts: list[AttemptRecord] = []

        for i, cfg in enumerate(self.chain):
            attempt_start = time.monotonic()
            logger.info(
                "%s[llm_router] attempt %d/%d: %s", tag, i + 1, chain_len, cfg
            )

            raw_text, call_error = await self._call_model(cfg, messages)
            attempt_ms = round((time.monotonic() - attempt_start) * 1000)

            if call_error:
                logger.warning(
                    "%s[llm_router] %s → rejected (%s) [%dms]",
                    tag, cfg.name, call_error, attempt_ms,
                )
                attempts.append(AttemptRecord(
                    model=cfg.name,
                    outcome="rejected",
                    reason=call_error,
                    latency_ms=attempt_ms,
                ))
                continue

            parsed, schema_error = self._validate(raw_text, schema)
            if schema_error:
                logger.warning(
                    "%s[llm_router] %s → invalid (%s) [%dms]",
                    tag, cfg.name, schema_error, attempt_ms,
                )
                attempts.append(AttemptRecord(
                    model=cfg.name,
                    outcome="invalid",
                    reason=schema_error,
                    latency_ms=attempt_ms,
                ))
                continue

            # ── success ───────────────────────────────────────────────
            # Rejections are only counted once the call's fate is known:
            # every model rejected before this one was a "fallback".
            self._record_rejections(attempts, "fallback")
            self.metrics.record(cfg.name, "success")
            total_ms = round((time.monotonic() - start) * 1000)
            logger.info(
                "%s[llm_router] %s → success [attempt %dms, total %dms]",
                tag, cfg.name, attempt_ms, total_ms,
            )
            attempts.append(AttemptRecord(
                model=cfg.name,
                outcome="success",
                reason=None,
                latency_ms=attempt_ms,
            ))
            result = RouteResult(
                content=parsed,
                raw_text=raw_text,
                model_used=cfg.name,
                attempts=attempts,
                latency_ms=total_ms,
            )
            await self._publish_metrics(result)
            return result

        # ── all models exhausted ──────────────────────────────────────
        # No model succeeded, so every rejection in this call is an "error".
        self._record_rejections(attempts, "error")
        total_ms = round((time.monotonic() - start) * 1000)
        result = RouteResult(
            content=None,
            raw_text="",
            model_used="none",
            attempts=attempts,
            latency_ms=total_ms,
        )
        await self._publish_metrics(result)

        attempt_log = json.dumps(
            [{"model": a.model, "reason": a.reason} for a in attempts],
            indent=2,
        )
        raise RuntimeError(
            f"{tag}[llm_router] All {chain_len} models in chain failed.\n"
            f"Attempts:\n{attempt_log}"
        )

    async def close(self) -> None:
        """Release the Redis connection (no-op if none was opened)."""
        if self._redis is not None:
            await self._redis.aclose()
            self._redis = None

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _record_rejections(
        self,
        attempts: list[AttemptRecord],
        outcome: str,
    ) -> None:
        """Record *outcome* in the metrics for every non-successful attempt."""
        for attempt in attempts:
            if attempt.outcome != "success":
                self.metrics.record(attempt.model, outcome)

    async def _call_model(
        self,
        cfg: ModelConfig,
        messages: list[dict],
    ) -> tuple[str, Optional[str]]:
        """Invoke one model.  Returns (raw_text, error_reason|None).

        Never raises for call failures (asyncio cancellation still
        propagates). Errors, empty completions and refusals are all reported
        through the second tuple element.
        """
        kwargs: dict[str, Any] = {
            "model": cfg.name,
            "messages": messages,
            "timeout": cfg.timeout,
            # Applied last, so extra_kwargs can override the keys above.
            **cfg.extra_kwargs,
        }
        if cfg.api_base:
            kwargs["api_base"] = cfg.api_base

        try:
            # asyncio.wait_for as belt-and-suspenders: litellm's timeout is
            # passed to the underlying HTTP client, but wait_for guarantees
            # this await is bounded even if that timeout is not honoured.
            resp = await asyncio.wait_for(
                litellm.acompletion(**kwargs),
                timeout=cfg.timeout + 2,  # +2 s grace for HTTP overhead
            )
            text = (resp.choices[0].message.content or "").strip()
        except asyncio.TimeoutError:
            return "", f"Timeout after {cfg.timeout}s"
        except litellm.exceptions.Timeout:
            return "", f"Timeout after {cfg.timeout}s"
        except litellm.exceptions.APIConnectionError as e:
            return "", f"APIConnectionError: {e}"
        except litellm.exceptions.AuthenticationError as e:
            return "", f"AuthenticationError: {e}"
        except Exception as e:
            return "", f"{type(e).__name__}: {e}"

        # An empty completion (e.g. content=None) is not a usable answer;
        # reject it so the next model in the chain gets a chance.
        if not text:
            return "", "EmptyResponse: model returned no content"

        # Check for refusals in the model's own text
        refusal = self._detect_refusal(text)
        if refusal:
            return text, f"RefusalPattern matched: '{refusal}'"

        return text, None

    @staticmethod
    def _detect_refusal(text: str) -> Optional[str]:
        """Return the first matching refusal pattern, or None.

        Matching is a case-insensitive substring test. Typographic
        apostrophes (U+2019) are folded to ASCII first, so patterns such as
        "I can't" also match "I can’t".
        """
        normalized = text.replace("\u2019", "'").lower()
        for pattern in REFUSAL_PATTERNS:
            if pattern.lower() in normalized:
                return pattern
        return None

    @staticmethod
    def _validate(
        text: str,
        schema: Optional[dict],
    ) -> tuple[Any, Optional[str]]:
        """Parse and validate the model response.

        Returns (parsed_content, error_reason|None).

        When schema is None, returns (raw_text, None). Only the empty-response
        and refusal checks (already done in _call_model) apply.
        When a schema is given, the text is parsed as JSON. If that fails, the
        JSON is extracted from a markdown code fence. The parsed value is then
        validated against the schema.
        """
        if schema is None:
            return text, None

        try:
            parsed = json.loads(text)
        except json.JSONDecodeError as exc:
            # Try to extract JSON from a markdown code fence
            extracted = _extract_json_from_fence(text)
            if extracted is not None:
                parsed = extracted
            else:
                return None, f"JSONDecodeError: {exc}"

        try:
            validate(instance=parsed, schema=schema)
        except ValidationError as exc:
            return None, f"SchemaValidationError: {exc.message}"

        return parsed, None

    async def _get_redis(self) -> Optional[aioredis.Redis]:
        """Return the lazily-created Redis client, or None if disabled."""
        if self._redis_url is None:
            return None
        if self._redis is None:
            self._redis = aioredis.from_url(
                self._redis_url,
                decode_responses=True,
                socket_connect_timeout=2,
                socket_timeout=2,
            )
        return self._redis

    async def _publish_metrics(self, result: RouteResult) -> None:
        """Best-effort publish to Redis channel 'llm_router_metrics'.

        The publish is awaited inline. The client's 2 s connect and socket
        timeouts bound how long an unreachable Redis can delay route().
        Failures are logged and never raised to the caller.
        """
        payload = {
            **result.to_dict(),
            "metrics_snapshot": self.metrics.snapshot(),
            "timestamp": time.time(),
        }
        try:
            r = await self._get_redis()
            if r is not None:
                await r.publish("llm_router_metrics", json.dumps(payload))
        except Exception as exc:
            # Never let a metrics failure break the caller
            logger.warning("[llm_router] metrics publish failed: %s", exc)


# ---------------------------------------------------------------------------
# Utility
# ---------------------------------------------------------------------------

# Matches a markdown code fence, optionally tagged "json" in any case, whose
# body is one JSON object or array. Compiled once at import time.
_JSON_FENCE_RE = re.compile(
    r"```(?:json)?\s*(\{.*?\}|\[.*?\])\s*```",
    re.DOTALL | re.IGNORECASE,
)


def _extract_json_from_fence(text: str) -> Optional[Any]:
    """Extract JSON from a ```json ... ``` markdown code fence, if present.

    Returns the parsed value, or None when there is no fence or its body is
    not valid JSON.
    """
    match = _JSON_FENCE_RE.search(text)
    if match:
        try:
            return json.loads(match.group(1))
        except json.JSONDecodeError:
            pass
    return None