|
import asyncio |
|
import json |
|
from tenacity import retry, stop_after_attempt, wait_exponential |
|
from openai import AsyncOpenAI |
|
from tqdm.asyncio import tqdm |
|
# NOTE(review): SECURITY — the endpoint and a live bearer token are hardcoded
# in source. Move both into environment variables (e.g. os.environ) and rotate
# this key; anyone with this file can use the account.
base_url="http://47.88.8.18:8088/v1"

api_key="eyJhbGciOiJIUzI1NiIsInR5cCI6Imp3dCJ9.eyJ1c2VybmFtZSI6IjQzNzkzNyIsInBhc3N3b3JkIjoiNDM3OTM3MTIzIiwiZXhwIjoyMDMxODEwOTAzfQ.M710cSMdw1OZ2TBVPRhlnoavZ8CQG5tXgj3WGl3FoIg"

# Single shared async client reused by every request in this script.
client = AsyncOpenAI(base_url=base_url, api_key=api_key)




# Retry / concurrency tuning.
# NOTE(review): the @retry decorator below hardcodes its own 10 / 1 / 60
# values instead of reading these constants — keep them in sync.
MAX_RETRIES = 10

BASE_DELAY = 1

MAX_DELAY = 60

# Upper bound on simultaneous in-flight API requests (semaphore size).
MAX_CONCURRENT = 64



# Model name; also used as the prefix of the output JSON filename.
model = "gpt-4"
|
|
|
|
|
@retry(
    stop=stop_after_attempt(MAX_RETRIES),
    wait=wait_exponential(multiplier=BASE_DELAY, min=4, max=MAX_DELAY),
)
async def get_chat_completion(message: str, semaphore, retry_count=0) -> dict:
    """Send one chat request and return ``{"prompt": message, "label": reply}``.

    Retries up to MAX_RETRIES attempts with exponential backoff (tenacity).

    Args:
        message: user prompt to send.
        semaphore: asyncio.Semaphore bounding concurrent in-flight requests.
        retry_count: unused; kept for backward compatibility (tenacity
            tracks attempts itself).

    Returns:
        dict with keys ``prompt`` (the input) and ``label`` (model reply).

    Raises:
        Exception: re-raised after logging so tenacity can retry; the last
            exception propagates once all attempts are exhausted.
    """
    try:
        # Hold the semaphore only while the request is in flight.
        async with semaphore:
            response = await client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": "you are a helpful assistant"},
                    {"role": "user", "content": message},
                ],
                timeout=80,
            )
            # Bug fix: original annotation said -> str, but a dict is returned.
            return {
                "prompt": message,
                "label": response.choices[0].message.content,
            }
    except Exception as e:
        # Bug fix: original log line said "for message" but omitted the message.
        print(f"Error in get_chat_completion for message {message!r}: {type(e).__name__} - {e}")
        raise  # let tenacity decide whether to retry
|
|
|
|
|
async def request_model(prompts):
    """Run all prompts concurrently and gather results as each completes.

    Failures that exhaust all retries are logged and recorded as ``None``
    so one bad prompt cannot sink the whole batch. Results arrive in
    completion order; each result dict embeds its own prompt.
    """
    limiter = asyncio.Semaphore(MAX_CONCURRENT)

    async def _run_one(prompt):
        # Convert a fully-retried failure into a None placeholder.
        try:
            return await get_chat_completion(prompt, limiter)
        except Exception as exc:
            print(f"Task failed after all retries with error: {exc}")
            return None

    pending = [_run_one(p) for p in prompts]

    collected = []
    for finished in tqdm.as_completed(pending, total=len(pending), desc="Processing prompts"):
        collected.append(await finished)

    return collected
|
|
|
if __name__ == "__main__":
    # Ten numbered dummy prompts for a smoke test.
    prompts = [f"测试测试测试{i}" for i in range(10)]

    results = asyncio.run(request_model(prompts))

    # Persist prompt/label pairs as one JSON array, keeping non-ASCII text.
    with open(f'{model}_result.json', 'w', encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False)