"""Fan a prompt out to several LLM agents concurrently, then merge their
answers with an aggregator model (a simple mixture-of-agents chain)."""

import asyncio
import json
import os

import httpx

# model_config.json supplies provider endpoints/credentials and a
# model -> provider routing table; both are loaded once at import time.
with open("llm/model_config.json", "r") as f:
    CONFIG = json.load(f)

PROVIDERS = CONFIG["providers"]            # {provider_key: {"url": ..., "key_env": ...}}
MODEL_PROVIDER_MAPPING = CONFIG["models"]  # {model_name: provider_key}
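
# A minimal model_config.json consistent with the lookups below might look
# like this (illustrative values only, not taken from the real config):
#
# {
#     "providers": {
#         "openai": {
#             "url": "https://api.openai.com/v1/chat/completions",
#             "key_env": "OPENAI_API_KEY"
#         }
#     },
#     "models": {
#         "gpt-4o": "openai"
#     }
# }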


async def call_model_api(model: str, prompt: str) -> str:
    """Send `prompt` to `model` via its provider's OpenAI-style chat endpoint."""
    # Resolve model -> provider -> endpoint/credentials, failing loudly on gaps.
    provider_key = MODEL_PROVIDER_MAPPING.get(model)
    if not provider_key:
        raise ValueError(f"No provider configured for model: {model}")

    provider = PROVIDERS.get(provider_key)
    if not provider:
        raise ValueError(f"Provider {provider_key} not found in config")

    url = provider["url"]
    api_key = os.getenv(provider["key_env"])
    if not api_key:
        raise ValueError(f"Missing API key for provider {provider_key}")

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    body = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.7,
    }

    async with httpx.AsyncClient(timeout=30) as client:
        response = await client.post(url, headers=headers, json=body)
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]


async def query_llm_agent(name: str, prompt: str, settings: dict) -> str:
    """Query the model assigned to agent `name`. Errors are reported as text
    rather than raised, so one failing agent cannot sink the whole gather()."""
    selected_model = settings.get("models", {}).get(name)
    if not selected_model:
        return f"[{name}] No model selected."
    if selected_model not in MODEL_PROVIDER_MAPPING:
        return f"[{name}] Model '{selected_model}' is not supported."

    try:
        response = await call_model_api(selected_model, prompt)
        return f"[{name}] {response}"
    except Exception as e:
        return f"[{name}] Error: {e}"


async def query_all_llms(prompt: str, settings: dict) -> list:
    """Query every agent concurrently and return their responses in order."""
    agents = ["LLM-A", "LLM-B", "LLM-C"]
    tasks = [query_llm_agent(agent, prompt, settings) for agent in agents]
    return await asyncio.gather(*tasks)


async def query_aggregator(responses: list, settings: dict) -> str:
    """Ask the configured aggregator model to merge the agents' responses."""
    model = settings.get("aggregator")
    if not model:
        return "[Aggregator] No aggregator model selected."
    if model not in MODEL_PROVIDER_MAPPING:
        return f"[Aggregator] Model '{model}' is not supported."

    system_prompt = (
        "You are an aggregator AI. Your task is to read the following responses "
        "from different AI agents and produce a single, high-quality response.\n\n"
        + "\n\n".join(responses)
    )

    try:
        result = await call_model_api(model, system_prompt)
        return f"[Aggregator] {result}"
    except Exception as e:
        return f"[Aggregator] Error: {e}"


async def query_moa_chain(prompt: str, settings: dict) -> list:
    """Run the full chain: agent responses first, aggregator verdict last."""
    responses = await query_all_llms(prompt, settings)
    aggregator = await query_aggregator(responses, settings)
    return responses + [aggregator]


def query_all_llms_sync(prompt: str, settings: dict) -> list:
    """Blocking wrapper around the chain. Note that asyncio.run() cannot be
    called from code that is already running inside an event loop."""
    return asyncio.run(query_moa_chain(prompt, settings))
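

if __name__ == "__main__":
    # Demo entry point: a hypothetical sketch. The model names below are
    # placeholders and must correspond to entries in model_config.json,
    # with the matching API key exported in the environment.
    demo_settings = {
        "models": {"LLM-A": "gpt-4o", "LLM-B": "gpt-4o", "LLM-C": "gpt-4o"},
        "aggregator": "gpt-4o",
    }
    for line in query_all_llms_sync("Summarize the CAP theorem.", demo_settings):
        print(line)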