import asyncio
from pathlib import Path
from typing import Dict, List
import streamlit as st
import yaml
from loguru import logger as _logger
import shutil
import uuid
from metagpt.const import METAGPT_ROOT
from metagpt.ext.spo.components.optimizer import PromptOptimizer
from metagpt.ext.spo.utils.llm_client import SPO_LLM, RequestType
def get_user_workspace():
    """Return the current session's private workspace directory, creating it if needed.

    The first call in a Streamlit session stores a random UUID string in
    ``st.session_state.user_id``; the directory is ``workspace/<user_id>``,
    relative to the current working directory.
    """
    state = st.session_state
    if "user_id" not in state:
        state.user_id = str(uuid.uuid4())
    user_dir = Path("workspace").joinpath(state.user_id)
    user_dir.mkdir(parents=True, exist_ok=True)
    return user_dir
def cleanup_workspace(workspace_dir: Path) -> None:
    """Recursively delete *workspace_dir*, best effort.

    A directory that does not exist is ignored. Any exception raised while
    checking or removing it is logged and swallowed rather than propagated.
    """
    try:
        if not workspace_dir.exists():
            return
        shutil.rmtree(workspace_dir)
        _logger.info(f"Cleaned up workspace directory: {workspace_dir}")
    except Exception as exc:
        _logger.error(f"Error cleaning up workspace: {exc}")
def get_template_path(template_name: str, is_new_template: bool = False) -> str:
    """
    Get the template file path, relative to the SPO settings directory.

    :param template_name: Name of the template, without the ``.yaml`` suffix.
        Existing user templates (as listed by ``get_all_templates``) already
        carry their ``<user_id>/`` prefix and are used unchanged.
    :param is_new_template: Whether it's a new template created by the user.
        New templates are placed in the current session's ``<user_id>/``
        subfolder (the id is created on first use).
    :return: Relative path *string* (not a ``Path``), e.g. ``"name.yaml"`` or
        ``"<user_id>/name.yaml"``.
    """
    if not is_new_template:
        # Default and already-existing templates are addressed by name as-is.
        return f"{template_name}.yaml"

    if "user_id" not in st.session_state:
        st.session_state.user_id = str(uuid.uuid4())
    # New names come from free-text input: keep only the final path component so
    # values such as "../x" or "/abs/x" cannot point outside the user's folder.
    safe_name = Path(template_name).name
    return f"{st.session_state.user_id}/{safe_name}.yaml"
def get_all_templates(settings_path: Path = Path("metagpt/ext/spo/settings")) -> List[str]:
    """
    Get list of all available templates (both default and user-specific).

    :param settings_path: Directory holding the default ``*.yaml`` templates;
        the current session's own templates live in its ``<user_id>``
        subdirectory. Defaults to the SPO settings folder relative to the
        current working directory.
    :return: Sorted, de-duplicated template names. Default templates appear as
        ``"<name>"`` and the session's own templates as ``"<user_id>/<name>"``.
    """
    # Only the top level is globbed, so other users' subfolders are never listed.
    templates = {f.stem for f in settings_path.glob("*.yaml")}

    if "user_id" in st.session_state:
        user_id = st.session_state.user_id
        user_path = settings_path / user_id
        if user_path.exists():
            templates.update(f"{user_id}/{f.stem}" for f in user_path.glob("*.yaml"))

    return sorted(templates)
def load_yaml_template(template_path: Path) -> Dict:
    """Load a template YAML file, falling back to an empty template.

    :param template_path: Path of the template file.
    :return: The parsed top-level mapping. If the file does not exist, is
        empty, or its top level is not a mapping, a fresh empty template is
        returned instead: empty ``prompt``/``requirements``, ``count`` None,
        and one blank question/answer pair. Callers can therefore always
        use ``dict.get`` on the result.
    """
    empty_template = {"prompt": "", "requirements": "", "count": None, "qa": [{"question": "", "answer": ""}]}
    if not template_path.exists():
        return empty_template
    with open(template_path, "r", encoding="utf-8") as f:
        # safe_load only builds plain YAML types, never arbitrary Python objects.
        loaded = yaml.safe_load(f)
    # safe_load returns None for an empty file; any non-mapping would break .get().
    return loaded if isinstance(loaded, dict) else empty_template
