mirror of
https://github.com/x1xhlol/system-prompts-and-models-of-ai-tools.git
synced 2026-06-18 23:39:34 +00:00
193 lines
6.8 KiB
Python
193 lines
6.8 KiB
Python
"""
|
|
Model Router — intelligently routes tasks to LLM providers with fallback.
|
|
مُوجّه النماذج — يرسل كل مهمة لأفضل مزود مع احتياط عند الفشل.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
from core.config.models import (
|
|
FALLBACK_CHAIN,
|
|
TASK_ROUTING,
|
|
Provider,
|
|
Task,
|
|
)
|
|
from core.config.settings import Settings, get_settings
|
|
from core.llm.anthropic_client import AnthropicClient
|
|
from core.llm.base import LLMClient, LLMResponse, Message
|
|
from core.llm.gemini_client import GeminiClient
|
|
from core.llm.glm_client import GLMClient
|
|
from core.llm.openai_compat import DeepSeekClient, GroqClient, OpenAIClient
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass()
class UsageRecord:
    """Running call, token and error counters kept for one provider."""

    # Attempts made against the provider (bumped before each chat call).
    calls: int = 0
    # Prompt-side tokens accumulated from successful responses.
    input_tokens: int = 0
    # Completion-side tokens accumulated from successful responses.
    output_tokens: int = 0
    # Attempts that ended in an exception.
    errors: int = 0
    # Times a request that started at this provider moved on to a fallback.
    fallbacks_triggered: int = 0
|
|
|
|
|
|
class ModelRouter:
    """Route a Task to an LLM client, falling back to other providers on failure.

    A client is built for every provider whose API key is present in the
    settings. ``run`` then walks the primary provider followed by its
    fallback chain, skipping providers without a client, and returns the
    first successful response. Per-provider counters are kept in ``usage``.
    """

    def __init__(self, settings: Settings | None = None) -> None:
        """Store settings, zero the usage counters and build the clients.

        Args:
            settings: Settings object to use; when omitted (or falsy) the
                result of ``get_settings()`` is used instead.
        """
        self.settings = settings or get_settings()
        self._clients: dict[Provider, LLMClient | None] = {}
        self.usage: dict[Provider, UsageRecord] = {p: UsageRecord() for p in Provider}
        self._build_clients()

    # ── Client construction ─────────────────────────────────────
    def _build_clients(self) -> None:
        """Instantiate clients only for providers that have API keys set."""
        s = self.settings

        if s.anthropic_api_key:
            self._clients[Provider.ANTHROPIC] = AnthropicClient(
                api_key=s.anthropic_api_key.get_secret_value(),
                model=s.anthropic_model,
                timeout=s.anthropic_timeout,
            )

        if s.deepseek_api_key:
            self._clients[Provider.DEEPSEEK] = DeepSeekClient(
                api_key=s.deepseek_api_key.get_secret_value(),
                model=s.deepseek_model,
                base_url=s.deepseek_base_url,
            )

        if s.glm_api_key:
            self._clients[Provider.GLM] = GLMClient(
                api_key=s.glm_api_key.get_secret_value(),
                model=s.glm_model,
                base_url=s.glm_base_url,
            )

        if s.google_api_key:
            self._clients[Provider.GEMINI] = GeminiClient(
                api_key=s.google_api_key.get_secret_value(),
                model=s.gemini_model,
            )

        if s.groq_api_key:
            self._clients[Provider.GROQ] = GroqClient(
                api_key=s.groq_api_key.get_secret_value(),
                model=s.groq_model,
                base_url=s.groq_base_url,
            )

        if s.openai_api_key:
            self._clients[Provider.OPENAI] = OpenAIClient(
                api_key=s.openai_api_key.get_secret_value(),
                model=s.openai_model,
                base_url=s.openai_base_url,
            )

        configured = [p.value for p in self._clients]
        logger.info("ModelRouter initialized with providers: %s", configured)

    # ── Public API ──────────────────────────────────────────────
    def available_providers(self) -> list[Provider]:
        """List providers that are actually configured."""
        return list(self._clients.keys())

    def get_client(self, provider: Provider) -> LLMClient | None:
        """Return the client registered for ``provider``, or None if there is none."""
        return self._clients.get(provider)

    async def run(
        self,
        task: Task,
        messages: list[Message] | str,
        *,
        system: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        preferred_provider: Provider | None = None,
    ) -> LLMResponse:
        """Execute a task through the routing + fallback chain.

        Args:
            task: Task to route; looked up in ``TASK_ROUTING`` to pick the
                primary provider (``Provider.ANTHROPIC`` when unmapped).
            messages: Conversation to send. A bare string is wrapped into a
                single ``user`` message.
            system: Optional system prompt forwarded to the client.
            max_tokens: Forwarded to the client's ``chat`` call.
            temperature: Forwarded to the client's ``chat`` call.
            preferred_provider: Overrides the ``TASK_ROUTING`` lookup when set.

        Returns:
            The response of the first provider in the chain that succeeds.

        Raises:
            RuntimeError: If no provider in the chain produced a response.
                The last provider exception, if any, is attached as
                ``__cause__``.
        """
        # Normalize input
        if isinstance(messages, str):
            messages = [Message(role="user", content=messages)]

        primary = preferred_provider or TASK_ROUTING.get(task, Provider.ANTHROPIC)
        chain = [primary] + [p for p in FALLBACK_CHAIN.get(primary, []) if p != primary]

        last_error: Exception | None = None
        for idx, provider in enumerate(chain):
            client = self._clients.get(provider)
            if client is None:
                logger.debug("Skipping unconfigured provider: %s", provider)
                continue

            usage = self.usage[provider]
            try:
                usage.calls += 1
                if idx > 0:
                    # Any attempt past the first chain entry is charged to the
                    # primary as a fallback, even when the primary itself was
                    # skipped as unconfigured.
                    self.usage[primary].fallbacks_triggered += 1
                    logger.warning(
                        "Task=%s fallback to provider=%s (primary=%s)",
                        task.value,
                        provider.value,
                        primary.value,
                    )

                response = await client.chat(
                    messages=messages,
                    system=system,
                    max_tokens=max_tokens,
                    temperature=temperature,
                )
                usage.input_tokens += response.input_tokens
                usage.output_tokens += response.output_tokens
                return response

            except Exception as e:
                # Record the failure and move on to the next provider.
                usage.errors += 1
                last_error = e
                logger.exception(
                    "Provider=%s failed for task=%s: %s", provider.value, task.value, e
                )

        # Chain the last provider failure so its traceback is not lost;
        # last_error stays None when every provider was skipped.
        raise RuntimeError(
            f"All providers failed for task {task.value}. Last error: {last_error}"
        ) from last_error

    def usage_summary(self) -> dict[str, Any]:
        """Return the usage counters as a dict keyed by provider value.

        Each entry carries the raw counters plus ``total_tokens``, the sum of
        input and output tokens.
        """
        return {
            provider.value: {
                "calls": record.calls,
                "input_tokens": record.input_tokens,
                "output_tokens": record.output_tokens,
                "total_tokens": record.input_tokens + record.output_tokens,
                "errors": record.errors,
                "fallbacks_triggered": record.fallbacks_triggered,
            }
            for provider, record in self.usage.items()
        }
|
|
|
|
|
|
# ── Singleton ───────────────────────────────────────────────────
|
|
_router_instance: ModelRouter | None = None


def get_router() -> ModelRouter:
    """Return the shared ModelRouter, building it on the first call."""
    global _router_instance
    if _router_instance is not None:
        return _router_instance
    # First use: build the router with default settings and cache it.
    _router_instance = ModelRouter()
    return _router_instance
|