File size: 8,722 Bytes
68ed57f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
#!/usr/bin/env python3
"""
Test script for the GAIA agent using real API keys.
This script simulates GAIA benchmark questions and helps debug/improve the agent.
"""

import os
import sys
import json
import tempfile
from typing import List, Dict, Any, Optional
import traceback
import dotenv

# Load environment variables from .env file
dotenv.load_dotenv()

# Import our agent
from core_agent import GAIAAgent

# Simulation of GAIA benchmark questions.
#
# Each entry is a dict with the following keys, all of which are read by
# run_tests():
#   task_id         - identifier; also used to name the temporary attachment
#   question        - prompt text passed to the agent
#   expected_answer - reference answer, compared case-insensitively with the
#                     agent's answer
#   has_file        - whether the question comes with an attachment
#   file_content    - attachment text (written to a temp .csv file), or None
# An attachment is only created when has_file is true AND file_content is
# non-empty.
SAMPLE_QUESTIONS = [
    {
        "task_id": "task_001",
        "question": "What is the capital of France?",
        "expected_answer": "Paris",
        "has_file": False,
        "file_content": None
    },
    {
        "task_id": "task_002",
        "question": "What is the square root of 144?",
        "expected_answer": "12",
        "has_file": False,
        "file_content": None
    },
    {
        "task_id": "task_003",
        "question": "If a train travels at 60 miles per hour, how far will it travel in 2.5 hours?",
        "expected_answer": "150 miles",
        "has_file": False,
        "file_content": None
    },
    # The question text below is deliberately written backwards; read right to
    # left it asks for the opposite of "left" and tells the agent to answer
    # with the word 'right'.
    {
        "task_id": "task_004",
        "question": ".rewsna eht sa 'thgir' drow eht etirw ,tfel fo etisoppo eht si tahW",
        "expected_answer": "right",
        "has_file": False,
        "file_content": None
    },
    # The expected answer equals the sum of the CSV "Total" column
    # (259.90 + 227.50 + 1005.00 + 389.85 + 364.00 + 8004.50 = 10250.75).
    # NOTE(review): in the last row Quantity * Price is 100 * 80.04 = 8004.00,
    # but the Total column says 8004.50, so an agent that recomputes totals
    # from Quantity and Price would arrive at 10250.25 instead - confirm which
    # figure is intended.
    {
        "task_id": "task_005",
        "question": "Analyze the data in the attached CSV file and tell me the total sales for the month of January.",
        "expected_answer": "$10,250.75",
        "has_file": True,
        "file_content": """Date,Product,Quantity,Price,Total
2023-01-05,Widget A,10,25.99,259.90
2023-01-12,Widget B,5,45.50,227.50
2023-01-15,Widget C,20,50.25,1005.00
2023-01-20,Widget A,15,25.99,389.85
2023-01-25,Widget B,8,45.50,364.00
2023-01-28,Widget D,100,80.04,8004.50"""
    },
    # NOTE(review): this question has many valid answers, but run_tests() only
    # accepts a case-insensitive match with the single expected string below.
    {
        "task_id": "task_006",
        "question": "I'm making a grocery list for my mom, but she's a picky eater. She only eats foods that don't contain the letter 'e'. List 5 common fruits and vegetables she can eat.",
        "expected_answer": "Banana, Kiwi, Corn, Fig, Taro",
        "has_file": False,
        "file_content": None
    },
    {
        "task_id": "task_007",
        "question": "How many studio albums were published by Mercedes Sosa between 1972 and 1985?",
        "expected_answer": "12",
        "has_file": False,
        "file_content": None
    },
    {
        "task_id": "task_008",
        "question": "In the video https://www.youtube.com/watch?v=L1vXC1KMRd0, what color is primarily associated with the main character?",
        "expected_answer": "Blue",
        "has_file": False,
        "file_content": None
    }
]

