|
import asyncio |
|
from pathlib import Path |
|
from typing import Dict, List |
|
|
|
import streamlit as st |
|
import yaml |
|
from loguru import logger as _logger |
|
import shutil |
|
import uuid |
|
|
|
from metagpt.const import METAGPT_ROOT |
|
from metagpt.ext.spo.components.optimizer import PromptOptimizer |
|
from metagpt.ext.spo.utils.llm_client import SPO_LLM, RequestType |
|
|
|
|
|
|
|
def get_user_workspace():
    """
    Return the workspace directory of the current Streamlit session

    A random UUID is stored in ``st.session_state.user_id`` the first time this runs
    for a session; the directory ``workspace/<user_id>`` (relative to the current
    working directory) is created when it does not exist yet.

    :return: Path of the session's workspace directory
    """
    session = st.session_state
    if "user_id" not in session:
        session.user_id = str(uuid.uuid4())

    user_dir = Path("workspace", session.user_id)
    user_dir.mkdir(parents=True, exist_ok=True)
    return user_dir
|
|
|
def cleanup_workspace(workspace_dir: Path) -> None:
    """
    Remove a workspace directory tree on a best-effort basis

    Nothing happens when the directory does not exist. Any failure is logged and
    swallowed, so this function never raises because of the removal itself.

    :param workspace_dir: Directory to delete recursively
    """
    try:
        if not workspace_dir.exists():
            return
        shutil.rmtree(workspace_dir)
        _logger.info(f"Cleaned up workspace directory: {workspace_dir}")
    except Exception as exc:
        # Deliberately broad: cleanup must not take the app down.
        _logger.error(f"Error cleaning up workspace: {exc}")
|
|
|
def get_template_path(template_name: str, is_new_template: bool = False) -> str:
    """
    Get the template file path, relative to the settings directory

    Default templates live directly in ``metagpt/ext/spo/settings``; templates created
    by a user live in a sub-directory of it named after the session's ``user_id``.

    :param template_name: Name of the template (without the ``.yaml`` suffix)
    :param is_new_template: Whether it's a new template created by user
    :return: Relative path string, either ``<name>.yaml`` or ``<user_id>/<name>.yaml``
    """
    default_relative = f"{template_name}.yaml"

    if is_new_template:
        # New templates always go to the user-specific sub-directory.
        # NOTE(review): template_name comes from a free-text input and is not
        # sanitized here (e.g. against "../"); confirm whether that is acceptable.
        if "user_id" not in st.session_state:
            st.session_state.user_id = str(uuid.uuid4())
        return f"{st.session_state.user_id}/{default_relative}"

    settings_path = Path("metagpt/ext/spo/settings")
    if (settings_path / default_relative).exists():
        return default_relative

    # get_all_templates() also offers the session's own templates for selection.
    # Those are stored under <user_id>/, so fall back to that location when no
    # default template of this name exists; otherwise the selection would resolve
    # to a missing file and load as a blank template.
    if "user_id" in st.session_state:
        user_relative = f"{st.session_state.user_id}/{default_relative}"
        if (settings_path / user_relative).exists():
            return user_relative

    return default_relative
|
|
|
def get_all_templates() -> List[str]:
    """
    Collect the names of every selectable template

    Names are the stems of the ``*.yaml`` files found directly in the settings
    directory plus, when the session already has a ``user_id`` and the matching
    sub-directory exists, the stems of the ``*.yaml`` files found there.

    :return: Sorted list of unique template names
    """
    settings_dir = Path("metagpt/ext/spo/settings")
    names = {yaml_file.stem for yaml_file in settings_dir.glob("*.yaml")}

    if "user_id" in st.session_state:
        user_dir = settings_dir / st.session_state.user_id
        if user_dir.exists():
            names.update(yaml_file.stem for yaml_file in user_dir.glob("*.yaml"))

    return sorted(names)
|
|
|
|
|
def load_yaml_template(template_path: Path) -> Dict:
    """
    Load a template from a YAML file

    :param template_path: Location of the template file
    :return: The parsed template mapping, or a blank template (empty prompt and
        requirements, ``count`` of None, one empty Q&A pair) when the file is
        missing, empty, or does not contain a mapping
    """
    blank_template = {"prompt": "", "requirements": "", "count": None, "qa": [{"question": "", "answer": ""}]}

    if not template_path.exists():
        return blank_template

    with open(template_path, "r", encoding="utf-8") as f:
        loaded = yaml.safe_load(f)

    # safe_load yields None for an empty document and may yield a list or scalar;
    # callers use dict access on the result, so anything else becomes the blank template.
    if not isinstance(loaded, dict):
        return blank_template
    return loaded
