"""
Unit tests for llm_router.py.

All LLM and Redis calls are mocked — no network required.

Run:
    pip install pytest pytest-asyncio litellm jsonschema redis
    pytest services/planner-agent/tests/test_llm_router.py -v
"""

import asyncio
import json
import sys
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

# Allow importing from src/ without installation
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

from llm_router import (  # noqa: E402  (must come after the sys.path tweak above)
    AttemptRecord,
    LLMRouter,
    ModelConfig,
    ModelMetrics,
    RouteResult,
    _extract_json_from_fence,
)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _fake_completion(content: str) -> MagicMock:
    """Build a minimal litellm-style response object.

    Only ``resp.choices[0].message.content`` is populated explicitly; every
    other attribute is a MagicMock auto-attribute.
    """
    msg = MagicMock()
    msg.content = content
    choice = MagicMock()
    choice.message = msg
    resp = MagicMock()
    resp.choices = [choice]
    return resp


def _chain_of(*models: tuple[str, float]) -> list[ModelConfig]:
    """Build a minimal test chain from ``(name, timeout)`` pairs, with no api_base."""
    return [ModelConfig(name=name, timeout=timeout) for name, timeout in models]


# ---------------------------------------------------------------------------
# ModelMetrics
# ---------------------------------------------------------------------------


class TestModelMetrics:
    def test_record_and_snapshot(self):
        m = ModelMetrics()
        m.record("qwen", "success")
        m.record("qwen", "success")
        m.record("qwen", "fallback")
        m.record("haiku", "success")
        snap = m.snapshot()
        assert snap["qwen"]["success"] == 2
        assert snap["qwen"]["fallback"] == 1
        assert snap["haiku"]["success"] == 1

    def test_success_rate(self):
        m = ModelMetrics()
        m.record("model-a", "success")
        m.record("model-a", "fallback")
        assert m.success_rate("model-a") == 0.5

    def test_success_rate_unknown_model(self):
        m = ModelMetrics()
        assert m.success_rate("ghost") is None

    def test_total_calls(self):
        m = ModelMetrics()
        m.record("x", "success")
        m.record("x", "error")
        assert m.total_calls("x") == 2

    def test_snapshot_is_copy(self):
        m = ModelMetrics()
        m.record("x", "success")
        snap = m.snapshot()
        snap["x"]["success"] = 999  # mutate the copy
        assert m.snapshot()["x"]["success"] == 1  # original unchanged


# ---------------------------------------------------------------------------
# RouteResult
# ---------------------------------------------------------------------------


class TestRouteResult:
    def test_succeeded(self):
        r = RouteResult("hello", "hello", "model-a", [], 100)
        assert r.succeeded is True

    def test_not_succeeded(self):
        r = RouteResult(None, "", "none", [], 100)
        assert r.succeeded is False

    def test_to_dict_structure(self):
        attempt = AttemptRecord("m", "success", None, 50)
        r = RouteResult("x", "x", "m", [attempt], 50)
        d = r.to_dict()
        assert d["model_used"] == "m"
        assert len(d["attempts"]) == 1
        assert d["attempts"][0]["outcome"] == "success"


# ---------------------------------------------------------------------------
# _extract_json_from_fence
# ---------------------------------------------------------------------------


class TestExtractJsonFromFence:
    def test_json_fence(self):
        text = 'Sure!\n```json\n{"a": 1}\n```\nDone.'
        assert _extract_json_from_fence(text) == {"a": 1}

    def test_plain_fence(self):
        text = "```\n[1, 2]\n```"
        assert _extract_json_from_fence(text) == [1, 2]

    def test_no_fence(self):
        assert _extract_json_from_fence("no json here") is None

    def test_broken_fence(self):
        assert _extract_json_from_fence("```json\n{broken```") is None


# ---------------------------------------------------------------------------
# LLMRouter — validation & refusal detection
# ---------------------------------------------------------------------------


class TestLLMRouterValidation:
    def setup_method(self):
        self.router = LLMRouter(redis_url=None)

    def test_validate_no_schema_returns_text(self):
        parsed, err = self.router._validate("hello world", schema=None)
        assert parsed == "hello world"
        assert err is None

    def test_validate_valid_json(self):
        schema = {"type": "object", "required": ["action"]}
        parsed, err = self.router._validate('{"action": "redeploy"}', schema)
        assert err is None
        assert parsed == {"action": "redeploy"}

    def test_validate_invalid_json(self):
        schema = {"type": "object"}
        _, err = self.router._validate("not json {", schema)
        assert err is not None
        assert "JSONDecodeError" in err

    def test_validate_schema_violation(self):
        schema = {"type": "object", "required": ["action"]}
        _, err = self.router._validate('{"other": 1}', schema)
        assert err is not None
        assert "SchemaValidationError" in err

    def test_validate_extracts_fenced_json(self):
        schema = {"type": "object", "required": ["action"]}
        text = '```json\n{"action": "restart"}\n```'
        parsed, err = self.router._validate(text, schema)
        assert err is None
        assert parsed == {"action": "restart"}

    def test_detect_refusal_nie_wiem(self):
        assert self.router._detect_refusal("Nie wiem co mam zrobić") == "nie wiem"

    def test_detect_refusal_as_an_ai(self):
        # Text contains both "as an AI" and "I cannot"; first match wins.
        # We only assert that a refusal IS detected, not which pattern fires.
        assert self.router._detect_refusal("As an AI I cannot help") is not None

    def test_detect_refusal_none(self):
        assert self.router._detect_refusal("Sure, here is the action.") is None

    def test_detect_refusal_case_insensitive(self):
        assert self.router._detect_refusal("I CANNOT do that") == "I cannot"


