import asyncio
import shutil
import uuid
from pathlib import Path
from typing import Dict, List

import streamlit as st
import yaml
from loguru import logger as _logger

from metagpt.const import METAGPT_ROOT
from metagpt.ext.spo.components.optimizer import PromptOptimizer
from metagpt.ext.spo.utils.llm_client import SPO_LLM, RequestType


def get_user_workspace() -> Path:
    """Return a per-session workspace directory, creating it on first use.

    A random ``user_id`` is stored in ``st.session_state`` so each browser
    session gets an isolated ``workspace/<uuid>`` directory.
    """
    if "user_id" not in st.session_state:
        st.session_state.user_id = str(uuid.uuid4())

    workspace_dir = Path("workspace") / st.session_state.user_id
    workspace_dir.mkdir(parents=True, exist_ok=True)
    return workspace_dir


def cleanup_workspace(workspace_dir: Path) -> None:
    """Best-effort removal of a session workspace directory.

    Errors are logged, never raised — cleanup must not crash the UI.
    """
    try:
        if workspace_dir.exists():
            shutil.rmtree(workspace_dir)
            _logger.info(f"Cleaned up workspace directory: {workspace_dir}")
    except Exception as e:
        _logger.error(f"Error cleaning up workspace: {e}")


def get_template_path(template_name: str, is_new_template: bool = False) -> str:
    """
    Get the template file path, relative to the settings directory.

    :param template_name: Name of the template
    :param is_new_template: Whether it's a new template created by the user
    :return: Relative path string for the template file
             (fix: the old docstring claimed a Path object was returned)
    """
    if is_new_template:
        # New templates live in a user-specific subdirectory of the settings
        # folder so concurrent sessions cannot clobber each other's files.
        if "user_id" not in st.session_state:
            st.session_state.user_id = str(uuid.uuid4())
        user_settings_path = st.session_state.user_id
        return f"{user_settings_path}/{template_name}.yaml"
    # Existing (default) templates sit at the settings root.
    return f"{template_name}.yaml"


def get_all_templates() -> List[str]:
    """
    Get a sorted, de-duplicated list of all available templates
    (both defaults and the current user's own).

    :return: List of template names; user templates are prefixed "<user_id>/"
    """
    settings_path = Path("metagpt/ext/spo/settings")

    # Default templates shipped at the settings root.
    templates = [f.stem for f in settings_path.glob("*.yaml")]

    # User-specific templates, if this session has created any.
    if "user_id" in st.session_state:
        user_path = settings_path / st.session_state.user_id
        if user_path.exists():
            templates.extend(f"{st.session_state.user_id}/{f.stem}" for f in user_path.glob("*.yaml"))

    return sorted(set(templates))


def load_yaml_template(template_path: Path) -> Dict:
    """Load a template YAML file, or return an empty skeleton if it is missing."""
    if template_path.exists():
        with open(template_path, "r", encoding="utf-8") as f:
            return yaml.safe_load(f)
    return {"prompt": "", "requirements": "", "count": None, "qa": [{"question": "", "answer": ""}]}


def save_yaml_template(template_path: Path, data: Dict, is_new: bool) -> None:
    """Persist a template to disk as YAML.

    Bug fix: the original implementation only wrote the file when
    ``is_new`` was True and silently did nothing for existing templates
    (``else: pass``), even though the caller reported a successful save.
    Both cases are now written. ``is_new`` is retained for interface
    compatibility with existing callers.

    :param template_path: Destination file (parent dirs are created)
    :param data: Mapping with "prompt", "requirements", "count" and "qa" keys
    :param is_new: Kept for backward compatibility; no longer changes behavior
    """
    template_format = {
        "prompt": str(data.get("prompt", "")),
        "requirements": str(data.get("requirements", "")),
        "count": data.get("count"),
        "qa": [
            {"question": str(qa.get("question", "")).strip(), "answer": str(qa.get("answer", "")).strip()}
            for qa in data.get("qa", [])
        ],
    }

    template_path.parent.mkdir(parents=True, exist_ok=True)
    with open(template_path, "w", encoding="utf-8") as f:
        yaml.dump(template_format, f, allow_unicode=True, sort_keys=False, default_flow_style=False, indent=2)


