|
import json
import logging
import os
import random
import threading
import time
from typing import Any, Literal, Optional

import backoff
import dashscope
import dspy
import litellm
import requests
from dashscope import Generation

# anthropic is optional; fall back to None so the import does not fail
# when the package is absent.
try:
    from anthropic import RateLimitError
except ImportError:
    RateLimitError = None


MAX_API_RETRY = 3
LLM_MIT_RETRY_SLEEP = 5
SUPPORT_ARGS = {"model", "messages", "frequency_penalty", "logit_bias", "logprobs", "top_logprobs", "max_tokens",
                "n", "presence_penalty", "response_format", "seed", "stop", "stream", "temperature", "top_p",
                "tools", "tool_choice", "user", "function_call", "functions", "tenant", "max_completion_tokens"}
def truncate_long_strings(d):
    """Recursively truncate string values longer than 100 characters in nested dicts/lists."""
    if isinstance(d, dict):
        return {k: truncate_long_strings(v) for k, v in d.items()}
    elif isinstance(d, list):
        return [truncate_long_strings(item) for item in d]
    elif isinstance(d, str) and len(d) > 100:
        return d[:100] + '...'
    else:
        return d
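
# A quick sanity check of the helper above (an illustrative addition, not part
# of the original module): long string leaves are clipped to 100 characters
# plus an ellipsis, while the dict/list structure is preserved.
assert truncate_long_strings({"msg": "x" * 150, "n": 3}) == {"msg": "x" * 100 + "...", "n": 3}

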
class QwenModel(dspy.OpenAI):
    """A wrapper around dspy.OpenAI that calls Qwen models through the DashScope Generation API."""

    def __init__(
        self,
        model: str = "qwen-max-allinone",
        api_key: Optional[str] = None,
        **kwargs
    ):
        super().__init__(model=model, api_key=api_key, **kwargs)
        self.model = model
        self.api_key = api_key
        self._token_usage_lock = threading.Lock()
        self.prompt_tokens = 0
        self.completion_tokens = 0

    def log_usage(self, response):
        """Accumulate prompt/completion token counts from a DashScope API response."""
        usage_data = response.get('usage')
        if usage_data:
            with self._token_usage_lock:
                self.prompt_tokens += usage_data.get('input_tokens', 0)
                self.completion_tokens += usage_data.get('output_tokens', 0)

    def get_usage_and_reset(self):
        """Get the total tokens used and reset the token usage."""
        # Read and reset under the same lock so concurrent calls cannot lose counts.
        with self._token_usage_lock:
            usage = {
                self.kwargs.get('model') or self.kwargs.get('engine'):
                    {'prompt_tokens': self.prompt_tokens, 'completion_tokens': self.completion_tokens}
            }
            self.prompt_tokens = 0
            self.completion_tokens = 0
        return usage

    def __call__(
        self,
        prompt: str,
        only_completed: bool = True,
        return_sorted: bool = False,
        **kwargs,
    ) -> list[str]:
        """Copied from dspy/dsp/modules/gpt3.py with the addition of tracking token usage."""
        assert only_completed, "for now"
        assert return_sorted is False, "for now"

        messages = [{'role': 'user', 'content': prompt}]
        response = None
        choices = None

        # Retry the DashScope call a few times, sleeping a short random
        # interval between attempts.
        for _ in range(MAX_API_RETRY):
            try:
                response = Generation.call(
                    model=self.model,
                    messages=messages,
                    result_format='message',
                )
                choices = response["output"]["choices"]
                break
            except Exception as e:
                print(f"Request failed: {e}. Retrying...")
                delay = random.uniform(0, 3)
                print(f"Waiting {delay:.2f} seconds before retrying...")
                time.sleep(delay)

        if choices is None:
            # All retries were exhausted without a successful response.
            raise RuntimeError(f"DashScope request failed after {MAX_API_RETRY} attempts.")

        self.log_usage(response)

        # Keep only choices that finished naturally rather than being cut off
        # by the length limit.
        completed_choices = [c for c in choices if c["finish_reason"] != "length"]
        if only_completed and len(completed_choices):
            choices = completed_choices

        completions = [c['message']['content'] for c in choices]
        return completions
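

# A minimal usage sketch (added for illustration; not part of the original
# module). It assumes a DashScope API key is exported as DASHSCOPE_API_KEY and
# that the account can call the "qwen-max" model; adjust both as needed.
if __name__ == "__main__":
    dashscope.api_key = os.environ.get("DASHSCOPE_API_KEY", "")
    lm = QwenModel(model="qwen-max")
    completions = lm("Say hello in one short sentence.")
    print(completions[0])
    # Per-model token counts accumulated via log_usage().
    print(lm.get_usage_and_reset())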
|
|