"""Mixture-of-Agents (MoA) LLM fan-out module.

Queries several "agent" LLMs in parallel with the same prompt, then feeds
their answers to an aggregator model that synthesizes a single response.

Provider endpoints and the model -> provider mapping are loaded once at
import time from ``llm/model_config.json``:

    {
        "providers": {"<key>": {"url": ..., "key_env": ...}, ...},
        "models": {"<model-name>": "<provider-key>", ...}
    }
"""

import asyncio
import json
import os

import httpx

# Load model config at startup (module import time). NOTE: this is a
# module-level side effect — importing this module requires the config
# file to exist relative to the working directory.
with open("llm/model_config.json", "r", encoding="utf-8") as f:
    CONFIG = json.load(f)

# Provider endpoint/credential definitions, keyed by provider name.
PROVIDERS = CONFIG["providers"]
# Maps a model name to the provider key that serves it.
MODEL_PROVIDER_MAPPING = CONFIG["models"]


async def call_model_api(model: str, prompt: str) -> str:
    """Send a single-turn chat completion request for *model*.

    Resolves the provider from ``MODEL_PROVIDER_MAPPING``, reads the API
    key from the environment variable named in the provider config, and
    POSTs an OpenAI-style chat payload.

    Raises:
        ValueError: unknown model/provider, or missing API key env var.
        httpx.HTTPStatusError: non-2xx response from the provider.

    Returns:
        The assistant message content from the first choice.
    """
    provider_key = MODEL_PROVIDER_MAPPING.get(model)
    if not provider_key:
        raise ValueError(f"No provider configured for model: {model}")

    provider = PROVIDERS.get(provider_key)
    if not provider:
        raise ValueError(f"Provider {provider_key} not found in config")

    url = provider["url"]
    api_key_env = provider["key_env"]
    api_key = os.getenv(api_key_env)
    if not api_key:
        raise ValueError(f"Missing API key for provider {provider_key}")

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    body = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.7,
    }

    # A fresh client per call keeps the function self-contained; callers
    # issuing many requests may want to share one AsyncClient instead.
    async with httpx.AsyncClient(timeout=30) as client:
        response = await client.post(url, headers=headers, json=body)
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]


async def query_llm_agent(name: str, prompt: str, settings: dict) -> str:
    """Query one named agent using the model chosen in *settings*.

    ``settings["models"]`` maps agent name -> model name. All failures
    (including API errors) are folded into the returned string rather
    than raised, so one bad agent never aborts the fan-out.
    """
    selected_model = settings.get("models", {}).get(name)
    if not selected_model:
        return f"[{name}] No model selected."
    if selected_model not in MODEL_PROVIDER_MAPPING:
        return f"[{name}] Model '{selected_model}' is not supported."
    try:
        response = await call_model_api(selected_model, prompt)
        return f"[{name}] {response}"
    except Exception as e:
        return f"[{name}] Error: {str(e)}"


async def query_all_llms(prompt: str, settings: dict) -> list:
    """Query the three fixed agents concurrently; returns their replies in order."""
    agents = ["LLM-A", "LLM-B", "LLM-C"]
    tasks = [query_llm_agent(agent, prompt, settings) for agent in agents]
    return await asyncio.gather(*tasks)


async def query_aggregator(responses: list, settings: dict) -> str:
    """Combine agent *responses* into one answer via the aggregator model.

    The aggregator model name comes from ``settings["aggregator"]``.
    Errors are returned as a tagged string, mirroring query_llm_agent.
    """
    model = settings.get("aggregator")
    if not model:
        return "[Aggregator] No aggregator model selected."
    if model not in MODEL_PROVIDER_MAPPING:
        return f"[Aggregator] Model '{model}' is not supported."
    system_prompt = (
        "You are an aggregator AI. Your task is to read the following responses "
        "from different AI agents and produce a single, high-quality response.\n\n"
        + "\n\n".join(responses)
    )
    try:
        result = await call_model_api(model, system_prompt)
        return f"[Aggregator] {result}"
    except Exception as e:
        return f"[Aggregator] Error: {str(e)}"


def query_all_llms_sync(prompt: str, settings: dict) -> list:
    """Synchronous entry point.

    NOTE: despite the name, this runs the full MoA chain (agents +
    aggregator), not just the agent fan-out — it returns the agent
    responses followed by the aggregator response. Must not be called
    from within a running event loop (uses asyncio.run).
    """
    return asyncio.run(query_moa_chain(prompt, settings))


async def query_moa_chain(prompt: str, settings: dict) -> list:
    """Run the full chain: agent fan-out, then aggregation.

    Returns the agent responses followed by the aggregator's response.
    """
    responses = await query_all_llms(prompt, settings)
    aggregator = await query_aggregator(responses, settings)
    return responses + [aggregator]