# ---------------------------------------------------------------------------
# LLMRouter — routing logic (mocked litellm + Redis)
# ---------------------------------------------------------------------------


@pytest.fixture
def router_no_redis():
    """Router with a 3-model chain and Redis disabled."""
    chain = _chain_of(
        ("local/qwen", 8.0),
        ("claude-haiku-test", 30.0),
        ("claude-sonnet-test", 30.0),
    )
    return LLMRouter(redis_url=None, chain=chain)


@pytest.mark.asyncio
class TestLLMRouterRouting:
    # Multi-step scenarios use AsyncMock(side_effect=[...]): each await returns
    # the next item in order, and exception instances in the list are raised.
    # A call beyond the list raises StopAsyncIteration, so an unexpected extra
    # LLM call surfaces as an error.

    async def test_primary_success(self, router_no_redis):
        """Primary model succeeds — no fallback."""
        with patch(
            "litellm.acompletion",
            AsyncMock(return_value=_fake_completion('{"action": "ok"}')),
        ):
            result = await router_no_redis.route(
                messages=[{"role": "user", "content": "test"}],
                schema={"type": "object", "required": ["action"]},
            )
        assert result.model_used == "local/qwen"
        assert result.content == {"action": "ok"}
        assert len(result.attempts) == 1
        assert result.attempts[0].outcome == "success"

    async def test_fallback_on_json_error(self, router_no_redis):
        """Primary returns bad JSON → falls back to haiku which returns valid JSON."""
        fake = AsyncMock(
            side_effect=[
                _fake_completion("not json at all"),
                _fake_completion('{"action": "restart"}'),
            ]
        )
        with patch("litellm.acompletion", fake):
            result = await router_no_redis.route(
                messages=[{"role": "user", "content": "x"}],
                schema={"type": "object", "required": ["action"]},
            )
        assert result.model_used == "claude-haiku-test"
        assert result.content == {"action": "restart"}
        assert result.attempts[0].outcome == "invalid"
        assert result.attempts[1].outcome == "success"
        assert fake.await_count == 2  # stopped at the first success

    async def test_fallback_on_refusal(self, router_no_redis):
        """Primary returns refusal text → falls back."""
        fake = AsyncMock(
            side_effect=[
                _fake_completion("I cannot help with that request."),
                _fake_completion("Sure! Here is the plan."),
            ]
        )
        with patch("litellm.acompletion", fake):
            result = await router_no_redis.route(
                messages=[{"role": "user", "content": "x"}],
            )
        assert result.model_used == "claude-haiku-test"
        assert result.content == "Sure! Here is the plan."
        assert result.attempts[0].outcome == "rejected"
        assert "RefusalPattern" in result.attempts[0].reason
        assert fake.await_count == 2

    async def test_fallback_on_timeout(self, router_no_redis):
        """Primary times out → falls back to haiku."""
        fake = AsyncMock(
            side_effect=[
                asyncio.TimeoutError(),
                _fake_completion("fallback response"),
            ]
        )
        with patch("litellm.acompletion", fake):
            result = await router_no_redis.route(
                messages=[{"role": "user", "content": "x"}],
            )
        assert result.model_used == "claude-haiku-test"
        assert result.content == "fallback response"
        assert "Timeout" in result.attempts[0].reason
        assert fake.await_count == 2

    async def test_fallback_on_api_exception(self, router_no_redis):
        """Primary raises a connection error → falls back."""
        import litellm.exceptions

        fake = AsyncMock(
            side_effect=[
                litellm.exceptions.APIConnectionError(
                    message="Connection refused", llm_provider="ollama", model="qwen"
                ),
                _fake_completion("ok"),
            ]
        )
        with patch("litellm.acompletion", fake):
            result = await router_no_redis.route(
                messages=[{"role": "user", "content": "x"}],
            )
        assert result.model_used == "claude-haiku-test"
        assert "APIConnectionError" in result.attempts[0].reason
        assert fake.await_count == 2

    async def test_all_models_fail_raises(self, router_no_redis):
        """All models return bad JSON → RuntimeError with attempt log."""
        with patch(
            "litellm.acompletion",
            AsyncMock(return_value=_fake_completion("not json")),
        ):
            # match= checks the message; an assert placed after the raising
            # call inside the pytest.raises block would never execute.
            with pytest.raises(RuntimeError, match="All 3 models in chain failed"):
                await router_no_redis.route(
                    messages=[{"role": "user", "content": "x"}],
                    schema={"type": "object"},
                )

    async def test_schema_none_no_json_required(self, router_no_redis):
        """Without a schema, plain text responses are accepted."""
        with patch(
            "litellm.acompletion",
            AsyncMock(return_value=_fake_completion("Here is your plan.")),
        ):
            result = await router_no_redis.route(
                messages=[{"role": "user", "content": "x"}],
            )
        assert result.content == "Here is your plan."
        assert result.succeeded

    async def test_metrics_recorded_on_success(self, router_no_redis):
        with patch(
            "litellm.acompletion",
            AsyncMock(return_value=_fake_completion("ok")),
        ):
            await router_no_redis.route([{"role": "user", "content": "x"}])
        snap = router_no_redis.metrics.snapshot()
        assert snap["local/qwen"]["success"] == 1

    async def test_metrics_fallback_recorded(self, router_no_redis):
        """Primary fails → fallback → metrics show primary=fallback, haiku=success."""
        fake = AsyncMock(
            side_effect=[
                _fake_completion("I cannot help"),  # refusal
                _fake_completion("ok"),
            ]
        )
        with patch("litellm.acompletion", fake):
            await router_no_redis.route([{"role": "user", "content": "x"}])
        snap = router_no_redis.metrics.snapshot()
        assert snap["local/qwen"]["fallback"] == 1
        assert snap["claude-haiku-test"]["success"] == 1

    async def test_context_label_in_logs(self, router_no_redis, caplog):
        """context= parameter appears in log output."""
        import logging

        with patch(
            "litellm.acompletion",
            AsyncMock(return_value=_fake_completion("ok")),
        ):
            with caplog.at_level(logging.INFO, logger="llm_router"):
                await router_no_redis.route(
                    messages=[{"role": "user", "content": "x"}],
                    context="supervisor.reconcile",
                )
        # getMessage() renders msg % args; unlike .message it does not depend
        # on a handler having formatted the record first.
        assert any("supervisor.reconcile" in r.getMessage() for r in caplog.records)