def save_yaml_template(template_path: Path, data: Dict, is_new: bool) -> bool:
    """Write a template YAML file, but only for new templates.

    Existing templates are never overwritten: when ``is_new`` is false nothing
    is written and False is returned, so callers can tell the save was skipped.
    A written file contains exactly the keys ``prompt``, ``requirements``,
    ``count`` and ``qa``, in that order.

    :param template_path: Destination file; missing parent directories are created.
    :param data: Template content. Missing or ``None`` text fields are written
        as empty strings, and a missing or ``None`` ``qa`` as an empty list.
        Question/answer text is whitespace-stripped.
    :param is_new: Only when true is the file written.
    :return: True if the file was written, False if it was skipped.
    """
    if not is_new:
        return False

    def as_text(value) -> str:
        # str(None) would persist the literal "None"; treat None as empty.
        return "" if value is None else str(value)

    template_format = {
        "prompt": as_text(data.get("prompt")),
        "requirements": as_text(data.get("requirements")),
        "count": data.get("count"),
        "qa": [
            {"question": as_text(qa.get("question")).strip(), "answer": as_text(qa.get("answer")).strip()}
            for qa in data.get("qa") or []
        ],
    }
    template_path.parent.mkdir(parents=True, exist_ok=True)
    with open(template_path, "w", encoding="utf-8") as f:
        yaml.dump(template_format, f, allow_unicode=True, sort_keys=False, default_flow_style=False, indent=2)
    return True
def display_optimization_results(result_data):
    """Render one expander per optimization round, then a success summary.

    :param result_data: Sequence of round records (it is iterated twice and
        passed to ``len``). Each record must provide ``round``, ``succeed``,
        ``prompt``, ``tokens`` and ``answers``. Every entry of ``answers``
        must provide ``question`` and ``answer``.
    """
    for result in result_data:
        round_num = result["round"]
        success = result["succeed"]
        prompt = result["prompt"]
        with st.expander(f"轮次 {round_num} {':white_check_mark:' if success else ':x:'}"):
            st.markdown("**提示词:**")
            st.code(prompt, language="text")
            # Vertical spacer between the prompt and the status row.
            st.markdown("<br>", unsafe_allow_html=True)
            col1, col2 = st.columns(2)
            with col1:
                st.markdown(f"**状态:** {'成功 ✅ ' if success else '失败 ❌ '}")
            with col2:
                st.markdown(f"**令牌数:** {result['tokens']}")
            st.markdown("**回答:**")
            for idx, answer in enumerate(result["answers"]):
                st.markdown(f"**问题 {idx + 1}:**")
                st.text(answer["question"])
                st.markdown("**答案:**")
                st.text(answer["answer"])
                st.markdown("---")

    # Summary across all rounds.
    success_count = sum(1 for r in result_data if r["succeed"])
    total_rounds = len(result_data)
    st.markdown("### 总结")
    col1, col2 = st.columns(2)
    with col1:
        st.metric("总轮次", total_rounds)
    with col2:
        st.metric("成功轮次", success_count)
def _init_spo_llm(
    base_url: str,
    api_key: str,
    opt_model: str,
    opt_temp: float,
    eval_model: str,
    eval_temp: float,
    exec_model: str,
    exec_temp: float,
) -> None:
    """Configure the shared ``SPO_LLM`` client for its optimize/evaluate/execute roles.

    All three roles share one endpoint and API key; only model and temperature differ.
    """

    def role_kwargs(model: str, temperature: float) -> Dict:
        return {"model": model, "temperature": temperature, "base_url": base_url, "api_key": api_key}

    SPO_LLM.initialize(
        optimize_kwargs=role_kwargs(opt_model, opt_temp),
        evaluate_kwargs=role_kwargs(eval_model, eval_temp),
        execute_kwargs=role_kwargs(exec_model, exec_temp),
    )


