|
import os |
|
import tempfile |
|
import gradio as gr |
|
import pandas as pd |
|
import traceback |
|
from core_agent import GAIAAgent |
|
from api_integration import GAIAApiClient |
|
|
|
|
|
# Base URL of the scoring service that serves GAIA questions and accepts answers.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
|
|
def save_task_file(file_content, task_id):
    """
    Save a task file to a temporary location.

    Args:
        file_content: Raw file payload as bytes (a str is accepted too and
            is encoded as UTF-8 before writing). Falsy values mean "no file".
        task_id: Task identifier used to build a unique temp-file name.

    Returns:
        Absolute path of the saved file, or None when there is no content.
    """
    if not file_content:
        return None

    # The file is opened in binary mode below; accept str payloads as well
    # (e.g. an API client that already decoded the response body).
    if isinstance(file_content, str):
        file_content = file_content.encode("utf-8")

    temp_dir = tempfile.gettempdir()
    file_path = os.path.join(temp_dir, f"gaia_task_{task_id}.txt")

    with open(file_path, 'wb') as f:
        f.write(file_content)

    print(f"File saved to {file_path}")
    return file_path
|
|
|
def get_agent_configuration():
    """
    Build the agent configuration from defaults plus environment overrides.

    Precedence (lowest to highest):
        1. Hard-coded defaults below.
        2. xAI credentials (XAI_API_KEY / XAI_API_BASE), which also switch
           the model id to mixtral-8x7b-32768.
        3. Individual AGENT_* environment variables.

    Returns:
        dict: Keyword arguments suitable for GAIAAgent(**config).
    """
    config = {
        "model_type": "OpenAIServerModel",
        "model_id": "gpt-4o",
        "temperature": 0.2,
        "executor_type": "local",
        "verbose": False,
        "provider": "hf-inference",
        "timeout": 120,
    }

    # xAI support: presence of the API key also switches the default model.
    xai_api_key = os.getenv("XAI_API_KEY")
    if xai_api_key:
        config["api_key"] = xai_api_key
        xai_api_base = os.getenv("XAI_API_BASE")
        if xai_api_base:
            config["api_base"] = xai_api_base
        config["model_id"] = "mixtral-8x7b-32768"

    # Generic AGENT_* overrides. Each variable is read exactly once and
    # applied only when set and non-empty, with a converter for typed values
    # (float/int conversion raises ValueError for malformed input, matching
    # the original behavior).
    overrides = {
        "AGENT_MODEL_TYPE": ("model_type", str),
        "AGENT_MODEL_ID": ("model_id", str),
        "AGENT_TEMPERATURE": ("temperature", float),
        "AGENT_EXECUTOR_TYPE": ("executor_type", str),
        "AGENT_API_BASE": ("api_base", str),
        "AGENT_PROVIDER": ("provider", str),
        "AGENT_TIMEOUT": ("timeout", int),
    }
    for env_name, (key, convert) in overrides.items():
        raw = os.getenv(env_name)
        if raw:
            config[key] = convert(raw)

    # AGENT_VERBOSE deliberately checks "is set at all" (not truthiness) so
    # an explicit empty or "false" value still disables verbosity.
    verbose = os.getenv("AGENT_VERBOSE")
    if verbose is not None:
        config["verbose"] = verbose.lower() == "true"

    return config