# ---------------------------------------------------------------------------
# LLMRouter — Redis metrics publish
# ---------------------------------------------------------------------------


def _router_with_mock_redis() -> tuple[LLMRouter, AsyncMock]:
    """Build a single-model ("m1") router whose Redis client is an AsyncMock.

    The mock replaces the router's private ``_redis`` attribute so that
    ``publish`` calls can be inspected (and made to fail) by the tests.

    Returns:
        ``(router, mock_redis)``.
    """
    chain = _chain_of(("m1", 8.0))
    router = LLMRouter(redis_url="redis://localhost:6379", chain=chain)
    mock_redis = AsyncMock()
    router._redis = mock_redis
    return router, mock_redis


@pytest.mark.asyncio
class TestLLMRouterRedis:
    async def test_metrics_published_on_success(self):
        """A successful route publishes one metrics event naming the model used."""
        router, mock_redis = _router_with_mock_redis()
        with patch(
            "litellm.acompletion",
            AsyncMock(return_value=_fake_completion("ok")),
        ):
            await router.route([{"role": "user", "content": "x"}])

        mock_redis.publish.assert_awaited_once()
        channel, payload_str = mock_redis.publish.call_args.args
        assert channel == "llm_router_metrics"
        payload = json.loads(payload_str)
        assert payload["model_used"] == "m1"
        assert "metrics_snapshot" in payload
        assert "timestamp" in payload

    async def test_redis_failure_does_not_raise(self):
        """A broken Redis must never break the LLM call result."""
        router, mock_redis = _router_with_mock_redis()
        mock_redis.publish.side_effect = ConnectionError("Redis down")
        with patch(
            "litellm.acompletion",
            AsyncMock(return_value=_fake_completion("ok")),
        ):
            result = await router.route([{"role": "user", "content": "x"}])

        assert result.succeeded  # LLM call still returned
        # Guard against a vacuous pass: the failing publish must actually have
        # been attempted for this test to prove anything.
        mock_redis.publish.assert_awaited_once()

    async def test_metrics_published_on_full_failure(self):
        """Metrics are published even when every model fails and route raises."""
        router, mock_redis = _router_with_mock_redis()
        with patch(
            "litellm.acompletion",
            AsyncMock(return_value=_fake_completion("not json")),
        ):
            with pytest.raises(RuntimeError):
                await router.route(
                    messages=[{"role": "user", "content": "x"}],
                    schema={"type": "object"},
                )

        # Metrics must still be published even when we raise
        mock_redis.publish.assert_awaited_once()
        _, payload_str = mock_redis.publish.call_args.args
        payload = json.loads(payload_str)
        assert payload["model_used"] == "none"