Jofthomas committed
Commit a4e3b45 · verified · 1 Parent(s): e1889c4

Update main.py

Files changed (1)
  1. main.py +78 -47
main.py CHANGED
@@ -3,7 +3,7 @@ import os
 import pandas as pd
 from fastapi import FastAPI, HTTPException, Body
 from pydantic import BaseModel, Field
-from typing import List, Dict, Any #<-- Make sure Any is imported
+from typing import List, Dict, Any, Optional
 from datasets import load_dataset, Dataset, DatasetDict
 from huggingface_hub import HfApi, hf_hub_download
 from datetime import datetime, timezone
@@ -12,8 +12,6 @@ import uvicorn
 import random

 # --- Constants and Config ---
-tool_threshold = 3
-step_threshold = 5
 HF_DATASET_ID = "agents-course/unit4-students-scores"

 # --- Data Structures ---
@@ -25,33 +23,37 @@ ground_truth_answers: Dict[str, str] = {}
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

-# --- Filtered Dataset Placeholder ---
-# Note: Making filtered_dataset global might not be ideal in larger apps,
-# but keeping it as is based on the original code.
-filtered_dataset = None

-# --- Modified load_questions Function ---
+logger = logging.getLogger(__name__) # Make sure logger is initialized
+tool_threshold = 3
+step_threshold = 5
+questions_for_api: List[Dict[str, Any]] = [] # Use Dict[str, Any] for flexibility before validation
+ground_truth_answers: Dict[str, str] = {}
+filtered_dataset = None # Or initialize as empty list: []
+
 def load_questions():
     global filtered_dataset
     global questions_for_api
     global ground_truth_answers
-    tempo_filtered=[]
+    tempo_filtered = []
     # Clear existing data
     questions_for_api.clear()
     ground_truth_answers.clear()

-    logger.info("Starting to load and filter GAIA dataset...")
+    logger.info("Starting to load and filter GAIA dataset (validation split)...")
     try:
-        # Load the 'validation' split specifically if that's intended
-        dataset = load_dataset("gaia-benchmark/GAIA", "2023_level1", split='validation', trust_remote_code=True)
-        logger.info("GAIA dataset validation split loaded.")
+        # Load the specified split and features
+        dataset = load_dataset("gaia-benchmark/GAIA", "2023_level1", split="validation", trust_remote_code=True)
+        logger.info(f"GAIA dataset validation split loaded. Features: {dataset.features}")
+        # You can uncomment below to see the first item's structure if needed
+        # logger.debug(f"First item structure: {dataset[0]}")
     except Exception as e:
-        logger.error(f"Failed to load GAIA dataset validation split: {e}", exc_info=True)
+        logger.error(f"Failed to load GAIA dataset: {e}", exc_info=True)
         raise RuntimeError("Could not load the primary GAIA dataset.") from e

     # --- Filtering Logic (remains the same) ---
-    for question in dataset: # Iterate directly over the loaded split
-        metadata = question.get('Annotator Metadata')
+    for item in dataset:
+        metadata = item.get('Annotator Metadata')

         if metadata:
             num_tools_str = metadata.get('Number of tools')
@@ -63,59 +65,88 @@ def load_questions():
                     num_steps = int(num_steps_str)

                     if num_tools < tool_threshold and num_steps < step_threshold:
-                        tempo_filtered.append(question)
+                        tempo_filtered.append(item) # Add the original item if it matches filter
                 except ValueError:
-                    logger.warning(f"Skipping Task ID: {question.get('task_id', 'N/A')} - Could not convert tool/step count: tools='{num_tools_str}', steps='{num_steps_str}'.")
-            # else: # Optional: Log if numbers are missing
-            #     logger.debug(f"Skipping Task ID: {question.get('task_id', 'N/A')} - Missing tool/step count in metadata.")
+                    logger.warning(f"Skipping Task ID: {item.get('task_id', 'N/A')} - Could not convert tool/step count: tools='{num_tools_str}', steps='{num_steps_str}'.")
+            # else: # If needed: log missing numbers in metadata
+            #     logger.warning(f"Skipping Task ID: {item.get('task_id', 'N/A')} - 'Number of tools' or 'Number of steps' missing in Metadata.")
         else:
-            logger.warning(f"Skipping Task ID: {question.get('task_id', 'N/A')} - Missing 'Annotator Metadata'.")
+            logger.warning(f"Skipping Task ID: {item.get('task_id', 'N/A')} - Missing 'Annotator Metadata'.")

-    # Store the filtered list (optional, could process directly)
-    filtered_dataset = tempo_filtered
+    filtered_dataset = tempo_filtered # Store the list of filtered original items
     logger.info(f"Found {len(filtered_dataset)} questions matching the criteria (tools < {tool_threshold}, steps < {step_threshold}).")

-    # --- Processing Logic (Modified) ---
     processed_count = 0
+    # --- REVISED Processing Logic to match the new Pydantic model ---
     for item in filtered_dataset:
         task_id = item.get('task_id')
-        question_text = item.get('Question') # Keep original key for now
+        original_question_text = item.get('Question') # Get original text
        final_answer = item.get('Final answer')

-        # Validate required fields needed for processing/scoring
-        if task_id and question_text and final_answer is not None:
-            # Create a copy to avoid modifying the original item in filtered_dataset
-            processed_item: Dict[str, Any] = item.copy()
-
-            # Remove the fields we explicitly want to exclude
-            processed_item.pop('Final answer', None)
-            processed_item.pop('Annotator Annotation', None)
-            # You could add more fields to pop here if needed later
-            # processed_item.pop('Another field to remove', None)
-
-            # Store the dictionary containing all remaining fields
+        # Validate essential fields needed for processing & ground truth
+        if task_id and original_question_text and final_answer is not None:
+
+            # Create the dictionary for the API, selecting only the desired fields
+            processed_item = {
+                "task_id": str(task_id), # Ensure string type
+                "question": str(original_question_text), # Rename and ensure string type
+                # Include optional fields *if they exist* in the source item
+                "Level": item.get("Level"), # Use .get() for safety, Pydantic handles None
+                "file_name": item.get("file_name"),
+                "file_path": item.get("file_path"),
+            }
+            # Optional: Clean up None values if Pydantic model doesn't handle them as desired
+            # processed_item = {k: v for k, v in processed_item.items() if v is not None}
+            # However, the Optional[...] fields in Pydantic should handle None correctly.
+
+            # Append the structured dictionary matching the Pydantic model
             questions_for_api.append(processed_item)

-            # Store the ground truth answer separately for scoring
+            # Store the ground truth answer separately (as before)
             ground_truth_answers[str(task_id)] = str(final_answer)
             processed_count += 1
         else:
-            # Log which required field was missing if possible
-            missing = [k for k, v in {'task_id': task_id, 'Question': question_text, 'Final answer': final_answer}.items() if not v and v is not None]
-            logger.warning(f"Skipping item due to missing required fields ({', '.join(missing)}): task_id={task_id}")
-
-    logger.info(f"Successfully processed {processed_count} questions into API format.")
+            logger.warning(f"Skipping item due to missing essential fields (task_id, Question, or Final answer): task_id={task_id}")

+    logger.info(f"Successfully processed {processed_count} questions for the API matching the Pydantic model.")
     if not questions_for_api:
-        logger.error("CRITICAL: No valid questions loaded after filtering. API endpoints needing questions will fail.")
+        logger.error("CRITICAL: No valid questions loaded after filtering and processing. API endpoints needing questions will fail.")
         # raise RuntimeError("Failed to load mandatory question data after filtering.")
+    # --- END REVISED Processing Logic ---
+
+

 # --- Pydantic Models ---
-# Keep Question simple for potential internal use or basic validation,
-# but the API will return Dict[str, Any]
+
+
 class Question(BaseModel):
     task_id: str
-    Question: str # Keep original casing if that's what in the data
+    question: str
+    Level: Optional[str] = None
+    file_name: Optional[str] = None
+    file_path: Optional[str] = None
+
+
+# --- The rest of your Pydantic models remain the same ---
+class AnswerItem(BaseModel):
+    task_id: str
+    submitted_answer: str = Field(..., description="The agent's answer for the task_id")
+
+class Submission(BaseModel):
+    username: str = Field(..., description="Hugging Face username", min_length=1)
+    agent_code: str = Field(..., description="The Python class code for the agent", min_length=10) # Basic check
+    answers: List[AnswerItem] = Field(..., description="List of answers submitted by the agent")
+
+class ScoreResponse(BaseModel):
+    username: str
+    score: float
+    correct_count: int
+    total_attempted: int
+    message: str
+    timestamp: str
+
+class ErrorResponse(BaseModel):
+    detail: str

 # Keep other models as they are (AnswerItem, Submission, ScoreResponse, ErrorResponse)
 # ... (rest of the Pydantic models remain the same) ...
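
For context, a minimal sketch (not part of the commit) of how the reshaped load_questions() output lines up with the new Question model: each processed_item dict built in the revised loop should validate directly against it. The field values below are hypothetical.

# Illustration only, not from main.py. Assumes pydantic is installed.
from typing import Any, Dict, Optional
from pydantic import BaseModel

class Question(BaseModel):  # same shape as the model added in this commit
    task_id: str
    question: str
    Level: Optional[str] = None
    file_name: Optional[str] = None
    file_path: Optional[str] = None

# Shape produced by the revised processing loop (values are made up)
processed_item: Dict[str, Any] = {
    "task_id": "task-0001",
    "question": "Example question text",
    "Level": "1",
    "file_name": None,
    "file_path": None,
}

validated = Question(**processed_item)
print(validated.task_id, validated.question)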
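
Likewise, a small self-contained sketch (not part of the commit) of the threshold filter that load_questions() applies to 'Annotator Metadata'; keep_item() is a hypothetical helper and the sample item is made up.

# Illustration only, not from main.py.
tool_threshold = 3
step_threshold = 5

def keep_item(item: dict) -> bool:
    """True when the item uses fewer than tool_threshold tools and step_threshold steps."""
    metadata = item.get('Annotator Metadata')
    if not metadata:
        return False
    try:
        num_tools = int(metadata.get('Number of tools'))
        num_steps = int(metadata.get('Number of steps'))
    except (TypeError, ValueError):
        # Missing or non-numeric counts are skipped, mirroring the warnings in load_questions()
        return False
    return num_tools < tool_threshold and num_steps < step_threshold

sample = {"task_id": "demo", "Annotator Metadata": {"Number of tools": "2", "Number of steps": "4"}}
print(keep_item(sample))  # True: 2 < 3 tools and 4 < 5 steps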
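
Finally, a hypothetical example (not part of the commit) of a request body that would satisfy the Submission and AnswerItem models shown above; the username, agent code, and answer are placeholders.

# Illustration only, not from main.py.
import json

submission_payload = {
    "username": "some-hf-username",                 # placeholder username
    "agent_code": "class BasicAgent:\n    pass\n",  # placeholder; min_length=10 applies
    "answers": [
        {"task_id": "task-0001", "submitted_answer": "example answer"},
    ],
}
print(json.dumps(submission_payload, indent=2))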