def initialize_agent():
    """Initialize the GAIAAgent with the first usable API key.

    Providers are tried in a fixed order: X.AI, then OpenAI, then Hugging
    Face. A provider is attempted only when its environment variable is set.
    If constructing the agent raises, the error and traceback are printed and
    the next provider is tried.

    Environment variables read:
        XAI_API_KEY, OPENAI_API_KEY, HUGGINGFACEHUB_API_TOKEN: provider keys.
        AGENT_MODEL_ID: model id for the OpenAI provider (default "gpt-4o").

    Returns:
        The constructed GAIAAgent, or None when no provider could be
        initialized.
    """
    print("Initializing GAIAAgent with API keys...")

    # Settings shared by every provider.
    common_kwargs = {
        "temperature": 0.2,
        "executor_type": "local",
        "verbose": True,
    }

    openai_model_id = os.getenv("AGENT_MODEL_ID", "gpt-4o")
    # Use a smaller model that might work within the Hugging Face free tier.
    hf_model_id = "tiiuae/falcon-7b-instruct"

    # NOTE: the X.AI entry used to pass
    # ``system_prompt_suffix=additional_system_prompt``, but that name is not
    # defined anywhere in this file. The resulting NameError was swallowed by
    # the ``except Exception`` below, so the X.AI provider could never be
    # initialized. The argument is dropped until such a prompt actually exists.
    providers = [
        {
            "env_var": "XAI_API_KEY",
            "label": "X.AI",
            "key_message": "Using X.AI API key",
            "success_message": "Using OpenAIServerModel with X.AI API",
            "kwargs": {
                "model_type": "OpenAIServerModel",
                "model_id": "grok-3-latest",  # Use the X.AI model
                "api_base": "https://api.x.ai/v1",  # X.AI endpoint
            },
        },
        {
            "env_var": "OPENAI_API_KEY",
            "label": "OpenAI",
            "key_message": "Using OpenAI API key",
            "success_message": f"Using OpenAIServerModel with model_id: {openai_model_id}",
            "kwargs": {
                "model_type": "OpenAIServerModel",
                "model_id": openai_model_id,
            },
        },
        {
            "env_var": "HUGGINGFACEHUB_API_TOKEN",
            "label": "Hugging Face",
            "key_message": "Using Hugging Face API token",
            "success_message": f"Using HfApiModel with model_id: {hf_model_id}",
            "kwargs": {
                "model_type": "HfApiModel",
                "model_id": hf_model_id,
            },
        },
    ]

    key_found = False
    for provider in providers:
        api_key = os.getenv(provider["env_var"])
        if not api_key:
            continue

        key_found = True
        print(provider["key_message"])
        try:
            agent = GAIAAgent(
                api_key=api_key,
                **provider["kwargs"],
                **common_kwargs,
            )
        except Exception as e:
            # Deliberately broad: any failure for one provider should fall
            # through to the next one instead of aborting the whole run.
            print(f"Error initializing with {provider['label']} API: {e}")
            traceback.print_exc()
            continue

        print(provider["success_message"])
        return agent

    if key_found:
        # Keys were present, so "no keys found" would be misleading here.
        print("ERROR: API key(s) were found, but the agent could not be initialized with any of them (see errors above).")
    else:
        print("ERROR: No valid API keys found in environment. Please set one of the following:")
        print("- XAI_API_KEY (for X.AI)")
        print("- OPENAI_API_KEY")
        print("- HUGGINGFACEHUB_API_TOKEN")
    return None

def save_test_file(task_id: str, content: str) -> str:
    """Save a test file to a temporary location.

    The file is written to the system temporary directory as
    ``test_file_<task_id>.csv``; an existing file with that name is
    overwritten.

    Args:
        task_id: Identifier embedded in the file name.
        content: Text to write to the file.

    Returns:
        The full path of the written file.
    """
    file_path = os.path.join(tempfile.gettempdir(), f"test_file_{task_id}.csv")

    # Pin the encoding so non-ASCII content is written the same way on every
    # platform instead of depending on the locale's default encoding.
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(content)

    return file_path

