Spaces:
Runtime error
Runtime error
File size: 7,668 Bytes
f7161fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 |
import base64
import datetime
import json
import os
from io import BytesIO
import gradio as gr
import requests
from PIL import Image
from loguru import logger
from zhipuai import ZhipuAI
from src import shared, config
from src.base_model import BaseLLMModel
from src.presets import (
INITIAL_SYSTEM_PROMPT,
TIMEOUT_ALL,
TIMEOUT_STREAMING,
STANDARD_ERROR_MSG,
CONNECTION_TIMEOUT_MSG,
READ_TIMEOUT_MSG,
ERROR_RETRIEVE_MSG,
GENERAL_ERROR_MSG,
CHAT_COMPLETION_URL,
SUMMARY_CHAT_SYSTEM_PROMPT
)
from src.openai_client import OpenAIClient
from src.utils import (
count_token,
construct_system,
construct_user,
get_last_day_of_month,
i18n,
replace_special_symbols,
)
def decode_chat_response(response):
    """Yield the text fragments carried by a streaming chat-completion response.

    Args:
        response: Iterable of streaming chunks. Every truthy chunk is read as
            ``chunk.choices[0].delta.content``; falsy chunks are skipped.

    Yields:
        Each non-empty ``content`` value, in arrival order.

    Raises:
        ValueError: If the generator is closed before the stream is exhausted
            (the ``GeneratorExit`` is converted, as it always has been here).
        Exception: Wrapping any other error raised while reading the stream;
            the message starts with ``"Error in generate: "``.
    """
    try:
        for chunk in response:
            if not chunk:
                continue
            content = chunk.choices[0].delta.content
            # A delta may carry no text at all (``None`` or ""). Calling
            # ``len()`` on ``None`` used to abort the whole stream, and the
            # former ``len(...) > 1`` test dropped one-character fragments.
            if content:
                yield content
    except GeneratorExit as ge:
        # NOTE(review): closing this generator early surfaces as ValueError in
        # the code that closes it; kept because callers may rely on it.
        raise ValueError(f"GeneratorExit: {ge}")
    except Exception as e:
        raise Exception(f"Error in generate: {str(e)}") from e
class ZhipuAIClient(OpenAIClient):
    """Chat client that sends the conversation to ZhipuAI via the ``zhipuai`` SDK.

    Construction is delegated to ``OpenAIClient``; this class replaces the
    request path (``_get_response``), the streaming iterator and the billing
    summary.
    """

    def __init__(
        self,
        model_name,
        api_key,
        system_prompt=INITIAL_SYSTEM_PROMPT,
        temperature=1.0,
        top_p=1.0,
        user_name="",
    ) -> None:
        """Initialise the client.

        Args:
            model_name: Model identifier passed to the SDK on every request.
            api_key: ZhipuAI API key.
            system_prompt: Optional system message prepended to the history.
            temperature: Sampling temperature forwarded to the SDK.
            top_p: Nucleus-sampling value forwarded to the SDK.
            user_name: Only logged; it is not forwarded to the parent class.
        """
        super().__init__(
            api_key=api_key,
            model_name=model_name,
            temperature=temperature,
            top_p=top_p,
            system_prompt=system_prompt,
        )
        self.api_key = api_key
        self.need_api_key = True
        self._refresh_header()
        # The SDK client is created lazily on the first request.
        self.client = None
        # Key the cached client was built with; None while no client has been
        # built by this class.
        self._client_api_key = None
        logger.info(f"user name: {user_name}")

    def get_answer_stream_iter(self):
        """Yield the reply accumulated so far after each streamed fragment.

        Yields:
            The growing reply text, or a single error message when
            ``_get_response`` returns ``None``.

        Raises:
            ValueError: If no API key is set.
        """
        if not self.api_key:
            raise ValueError("API key is not set")
        response = self._get_response(stream=True)
        if response is None:
            yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
            return
        partial_text = ""
        for fragment in decode_chat_response(response):
            partial_text += fragment
            yield partial_text

    # NOTE: a ZhipuAI-specific ``get_answer_at_once`` is not defined here.

    # This decorator has no effect unless multi-account mode is enabled.
    @shared.state.switching_api_key
    def _get_response(self, stream=False):
        """Send the current conversation to ZhipuAI and return the SDK response.

        Args:
            stream: Forwarded to the SDK as its ``stream`` argument.

        Returns:
            Whatever ``client.chat.completions.create`` returns.
        """
        zhipuai_api_key = self.api_key
        history = self.history
        if self.system_prompt is not None:
            history = [construct_system(self.system_prompt), *history]

        # Rebuild the cached client when the key it was built from no longer
        # matches ``self.api_key``; otherwise a key change made after the first
        # request would be ignored. A client assigned from outside (no
        # recorded key) is left untouched.
        built_with = getattr(self, "_client_api_key", None)
        key_changed = built_with is not None and built_with != zhipuai_api_key
        if self.client is None or key_changed:
            self.client = ZhipuAI(api_key=zhipuai_api_key)
            self._client_api_key = zhipuai_api_key

        # NOTE: only model, messages, temperature, top_p and stream are sent.
        # n_choices, the penalties, max tokens, stop sequence, logit bias and
        # the user identifier are not forwarded, and no custom API host or
        # explicit timeout is applied.
        return self.client.chat.completions.create(
            model=self.model_name,
            messages=history,
            temperature=self.temperature,
            top_p=self.top_p,
            stream=stream,
        )

    # TODO: fix bug - usage lookup is not implemented for ZhipuAI.
    def billing_info(self):
        """Return a fixed notice that API usage could not be retrieved."""
        status_text = "获取API使用情况失败,未更新ZhipuAI代价代码。"
        return status_text