|
|
|
|
|
def save_yaml_template(template_path: Path, data: Dict) -> None:
    """
    Write a template to a YAML file in a normalized layout

    The file always contains the keys ``prompt``, ``requirements``, ``count`` and
    ``qa`` in that order. Prompt and requirements are converted to strings; every
    Q&A entry is reduced to stripped ``question``/``answer`` strings. Missing parent
    directories are created.

    :param template_path: Destination file
    :param data: Template values; absent keys fall back to empty defaults
    """
    payload = {
        "prompt": str(data.get("prompt", "")),
        "requirements": str(data.get("requirements", "")),
        "count": data.get("count"),
    }
    payload["qa"] = []
    for pair in data.get("qa", []):
        question_text = str(pair.get("question", "")).strip()
        answer_text = str(pair.get("answer", "")).strip()
        payload["qa"].append({"question": question_text, "answer": answer_text})

    template_path.parent.mkdir(parents=True, exist_ok=True)

    with template_path.open("w", encoding="utf-8") as handle:
        yaml.dump(payload, handle, allow_unicode=True, sort_keys=False, default_flow_style=False, indent=2)
|
|
|
|
|
def display_optimization_results(result_data):
    """
    Render the optimization rounds and a short summary with Streamlit

    Each entry of ``result_data`` is read through the keys ``round``, ``succeed``,
    ``prompt``, ``tokens`` and ``answers`` (a sequence of mappings with ``question``
    and ``answer``). Every round gets its own expander; afterwards the total number
    of rounds and the number of successful rounds are shown as metrics.

    :param result_data: Sequence of per-round result mappings
    """
    for entry in result_data:
        round_no = entry["round"]
        succeeded = entry["succeed"]
        round_prompt = entry["prompt"]

        badge = ":white_check_mark:" if succeeded else ":x:"
        with st.expander(f"Round {round_no} {badge}"):
            st.markdown("**Prompt:**")
            st.code(round_prompt, language="text")
            st.markdown("<br>", unsafe_allow_html=True)

            status_col, tokens_col = st.columns(2)
            with status_col:
                status_text = "Success ✅ " if succeeded else "Failed ❌ "
                st.markdown(f"**Status:** {status_text}")
            with tokens_col:
                st.markdown(f"**Tokens:** {entry['tokens']}")

            st.markdown("**Answers:**")
            for number, qa_pair in enumerate(entry["answers"], start=1):
                st.markdown(f"**Question {number}:**")
                st.text(qa_pair["question"])
                st.markdown("**Answer:**")
                st.text(qa_pair["answer"])
                st.markdown("---")

    # Summary across all rounds.
    succeeded_rounds = len([entry for entry in result_data if entry["succeed"]])
    all_rounds = len(result_data)

    st.markdown("### Summary")
    total_col, success_col = st.columns(2)
    with total_col:
        st.metric("Total Rounds", all_rounds)
    with success_col:
        st.metric("Successful Rounds", succeeded_rounds)