def _attach_log_sinks(log_container) -> None:
    """Replace all loguru handlers with a UI sink for optimizer logs plus a daily DEBUG file.

    ``_logger.remove()`` drops every previously added handler first, so calling
    this on each Streamlit rerun does not stack duplicate sinks.

    :param log_container: Streamlit placeholder (from ``st.empty()``) that is
        re-rendered with all accumulated log lines on every record.
    """

    def write_to_ui(message) -> None:
        # Lines accumulate in session state, so earlier lines remain visible across reruns.
        logs = st.session_state.get("logs", [])
        logs.append(message.strip())
        st.session_state.logs = logs
        log_container.code("\n".join(logs), language="plaintext")

    def is_optimizer_record(record) -> bool:
        # Only records from modules whose name mentions "optimizer" reach the UI.
        return "optimizer" in (record["name"] or "").lower()

    _logger.remove()
    _logger.add(
        write_to_ui,
        format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} - {message}",
        filter=is_optimizer_record,
    )
    _logger.add(METAGPT_ROOT / "logs/{time:YYYYMMDD}.txt", level="DEBUG")


def main():
    """Render the SPO (self-supervised prompt optimization) Streamlit page.

    Each rerun builds, top to bottom: the sidebar configuration, the template
    editor, the optimization log, and, once results exist in session state,
    the results view and a prompt-testing panel.
    """
    if "optimization_results" not in st.session_state:
        st.session_state.optimization_results = []
    workspace_dir = get_user_workspace()
    settings_path = Path("metagpt/ext/spo/settings")

    st.markdown(
        """
SPO | 自监督提示词优化 🤖
一个自监督提示词优化框架
""",
        unsafe_allow_html=True,
    )

    # Sidebar configuration.
    with st.sidebar:
        st.header("配置")

        # Template selection / creation.
        template_mode = st.radio("模板模式", ["使用现有", "创建新模板"])
        existing_templates = get_all_templates()
        if template_mode == "使用现有":
            template_name = st.selectbox("选择模板", existing_templates)
            is_new_template = False
        else:
            # Free-text name: keep only the final path component so it cannot
            # escape the settings or workspace directories (e.g. "../x").
            template_name = Path(st.text_input("新模板名称")).name
            is_new_template = True

        # LLM settings: one endpoint/key shared by three model roles.
        st.subheader("LLM 设置")
        base_url = st.text_input("基础 URL", value="https://api.example.com")
        api_key = st.text_input("API 密钥", type="password")
        model_options = ["gpt-4o-mini", "gpt-4o", "deepseek-chat", "claude-3-5-sonnet-20240620"]
        opt_model = st.selectbox("优化模型", model_options, index=0)
        opt_temp = st.slider("优化温度", 0.0, 1.0, 0.7)
        eval_model = st.selectbox("评估模型", model_options, index=0)
        eval_temp = st.slider("评估温度", 0.0, 1.0, 0.3)
        exec_model = st.selectbox("执行模型", model_options, index=0)
        exec_temp = st.slider("执行温度", 0.0, 1.0, 0.0)

        # Optimizer settings.
        st.subheader("优化器设置")
        initial_round = st.number_input("初始轮次", 1, 100, 1)
        max_rounds = st.number_input("最大轮次", 1, 100, 10)

    llm_config = {
        "base_url": base_url,
        "api_key": api_key,
        "opt_model": opt_model,
        "opt_temp": opt_temp,
        "eval_model": eval_model,
        "eval_temp": eval_temp,
        "exec_model": exec_model,
        "exec_temp": exec_temp,
    }

    # Main area: template editor.
    st.header("模板配置")
    # Stays None when no template is selected or named; the optimize button checks it.
    template_real_name = None
    if template_name:
        template_real_name = get_template_path(template_name, is_new_template)
        template_path = settings_path / template_real_name
        template_data = load_yaml_template(template_path) or {}

        # Reload the editable Q&A list only when the selected template changes,
        # so in-progress edits survive ordinary reruns.
        if st.session_state.get("current_template") != template_name:
            st.session_state.current_template = template_name
            st.session_state.qas = template_data.get("qa") or []

        prompt = st.text_area("提示词", template_data.get("prompt", ""), height=100)
        requirements = st.text_area("要求", template_data.get("requirements", ""), height=100)

        # Q&A examples.
        st.subheader("问答示例")
        if st.button("添加新问答"):
            st.session_state.qas.append({"question": "", "answer": ""})

        new_qas = []
        for i, qa in enumerate(st.session_state.qas):
            st.markdown(f"**问答 #{i + 1}**")
            col1, col2, col3 = st.columns([45, 45, 10])
            with col1:
                question = st.text_area(f"问题 {i + 1}", qa.get("question", ""), key=f"q_{i}", height=100)
            with col2:
                answer = st.text_area(f"答案 {i + 1}", qa.get("answer", ""), key=f"a_{i}", height=100)
            with col3:
                if st.button("🗑️", key=f"delete_{i}"):
                    st.session_state.qas.pop(i)
                    # st.rerun() halts this run, so the loop never continues over the mutated list.
                    st.rerun()
            new_qas.append({"question": question, "answer": answer})

        if st.button("保存模板"):
            new_template_data = {"prompt": prompt, "requirements": requirements, "count": None, "qa": new_qas}
            save_yaml_template(template_path, new_template_data, is_new_template)
            st.session_state.qas = new_qas
            if is_new_template:
                st.success(f"模板已保存到 {template_path}")
            else:
                # save_yaml_template writes nothing for existing templates; don't claim it did.
                st.warning("现有模板不会被覆盖，修改未保存。请使用“创建新模板”另存。")

        st.subheader("当前模板预览")
        preview_data = {"qa": new_qas, "requirements": requirements, "prompt": prompt}
        st.code(yaml.dump(preview_data, allow_unicode=True), language="yaml")

    st.subheader("优化日志")
    log_container = st.empty()
    _attach_log_sinks(log_container)

    if st.button("开始优化"):
        if template_real_name is None:
            st.warning("请先选择或输入模板名称。")
        else:
            try:
                _init_spo_llm(**llm_config)
                optimizer = PromptOptimizer(
                    optimized_path=str(workspace_dir),
                    initial_round=initial_round,
                    max_rounds=max_rounds,
                    template=template_real_name,
                    name=template_name,
                )
                with st.spinner("Optimizing prompts..."):
                    optimizer.optimize()
                st.success("优化完成!")
                # The "优化结果" header is rendered once, by the results section below.
                prompt_path = optimizer.root_path / "prompts"
                st.session_state.optimization_results = optimizer.data_utils.load_results(prompt_path)
            except Exception as e:
                st.error(f"发生错误:{str(e)}")
                _logger.error(f"优化过程中出错:{str(e)}")

    if st.session_state.optimization_results:
        st.header("优化结果")
        display_optimization_results(st.session_state.optimization_results)

        st.markdown("---")
        st.subheader("测试优化后的提示词")
        col1, col2 = st.columns(2)
        with col1:
            test_prompt = st.text_area("优化后的提示词", value="", height=200, key="test_prompt")
        with col2:
            test_question = st.text_area("你的问题", value="", height=200, key="test_question")

        if st.button("测试提示词"):
            if test_prompt and test_question:
                try:
                    with st.spinner("正在生成回答..."):
                        _init_spo_llm(**llm_config)
                        llm = SPO_LLM.get_instance()
                        messages = [{"role": "user", "content": f"{test_prompt}\n\n{test_question}"}]

                        async def get_response():
                            return await llm.responser(request_type=RequestType.EXECUTE, messages=messages)

                        # asyncio.run creates a fresh event loop and closes it when done,
                        # instead of leaving a closed loop registered on this thread.
                        response = asyncio.run(get_response())
                    st.subheader("回答:")
                    st.markdown(response)
                except Exception as e:
                    st.error(f"生成回答时出错:{str(e)}")
            else:
                st.warning("请输入提示词和问题。")
if __name__ == "__main__":
    # Script entry point: build the page when this file is executed as the main module.
    main()