def run_tests():
    """Run tests using the GAIAAgent with API keys.

    Every entry of SAMPLE_QUESTIONS is sent to the agent. The agent's answer
    is compared with the expected answer ignoring case and surrounding
    whitespace. A question whose processing raises is recorded as incorrect
    with an "ERROR: ..." answer, and the run continues with the next one.

    Returns:
        A list with one dict per question (keys: "task_id", "question",
        "expected", "answer", "is_correct"), or None when the agent could not
        be initialized.
    """
    agent = initialize_agent()

    if agent is None:
        print("Failed to initialize agent. Exiting.")
        return None

    results = []
    correct_count = 0
    total_count = len(SAMPLE_QUESTIONS)

    for idx, question_data in enumerate(SAMPLE_QUESTIONS, start=1):
        task_id = question_data["task_id"]
        question = question_data["question"]
        expected = question_data["expected_answer"]

        print(f"\n{'='*80}")
        print(f"Question {idx}/{total_count}: {question}")
        print(f"Expected: {expected}")

        file_path = None
        try:
            # Process any attached file
            if question_data["has_file"] and question_data["file_content"]:
                file_path = save_test_file(task_id, question_data["file_content"])
                print(f"Created test file: {file_path}")

            # Get answer from agent
            answer = agent.answer_question(question, file_path)
            print(f"Agent answer: {answer}")

            # The agent may return a non-string (e.g. a number or None);
            # normalise to text so the comparison and the JSON dump both work.
            answer_text = "" if answer is None else str(answer)

            # Check if answer matches expected (case/whitespace-insensitive)
            is_correct = answer_text.strip().lower() == expected.strip().lower()
            if is_correct:
                correct_count += 1
                print("✅ CORRECT")
            else:
                print(f"❌ INCORRECT - Expected: {expected}")

            results.append({
                "task_id": task_id,
                "question": question,
                "expected": expected,
                "answer": answer_text,
                "is_correct": is_correct
            })
        except Exception as e:
            # Deliberately broad: one failing question must not abort the run.
            error_details = traceback.format_exc()
            print(f"Error processing question: {e}\n{error_details}")
            results.append({
                "task_id": task_id,
                "question": question,
                "expected": expected,
                "answer": f"ERROR: {str(e)}",
                "is_correct": False
            })
        finally:
            # Best-effort cleanup so attachments do not pile up in the temp dir.
            if file_path is not None:
                try:
                    os.remove(file_path)
                except OSError:
                    pass

    # Print summary (guard against an empty question list)
    accuracy = (correct_count / total_count) * 100 if total_count else 0.0
    print(f"\n{'='*80}")
    print(f"Test Results: {correct_count}/{total_count} correct ({accuracy:.1f}%)")

    return results


if __name__ == "__main__":
    # Script entry point: report which credentials are configured, run the
    # sample questions, and persist the per-question results as JSON.
    print("Running tests for GAIA agent with API keys...")

    # Print environment information (only whether each key is set, never its value)
    print("\nEnvironment information:")
    print(f"XAI_API_KEY set: {'Yes' if os.getenv('XAI_API_KEY') else 'No'}")
    print(f"OPENAI_API_KEY set: {'Yes' if os.getenv('OPENAI_API_KEY') else 'No'}")
    print(f"HUGGINGFACEHUB_API_TOKEN set: {'Yes' if os.getenv('HUGGINGFACEHUB_API_TOKEN') else 'No'}")
    # NOTE(review): AGENT_MODEL_TYPE is only displayed here; nothing else in
    # this file reads it when choosing the model type.
    print(f"AGENT_MODEL_TYPE: {os.getenv('AGENT_MODEL_TYPE', 'OpenAIServerModel')} (default: OpenAIServerModel)")
    print(f"AGENT_MODEL_ID: {os.getenv('AGENT_MODEL_ID', 'gpt-4o')} (default: gpt-4o)")

    results = run_tests()

    # run_tests() returns None when no agent could be initialized. Writing
    # "null" to the results file and reporting success would be misleading,
    # so signal the failure through the exit status instead.
    if results is None:
        print("\nNo results to save.")
        sys.exit(1)

    # Save results to a file
    with open("test_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    print("\nResults saved to test_results.json")