import asyncio
import json
import os

import httpx

# Load model config at startup
with open("llm/model_config.json", "r") as f:
    CONFIG = json.load(f)

PROVIDERS = CONFIG["providers"]
MODEL_PROVIDER_MAPPING = CONFIG["models"]


async def call_model_api(model: str, prompt: str) -> str:
    """Resolve the model's provider from config and POST a chat request to it."""
    provider_key = MODEL_PROVIDER_MAPPING.get(model)
    if not provider_key:
        raise ValueError(f"No provider configured for model: {model}")

    provider = PROVIDERS.get(provider_key)
    if not provider:
        raise ValueError(f"Provider {provider_key} not found in config")

    url = provider["url"]
    api_key_env = provider["key_env"]
    api_key = os.getenv(api_key_env)
    if not api_key:
        raise ValueError(f"Missing API key for provider {provider_key}")

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    body = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.7,
    }

    async with httpx.AsyncClient(timeout=30) as client:
        response = await client.post(url, headers=headers, json=body)
        response.raise_for_status()
        # Assumes an OpenAI-compatible chat completions response body.
        return response.json()["choices"][0]["message"]["content"]
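
# For reference, the extraction above assumes every configured provider speaks
# the OpenAI-style chat completions format. A sketch of the relevant part of
# the response (field values are illustrative):
#
# {
#     "choices": [
#         {"message": {"role": "assistant", "content": "..."}}
#     ]
# }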


async def query_llm_agent(name: str, prompt: str, settings: dict) -> str:
    """Query one agent's selected model; never raises, always returns a tagged string."""
    selected_model = settings.get("models", {}).get(name)
    if not selected_model:
        return f"[{name}] No model selected."
    if selected_model not in MODEL_PROVIDER_MAPPING:
        return f"[{name}] Model '{selected_model}' is not supported."
    try:
        response = await call_model_api(selected_model, prompt)
        return f"[{name}] {response}"
    except Exception as e:
        return f"[{name}] Error: {e}"


async def query_all_llms(prompt: str, settings: dict) -> list:
    agents = ["LLM-A", "LLM-B", "LLM-C"]
    tasks = [query_llm_agent(agent, prompt, settings) for agent in agents]
    return await asyncio.gather(*tasks)


async def query_aggregator(responses: list, settings: dict) -> str:
    model = settings.get("aggregator")
    if not model:
        return "[Aggregator] No aggregator model selected."
    if model not in MODEL_PROVIDER_MAPPING:
        return f"[Aggregator] Model '{model}' is not supported."

    system_prompt = (
        "You are an aggregator AI. Your task is to read the following responses "
        "from different AI agents and produce a single, high-quality response.\n\n"
        + "\n\n".join(responses)
    )
    try:
        result = await call_model_api(model, system_prompt)
        return f"[Aggregator] {result}"
    except Exception as e:
        return f"[Aggregator] Error: {e}"


async def query_moa_chain(prompt: str, settings: dict) -> list:
    """Run the full mixture-of-agents chain: parallel agents, then aggregation."""
    responses = await query_all_llms(prompt, settings)
    aggregator = await query_aggregator(responses, settings)
    return responses + [aggregator]


def query_all_llms_sync(prompt: str, settings: dict) -> list:
    """Synchronous entry point for callers without a running event loop."""
    return asyncio.run(query_moa_chain(prompt, settings))
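

if __name__ == "__main__":
    # Usage sketch: the model names here are hypothetical and must match
    # entries in model_config.json's "models" mapping; the agent keys must
    # match the names in query_all_llms.
    demo_settings = {
        "models": {
            "LLM-A": "gpt-4o-mini",
            "LLM-B": "gpt-4o-mini",
            "LLM-C": "gpt-4o-mini",
        },
        "aggregator": "gpt-4o-mini",
    }
    for line in query_all_llms_sync("What is mixture-of-agents?", demo_settings):
        print(line)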