|
|
|
|
|
def main() -> None:
    """
    Build and run the SPO Streamlit page

    The sidebar collects the template choice, the LLM connection/model settings and
    the optimizer round settings. When a template name is available, the main area
    lets the user edit and save the template, shows optimizer log output, starts a
    prompt optimization run, displays stored results and offers a free-form test of
    a prompt against the execution model.
    """
    if "optimization_results" not in st.session_state:
        st.session_state.optimization_results = []

    workspace_dir = get_user_workspace()

    st.title("SPO | Self-Supervised Prompt Optimization 🤖")

    settings_path = Path("metagpt/ext/spo/settings")
    model_choices = ["gpt-4o-mini", "gpt-4o", "deepseek-chat", "claude-3-5-sonnet-20240620"]

    # Sidebar: configuration widgets.
    with st.sidebar:
        st.header("Configuration")

        # Template selection / creation.
        template_mode = st.radio("Template Mode", ["Use Existing", "Create New"])

        existing_templates = get_all_templates()

        if template_mode == "Use Existing":
            template_name = st.selectbox("Select Template", existing_templates)
            is_new_template = False
        else:
            template_name = st.text_input("New Template Name")
            is_new_template = True

        # LLM settings.
        st.subheader("LLM Settings")

        base_url = st.text_input("Base URL", value="https://api.example.com")
        api_key = st.text_input("API Key", type="password")

        opt_model = st.selectbox("Optimization Model", model_choices, index=0)
        opt_temp = st.slider("Optimization Temperature", 0.0, 1.0, 0.7)

        eval_model = st.selectbox("Evaluation Model", model_choices, index=0)
        eval_temp = st.slider("Evaluation Temperature", 0.0, 1.0, 0.3)

        exec_model = st.selectbox("Execution Model", model_choices, index=0)
        exec_temp = st.slider("Execution Temperature", 0.0, 1.0, 0.0)

        # Optimizer settings.
        st.subheader("Optimizer Settings")
        initial_round = st.number_input("Initial Round", 1, 100, 1)
        max_rounds = st.number_input("Maximum Rounds", 1, 100, 10)

    # Main content area.
    st.header("Template Configuration")

    # Everything below needs a template name (empty text input / no selection otherwise).
    if not template_name:
        return

    def initialize_llm() -> None:
        """Configure SPO_LLM for all three request types from the sidebar values."""
        connection = {"base_url": base_url, "api_key": api_key}
        SPO_LLM.initialize(
            optimize_kwargs={"model": opt_model, "temperature": opt_temp, **connection},
            evaluate_kwargs={"model": eval_model, "temperature": eval_temp, **connection},
            execute_kwargs={"model": exec_model, "temperature": exec_temp, **connection},
        )

    template_real_name = get_template_path(template_name, is_new_template)
    template_path = settings_path / template_real_name
    template_data = load_yaml_template(template_path)

    # Reset the editable Q&A list whenever a different template is chosen.
    if "current_template" not in st.session_state or st.session_state.current_template != template_name:
        st.session_state.current_template = template_name
        # "or []" also covers a template whose qa entry is present but null.
        st.session_state.qas = template_data.get("qa") or []

    # Editable template sections.
    prompt = st.text_area("Prompt", template_data.get("prompt", ""), height=100)
    requirements = st.text_area("Requirements", template_data.get("requirements", ""), height=100)

    st.subheader("Q&A Examples")

    if st.button("Add New Q&A"):
        st.session_state.qas.append({"question": "", "answer": ""})

    # One row of widgets per Q&A pair; collect the current widget values.
    new_qas = []
    for i, qa in enumerate(st.session_state.qas):
        st.markdown(f"**QA #{i + 1}**")
        col1, col2, col3 = st.columns([45, 45, 10])

        with col1:
            question = st.text_area(f"Question {i + 1}", qa.get("question", ""), key=f"q_{i}", height=100)
        with col2:
            answer = st.text_area(f"Answer {i + 1}", qa.get("answer", ""), key=f"a_{i}", height=100)
        with col3:
            if st.button("🗑️", key=f"delete_{i}"):
                st.session_state.qas.pop(i)
                # Restart the script right away so the removed row disappears.
                st.rerun()

        new_qas.append({"question": question, "answer": answer})

    if st.button("Save Template"):
        new_template_data = {"prompt": prompt, "requirements": requirements, "count": None, "qa": new_qas}

        save_yaml_template(template_path, new_template_data)

        st.session_state.qas = new_qas
        st.success(f"Template saved to {template_path}")

    st.subheader("Current Template Preview")
    preview_data = {"qa": new_qas, "requirements": requirements, "prompt": prompt}
    st.code(yaml.dump(preview_data, allow_unicode=True), language="yaml")

    st.subheader("Optimization Logs")
    log_container = st.empty()

    class StreamlitSink:
        """Log sink that accumulates messages in the session and re-renders them."""

        def write(self, message):
            current_logs = st.session_state.get("logs", [])
            current_logs.append(message.strip())
            st.session_state.logs = current_logs

            log_container.code("\n".join(current_logs), language="plaintext")

    streamlit_sink = StreamlitSink()
    # Drop handlers from previous script runs before registering fresh ones.
    _logger.remove()

    def prompt_optimizer_filter(record):
        # The record name may be None; treat that as "not from the optimizer".
        return "optimizer" in (record["name"] or "").lower()

    _logger.add(
        streamlit_sink.write,
        format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} - {message}",
        filter=prompt_optimizer_filter,
    )
    _logger.add(METAGPT_ROOT / "logs/{time:YYYYMMDD}.txt", level="DEBUG")

    if st.button("Start Optimization"):
        try:
            initialize_llm()

            optimizer = PromptOptimizer(
                optimized_path=str(workspace_dir),
                initial_round=initial_round,
                max_rounds=max_rounds,
                template=template_real_name,
                name=template_name,
            )

            with st.spinner("Optimizing prompts..."):
                optimizer.optimize()

            st.success("Optimization completed!")

            # The results header is rendered once, below, together with the stored
            # results; rendering it here as well produced a duplicated heading.
            prompt_path = optimizer.root_path / "prompts"
            result_data = optimizer.data_utils.load_results(prompt_path)

            st.session_state.optimization_results = result_data

        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
            _logger.error(f"Error during optimization: {str(e)}")

    if st.session_state.optimization_results:
        st.header("Optimization Results")
        display_optimization_results(st.session_state.optimization_results)

    st.markdown("---")
    st.subheader("Test Optimized Prompt")
    col1, col2 = st.columns(2)

    with col1:
        test_prompt = st.text_area("Optimized Prompt", value="", height=200, key="test_prompt")

    with col2:
        test_question = st.text_area("Your Question", value="", height=200, key="test_question")

    if st.button("Test Prompt"):
        if test_prompt and test_question:
            try:
                with st.spinner("Generating response..."):
                    initialize_llm()

                    llm = SPO_LLM.get_instance()
                    messages = [{"role": "user", "content": f"{test_prompt}\n\n{test_question}"}]

                    async def get_response():
                        return await llm.responser(request_type=RequestType.EXECUTE, messages=messages)

                    # asyncio.run creates a fresh loop, closes it afterwards and does
                    # not leave a closed loop installed as the thread's current loop.
                    response = asyncio.run(get_response())

                    st.subheader("Response:")
                    st.markdown(response)

            except Exception as e:
                st.error(f"Error generating response: {str(e)}")
        else:
            st.warning("Please enter both prompt and question.")
|
|
|
|
|
# Only build the page when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
|
|