Spaces:
Sleeping
Sleeping
import os | |
import base64 | |
import io | |
import time | |
import streamlit as st | |
from PIL import Image | |
from service import Service | |
""" | |
使用 mistralai 官方库的 Service 类处理 API 请求 | |
""" | |
# 设置页面配置 - 必须是第一个Streamlit命令 | |
st.set_page_config( | |
page_title="Mistral 聊天助手", | |
page_icon="🤖", | |
layout="wide", | |
initial_sidebar_state="collapsed" | |
) | |
# 初始化API服务 | |
service = Service() | |
# 初始化会话状态 | |
if "messages" not in st.session_state: | |
st.session_state.messages = [] | |
if "image_data" not in st.session_state: | |
st.session_state.image_data = None | |
def encode_image_to_base64(image): | |
"""将图像转换为 base64 字符串""" | |
if image is None: | |
return None | |
try: | |
# 如果是PIL图像 | |
if isinstance(image, Image.Image): | |
buffered = io.BytesIO() | |
image.save(buffered, format="PNG") | |
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") | |
return f"data:image/png;base64,{img_str}" | |
# 如果是字节流或文件上传对象 | |
elif hasattr(image, 'read') or isinstance(image, bytes): | |
if hasattr(image, 'read'): | |
image_bytes = image.read() | |
else: | |
image_bytes = image | |
img_str = base64.b64encode(image_bytes).decode("utf-8") | |
return f"data:image/png;base64,{img_str}" | |
# 如果是文件路径 | |
elif isinstance(image, str) and os.path.isfile(image): | |
with open(image, "rb") as img_file: | |
img_str = base64.b64encode(img_file.read()).decode("utf-8") | |
return f"data:image/png;base64,{img_str}" | |
else: | |
st.error(f"不支持的图像类型: {type(image)}") | |
return None | |
except Exception as e: | |
st.error(f"编码图像时出错: {str(e)}") | |
return None | |
def read_file_content(file_path): | |
"""提取文件内容""" | |
if file_path is None: | |
return None | |
try: | |
print(f"尝试读取文件内容: {file_path}") | |
file_ext = os.path.splitext(file_path)[1].lower() | |
# 文本文件扩展名列表 | |
text_exts = ['.txt', '.md', '.py', '.js', '.html', '.css', '.json', '.csv', '.xml', '.yaml', '.yml', '.ini', '.conf'] | |
if file_ext in text_exts: | |
try: | |
with open(file_path, 'r', encoding='utf-8') as f: | |
content = f.read() | |
print(f"成功读取文件内容,长度: {len(content)}") | |
return content | |
except UnicodeDecodeError: | |
# 尝试使用其他编码 | |
try: | |
with open(file_path, 'r', encoding='gbk') as f: | |
content = f.read() | |
print(f"使用GBK编码成功读取文件内容,长度: {len(content)}") | |
return content | |
except: | |
print(f"无法解码文件内容,可能是二进制文件") | |
return f"无法读取文件内容,文件可能是二进制格式或使用了不支持的编码。" | |
else: | |
return f"文件类型 {file_ext} 暂不支持直接读取内容,但我可以尝试分析文件名称。" | |
except Exception as e: | |
print(f"读取文件时出错: {str(e)}") | |
return f"读取文件时出错: {str(e)}" | |
def respond( | |
message, | |
history, | |
system_message, | |
max_tokens, | |
temperature, | |
top_p, | |
image=None | |
): | |
try: | |
print(f"响应函数收到:message={message[:50]}...(已截断), 图片={image is not None}") | |
# 准备完整的消息历史 | |
messages = [{"role": "system", "content": system_message}] | |
# 添加历史消息 | |
for msg in history: | |
if msg["role"] == "user": | |
messages.append({"role": "user", "content": msg["content"]}) | |
elif msg["role"] == "assistant": | |
messages.append({"role": "assistant", "content": msg["content"]}) | |
# 设置模型和参数 | |
service.model = "mistral-small-latest" # 可以根据需要修改为其他模型 | |
# 处理带图像的请求 | |
if image is not None: | |
print("处理带图像的请求...") | |
# 使用 chat_with_image 方法处理多模态请求 | |
response = service.chat_with_image( | |
text_prompt=message if message else "请分析这张图片", | |
image_base64=image, | |
history=messages | |
) | |
print("图像请求已发送到API") | |
else: | |
print("处理纯文本请求...") | |
# 纯文本请求,添加用户消息并获取响应 | |
messages.append({"role": "user", "content": message}) | |
response = service.get_response(messages) | |
# 返回响应结果 | |
print(f"API返回响应: {response[:50]}...(已截断)") | |
return response | |
except Exception as e: | |
print(f"API 请求错误: {str(e)}") | |
return f"处理请求时出错: {str(e)}" | |
# 加载系统提示 | |
def load_system_prompt(): | |
return """你是一个有帮助的AI助手,可以回答用户的问题,也可以分析用户上传的图片。 | |
如果用户上传了图片,请详细描述图片内容,并回答用户关于图片的问题。 | |
如果用户没有上传图片,请正常回答用户的文本问题。 | |
""" | |
# 获取API响应 | |
def get_api_response(prompt, image_data=None): | |
try: | |
# 准备消息历史(不包括最新的用户消息) | |
messages = [] | |
# 添加系统消息 | |
messages.append({"role": "system", "content": load_system_prompt()}) | |
# 添加历史消息 | |
for msg in st.session_state.messages: | |
if msg["role"] != "system": # 跳过系统消息,因为我们已经添加了 | |
messages.append({"role": msg["role"], "content": msg["content"]}) | |
# 处理带图像的请求 | |
if image_data: | |
st.info("正在处理图像...") | |
# 使用 chat_with_image 方法处理多模态请求 | |
return service.chat_with_image( | |
text_prompt=prompt if prompt else "请分析这张图片", | |
image_base64=image_data, | |
history=messages | |
) | |
else: | |
# 添加最新的用户消息 | |
messages.append({"role": "user", "content": prompt}) | |
# 纯文本请求 | |
return service.get_response(messages) | |
except Exception as e: | |
st.error(f"API 请求错误: {str(e)}") | |
return f"处理请求时出错: {str(e)}" | |
# 显示标题和说明 | |
st.title("🤖 Mistral 多模态聊天助手") | |
st.markdown(""" | |
### 使用说明 | |
- 输入文字问题并按回车发送 | |
- 点击"📋 粘贴图片"按钮,然后粘贴剪贴板中的图片 | |
- 也可以使用"📎 上传图片"上传本地图片文件 | |
- 图片和文字可以一起发送,或单独发送 | |
""") | |
# 创建两列布局 | |
col1, col2 = st.columns([3, 1]) | |
with col2: | |
st.subheader("选项") | |
# 添加图片上传按钮 | |
uploaded_file = st.file_uploader("📎 上传图片", type=["jpg", "jpeg", "png"], key="file_uploader") | |
# 粘贴图片按钮 | |
if st.button("📋 粘贴图片"): | |
st.session_state.paste_mode = True | |
# 粘贴模式激活时显示粘贴区域 | |
if "paste_mode" in st.session_state and st.session_state.paste_mode: | |
st.markdown("### 粘贴图片区域") | |
st.markdown("按 Ctrl+V 粘贴图片") | |
# 使用实验性功能接收粘贴的图片 | |
pasted_image = st.camera_input("粘贴的图片会显示在这里", key="camera") | |
if pasted_image: | |
st.session_state.image_data = encode_image_to_base64(pasted_image) | |
st.session_state.paste_mode = False | |
st.experimental_rerun() | |
# 如果通过文件上传器上传了图片 | |
if uploaded_file: | |
st.session_state.image_data = encode_image_to_base64(uploaded_file) | |
st.image(uploaded_file, caption="已上传的图片", use_column_width=True) | |
# 清除图片按钮 | |
if st.session_state.image_data and st.button("🗑️ 清除图片"): | |
st.session_state.image_data = None | |
st.experimental_rerun() | |
# 清除对话按钮 | |
if st.button("🧹 清除对话"): | |
st.session_state.messages = [] | |
st.session_state.image_data = None | |
st.experimental_rerun() | |
with col1: | |
# 显示聊天历史 | |
for message in st.session_state.messages: | |
with st.chat_message(message["role"]): | |
# 显示消息内容 | |
st.markdown(message["content"]) | |
# 如果消息包含图片 | |
if "image" in message and message["image"]: | |
st.image(message["image"], use_column_width=True) | |
# 显示当前上传的图片预览 | |
if st.session_state.image_data: | |
with st.expander("📷 当前图片预览", expanded=True): | |
# 从base64解码图片以显示预览 | |
if "base64" in st.session_state.image_data: | |
image_b64 = st.session_state.image_data.split(",")[1] | |
image_bytes = base64.b64decode(image_b64) | |
st.image(image_bytes, caption="即将发送的图片", use_column_width=True) | |
# 用户输入 | |
prompt = st.chat_input("输入您的问题...", key="user_input") | |
# 处理用户输入 | |
if prompt: | |
# 添加用户消息到历史 | |
user_message = {"role": "user", "content": prompt} | |
if st.session_state.image_data: | |
user_message["image"] = st.session_state.image_data | |
st.session_state.messages.append(user_message) | |
# 显示用户消息 | |
with st.chat_message("user"): | |
st.markdown(prompt) | |
if st.session_state.image_data: | |
# 从base64解码图片以显示预览 | |
if "base64" in st.session_state.image_data: | |
image_b64 = st.session_state.image_data.split(",")[1] | |
image_bytes = base64.b64decode(image_b64) | |
st.image(image_bytes, use_column_width=True) | |
# 显示助手思考中的状态 | |
with st.chat_message("assistant"): | |
with st.spinner("思考中..."): | |
# 获取API响应 | |
response = get_api_response(prompt, st.session_state.image_data) | |
# 显示响应 | |
message_placeholder = st.empty() | |
full_response = "" | |
# 模拟流式响应 | |
for chunk in response.split(): | |
full_response += chunk + " " | |
message_placeholder.markdown(full_response + "▌") | |
time.sleep(0.01) | |
message_placeholder.markdown(full_response) | |
# 添加助手响应到历史 | |
st.session_state.messages.append({"role": "assistant", "content": full_response}) | |
# 清除当前图片数据,防止重复使用 | |
st.session_state.image_data = None | |
# 重新运行页面以更新UI | |
st.experimental_rerun() | |
if __name__ == "__main__": | |
# 从环境变量获取 API 密钥,或者提示用户设置 | |
api_key = os.environ.get("MISTRAL_API_KEY", "") | |
if not api_key: | |
st.sidebar.warning("未设置 MISTRAL_API_KEY 环境变量。请设置环境变量或在代码中直接设置密钥。") | |
api_key = st.sidebar.text_input("输入您的 Mistral API 密钥:", type="password") | |
# 设置 API 密钥 | |
if api_key: | |
service.headers = {"Authorization": f"Bearer {api_key}"} | |
st.sidebar.success("API密钥已配置") | |
else: | |
st.sidebar.error("请设置 Mistral API 密钥以继续使用") | |