# OmniThink / src / test.py
# (header residue from the hosting page — ZekunXi, commit 80a598c — kept as a comment)
import asyncio
import json
from tenacity import retry, stop_after_attempt, wait_exponential
from openai import AsyncOpenAI
from tqdm.asyncio import tqdm
import os

# SECURITY NOTE(review): the endpoint and API key were hard-coded here.
# They are now read from the environment first; the original literals remain
# only as a fallback so existing runs keep working. Rotate this key and
# delete the fallback as soon as possible.
base_url = os.environ.get(
    "OPENAI_BASE_URL", "http://47.88.8.18:8088/v1"
)
api_key = os.environ.get(
    "OPENAI_API_KEY",
    "eyJhbGciOiJIUzI1NiIsInR5cCI6Imp3dCJ9.eyJ1c2VybmFtZSI6IjQzNzkzNyIsInBhc3N3b3JkIjoiNDM3OTM3MTIzIiwiZXhwIjoyMDMxODEwOTAzfQ.M710cSMdw1OZ2TBVPRhlnoavZ8CQG5tXgj3WGl3FoIg",
)
client = AsyncOpenAI(base_url=base_url, api_key=api_key)

# Retry / concurrency tuning knobs (consumed by get_chat_completion /
# request_model below).
MAX_RETRIES = 10     # attempts per prompt before giving up
BASE_DELAY = 1       # seconds — exponential-backoff multiplier
MAX_DELAY = 60       # seconds — backoff cap
MAX_CONCURRENT = 64  # simultaneous in-flight API requests

model = "gpt-4"
# model = "gpt-4o-2024-08-06"
# Consistency fix: the retry policy now uses the module-level knobs
# (MAX_RETRIES / BASE_DELAY / MAX_DELAY) instead of repeating the literals.
@retry(stop=stop_after_attempt(MAX_RETRIES),
       wait=wait_exponential(multiplier=BASE_DELAY, min=4, max=MAX_DELAY))
async def get_chat_completion(message: str, semaphore, retry_count=0) -> dict:
    """Send one user *message* to the chat API and return the result.

    Args:
        message: the user prompt text.
        semaphore: asyncio.Semaphore bounding concurrent API calls.
        retry_count: unused; kept for backward compatibility with callers.

    Returns:
        dict with keys ``"prompt"`` (the input message) and ``"label"``
        (the model's reply text).  Note the original annotation said
        ``str`` but the function has always returned a dict.

    Raises:
        Whatever the API client raises, after tenacity exhausts retries.
    """
    try:
        async with semaphore:  # bound concurrency with the caller's semaphore
            response = await client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": "you are a helpful assistant"},
                    {"role": "user", "content": message},
                ],
                timeout=80,
            )
            # Bug fix: the original f-string never interpolated the message.
            return {
                "prompt": message,
                "label": response.choices[0].message.content,
            }
    except Exception as e:
        print(f"Error in get_chat_completion for message {message!r}: {type(e).__name__} - {str(e)}")
        raise
async def request_model(prompts):
    """Fan the given prompts out to the chat API with bounded concurrency.

    Each prompt is processed by get_chat_completion under a shared
    semaphore of MAX_CONCURRENT permits.  A prompt that still fails after
    all retries contributes ``None`` instead of aborting the whole batch.
    Results arrive in completion order, not input order (each result dict
    carries its own "prompt" key, so the pairing is not lost).
    """
    sem = asyncio.Semaphore(MAX_CONCURRENT)

    async def _safe_call(prompt):
        # Shield the batch: swallow the final failure and record None.
        try:
            return await get_chat_completion(prompt, sem)
        except Exception as exc:
            print(f"Task failed after all retries with error: {exc}")
            return None

    pending = [_safe_call(p) for p in prompts]
    collected = []
    for finished in tqdm.as_completed(pending, total=len(pending), desc="Processing prompts"):
        collected.append(await finished)
    return collected
if __name__ == "__main__":
    # Smoke-test batch: ten dummy CJK prompts.
    prompts = ["测试测试测试" + str(i) for i in range(10)]
    results = asyncio.run(request_model(prompts))
    # Idiom fix: stream straight to the file with json.dump instead of
    # building the whole string via json.dumps + f.write.
    # ensure_ascii=False keeps the CJK text human-readable in the file.
    with open(f'{model}_result.json', 'w', encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False)