Update main.py

main.py CHANGED
@@ -3,7 +3,7 @@ import os
 import pandas as pd
 from fastapi import FastAPI, HTTPException, Body
 from pydantic import BaseModel, Field
-from typing import List, Dict, Any
+from typing import List, Dict, Any, Optional
 from datasets import load_dataset, Dataset, DatasetDict
 from huggingface_hub import HfApi, hf_hub_download
 from datetime import datetime, timezone
@@ -12,8 +12,6 @@ import uvicorn
 import random
 
 # --- Constants and Config ---
-tool_threshold = 3
-step_threshold = 5
 HF_DATASET_ID = "agents-course/unit4-students-scores"
 
 # --- Data Structures ---
@@ -25,33 +23,37 @@ ground_truth_answers: Dict[str, str] = {}
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-# --- Filtered Dataset Placeholder ---
-# Note: Making filtered_dataset global might not be ideal in larger apps,
-# but keeping it as is based on the original code.
-filtered_dataset = None
 
-#
+logger = logging.getLogger(__name__)  # Make sure logger is initialized
+tool_threshold = 3
+step_threshold = 5
+questions_for_api: List[Dict[str, Any]] = []  # Use Dict[str, Any] for flexibility before validation
+ground_truth_answers: Dict[str, str] = {}
+filtered_dataset = None  # Or initialize as empty list: []
+
 def load_questions():
     global filtered_dataset
     global questions_for_api
     global ground_truth_answers
-    tempo_filtered=[]
+    tempo_filtered = []
     # Clear existing data
     questions_for_api.clear()
     ground_truth_answers.clear()
 
-    logger.info("Starting to load and filter GAIA dataset...")
+    logger.info("Starting to load and filter GAIA dataset (validation split)...")
     try:
-        # Load the
-        dataset = load_dataset("gaia-benchmark/GAIA", "2023_level1", split=
-        logger.info("GAIA dataset validation split loaded.")
+        # Load the specified split and features
+        dataset = load_dataset("gaia-benchmark/GAIA", "2023_level1", split="validation", trust_remote_code=True)
+        logger.info(f"GAIA dataset validation split loaded. Features: {dataset.features}")
+        # You can uncomment below to see the first item's structure if needed
+        # logger.debug(f"First item structure: {dataset[0]}")
     except Exception as e:
-        logger.error(f"Failed to load GAIA dataset
+        logger.error(f"Failed to load GAIA dataset: {e}", exc_info=True)
         raise RuntimeError("Could not load the primary GAIA dataset.") from e
 
     # --- Filtering Logic (remains the same) ---
-    for
-    metadata =
+    for item in dataset:
+        metadata = item.get('Annotator Metadata')
 
         if metadata:
             num_tools_str = metadata.get('Number of tools')
@@ -63,59 +65,88 @@ def load_questions():
                 num_steps = int(num_steps_str)
 
                 if num_tools < tool_threshold and num_steps < step_threshold:
-                    tempo_filtered.append(
+                    tempo_filtered.append(item)  # Add the original item if it matches filter
             except ValueError:
-
-            # else: #
-            # logger.
+                logger.warning(f"Skipping Task ID: {item.get('task_id', 'N/A')} - Could not convert tool/step count: tools='{num_tools_str}', steps='{num_steps_str}'.")
+            # else: # If needed: log missing numbers in metadata
+            #     logger.warning(f"Skipping Task ID: {item.get('task_id', 'N/A')} - 'Number of tools' or 'Number of steps' missing in Metadata.")
         else:
-            logger.warning(f"Skipping Task ID: {
+            logger.warning(f"Skipping Task ID: {item.get('task_id', 'N/A')} - Missing 'Annotator Metadata'.")
 
-    # Store the
-    filtered_dataset = tempo_filtered
+    filtered_dataset = tempo_filtered  # Store the list of filtered original items
     logger.info(f"Found {len(filtered_dataset)} questions matching the criteria (tools < {tool_threshold}, steps < {step_threshold}).")
 
-    # --- Processing Logic (Modified) ---
     processed_count = 0
+    # --- REVISED Processing Logic to match the new Pydantic model ---
    for item in filtered_dataset:
         task_id = item.get('task_id')
-
+        original_question_text = item.get('Question')  # Get original text
         final_answer = item.get('Final answer')
 
-        # Validate
-        if task_id and
-
-
-
-
-
-
-
-
-
-
+        # Validate essential fields needed for processing & ground truth
+        if task_id and original_question_text and final_answer is not None:
+
+            # Create the dictionary for the API, selecting only the desired fields
+            processed_item = {
+                "task_id": str(task_id),  # Ensure string type
+                "question": str(original_question_text),  # Rename and ensure string type
+                # Include optional fields *if they exist* in the source item
+                "Level": item.get("Level"),  # Use .get() for safety, Pydantic handles None
+                "file_name": item.get("file_name"),
+                "file_path": item.get("file_path"),
+            }
+            # Optional: Clean up None values if Pydantic model doesn't handle them as desired
+            # processed_item = {k: v for k, v in processed_item.items() if v is not None}
+            # However, the Optional[...] fields in Pydantic should handle None correctly.
+
+            # Append the structured dictionary matching the Pydantic model
             questions_for_api.append(processed_item)
 
-            # Store the ground truth answer separately
+            # Store the ground truth answer separately (as before)
             ground_truth_answers[str(task_id)] = str(final_answer)
             processed_count += 1
         else:
-
-            missing = [k for k, v in {'task_id': task_id, 'Question': question_text, 'Final answer': final_answer}.items() if not v and v is not None]
-            logger.warning(f"Skipping item due to missing required fields ({', '.join(missing)}): task_id={task_id}")
-
-    logger.info(f"Successfully processed {processed_count} questions into API format.")
+            logger.warning(f"Skipping item due to missing essential fields (task_id, Question, or Final answer): task_id={task_id}")
 
+    logger.info(f"Successfully processed {processed_count} questions for the API matching the Pydantic model.")
     if not questions_for_api:
-        logger.error("CRITICAL: No valid questions loaded after filtering. API endpoints needing questions will fail.")
+        logger.error("CRITICAL: No valid questions loaded after filtering and processing. API endpoints needing questions will fail.")
         # raise RuntimeError("Failed to load mandatory question data after filtering.")
+    # --- END REVISED Processing Logic ---
+
+
 
 # --- Pydantic Models ---
-
-
+
+
 class Question(BaseModel):
     task_id: str
-
+    question: str
+    Level: Optional[str] = None
+    file_name: Optional[str] = None
+    file_path: Optional[str] = None
+
+
+# --- The rest of your Pydantic models remain the same ---
+class AnswerItem(BaseModel):
+    task_id: str
+    submitted_answer: str = Field(..., description="The agent's answer for the task_id")
+
+class Submission(BaseModel):
+    username: str = Field(..., description="Hugging Face username", min_length=1)
+    agent_code: str = Field(..., description="The Python class code for the agent", min_length=10)  # Basic check
+    answers: List[AnswerItem] = Field(..., description="List of answers submitted by the agent")
+
+class ScoreResponse(BaseModel):
+    username: str
+    score: float
+    correct_count: int
+    total_attempted: int
+    message: str
+    timestamp: str
+
+class ErrorResponse(BaseModel):
+    detail: str
 
 # Keep other models as they are (AnswerItem, Submission, ScoreResponse, ErrorResponse)
 # ... (rest of the Pydantic models remain the same) ...
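For reference, the processed dictionaries built in `load_questions()` are shaped to validate directly against the revised `Question` model. A minimal standalone sketch (the sample item below is hypothetical, not a real GAIA record):

    from typing import Any, Dict, Optional
    from pydantic import BaseModel

    class Question(BaseModel):
        task_id: str
        question: str
        Level: Optional[str] = None
        file_name: Optional[str] = None
        file_path: Optional[str] = None

    # Hypothetical item in the shape load_questions() now produces;
    # the Optional fields accept None without failing validation.
    processed_item: Dict[str, Any] = {
        "task_id": "example-task-001",
        "question": "What is the capital of France?",
        "Level": "1",
        "file_name": None,
        "file_path": None,
    }

    q = Question(**processed_item)
    print(q)

Because `Level`, `file_name`, and `file_path` default to `None`, items without attachments validate as-is, which is why the commented-out None-stripping step in `load_questions()` is unnecessary.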
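The diff ends before the route definitions, but the globals above are clearly meant to back FastAPI endpoints. A minimal sketch of how they might be consumed; the paths `/questions` and `/random-question` are illustrative guesses, not taken from this commit:

    import random
    from typing import List

    from fastapi import FastAPI, HTTPException

    # Assumes load_questions, questions_for_api, and Question
    # (defined earlier in main.py) are in scope.
    app = FastAPI()

    @app.on_event("startup")
    def startup_event():
        # Populate questions_for_api and ground_truth_answers before serving traffic.
        load_questions()

    @app.get("/questions", response_model=List[Question])
    def get_questions():
        if not questions_for_api:
            raise HTTPException(status_code=404, detail="No questions available.")
        return questions_for_api

    @app.get("/random-question", response_model=Question)
    def get_random_question():
        if not questions_for_api:
            raise HTTPException(status_code=404, detail="No questions available.")
        return random.choice(questions_for_api)

Calling `load_questions()` in a startup hook ensures the question cache is populated, or the app fails fast via the `RuntimeError`, before the first request arrives.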