import asyncio
import json
import os

import httpx

# Load model config at startup
with open("llm/model_config.json", "r") as f:
    CONFIG = json.load(f)

PROVIDERS = CONFIG["providers"]
MODEL_PROVIDER_MAPPING = CONFIG["models"]
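
# Expected shape of model_config.json, inferred from the lookups above.
# A minimal sketch; the provider and model names below are hypothetical:
#
# {
#     "providers": {
#         "openai": {
#             "url": "https://api.openai.com/v1/chat/completions",
#             "key_env": "OPENAI_API_KEY"
#         }
#     },
#     "models": {
#         "gpt-4o": "openai"
#     }
# }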

async def call_model_api(model: str, prompt: str) -> str:
    """Look up the provider for `model` and send an OpenAI-style chat request."""
    provider_key = MODEL_PROVIDER_MAPPING.get(model)
    if not provider_key:
        raise ValueError(f"No provider configured for model: {model}")

    provider = PROVIDERS.get(provider_key)
    if not provider:
        raise ValueError(f"Provider {provider_key} not found in config")

    url = provider["url"]
    api_key_env = provider["key_env"]
    api_key = os.getenv(api_key_env)

    if not api_key:
        raise ValueError(f"Missing API key for provider {provider_key}")

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }

    body = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.7,
    }

    async with httpx.AsyncClient(timeout=30) as client:
        response = await client.post(url, headers=headers, json=body)
        response.raise_for_status()
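        # Assumes an OpenAI-style chat completions payload:
        # {"choices": [{"message": {"content": "..."}}], ...}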
        return response.json()["choices"][0]["message"]["content"]

async def query_llm_agent(name: str, prompt: str, settings: dict) -> str:
    """Query a single agent with the model selected for it in `settings`."""
    selected_model = settings.get("models", {}).get(name)

    if not selected_model:
        return f"[{name}] No model selected."

    if selected_model not in MODEL_PROVIDER_MAPPING:
        return f"[{name}] Model '{selected_model}' is not supported."

    try:
        response = await call_model_api(selected_model, prompt)
        return f"[{name}] {response}"
    except Exception as e:
        return f"[{name}] Error: {str(e)}"

async def query_all_llms(prompt: str, settings: dict) -> list:
    """Fan the prompt out to all agents concurrently."""
    agents = ["LLM-A", "LLM-B", "LLM-C"]
    tasks = [query_llm_agent(agent, prompt, settings) for agent in agents]
    return await asyncio.gather(*tasks)

async def query_aggregator(responses: list, settings: dict) -> str:
    """Merge the agent responses into one answer using the aggregator model.

    The aggregation instructions are sent as the user message, since
    call_model_api only builds user-role requests.
    """
    model = settings.get("aggregator")
    if not model:
        return "[Aggregator] No aggregator model selected."
    if model not in MODEL_PROVIDER_MAPPING:
        return f"[Aggregator] Model '{model}' is not supported."

    system_prompt = (
        "You are an aggregator AI. Your task is to read the following responses "
        "from different AI agents and produce a single, high-quality response.\n\n"
        + "\n\n".join(responses)
    )

    try:
        result = await call_model_api(model, system_prompt)
        return f"[Aggregator] {result}"
    except Exception as e:
        return f"[Aggregator] Error: {str(e)}"

async def query_moa_chain(prompt: str, settings: dict) -> list:
    """Run the full mixture-of-agents chain: fan out, then aggregate."""
    responses = await query_all_llms(prompt, settings)
    aggregator = await query_aggregator(responses, settings)
    return responses + [aggregator]

def query_all_llms_sync(prompt: str, settings: dict) -> list:
    """Blocking entry point that runs the whole MoA chain.

    asyncio.run() starts a fresh event loop, so this must not be called
    from code that is already running inside one.
    """
    return asyncio.run(query_moa_chain(prompt, settings))
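
# Example usage (a sketch; the model names below are hypothetical and must
# match entries in model_config.json):
#
# settings = {
#     "models": {
#         "LLM-A": "gpt-4o",
#         "LLM-B": "claude-3-5-sonnet",
#         "LLM-C": "llama-3-70b",
#     },
#     "aggregator": "gpt-4o",
# }
# results = query_all_llms_sync("Summarize the benefits of ensembling.", settings)
# for line in results:
#     print(line)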