# OmniThink / src / lm.py
import logging
import os
import random
import threading
import json
import dashscope
from typing import Optional, Literal, Any
import time
import backoff
import dspy
import litellm
import requests
from dashscope import Generation
try:
from anthropic import RateLimitError
except ImportError:
RateLimitError = None
# Maximum number of attempts for a single LLM API request.
MAX_API_RETRY = 3
# Seconds to sleep before retrying after a rate-limit response.
LLM_MIT_RETRY_SLEEP = 5
# Keyword arguments accepted by the chat-completion endpoint; presumably used to
# filter **kwargs before forwarding a request — the filtering site is not
# visible in this chunk, so verify against the caller.
SUPPORT_ARGS = {"model", "messages", "frequency_penalty", "logit_bias", "logprobs", "top_logprobs", "max_tokens",
                "n", "presence_penalty", "response_format", "seed", "stop", "stream", "temperature", "top_p",
                "tools", "tool_choice", "user", "function_call", "functions", "tenant", "max_completion_tokens"}
def truncate_long_strings(d):
    """Recursively shorten every string longer than 100 characters.

    Walks nested dicts and lists, replacing any string over 100 characters
    with its first 100 characters plus an ``'...'`` marker. All other values
    are returned unchanged. The input is never mutated; new containers are
    built for dicts and lists.
    """
    if isinstance(d, str):
        # Strings of exactly 100 characters are left untouched.
        return d if len(d) <= 100 else d[:100] + '...'
    if isinstance(d, dict):
        return {key: truncate_long_strings(value) for key, value in d.items()}
    if isinstance(d, list):
        return [truncate_long_strings(element) for element in d]
    # Any non-container, non-string value passes through as-is.
    return d
class QwenModel(dspy.OpenAI):
    """A dspy.OpenAI wrapper that routes completions through DashScope's Generation API.

    Accumulates prompt/completion token usage across calls under a lock so
    the counters are safe to update from multiple threads.
    """

    def __init__(
            self,
            model: str = "qwen-max-allinone",
            api_key: Optional[str] = None,
            **kwargs
    ):
        """Initialize the wrapper.

        Args:
            model: DashScope model name to call.
            api_key: API key forwarded to the dspy.OpenAI base class.
            **kwargs: Extra options forwarded to dspy.OpenAI.
        """
        super().__init__(model=model, api_key=api_key, **kwargs)
        self.model = model
        self.api_key = api_key
        # Guards the two token counters below; log_usage may run concurrently.
        self._token_usage_lock = threading.Lock()
        self.prompt_tokens = 0
        self.completion_tokens = 0

    def log_usage(self, response):
        """Accumulate token usage from a DashScope Generation response."""
        usage_data = response.get('usage')
        if usage_data:
            with self._token_usage_lock:
                # DashScope reports 'input_tokens'/'output_tokens', not the
                # OpenAI-style 'prompt_tokens'/'completion_tokens' keys.
                self.prompt_tokens += usage_data.get('input_tokens', 0)
                self.completion_tokens += usage_data.get('output_tokens', 0)

    def get_usage_and_reset(self):
        """Return accumulated token usage keyed by model name and reset the counters.

        Returns:
            dict: ``{model_name: {'prompt_tokens': int, 'completion_tokens': int}}``.
        """
        # Read and reset under the same lock used by log_usage so a concurrent
        # increment cannot be lost between the read and the reset.
        with self._token_usage_lock:
            usage = {
                self.kwargs.get('model') or self.kwargs.get('engine'):
                    {'prompt_tokens': self.prompt_tokens, 'completion_tokens': self.completion_tokens}
            }
            self.prompt_tokens = 0
            self.completion_tokens = 0
        return usage

    def __call__(
            self,
            prompt: str,
            only_completed: bool = True,
            return_sorted: bool = False,
            **kwargs,
    ) -> list[str]:
        """Send *prompt* to DashScope and return the completion texts.

        Retries up to MAX_API_RETRY times with a short random backoff on any
        error, then logs token usage and returns the message contents of the
        (preferably fully-completed) choices.

        Args:
            prompt: User prompt to send as a single-message chat.
            only_completed: Must be True (partial support, as in dspy's gpt3.py).
            return_sorted: Must be False (unsupported).

        Returns:
            list[str]: Completion texts, one per returned choice.

        Raises:
            RuntimeError: If every retry attempt fails.
        """
        assert only_completed, "for now"
        assert return_sorted is False, "for now"
        messages = [{'role': 'user', 'content': prompt}]
        response = None
        choices = None
        # Use the module-level retry constant instead of a local magic number.
        for _ in range(MAX_API_RETRY):
            try:
                response = Generation.call(
                    model=self.model,
                    messages=messages,
                    result_format='message',
                )
                choices = response["output"]["choices"]
                break
            except Exception as e:
                print(f"请求失败: {e}. 尝试重新请求...")
                delay = random.uniform(0, 3)
                print(f"等待 {delay:.2f} 秒后重试...")
                time.sleep(delay)
        if choices is None:
            # BUG FIX: the original fell through with `response`/`choices`
            # unbound after exhausting retries, raising a confusing NameError
            # at self.log_usage(response). Fail loudly and clearly instead.
            raise RuntimeError(f"Generation.call failed after {MAX_API_RETRY} attempts")
        self.log_usage(response)
        # Prefer choices that finished naturally over ones cut off by length,
        # but fall back to whatever we got if none completed.
        completed_choices = [c for c in choices if c["finish_reason"] != "length"]
        if only_completed and len(completed_choices):
            choices = completed_choices
        completions = [c['message']['content'] for c in choices]
        return completions