diff --git a/services/planner-agent/requirements.txt b/services/planner-agent/requirements.txt new file mode 100644 index 0000000..3d88ca9 --- /dev/null +++ b/services/planner-agent/requirements.txt @@ -0,0 +1,3 @@ +litellm>=1.40.0 +redis[asyncio]>=5.0.0 +jsonschema>=4.21.0 diff --git a/services/planner-agent/src/llm_router.py b/services/planner-agent/src/llm_router.py new file mode 100644 index 0000000..505a89b --- /dev/null +++ b/services/planner-agent/src/llm_router.py @@ -0,0 +1,447 @@ +""" +llm_router.py — LLM routing with local-first fallback chain. + +Routing strategy: + 1. Local Qwen via Ollama (piha:11434, timeout 8 s) + 2. claude-haiku-4-5 (Anthropic cloud, timeout 30 s) + 3. claude-sonnet-4-6 (Anthropic cloud, timeout 30 s) + +A model is rejected when it: + - times out + - raises any API / network exception + - returns text matching a refusal pattern + - returns JSON that fails the caller-supplied JSON Schema + +After every call (success or full-chain failure) a metrics event is +published to the Redis channel "llm_router_metrics". + +Usage +----- + router = LLMRouter() + result = await router.route( + messages=[{"role": "user", "content": "What should I do?"}], + schema={"type": "object", "required": ["action"], "properties": {...}}, + ) + print(result.model_used, result.content) + await router.close() +""" + +import asyncio +import json +import logging +import time +from dataclasses import dataclass, field +from typing import Any, Optional + +import litellm +import redis.asyncio as aioredis +from jsonschema import validate, ValidationError + +litellm.suppress_debug_info = True + +logger = logging.getLogger("llm_router") + +# --------------------------------------------------------------------------- +# Refusal patterns — any substring match (case-insensitive) triggers fallback +# --------------------------------------------------------------------------- +REFUSAL_PATTERNS: list[str] = [ + "nie wiem", + "I cannot", + "I can't", + "as an AI", + "I don't know", + "I'm not able", + "I am not able", + "I'm unable", + "I am unable", + "beyond my capabilities", +] + + +# --------------------------------------------------------------------------- +# Data structures +# --------------------------------------------------------------------------- + +@dataclass +class ModelConfig: + """Configuration for one model in the fallback chain.""" + name: str # litellm model string, e.g. "ollama/qwen2.5:7b" + timeout: float # hard wall-clock timeout in seconds + api_base: Optional[str] = None # override API base URL (Ollama needs this) + extra_kwargs: dict = field(default_factory=dict) + + def __str__(self) -> str: + base = f" @ {self.api_base}" if self.api_base else "" + return f"{self.name}{base} (timeout={self.timeout}s)" + + +@dataclass +class AttemptRecord: + model: str + outcome: str # "success" | "rejected" | "invalid" + reason: Optional[str] # None on success + latency_ms: int + + +@dataclass +class RouteResult: + """Return value of LLMRouter.route().""" + content: Any # parsed JSON (if schema given) or raw str + raw_text: str + model_used: str # "none" if every model failed + attempts: list[AttemptRecord] + latency_ms: int # wall-clock from first attempt to return + + @property + def succeeded(self) -> bool: + return self.model_used != "none" + + def to_dict(self) -> dict: + return { + "model_used": self.model_used, + "latency_ms": self.latency_ms, + "attempts": [ + { + "model": a.model, + "outcome": a.outcome, + "reason": a.reason, + "latency_ms": a.latency_ms, + } + for a in self.attempts + ], + } + + +# --------------------------------------------------------------------------- +# Metrics +# --------------------------------------------------------------------------- + +class ModelMetrics: + """Thread-safe-ish counter per model × outcome. + + Outcomes: "success", "fallback", "error" + ("fallback" = rejected but another model succeeded after it; + "error" = rejected and it was the last in chain or chain exhausted) + """ + + def __init__(self) -> None: + self._counts: dict[str, dict[str, int]] = {} + + def record(self, model: str, outcome: str) -> None: + if model not in self._counts: + self._counts[model] = {"success": 0, "fallback": 0, "error": 0} + self._counts[model][outcome] = self._counts[model].get(outcome, 0) + 1 + + def snapshot(self) -> dict[str, dict[str, int]]: + return {m: dict(c) for m, c in self._counts.items()} + + def total_calls(self, model: str) -> int: + return sum(self._counts.get(model, {}).values()) + + def success_rate(self, model: str) -> Optional[float]: + counts = self._counts.get(model, {}) + total = sum(counts.values()) + if total == 0: + return None + return counts.get("success", 0) / total + + +# --------------------------------------------------------------------------- +# Router +# --------------------------------------------------------------------------- + +class LLMRouter: + """Route LLM calls through a local-first fallback chain. + + Parameters + ---------- + redis_url: + Redis connection URL for metrics publishing. + Set to None to disable Redis (useful in tests / local dev). + ollama_host: + Base URL of the Ollama API. Defaults to piha's Tailscale address. + ollama_model: + Model tag as known to Ollama (e.g. "qwen2.5:7b"). + chain: + Override the entire fallback chain. When None the default + Qwen → haiku → sonnet chain is used. + """ + + DEFAULT_OLLAMA_HOST = "http://100.108.208.3:11434" + DEFAULT_OLLAMA_MODEL = "qwen2.5:7b" + DEFAULT_REDIS_URL = "redis://100.108.208.3:6379" + + def __init__( + self, + redis_url: Optional[str] = DEFAULT_REDIS_URL, + ollama_host: str = DEFAULT_OLLAMA_HOST, + ollama_model: str = DEFAULT_OLLAMA_MODEL, + chain: Optional[list[ModelConfig]] = None, + ) -> None: + if chain is not None: + self.chain = chain + else: + self.chain = [ + ModelConfig( + name=f"ollama/{ollama_model}", + timeout=8.0, + api_base=ollama_host, + ), + ModelConfig( + name="claude-haiku-4-5-20251001", + timeout=30.0, + ), + ModelConfig( + name="claude-sonnet-4-6", + timeout=30.0, + ), + ] + + self.metrics = ModelMetrics() + self._redis_url = redis_url + self._redis: Optional[aioredis.Redis] = None + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + async def route( + self, + messages: list[dict], + schema: Optional[dict] = None, + context: Optional[str] = None, + ) -> RouteResult: + """Try each model in order; return the first valid response. + + Parameters + ---------- + messages: + OpenAI-style message list, e.g. + [{"role": "user", "content": "..."}] + schema: + Optional JSON Schema dict. When provided the model's response + must be valid JSON that conforms to the schema. + context: + Optional free-text caller label included in log lines (e.g. + "supervisor.reconcile") for easier tracing. + + Raises + ------ + RuntimeError + When every model in the chain fails. The exception message + contains a JSON-formatted attempt log. + """ + tag = f"[{context}] " if context else "" + start = time.monotonic() + attempts: list[AttemptRecord] = [] + + for i, cfg in enumerate(self.chain): + is_last = i == len(self.chain) - 1 + attempt_start = time.monotonic() + + logger.info( + f"{tag}[llm_router] attempt {i+1}/{len(self.chain)}: {cfg}" + ) + + raw_text, call_error = await self._call_model(cfg, messages) + attempt_ms = round((time.monotonic() - attempt_start) * 1000) + + if call_error: + self.metrics.record(cfg.name, "error" if is_last else "fallback") + logger.warning( + f"{tag}[llm_router] {cfg.name} → rejected " + f"({call_error}) [{attempt_ms}ms]" + ) + attempts.append(AttemptRecord( + model=cfg.name, outcome="rejected", + reason=call_error, latency_ms=attempt_ms, + )) + continue + + parsed, schema_error = self._validate(raw_text, schema) + if schema_error: + self.metrics.record(cfg.name, "error" if is_last else "fallback") + logger.warning( + f"{tag}[llm_router] {cfg.name} → invalid " + f"({schema_error}) [{attempt_ms}ms]" + ) + attempts.append(AttemptRecord( + model=cfg.name, outcome="invalid", + reason=schema_error, latency_ms=attempt_ms, + )) + continue + + # ── success ─────────────────────────────────────────────── + self.metrics.record(cfg.name, "success") + total_ms = round((time.monotonic() - start) * 1000) + logger.info( + f"{tag}[llm_router] {cfg.name} → success " + f"[attempt {attempt_ms}ms, total {total_ms}ms]" + ) + attempts.append(AttemptRecord( + model=cfg.name, outcome="success", + reason=None, latency_ms=attempt_ms, + )) + result = RouteResult( + content=parsed, + raw_text=raw_text, + model_used=cfg.name, + attempts=attempts, + latency_ms=total_ms, + ) + await self._publish_metrics(result) + return result + + # ── all models exhausted ────────────────────────────────────── + total_ms = round((time.monotonic() - start) * 1000) + result = RouteResult( + content=None, + raw_text="", + model_used="none", + attempts=attempts, + latency_ms=total_ms, + ) + await self._publish_metrics(result) + + attempt_log = json.dumps( + [{"model": a.model, "reason": a.reason} for a in attempts], + indent=2, + ) + raise RuntimeError( + f"{tag}[llm_router] All {len(self.chain)} models in chain failed.\n" + f"Attempts:\n{attempt_log}" + ) + + async def close(self) -> None: + """Release the Redis connection.""" + if self._redis is not None: + await self._redis.aclose() + self._redis = None + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + async def _call_model( + self, + cfg: ModelConfig, + messages: list[dict], + ) -> tuple[str, Optional[str]]: + """Invoke one model. Returns (raw_text, error_reason|None).""" + kwargs: dict[str, Any] = { + "model": cfg.name, + "messages": messages, + "timeout": cfg.timeout, + **cfg.extra_kwargs, + } + if cfg.api_base: + kwargs["api_base"] = cfg.api_base + + try: + # asyncio.wait_for as belt-and-suspenders — litellm timeout + # is passed to the underlying HTTP client, but asyncio task + # cancellation ensures we never block the event loop. + resp = await asyncio.wait_for( + litellm.acompletion(**kwargs), + timeout=cfg.timeout + 2, # +2 s grace for HTTP overhead + ) + text = (resp.choices[0].message.content or "").strip() + except asyncio.TimeoutError: + return "", f"Timeout after {cfg.timeout}s" + except litellm.exceptions.Timeout: + return "", f"Timeout after {cfg.timeout}s" + except litellm.exceptions.APIConnectionError as e: + return "", f"APIConnectionError: {e}" + except litellm.exceptions.AuthenticationError as e: + return "", f"AuthenticationError: {e}" + except Exception as e: + return "", f"{type(e).__name__}: {e}" + + # Check for refusals in the model's own text + refusal = self._detect_refusal(text) + if refusal: + return text, f"RefusalPattern matched: '{refusal}'" + return text, None + + @staticmethod + def _detect_refusal(text: str) -> Optional[str]: + """Return the first matching refusal pattern, or None.""" + lower = text.lower() + for pattern in REFUSAL_PATTERNS: + if pattern.lower() in lower: + return pattern + return None + + @staticmethod + def _validate( + text: str, + schema: Optional[dict], + ) -> tuple[Any, Optional[str]]: + """Parse and validate the model response. + + Returns (parsed_content, error_reason|None). + When schema is None, returns (raw_text, None) — only refusal + detection (already done in _call_model) applies. + """ + if schema is None: + return text, None + + try: + parsed = json.loads(text) + except json.JSONDecodeError as exc: + # Try to extract JSON from a markdown code fence + extracted = _extract_json_from_fence(text) + if extracted is not None: + parsed = extracted + else: + return None, f"JSONDecodeError: {exc}" + + try: + validate(instance=parsed, schema=schema) + except ValidationError as exc: + return None, f"SchemaValidationError: {exc.message}" + + return parsed, None + + async def _get_redis(self) -> Optional[aioredis.Redis]: + if self._redis_url is None: + return None + if self._redis is None: + self._redis = aioredis.from_url( + self._redis_url, + decode_responses=True, + socket_connect_timeout=2, + socket_timeout=2, + ) + return self._redis + + async def _publish_metrics(self, result: RouteResult) -> None: + """Non-blocking publish to Redis channel 'llm_router_metrics'.""" + payload = { + **result.to_dict(), + "metrics_snapshot": self.metrics.snapshot(), + "timestamp": time.time(), + } + try: + r = await self._get_redis() + if r is not None: + await r.publish("llm_router_metrics", json.dumps(payload)) + except Exception as exc: + # Never let a metrics failure break the caller + logger.warning(f"[llm_router] metrics publish failed: {exc}") + + +# --------------------------------------------------------------------------- +# Utility +# --------------------------------------------------------------------------- + +def _extract_json_from_fence(text: str) -> Optional[Any]: + """Extract JSON from a ```json ... ``` markdown code fence, if present.""" + import re + match = re.search(r"```(?:json)?\s*(\{.*?\}|\[.*?\])\s*```", text, re.DOTALL) + if match: + try: + return json.loads(match.group(1)) + except json.JSONDecodeError: + pass + return None diff --git a/services/planner-agent/tests/test_llm_router.py b/services/planner-agent/tests/test_llm_router.py new file mode 100644 index 0000000..9c025de --- /dev/null +++ b/services/planner-agent/tests/test_llm_router.py @@ -0,0 +1,449 @@ +""" +Unit tests for llm_router.py. + +All LLM and Redis calls are mocked — no network required. + +Run: + pip install pytest pytest-asyncio litellm jsonschema redis + pytest services/planner-agent/tests/test_llm_router.py -v +""" + +import asyncio +import json +import sys +from pathlib import Path +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# Allow importing from src/ without installation +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + +from llm_router import ( + AttemptRecord, + LLMRouter, + ModelConfig, + ModelMetrics, + RouteResult, + _extract_json_from_fence, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _fake_completion(content: str): + """Build a minimal litellm-style response object.""" + msg = MagicMock() + msg.content = content + choice = MagicMock() + choice.message = msg + resp = MagicMock() + resp.choices = [choice] + return resp + + +def _chain_of(*models: tuple[str, float]) -> list[ModelConfig]: + """Build a minimal test chain with no api_base.""" + return [ModelConfig(name=name, timeout=timeout) for name, timeout in models] + + +# --------------------------------------------------------------------------- +# ModelMetrics +# --------------------------------------------------------------------------- + +class TestModelMetrics: + def test_record_and_snapshot(self): + m = ModelMetrics() + m.record("qwen", "success") + m.record("qwen", "success") + m.record("qwen", "fallback") + m.record("haiku", "success") + + snap = m.snapshot() + assert snap["qwen"]["success"] == 2 + assert snap["qwen"]["fallback"] == 1 + assert snap["haiku"]["success"] == 1 + + def test_success_rate(self): + m = ModelMetrics() + m.record("model-a", "success") + m.record("model-a", "fallback") + assert m.success_rate("model-a") == 0.5 + + def test_success_rate_unknown_model(self): + m = ModelMetrics() + assert m.success_rate("ghost") is None + + def test_total_calls(self): + m = ModelMetrics() + m.record("x", "success") + m.record("x", "error") + assert m.total_calls("x") == 2 + + def test_snapshot_is_copy(self): + m = ModelMetrics() + m.record("x", "success") + snap = m.snapshot() + snap["x"]["success"] = 999 # mutate the copy + assert m.snapshot()["x"]["success"] == 1 # original unchanged + + +# --------------------------------------------------------------------------- +# RouteResult +# --------------------------------------------------------------------------- + +class TestRouteResult: + def test_succeeded(self): + r = RouteResult("hello", "hello", "model-a", [], 100) + assert r.succeeded is True + + def test_not_succeeded(self): + r = RouteResult(None, "", "none", [], 100) + assert r.succeeded is False + + def test_to_dict_structure(self): + attempt = AttemptRecord("m", "success", None, 50) + r = RouteResult("x", "x", "m", [attempt], 50) + d = r.to_dict() + assert d["model_used"] == "m" + assert len(d["attempts"]) == 1 + assert d["attempts"][0]["outcome"] == "success" + + +# --------------------------------------------------------------------------- +# _extract_json_from_fence +# --------------------------------------------------------------------------- + +class TestExtractJsonFromFence: + def test_json_fence(self): + text = 'Sure!\n```json\n{"a": 1}\n```\nDone.' + assert _extract_json_from_fence(text) == {"a": 1} + + def test_plain_fence(self): + text = "```\n[1, 2]\n```" + assert _extract_json_from_fence(text) == [1, 2] + + def test_no_fence(self): + assert _extract_json_from_fence("no json here") is None + + def test_broken_fence(self): + assert _extract_json_from_fence("```json\n{broken```") is None + + +# --------------------------------------------------------------------------- +# LLMRouter — validation & refusal detection +# --------------------------------------------------------------------------- + +class TestLLMRouterValidation: + def setup_method(self): + self.router = LLMRouter(redis_url=None) + + def test_validate_no_schema_returns_text(self): + parsed, err = self.router._validate("hello world", schema=None) + assert parsed == "hello world" + assert err is None + + def test_validate_valid_json(self): + schema = {"type": "object", "required": ["action"]} + parsed, err = self.router._validate('{"action": "redeploy"}', schema) + assert err is None + assert parsed == {"action": "redeploy"} + + def test_validate_invalid_json(self): + schema = {"type": "object"} + _, err = self.router._validate("not json {", schema) + assert err is not None + assert "JSONDecodeError" in err + + def test_validate_schema_violation(self): + schema = {"type": "object", "required": ["action"]} + _, err = self.router._validate('{"other": 1}', schema) + assert err is not None + assert "SchemaValidationError" in err + + def test_validate_extracts_fenced_json(self): + schema = {"type": "object", "required": ["action"]} + text = '```json\n{"action": "restart"}\n```' + parsed, err = self.router._validate(text, schema) + assert err is None + assert parsed == {"action": "restart"} + + def test_detect_refusal_nie_wiem(self): + assert self.router._detect_refusal("Nie wiem co mam zrobić") == "nie wiem" + + def test_detect_refusal_as_an_ai(self): + # Text contains both "as an AI" and "I cannot"; first match wins. + # We only assert that a refusal IS detected, not which pattern fires. + assert self.router._detect_refusal("As an AI I cannot help") is not None + + def test_detect_refusal_none(self): + assert self.router._detect_refusal("Sure, here is the action.") is None + + def test_detect_refusal_case_insensitive(self): + assert self.router._detect_refusal("I CANNOT do that") == "I cannot" + + +# --------------------------------------------------------------------------- +# LLMRouter — routing logic (mocked litellm + Redis) +# --------------------------------------------------------------------------- + +@pytest.fixture +def router_no_redis(): + """Router with a 3-model chain and Redis disabled.""" + chain = _chain_of( + ("local/qwen", 8.0), + ("claude-haiku-test", 30.0), + ("claude-sonnet-test", 30.0), + ) + return LLMRouter(redis_url=None, chain=chain) + + +@pytest.mark.asyncio +class TestLLMRouterRouting: + async def test_primary_success(self, router_no_redis): + """Primary model succeeds — no fallback.""" + with patch( + "litellm.acompletion", + AsyncMock(return_value=_fake_completion('{"action": "ok"}')), + ): + result = await router_no_redis.route( + messages=[{"role": "user", "content": "test"}], + schema={"type": "object", "required": ["action"]}, + ) + + assert result.model_used == "local/qwen" + assert result.content == {"action": "ok"} + assert len(result.attempts) == 1 + assert result.attempts[0].outcome == "success" + + async def test_fallback_on_json_error(self, router_no_redis): + """Primary returns bad JSON → falls back to haiku which returns valid JSON.""" + responses = [ + _fake_completion("not json at all"), + _fake_completion('{"action": "restart"}'), + ] + call_count = 0 + + async def fake_acompletion(**kwargs): + nonlocal call_count + r = responses[call_count] + call_count += 1 + return r + + with patch("litellm.acompletion", fake_acompletion): + result = await router_no_redis.route( + messages=[{"role": "user", "content": "x"}], + schema={"type": "object", "required": ["action"]}, + ) + + assert result.model_used == "claude-haiku-test" + assert result.content == {"action": "restart"} + assert result.attempts[0].outcome == "invalid" + assert result.attempts[1].outcome == "success" + + async def test_fallback_on_refusal(self, router_no_redis): + """Primary returns refusal text → falls back.""" + responses = [ + _fake_completion("I cannot help with that request."), + _fake_completion("Sure! Here is the plan."), + ] + idx = 0 + + async def fake_acompletion(**kwargs): + nonlocal idx + r = responses[idx] + idx += 1 + return r + + with patch("litellm.acompletion", fake_acompletion): + result = await router_no_redis.route( + messages=[{"role": "user", "content": "x"}], + ) + + assert result.model_used == "claude-haiku-test" + assert result.attempts[0].outcome == "rejected" + assert "RefusalPattern" in result.attempts[0].reason + + async def test_fallback_on_timeout(self, router_no_redis): + """Primary times out → falls back to haiku.""" + call_count = 0 + + async def fake_acompletion(**kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise asyncio.TimeoutError() + return _fake_completion("fallback response") + + with patch("litellm.acompletion", fake_acompletion): + result = await router_no_redis.route( + messages=[{"role": "user", "content": "x"}], + ) + + assert result.model_used == "claude-haiku-test" + assert "Timeout" in result.attempts[0].reason + + async def test_fallback_on_api_exception(self, router_no_redis): + """Primary raises a connection error → falls back.""" + import litellm.exceptions + + call_count = 0 + + async def fake_acompletion(**kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise litellm.exceptions.APIConnectionError( + message="Connection refused", llm_provider="ollama", model="qwen" + ) + return _fake_completion("ok") + + with patch("litellm.acompletion", fake_acompletion): + result = await router_no_redis.route( + messages=[{"role": "user", "content": "x"}], + ) + + assert result.model_used == "claude-haiku-test" + assert "APIConnectionError" in result.attempts[0].reason + + async def test_all_models_fail_raises(self, router_no_redis): + """All models return bad JSON → RuntimeError with attempt log.""" + with patch( + "litellm.acompletion", + AsyncMock(return_value=_fake_completion("not json")), + ): + with pytest.raises(RuntimeError) as exc_info: + await router_no_redis.route( + messages=[{"role": "user", "content": "x"}], + schema={"type": "object"}, + ) + + assert "All 3 models in chain failed" in str(exc_info.value) + + async def test_schema_none_no_json_required(self, router_no_redis): + """Without a schema, plain text responses are accepted.""" + with patch( + "litellm.acompletion", + AsyncMock(return_value=_fake_completion("Here is your plan.")), + ): + result = await router_no_redis.route( + messages=[{"role": "user", "content": "x"}], + ) + + assert result.content == "Here is your plan." + assert result.succeeded + + async def test_metrics_recorded_on_success(self, router_no_redis): + with patch( + "litellm.acompletion", + AsyncMock(return_value=_fake_completion("ok")), + ): + await router_no_redis.route([{"role": "user", "content": "x"}]) + + snap = router_no_redis.metrics.snapshot() + assert snap["local/qwen"]["success"] == 1 + + async def test_metrics_fallback_recorded(self, router_no_redis): + """Primary fails → fallback → metrics show primary=fallback, haiku=success.""" + responses = [ + _fake_completion("I cannot help"), # refusal + _fake_completion("ok"), + ] + idx = 0 + + async def fake_acompletion(**kwargs): + nonlocal idx + r = responses[idx]; idx += 1 + return r + + with patch("litellm.acompletion", fake_acompletion): + await router_no_redis.route([{"role": "user", "content": "x"}]) + + snap = router_no_redis.metrics.snapshot() + assert snap["local/qwen"]["fallback"] == 1 + assert snap["claude-haiku-test"]["success"] == 1 + + async def test_context_label_in_logs(self, router_no_redis, caplog): + """context= parameter appears in log output.""" + import logging + with patch( + "litellm.acompletion", + AsyncMock(return_value=_fake_completion("ok")), + ): + with caplog.at_level(logging.INFO, logger="llm_router"): + await router_no_redis.route( + messages=[{"role": "user", "content": "x"}], + context="supervisor.reconcile", + ) + + assert any("supervisor.reconcile" in r.message for r in caplog.records) + + +# --------------------------------------------------------------------------- +# LLMRouter — Redis metrics publish +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +class TestLLMRouterRedis: + async def test_metrics_published_on_success(self): + chain = _chain_of(("m1", 8.0),) + router = LLMRouter(redis_url="redis://localhost:6379", chain=chain) + + mock_redis = AsyncMock() + router._redis = mock_redis + + with patch( + "litellm.acompletion", + AsyncMock(return_value=_fake_completion("ok")), + ): + await router.route([{"role": "user", "content": "x"}]) + + mock_redis.publish.assert_awaited_once() + channel, payload_str = mock_redis.publish.call_args[0] + assert channel == "llm_router_metrics" + payload = json.loads(payload_str) + assert payload["model_used"] == "m1" + assert "metrics_snapshot" in payload + assert "timestamp" in payload + + async def test_redis_failure_does_not_raise(self): + """A broken Redis must never break the LLM call result.""" + chain = _chain_of(("m1", 8.0),) + router = LLMRouter(redis_url="redis://localhost:6379", chain=chain) + + mock_redis = AsyncMock() + mock_redis.publish.side_effect = ConnectionError("Redis down") + router._redis = mock_redis + + with patch( + "litellm.acompletion", + AsyncMock(return_value=_fake_completion("ok")), + ): + result = await router.route([{"role": "user", "content": "x"}]) + + assert result.succeeded # LLM call still returned + + async def test_metrics_published_on_full_failure(self): + chain = _chain_of(("m1", 8.0),) + router = LLMRouter(redis_url="redis://localhost:6379", chain=chain) + + mock_redis = AsyncMock() + router._redis = mock_redis + + with patch( + "litellm.acompletion", + AsyncMock(return_value=_fake_completion("not json")), + ): + with pytest.raises(RuntimeError): + await router.route( + messages=[{"role": "user", "content": "x"}], + schema={"type": "object"}, + ) + + # Metrics must still be published even when we raise + mock_redis.publish.assert_awaited_once() + _, payload_str = mock_redis.publish.call_args[0] + payload = json.loads(payload_str) + assert payload["model_used"] == "none"