def display_optimization_results(result_data):
    """Render per-round optimization results in expanders, then a summary.

    :param result_data: list of dicts with keys "round", "succeed", "prompt",
                        "tokens" and "answers" (each answer has
                        "question"/"answer") — as produced by the optimizer.
    """
    for result in result_data:
        round_num = result["round"]
        success = result["succeed"]
        prompt = result["prompt"]

        with st.expander(f"轮次 {round_num} {':white_check_mark:' if success else ':x:'}"):
            st.markdown("**提示词:**")
            st.code(prompt, language="text")
            st.markdown("\n", unsafe_allow_html=True)

            col1, col2 = st.columns(2)
            with col1:
                st.markdown(f"**状态:** {'成功 ✅ ' if success else '失败 ❌ '}")
            with col2:
                st.markdown(f"**令牌数:** {result['tokens']}")

            st.markdown("**回答:**")
            for idx, answer in enumerate(result["answers"]):
                st.markdown(f"**问题 {idx + 1}:**")
                st.text(answer["question"])
                st.markdown("**答案:**")
                st.text(answer["answer"])
                st.markdown("---")

    # Summary
    success_count = sum(1 for r in result_data if r["succeed"])
    total_rounds = len(result_data)

    st.markdown("### 总结")
    col1, col2 = st.columns(2)
    with col1:
        st.metric("总轮次", total_rounds)
    with col2:
        st.metric("成功轮次", success_count)