|
|
|
def run_and_submit_all(profile: gr.OAuthProfile | None):
    """
    Fetches all questions, runs the GAIAAgent on them, submits all answers,
    and displays the results.

    Args:
        profile: The logged-in Hugging Face OAuth profile (injected by
            Gradio), or None when the user has not authenticated.

    Returns:
        tuple: (status message string, pandas DataFrame of per-question
        results, or None when the run aborted before processing questions).
    """
    if not profile:
        return "Please Login to Hugging Face with the button.", None

    username = profile.username
    print(f"User logged in: {username}")

    # Link to this Space's code so the scoring backend can audit the agent.
    space_id = os.getenv("SPACE_ID")
    agent_code = f"https://huggingface.co./spaces/{space_id}/tree/main"

    api_client = GAIAApiClient(DEFAULT_API_URL)

    # --- Initialize the agent -------------------------------------------
    try:
        agent_config = get_agent_configuration()
        print(f"Using agent configuration: {agent_config}")

        agent = GAIAAgent(**agent_config)
        print("Agent initialized successfully")
    except Exception as e:
        error_details = traceback.format_exc()
        print(f"Error initializing agent: {e}\n{error_details}")
        return f"Error initializing agent: {e}", None

    # --- Fetch the question set -----------------------------------------
    try:
        questions_data = api_client.get_questions()
        if not questions_data:
            return "Fetched questions list is empty or invalid format.", None
        print(f"Fetched {len(questions_data)} questions.")
    except Exception as e:
        error_details = traceback.format_exc()
        print(f"Error fetching questions: {e}\n{error_details}")
        return f"Error fetching questions: {e}", None

    # --- Run the agent on every question --------------------------------
    results_log = []
    answers_payload = []
    print(f"Running agent on {len(questions_data)} questions...")

    total_questions = len(questions_data)
    completed = 0
    failed = 0

    for item in questions_data:
        task_id = item.get("task_id")
        question_text = item.get("question")
        if not task_id or question_text is None:
            print(f"Skipping item with missing task_id or question: {item}")
            continue

        try:
            completed += 1
            print(f"Processing question {completed}/{total_questions}: Task ID {task_id}")

            # Some tasks ship an attachment; a missing file is expected for
            # most tasks, so download failures only log and do not abort.
            file_path = None
            try:
                file_content = api_client.get_file(task_id)
                print(f"Downloaded file for task {task_id}")
                file_path = save_task_file(file_content, task_id)
            except Exception as file_e:
                print(f"No file found for task {task_id} or error: {file_e}")

            submitted_answer = agent.answer_question(question_text, file_path)

            answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
            results_log.append({
                "Task ID": task_id,
                "Question": question_text,
                "Submitted Answer": submitted_answer
            })
        except Exception as e:
            failed += 1
            error_details = traceback.format_exc()
            print(f"Error running agent on task {task_id}: {e}\n{error_details}")

            # Submit a placeholder answer anyway so the payload stays
            # aligned with the fetched question list.
            error_msg = f"AGENT ERROR: {e}"
            answers_payload.append({"task_id": task_id, "submitted_answer": error_msg})
            results_log.append({
                "Task ID": task_id,
                "Question": question_text,
                "Submitted Answer": error_msg
            })

    print(f"\nProcessing complete: {completed} questions processed, {failed} failures")

    if not answers_payload:
        return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)

    # NOTE: the previous version also built an unused `submission_data` dict
    # here; submission goes through positional arguments below.
    print(f"Submitting {len(answers_payload)} answers for username '{username}'...")

    # --- Submit and summarize -------------------------------------------
    try:
        result_data = api_client.submit_answers(
            username.strip(),
            agent_code,
            answers_payload
        )

        correct_count = result_data.get('correct_count', 0)
        total_attempted = result_data.get('total_attempted', len(answers_payload))
        success_rate = (correct_count / total_attempted) * 100 if total_attempted > 0 else 0

        final_status = (
            f"Submission Successful!\n"
            f"User: {result_data.get('username')}\n"
            f"Overall Score: {result_data.get('score', 'N/A')}% "
            f"({correct_count}/{total_attempted} correct, {success_rate:.1f}% success rate)\n"
            f"Message: {result_data.get('message', 'No message received.')}"
        )

        print("Submission successful.")
        return final_status, pd.DataFrame(results_log)
    except Exception as e:
        error_details = traceback.format_exc()
        status_message = f"Submission Failed: {e}\n{error_details}"
        print(status_message)
        return status_message, pd.DataFrame(results_log)
|
|
|
|
|
# Top-level Gradio UI: login button, one "run" button, and two output widgets
# (status textbox + results table) wired to run_and_submit_all.
with gr.Blocks() as demo:
    gr.Markdown("# GAIA Agent Evaluation Runner")
    gr.Markdown(
        """
        **Instructions:**

        1. Log in to your Hugging Face account using the button below.
        2. Click 'Run Evaluation & Submit All Answers' to fetch questions, run your agent, submit answers, and see the score.

        **Configuration:**

        You can configure the agent by setting these environment variables:
        - `AGENT_MODEL_TYPE`: Model type (HfApiModel, InferenceClientModel, LiteLLMModel, OpenAIServerModel)
        - `AGENT_MODEL_ID`: Model ID
        - `AGENT_TEMPERATURE`: Temperature for generation (0.0-1.0)
        - `AGENT_EXECUTOR_TYPE`: Type of executor ('local' or 'e2b')
        - `AGENT_VERBOSE`: Enable verbose logging (true/false)
        - `AGENT_API_BASE`: Base URL for API calls (for OpenAIServerModel)

        **xAI Support:**
        - `XAI_API_KEY`: Your xAI API key
        - `XAI_API_BASE`: Base URL for xAI API (default: https://api.groq.com/openai/v1)
        - When using xAI, set AGENT_MODEL_TYPE=OpenAIServerModel and AGENT_MODEL_ID=mixtral-8x7b-32768

        **InferenceClientModel specific settings:**
        - `AGENT_PROVIDER`: Provider for InferenceClientModel (e.g., "hf-inference")
        - `AGENT_TIMEOUT`: Timeout in seconds for API calls
        """
    )

    # OAuth login; Gradio supplies the resulting gr.OAuthProfile to any
    # event handler annotated with that type (see run_and_submit_all).
    gr.LoginButton()

    run_button = gr.Button("Run Evaluation & Submit All Answers")

    status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
    results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)

    # No explicit inputs: the OAuth profile argument is injected implicitly
    # based on the handler's type annotation.
    run_button.click(
        fn=run_and_submit_all,
        outputs=[status_output, results_table]
    )
|
|
|
if __name__ == "__main__":
    print("\n" + "-"*30 + " App Starting " + "-"*30)

    # Log the effective configuration up front so misconfigured environment
    # variables are visible before the first evaluation run.
    config = get_agent_configuration()
    print(f"Agent configuration: {config}")

    # debug=True surfaces tracebacks in the UI; share=False keeps the app
    # serving locally / within the Space only.
    demo.launch(debug=True, share=False)
|
|