Update agent.py
agent.py (CHANGED)
@@ -14,10 +14,12 @@ import mimetypes
 import os
 import re
 import tempfile
+import time
 from typing import List, Dict, Any, Optional
 import json
 import requests
 from urllib.parse import urlparse
+import random

 from smolagents import (
     CodeAgent,
@@ -42,6 +44,39 @@ def _download_file(file_id: str) -> bytes:
     resp.raise_for_status()
     return resp.content

+# --------------------------------------------------------------------------- #
+# Rate limiting helper
+# --------------------------------------------------------------------------- #
+class RateLimiter:
+    """Simple rate limiter to prevent Anthropic API rate limit errors"""
+    def __init__(self, requests_per_minute=20, burst=3):
+        self.requests_per_minute = requests_per_minute
+        self.burst = burst
+        self.request_times = []
+
+    def wait(self):
+        """Wait if needed to avoid exceeding rate limits"""
+        now = time.time()
+        # Remove timestamps older than 1 minute
+        self.request_times = [t for t in self.request_times if now - t < 60]
+
+        # If we've made too many requests in the last minute, wait
+        if len(self.request_times) >= self.requests_per_minute:
+            oldest = min(self.request_times)
+            sleep_time = 60 - (now - oldest) + 1  # +1 for safety
+            print(f"Rate limit approaching. Waiting {sleep_time:.2f} seconds before next request...")
+            time.sleep(sleep_time)
+
+        # Add current timestamp to the list
+        self.request_times.append(time.time())
+
+        # Add a small random delay to avoid bursts of requests
+        if len(self.request_times) > self.burst:
+            time.sleep(random.uniform(0.2, 1.0))
+
+# Global rate limiter instance
+RATE_LIMITER = RateLimiter(requests_per_minute=25)  # Keep below 40 for safety
+
 # --------------------------------------------------------------------------- #
 # custom tool: fetch GAIA attachments
 # --------------------------------------------------------------------------- #
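The `RateLimiter` above is a sliding-window limiter: `wait()` prunes timestamps older than 60 seconds and only sleeps once a full window's worth of calls is outstanding. A minimal sketch of pacing a loop with it, assuming the class from this diff is in scope; `do_request` is a hypothetical stand-in for any rate-limited call:

```python
import time

# Assumes the RateLimiter class from this commit is defined or imported.
limiter = RateLimiter(requests_per_minute=5, burst=2)

def do_request(i: int) -> None:
    # Hypothetical stand-in for a rate-limited API call.
    print(f"request {i} at {time.strftime('%H:%M:%S')}")

for i in range(12):
    limiter.wait()  # sleeps once 5 timestamps fall inside one 60 s window
    do_request(i)
```

Past the `burst` threshold, `wait()` also sleeps a random 0.2 to 1.0 s, so calls spread out rather than firing back-to-back at the edge of the window.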
@@ -224,6 +259,81 @@ def analyze_excel_file(file_path: str, query: str) -> str:
     except Exception as e:
         return f"Error analyzing Excel file: {str(e)}"

+# --------------------------------------------------------------------------- #
+# Custom LiteLLM model with rate limiting and error handling
+# --------------------------------------------------------------------------- #
+class RateLimitedClaudeModel:
+    def __init__(
+        self,
+        model_id: str = "anthropic/claude-3-5-sonnet-20240620",
+        api_key: Optional[str] = None,
+        temperature: float = 0.1,
+        max_tokens: int = 1024,
+        max_retries: int = 3,
+        retry_delay: int = 5,
+    ):
+        """
+        Initialize a Claude model with rate limiting and error handling
+
+        Args:
+            model_id: The model ID to use
+            api_key: The API key to use
+            temperature: The temperature to use
+            max_tokens: The maximum number of tokens to generate
+            max_retries: The maximum number of retries on rate limit errors
+            retry_delay: The initial delay between retries (will increase exponentially)
+        """
+        # Get API key
+        if api_key is None:
+            api_key = os.getenv("ANTHROPIC_API_KEY")
+            if not api_key:
+                raise ValueError("No Anthropic token provided. Please set ANTHROPIC_API_KEY environment variable or pass api_key parameter.")
+
+        self.model_id = model_id
+        self.api_key = api_key
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.max_retries = max_retries
+        self.retry_delay = retry_delay
+
+        # Create the underlying LiteLLM model
+        self.model = LiteLLMModel(
+            model_id=model_id,
+            api_key=api_key,
+            temperature=temperature
+        )
+
+    def __call__(self, prompt: str, system_instruction: str, **kwargs) -> str:
+        """
+        Call the model with rate limiting and error handling
+
+        Args:
+            prompt: The prompt to generate from
+            system_instruction: The system instruction to use
+
+        Returns:
+            The generated text
+        """
+        retries = 0
+        while True:
+            try:
+                # Wait according to rate limiter
+                RATE_LIMITER.wait()
+
+                # Call the model
+                return self.model(prompt, system_instruction=system_instruction, **kwargs)
+
+            except Exception as e:
+                # Check if it's a rate limit error
+                if "rate_limit_error" in str(e) and retries < self.max_retries:
+                    retries += 1
+                    sleep_time = self.retry_delay * (2 ** (retries - 1))  # Exponential backoff
+                    print(f"Rate limit exceeded, retrying in {sleep_time} seconds (attempt {retries}/{self.max_retries})...")
+                    time.sleep(sleep_time)
+                else:
+                    # If it's not a rate limit error or we've exceeded max retries, raise
+                    raise
+
 # --------------------------------------------------------------------------- #
 # GAIAAgent class
 # --------------------------------------------------------------------------- #
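With the defaults above (`retry_delay=5`, `max_retries=3`), the retry loop sleeps 5, 10, then 20 seconds on successive rate-limit errors before re-raising. A quick illustration of that schedule, using only the values from the class defaults:

```python
# Backoff schedule produced by sleep_time = retry_delay * 2 ** (retries - 1)
retry_delay, max_retries = 5, 3  # defaults from RateLimitedClaudeModel
delays = [retry_delay * (2 ** (r - 1)) for r in range(1, max_retries + 1)]
print(delays)            # [5, 10, 20]
print(sum(delays), "s")  # 35 s of total backoff before the error is re-raised
```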
@@ -233,7 +343,8 @@ class GAIAAgent:
         api_key: Optional[str] = None,
         temperature: float = 0.1,
         verbose: bool = False,
-        system_prompt: Optional[str] = None
+        system_prompt: Optional[str] = None,
+        max_tokens: int = 1024,
     ):
         """
         Initialize a GAIAAgent with Claude model
@@ -243,6 +354,7 @@ class GAIAAgent:
             temperature: Temperature for text generation
             verbose: Enable verbose logging
             system_prompt: Custom system prompt (optional)
+            max_tokens: Maximum number of tokens to generate per response
         """
         # Set verbosity
         self.verbose = verbose
@@ -260,15 +372,16 @@ All answers are graded by exact string match, so format carefully!"""
         if self.verbose:
             print(f"Using Anthropic token: {api_key[:5]}...")

-        # Initialize Claude model
-        self.model = LiteLLMModel(
+        # Initialize Claude model with rate limiting
+        self.model = RateLimitedClaudeModel(
             model_id="anthropic/claude-3-5-sonnet-20240620",  # Use Claude 3.5 Sonnet
             api_key=api_key,
-            temperature=temperature
+            temperature=temperature,
+            max_tokens=max_tokens,
         )

         if self.verbose:
-            print(f"Initialized model:
+            print(f"Initialized model: RateLimitedClaudeModel - anthropic/claude-3-5-sonnet-20240620")

         # Initialize default tools
         self.tools = [
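A design note on this swap: `RateLimitedClaudeModel` is a plain wrapper rather than a smolagents subclass. It accepts the same constructor arguments the old `LiteLLMModel` call used and forwards `__call__` to the wrapped model, so it slots in here by duck typing with no other changes to `GAIAAgent`.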
@@ -334,8 +447,12 @@ All answers are graded by exact string match, so format carefully!"""
         # If there's a file, read it and include its content in the context
         if task_file_path:
             try:
+                # Limit file content size to avoid token limits
+                max_file_size = 10000  # Characters
                 with open(task_file_path, 'r', errors='ignore') as f:
-                    file_content = f.read()
+                    file_content = f.read(max_file_size)
+                    if len(file_content) >= max_file_size:
+                        file_content = file_content[:max_file_size] + "... [content truncated to prevent exceeding token limits]"

                 # Determine file type from extension
                 import os
@@ -343,11 +460,11 @@ All answers are graded by exact string match, so format carefully!"""

                 context = f"""
 Question: {question}
-This question has an associated file. Here is the file content:
+This question has an associated file. Here is the file content (it may be truncated):
 ```{file_ext}
 {file_content}
 ```
-Analyze the file content
+Analyze the available file content to answer the question.
 """
             except Exception as file_e:
                 try:
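The capped read added above is easy to exercise in isolation. A small sketch with an in-memory stream in place of the real task file (the 16-character sample and 10-character cap are illustrative; the commit caps at 10,000 characters):

```python
import io

max_file_size = 10  # tiny cap so the example triggers truncation
f = io.StringIO("0123456789ABCDEF")   # stands in for open(task_file_path)
file_content = f.read(max_file_size)  # reads at most max_file_size chars
if len(file_content) >= max_file_size:
    file_content = file_content[:max_file_size] + "... [truncated]"
print(file_content)  # 0123456789... [truncated]
```

Note that a file of exactly `max_file_size` characters also receives the truncation marker, matching the behavior of the committed code.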
@@ -385,12 +502,12 @@ This question appears to be in reversed text. Here's the reversed version:
 Now answer the question above. Remember to format your answer exactly as requested.
 """

-        # Add a prompt to ensure precise answers
+        # Add a prompt to ensure precise answers but keep it concise
         full_prompt = f"""{context}
 When answering, provide ONLY the precise answer requested.
 Do not include explanations, steps, reasoning, or additional text.
 Be direct and specific. GAIA benchmark requires exact matching answers.
-
+Example: If asked "What is the capital of France?", respond just with "Paris".
 """

         # Run the agent with the question
@@ -486,8 +603,9 @@ class ClaudeAgent:
             # Create GAIAAgent instance
             self.agent = GAIAAgent(
                 api_key=api_key,
-                temperature=0.1,
-                verbose=True,
+                temperature=0.1,  # Use low temperature for precise answers
+                verbose=True,  # Enable verbose logging
+                max_tokens=1024,  # Reduce max tokens to avoid hitting rate limits
             )
         except Exception as e:
             print(f"Error initializing GAIAAgent: {e}")
@@ -506,6 +624,9 @@ class ClaudeAgent:
         try:
             print(f"Received question: {question[:100]}..." if len(question) > 100 else f"Received question: {question}")

+            # Add delay between questions to respect rate limits
+            time.sleep(random.uniform(0.5, 2.0))
+
             # Detect reversed text
             if question.startswith(".") or ".rewsna eht sa" in question:
                 print("Detected reversed text question")
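Finally, each incoming question now starts with a small random pause on top of the global limiter. A sketch of the added pacing (bounds taken from the diff; the mean works out to about 1.25 s per question):

```python
import random
import time

def pause_between_questions() -> None:
    # Uniform jitter in [0.5, 2.0] s (mean 1.25 s), as added before each question.
    time.sleep(random.uniform(0.5, 2.0))
```

Combined with the global limiter's cap of 25 requests per minute, this keeps a full run comfortably under the 40-per-minute ceiling the limiter's comment assumes.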
|