def main():
    """Streamlit entry point: configuration sidebar, template editor,
    optimization runner, results display and a prompt test harness."""
    if "optimization_results" not in st.session_state:
        st.session_state.optimization_results = []

    workspace_dir = get_user_workspace()

    # NOTE(review): the original HTML markup of this header was lost in the
    # source; the visible text is preserved verbatim.
    st.markdown(
        """

SPO | 自监督提示词优化 🤖

论文 GitHub 一个自监督提示词优化框架
""",
        unsafe_allow_html=True,
    )

    # Sidebar configuration
    with st.sidebar:
        st.header("配置")

        # Template selection / creation.
        # Fix: removed a dead glob of the settings directory whose result was
        # immediately overwritten by get_all_templates().
        template_mode = st.radio("模板模式", ["使用现有", "创建新模板"])
        existing_templates = get_all_templates()

        if template_mode == "使用现有":
            template_name = st.selectbox("选择模板", existing_templates)
            is_new_template = False
        else:
            template_name = st.text_input("新模板名称")
            is_new_template = True

        # LLM settings
        st.subheader("LLM 设置")
        base_url = st.text_input("基础 URL", value="https://api.example.com")
        api_key = st.text_input("API 密钥", type="password")

        opt_model = st.selectbox(
            "优化模型", ["gpt-4o-mini", "gpt-4o", "deepseek-chat", "claude-3-5-sonnet-20240620"], index=0
        )
        opt_temp = st.slider("优化温度", 0.0, 1.0, 0.7)

        eval_model = st.selectbox(
            "评估模型", ["gpt-4o-mini", "gpt-4o", "deepseek-chat", "claude-3-5-sonnet-20240620"], index=0
        )
        eval_temp = st.slider("评估温度", 0.0, 1.0, 0.3)

        exec_model = st.selectbox(
            "执行模型", ["gpt-4o-mini", "gpt-4o", "deepseek-chat", "claude-3-5-sonnet-20240620"], index=0
        )
        exec_temp = st.slider("执行温度", 0.0, 1.0, 0.0)

        # Optimizer settings
        st.subheader("优化器设置")
        initial_round = st.number_input("初始轮次", 1, 100, 1)
        max_rounds = st.number_input("最大轮次", 1, 100, 10)

    # Main content area
    st.header("模板配置")
    if template_name:
        template_real_name = get_template_path(template_name, is_new_template)
        settings_path = Path("metagpt/ext/spo/settings")
        template_path = settings_path / template_real_name
        template_data = load_yaml_template(template_path)

        # Reset the editable Q&A list whenever the selected template changes.
        if "current_template" not in st.session_state or st.session_state.current_template != template_name:
            st.session_state.current_template = template_name
            st.session_state.qas = template_data.get("qa", [])

        # Template editing section
        prompt = st.text_area("提示词", template_data.get("prompt", ""), height=100)
        requirements = st.text_area("要求", template_data.get("requirements", ""), height=100)

        # Q&A examples section
        st.subheader("问答示例")

        # Button to append a blank Q&A pair
        if st.button("添加新问答"):
            st.session_state.qas.append({"question": "", "answer": ""})

        # Edit Q&A pairs
        new_qas = []
        for i in range(len(st.session_state.qas)):
            st.markdown(f"**问答 #{i + 1}**")
            col1, col2, col3 = st.columns([45, 45, 10])

            with col1:
                question = st.text_area(
                    f"问题 {i + 1}", st.session_state.qas[i].get("question", ""), key=f"q_{i}", height=100
                )
            with col2:
                answer = st.text_area(
                    f"答案 {i + 1}", st.session_state.qas[i].get("answer", ""), key=f"a_{i}", height=100
                )
            with col3:
                # Deleting pops in place and immediately reruns the script,
                # so the loop never continues over the mutated list.
                if st.button("🗑️", key=f"delete_{i}"):
                    st.session_state.qas.pop(i)
                    st.rerun()

            new_qas.append({"question": question, "answer": answer})

        # Save template button
        if st.button("保存模板"):
            new_template_data = {"prompt": prompt, "requirements": requirements, "count": None, "qa": new_qas}
            save_yaml_template(template_path, new_template_data, is_new_template)
            st.session_state.qas = new_qas
            st.success(f"模板已保存到 {template_path}")

        st.subheader("当前模板预览")
        preview_data = {"qa": new_qas, "requirements": requirements, "prompt": prompt}
        st.code(yaml.dump(preview_data, allow_unicode=True), language="yaml")

        st.subheader("优化日志")
        log_container = st.empty()

        class StreamlitSink:
            """Loguru sink that mirrors log lines into the page's log panel."""

            def write(self, message):
                current_logs = st.session_state.get("logs", [])
                current_logs.append(message.strip())
                st.session_state.logs = current_logs
                log_container.code("\n".join(current_logs), language="plaintext")

        streamlit_sink = StreamlitSink()
        _logger.remove()

        def prompt_optimizer_filter(record):
            # Only surface optimizer-module logs in the UI panel.
            return "optimizer" in record["name"].lower()

        _logger.add(
            streamlit_sink.write,
            format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} - {message}",
            filter=prompt_optimizer_filter,
        )
        _logger.add(METAGPT_ROOT / "logs/{time:YYYYMMDD}.txt", level="DEBUG")

        # Start optimization button
        if st.button("开始优化"):
            try:
                # Initialize LLM
                SPO_LLM.initialize(
                    optimize_kwargs={
                        "model": opt_model, "temperature": opt_temp, "base_url": base_url, "api_key": api_key
                    },
                    evaluate_kwargs={
                        "model": eval_model, "temperature": eval_temp, "base_url": base_url, "api_key": api_key
                    },
                    execute_kwargs={
                        "model": exec_model, "temperature": exec_temp, "base_url": base_url, "api_key": api_key
                    },
                )

                # Create optimizer instance
                optimizer = PromptOptimizer(
                    optimized_path=str(workspace_dir),
                    initial_round=initial_round,
                    max_rounds=max_rounds,
                    template=f"{template_real_name}",
                    name=template_name,
                )

                # Run optimization with progress bar
                with st.spinner("Optimizing prompts..."):
                    optimizer.optimize()

                st.success("优化完成!")
                st.header("优化结果")
                prompt_path = optimizer.root_path / "prompts"
                result_data = optimizer.data_utils.load_results(prompt_path)
                st.session_state.optimization_results = result_data

            except Exception as e:
                st.error(f"发生错误:{str(e)}")
                _logger.error(f"优化过程中出错:{str(e)}")

    if st.session_state.optimization_results:
        st.header("优化结果")
        display_optimization_results(st.session_state.optimization_results)

    st.markdown("---")
    st.subheader("测试优化后的提示词")

    col1, col2 = st.columns(2)
    with col1:
        test_prompt = st.text_area("优化后的提示词", value="", height=200, key="test_prompt")
    with col2:
        test_question = st.text_area("你的问题", value="", height=200, key="test_question")

    if st.button("测试提示词"):
        if test_prompt and test_question:
            try:
                with st.spinner("正在生成回答..."):
                    SPO_LLM.initialize(
                        optimize_kwargs={
                            "model": opt_model, "temperature": opt_temp, "base_url": base_url, "api_key": api_key
                        },
                        evaluate_kwargs={
                            "model": eval_model, "temperature": eval_temp, "base_url": base_url, "api_key": api_key
                        },
                        execute_kwargs={
                            "model": exec_model, "temperature": exec_temp, "base_url": base_url, "api_key": api_key
                        },
                    )
                    llm = SPO_LLM.get_instance()
                    messages = [{"role": "user", "content": f"{test_prompt}\n\n{test_question}"}]

                    async def get_response():
                        return await llm.responser(request_type=RequestType.EXECUTE, messages=messages)

                    # Streamlit script threads have no running event loop, so a
                    # dedicated one is created and always closed afterwards.
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)
                    try:
                        response = loop.run_until_complete(get_response())
                    finally:
                        loop.close()

                    st.subheader("回答:")
                    st.markdown(response)
            except Exception as e:
                st.error(f"生成回答时出错:{str(e)}")
        else:
            st.warning("请输入提示词和问题。")


if __name__ == "__main__":
    main()