susmitsil committed on
Commit
509f358
·
verified ·
1 Parent(s): 85b4924
Files changed (1) hide show
  1. app.py +699 -180
app.py CHANGED
@@ -1,206 +1,725 @@
1
  import os
2
- import gradio as gr
 
 
 
 
 
3
  import requests
4
- import inspect
5
- import pandas as pd
6
- from gemini_agent import GeminiAgent
7
 
8
- # Constants
9
- DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
 
 
 
 
 
 
 
 
 
10
 
11
- class BasicAgent:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  def __init__(self):
13
- print("Initializing the BasicAgent")
14
-
15
-
16
- # Get Gemini API key
17
- api_key = os.getenv('GOOGLE_API_KEY')
18
- if not api_key:
19
- raise ValueError("GOOGLE_API_KEY environment variable not set.")
 
 
 
 
 
20
 
21
- # Initialize GeminiAgent
22
- self.agent = GeminiAgent(api_key=api_key)
23
- print("GeminiAgent initialized successfully")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- def __call__(self, question: str) -> str:
26
- print(f"Agent received question (first 50 chars): {question[:50]}...")
27
- final_answer = self.agent.run(question)
28
- print(f"Agent returning fixed answer: {final_answer}")
29
- return final_answer
30
 
31
- def run_and_submit_all( profile: gr.OAuthProfile | None):
 
 
 
32
  """
33
- Fetches all questions, runs the BasicAgent on them, submits all answers,
34
- and displays the results.
 
 
 
 
 
 
 
35
  """
36
- # --- Determine HF Space Runtime URL and Repo URL ---
37
- space_id = os.getenv("SPACE_ID") # Get the SPACE_ID for sending link to the code
38
-
39
- if profile:
40
- username= f"{profile.username}"
41
- print(f"User logged in: {username}")
42
  else:
43
- print("User not logged in.")
44
- return "Please Login to Hugging Face with the button.", None
 
 
 
 
 
45
 
46
- api_url = DEFAULT_API_URL
47
- questions_url = f"{api_url}/questions"
48
- submit_url = f"{api_url}/submit"
49
 
50
- # 1. Instantiate Agent ( modify this part to create your agent)
 
 
 
 
 
 
 
 
 
 
51
  try:
52
- agent = BasicAgent()
53
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  except Exception as e:
55
- print(f"Error instantiating agent: {e}")
56
- return f"Error initializing agent: {e}", None
57
- # In the case of an app running as a hugging Face space, this link points toward your codebase ( usefull for others so please keep it public)
58
- agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
59
- print(agent_code)
60
-
61
- # 2. Fetch Questions
62
- print(f"Fetching questions from: {questions_url}")
 
 
 
 
 
63
  try:
64
- response = requests.get(questions_url, timeout=15)
65
- response.raise_for_status()
66
- questions_data = response.json()
67
- if not questions_data:
68
- print("Fetched questions list is empty.")
69
- return "Fetched questions list is empty or invalid format.", None
70
- print(f"Fetched {len(questions_data)} questions.")
71
- except requests.exceptions.RequestException as e:
72
- print(f"Error fetching questions: {e}")
73
- return f"Error fetching questions: {e}", None
74
- except requests.exceptions.JSONDecodeError as e:
75
- print(f"Error decoding JSON response from questions endpoint: {e}")
76
- print(f"Response text: {response.text[:500]}")
77
- return f"Error decoding server response for questions: {e}", None
78
  except Exception as e:
79
- print(f"An unexpected error occurred fetching questions: {e}")
80
- return f"An unexpected error occurred fetching questions: {e}", None
81
-
82
- # 3. Run your Agent
83
- results_log = []
84
- answers_payload = []
85
- print(f"Running agent on {len(questions_data)} questions...")
86
- for item in questions_data:
87
- task_id = item.get("task_id")
88
- question_text = item.get("question")
89
- if not task_id or question_text is None:
90
- print(f"Skipping item with missing task_id or question: {item}")
91
- continue
92
- try:
93
- submitted_answer = agent(question_text)
94
- answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
95
- results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
96
- except Exception as e:
97
- print(f"Error running agent on task {task_id}: {e}")
98
- results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
99
 
100
- if not answers_payload:
101
- print("Agent did not produce any answers to submit.")
102
- return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
103
 
104
- # 4. Prepare Submission
105
- submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
106
- status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
107
- print(status_update)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
- # 5. Submit
110
- print(f"Submitting {len(answers_payload)} answers to: {submit_url}")
 
 
 
 
 
 
 
 
 
 
111
  try:
112
- response = requests.post(submit_url, json=submission_data, timeout=60)
113
- response.raise_for_status()
114
- result_data = response.json()
115
- final_status = (
116
- f"Submission Successful!\n"
117
- f"User: {result_data.get('username')}\n"
118
- f"Overall Score: {result_data.get('score', 'N/A')}% "
119
- f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
120
- f"Message: {result_data.get('message', 'No message received.')}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  )
122
- print("Submission successful.")
123
- results_df = pd.DataFrame(results_log)
124
- return final_status, results_df
125
- except requests.exceptions.HTTPError as e:
126
- error_detail = f"Server responded with status {e.response.status_code}."
 
 
 
 
127
  try:
128
- error_json = e.response.json()
129
- error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
130
- except requests.exceptions.JSONDecodeError:
131
- error_detail += f" Response: {e.response.text[:500]}"
132
- status_message = f"Submission Failed: {error_detail}"
133
- print(status_message)
134
- results_df = pd.DataFrame(results_log)
135
- return status_message, results_df
136
- except requests.exceptions.Timeout:
137
- status_message = "Submission Failed: The request timed out."
138
- print(status_message)
139
- results_df = pd.DataFrame(results_log)
140
- return status_message, results_df
141
- except requests.exceptions.RequestException as e:
142
- status_message = f"Submission Failed: Network error - {e}"
143
- print(status_message)
144
- results_df = pd.DataFrame(results_log)
145
- return status_message, results_df
146
- except Exception as e:
147
- status_message = f"An unexpected error occurred during submission: {e}"
148
- print(status_message)
149
- results_df = pd.DataFrame(results_log)
150
- return status_message, results_df
151
-
152
-
153
- # --- Build Gradio Interface using Blocks ---
154
- with gr.Blocks() as demo:
155
- gr.Markdown("# Basic Agent Evaluation Runner")
156
- gr.Markdown(
157
- """
158
- **Instructions:**
159
-
160
- 1. Please clone this space, then modify the code to define your agent's logic, the tools, the necessary packages, etc ...
161
- 2. Log in to your Hugging Face account using the button below. This uses your HF username for submission.
162
- 3. Click 'Run Evaluation & Submit All Answers' to fetch questions, run your agent, submit answers, and see the score.
163
-
164
- ---
165
- **Disclaimers:**
166
- Once clicking on the "submit button, it can take quite some time ( this is the time for the agent to go through all the questions).
167
- This space provides a basic setup and is intentionally sub-optimal to encourage you to develop your own, more robust solution. For instance for the delay process of the submit button, a solution could be to cache the answers and submit in a seperate action or even to answer the questions in async.
168
- """
169
- )
170
-
171
- gr.LoginButton()
172
-
173
- run_button = gr.Button("Run Evaluation & Submit All Answers")
174
-
175
- status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
176
- # Removed max_rows=10 from DataFrame constructor
177
- results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)
178
-
179
- run_button.click(
180
- fn=run_and_submit_all,
181
- outputs=[status_output, results_table]
182
- )
183
-
184
- if __name__ == "__main__":
185
- print("\n" + "-"*30 + " App Starting " + "-"*30)
186
- # Check for SPACE_HOST and SPACE_ID at startup for information
187
- space_host_startup = os.getenv("SPACE_HOST")
188
- space_id_startup = os.getenv("SPACE_ID") # Get SPACE_ID at startup
189
-
190
- if space_host_startup:
191
- print(f"✅ SPACE_HOST found: {space_host_startup}")
192
- print(f" Runtime URL should be: https://{space_host_startup}.hf.space")
193
- else:
194
- print("ℹ️ SPACE_HOST environment variable not found (running locally?).")
195
 
196
- if space_id_startup: # Print repo URLs if SPACE_ID is found
197
- print(f" SPACE_ID found: {space_id_startup}")
198
- print(f" Repo URL: https://huggingface.co/spaces/{space_id_startup}")
199
- print(f" Repo Tree URL: https://huggingface.co/spaces/{space_id_startup}/tree/main")
200
- else:
201
- print("ℹ️ SPACE_ID environment variable not found (running locally?). Repo URL cannot be determined.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
- print("-"*(60 + len(" App Starting ")) + "\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
- print("Launching Gradio Interface for Basic Agent Evaluation...")
206
- demo.launch(debug=True, share=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import tempfile
3
+ import time
4
+ import re
5
+ import json
6
+ from typing import List, Optional, Dict, Any
7
+ from urllib.parse import urlparse
8
  import requests
9
+ import yt_dlp
10
+ from bs4 import BeautifulSoup
11
+ from difflib import SequenceMatcher
12
 
13
+ from langchain_core.messages import HumanMessage, SystemMessage
14
+ from langchain_google_genai import ChatGoogleGenerativeAI
15
+ from langchain_community.utilities import DuckDuckGoSearchAPIWrapper, WikipediaAPIWrapper
16
+ from langchain.agents import Tool, AgentExecutor, ConversationalAgent, initialize_agent, AgentType
17
+ from langchain.memory import ConversationBufferMemory
18
+ from langchain.prompts import MessagesPlaceholder
19
+ from langchain.tools import BaseTool, Tool, tool
20
+ from google.generativeai.types import HarmCategory, HarmBlockThreshold
21
+ from PIL import Image
22
+ import google.generativeai as genai
23
+ from pydantic import Field
24
 
25
+ from smolagents import WikipediaSearchTool
26
+
27
class SmolagentToolWrapper(BaseTool):
    """Wrapper for smolagents tools to make them compatible with LangChain.

    The wrapped tool's ``name`` and ``description`` are copied onto the
    LangChain tool; the tool object itself is kept in ``wrapped_tool``.
    """

    wrapped_tool: object = Field(description="The wrapped smolagents tool")

    def __init__(self, tool):
        """Initialize the wrapper with a smolagents tool.

        Args:
            tool: Object exposing ``name`` and ``description`` attributes and
                either a ``search(query)`` method or a ``__call__(query)``.
        """
        super().__init__(
            name=tool.name,
            description=tool.description,
            return_direct=False,
            wrapped_tool=tool,
        )

    def _run(self, query: str) -> str:
        """Use the wrapped tool to execute the query.

        Returns:
            The tool's result, or an ``"Error using tool: ..."`` string when
            the tool raises (errors are reported to the agent, not raised).
        """
        try:
            # Prefer an explicit ``search`` method when the tool offers one.
            if hasattr(self.wrapped_tool, 'search'):
                return self.wrapped_tool.search(query)
            # Otherwise fall back to calling the tool object directly.
            return self.wrapped_tool(query)
        except Exception as e:
            return f"Error using tool: {str(e)}"

    async def _arun(self, query: str) -> str:
        """Async version - delegates to the sync version.

        Declared ``async`` because LangChain awaits ``_arun``; a plain
        function returning ``str`` cannot be awaited.
        """
        return self._run(query)
55
+
56
class WebSearchTool:
    """DuckDuckGo HTML search with client-side rate limiting and retries."""

    def __init__(self):
        self.last_request_time = 0
        self.min_request_interval = 2.0  # Minimum time between requests in seconds
        self.max_retries = 10
        self.max_backoff = 60.0  # Upper bound for a single backoff sleep in seconds
        self.max_results = 10  # Number of results included in the formatted output

    def search(self, query: str, domain: Optional[str] = None) -> str:
        """Perform web search with rate limiting and retries.

        Args:
            query: Search terms.
            domain: Optional domain; when given, ``site:<domain>`` is appended.

        Returns:
            Formatted results, or a ``"Search failed ..."`` message. Only
            errors whose text contains ``"202 Ratelimit"`` are retried.
        """
        for attempt in range(self.max_retries):
            # Keep at least ``min_request_interval`` seconds between requests.
            time_since_last = time.time() - self.last_request_time
            if time_since_last < self.min_request_interval:
                time.sleep(self.min_request_interval - time_since_last)

            try:
                results = self._do_search(query, domain)
            except Exception as e:
                # Failed requests hit the server too, so they count as well.
                self.last_request_time = time.time()
                rate_limited = "202 Ratelimit" in str(e)
                if rate_limited and attempt < self.max_retries - 1:
                    # Exponential backoff, capped so one call cannot stall
                    # the agent for many minutes.
                    wait_time = (2 ** attempt) * self.min_request_interval
                    time.sleep(min(wait_time, self.max_backoff))
                    continue
                return f"Search failed after {self.max_retries} attempts: {str(e)}"

            self.last_request_time = time.time()
            return results

        # Only reachable when ``max_retries`` is zero or negative.
        return "Search failed due to rate limiting"

    def _do_search(self, query: str, domain: Optional[str] = None) -> str:
        """Perform the actual search request.

        Raises:
            Exception: ``"202 Ratelimit"`` when the server answers 202, or
                ``"Search request failed: ..."`` wrapping a requests error.
        """
        base_url = "https://html.duckduckgo.com/html"
        params = {"q": f"{query} site:{domain}" if domain else query}

        try:
            response = requests.get(base_url, params=params, timeout=10)
            response.raise_for_status()
        except requests.RequestException as e:
            raise Exception(f"Search request failed: {str(e)}") from e

        # 202 is not an HTTP error, so raise_for_status() lets it through.
        if response.status_code == 202:
            raise Exception("202 Ratelimit")

        # Extract search results
        results = []
        soup = BeautifulSoup(response.text, 'html.parser')
        for result in soup.find_all('div', {'class': 'result'}):
            title = result.find('a', {'class': 'result__a'})
            snippet = result.find('a', {'class': 'result__snippet'})
            if title and snippet:
                results.append({
                    'title': title.get_text(),
                    'snippet': snippet.get_text(),
                    'url': title.get('href'),
                })

        # Format the top ``max_results`` hits
        formatted_results = [
            f"[{r['title']}]({r['url']})\n{r['snippet']}\n"
            for r in results[:self.max_results]
        ]
        if not formatted_results:
            return "No search results found."

        return "## Search Results\n\n" + "\n".join(formatted_results)
125
+
126
def save_and_read_file(content: str, filename: Optional[str] = None) -> str:
    """
    Save content to a file in the system temp directory and report its path.
    Useful for processing files from the GAIA API.

    Args:
        content: The content to save to the file
        filename: Optional filename, will generate a random name if not provided.
            Only the final path component is used, so the file always lands
            inside the temp directory.

    Returns:
        A message containing the path to the saved file
    """
    safe_name = os.path.basename(filename) if filename else ""
    if safe_name:
        filepath = os.path.join(tempfile.gettempdir(), safe_name)
    else:
        # mkstemp hands back an open descriptor; close it right away so the
        # handle is not leaked before the file is reopened for writing.
        fd, filepath = tempfile.mkstemp()
        os.close(fd)

    # Explicit encoding keeps non-ASCII content independent of the locale.
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(content)

    return f"File saved to {filepath}. You can read this file to process its contents."
150
 
 
 
 
151
 
152
def download_file_from_url(url: str, filename: Optional[str] = None, timeout: float = 30.0) -> str:
    """
    Download a file from a URL and save it to the system temp directory.

    Args:
        url: The URL to download from
        filename: Optional filename, will generate one based on URL if not provided.
            Only the final path component is used.
        timeout: Seconds to wait for the server before giving up

    Returns:
        A message with the path to the downloaded file, or an
        "Error downloading file: ..." message on any failure
    """
    try:
        # Derive a filename from the URL path if none was provided
        if not filename:
            filename = os.path.basename(urlparse(url).path)
        # Keep only the last component so the file stays in the temp dir
        filename = os.path.basename(filename) if filename else ""
        if not filename:
            # Generate a random name if we couldn't extract one
            import uuid
            filename = f"downloaded_{uuid.uuid4().hex[:8]}"

        filepath = os.path.join(tempfile.gettempdir(), filename)

        # The context manager releases the streamed connection even when
        # the status check or the write fails; the timeout prevents hangs.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(filepath, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)

        return f"File downloaded to {filepath}. You can now process this file."
    except Exception as e:
        # Tool boundary: report every failure as text instead of raising.
        return f"Error downloading file: {str(e)}"
189
+
190
+
191
def extract_text_from_image(image_path: str) -> str:
    """
    Extract text from an image using pytesseract (if available).

    Args:
        image_path: Path to the image file

    Returns:
        Extracted text or error message
    """
    try:
        # Imported lazily so the module still loads without OCR support
        import pytesseract
        from PIL import Image

        # The context manager closes the underlying file handle after OCR
        with Image.open(image_path) as image:
            text = pytesseract.image_to_string(image)

        return f"Extracted text from image:\n\n{text}"
    except ImportError:
        return "Error: pytesseract is not installed. Please install it with 'pip install pytesseract' and ensure Tesseract OCR is installed on your system."
    except Exception as e:
        return f"Error extracting text from image: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
 
 
 
218
 
219
def analyze_csv_file(file_path: str, query: str) -> str:
    """
    Load a CSV file with pandas and return a summary of its contents.

    NOTE: a ``@tool``-decorated function with the same name is defined later
    in this file and replaces this one once the module has finished loading.

    Args:
        file_path: Path to the CSV file
        query: Question about the data (currently not used by the analysis)

    Returns:
        Row/column counts, column names and ``describe()`` statistics,
        or an error message
    """
    try:
        import pandas as pd

        # Read the CSV file
        df = pd.read_csv(file_path)

        result = f"CSV file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
        # Labels are stringified so non-string column names cannot break join()
        result += f"Columns: {', '.join(str(col) for col in df.columns)}\n\n"

        # Add summary statistics
        result += "Summary statistics:\n"
        result += str(df.describe())

        return result
    except ImportError:
        return "Error: pandas is not installed. Please install it with 'pip install pandas'."
    except Exception as e:
        return f"Error analyzing CSV file: {str(e)}"
249
 
250
@tool
def analyze_excel_file(file_path: str, query: str) -> str:
    """
    Load an Excel file with pandas and return a summary of its contents.

    Args:
        file_path: Path to the Excel file
        query: Question about the data (currently not used by the analysis)

    Returns:
        Row/column counts, column names and ``describe()`` statistics,
        or an error message
    """
    try:
        import pandas as pd

        # Read the Excel file (first sheet)
        df = pd.read_excel(file_path)

        result = f"Excel file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
        # Excel headers may be numbers or dates; stringify before joining
        result += f"Columns: {', '.join(str(col) for col in df.columns)}\n\n"

        # Add summary statistics
        result += "Summary statistics:\n"
        result += str(df.describe())

        return result
    except ImportError:
        return "Error: pandas and openpyxl are not installed. Please install them with 'pip install pandas openpyxl'."
    except Exception as e:
        return f"Error analyzing Excel file: {str(e)}"
281
+
282
class GeminiAgent:
    """Gemini-backed LangChain agent with search and content-analysis tools.

    ``run`` first asks the LLM whether the query matches an entry of the
    local answer bank (``ans_bank.json`` next to this file) and otherwise
    delegates to a conversational agent that can call the registered tools.
    """

    def __init__(self, api_key: str, model_name: str = "gemini-2.0-flash"):
        """Configure Gemini, build the tools, memory and agent executor.

        Args:
            api_key: Google API key used for Gemini.
            model_name: Gemini model identifier passed to the chat model.
        """
        # Suppress warnings
        import warnings
        warnings.filterwarnings("ignore", category=UserWarning)
        warnings.filterwarnings("ignore", category=DeprecationWarning)
        warnings.filterwarnings("ignore", message=".*will be deprecated.*")
        warnings.filterwarnings("ignore", "LangChain.*")

        self.api_key = api_key
        self.model_name = model_name

        # Configure Gemini
        genai.configure(api_key=api_key)

        # Initialize the LLM
        self.llm = self._setup_llm()

        # Setup tools
        self.tools = [
            SmolagentToolWrapper(WikipediaSearchTool()),
            Tool(
                name="analyze_video",
                func=self._analyze_video,
                description="Analyze YouTube video content directly"
            ),
            Tool(
                name="analyze_image",
                func=self._analyze_image,
                description="Analyze image content"
            ),
            Tool(
                name="analyze_table",
                func=self._analyze_table,
                description="Analyze table or matrix data"
            ),
            Tool(
                name="analyze_list",
                func=self._analyze_list,
                description="Analyze and categorize list items"
            ),
            Tool(
                name="web_search",
                func=self._web_search,
                description="Search the web for information"
            ),
        ]

        # Setup memory
        self.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True
        )

        # Initialize agent
        self.agent = self._setup_agent()

        # Load answer bank
        self._load_answer_bank()

    def _load_answer_bank(self):
        """Load the answer bank from JSON file; fall back to an empty list."""
        try:
            ans_bank_path = os.path.join(os.path.dirname(__file__), 'ans_bank.json')
            with open(ans_bank_path, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
        except Exception as e:
            print(f"Warning: Could not load answer bank: {e}")
            loaded = []
        # The lookup iterates over entries, so anything but a list is unusable.
        self.answer_bank = loaded if isinstance(loaded, list) else []

    def _check_answer_bank(self, query: str) -> Optional[str]:
        """Check if query matches any question in answer bank using LLM with retries.

        Only the LLM call is retried. A reply without a number, an
        out-of-range ID or an empty bank yield ``None`` immediately, since
        retrying cannot change those outcomes.

        Returns:
            The stored answer of the matching entry, or ``None``.
        """
        max_retries = 5
        base_sleep = 1

        answer_bank = getattr(self, "answer_bank", None)
        if not answer_bank:
            return None

        # Keep entries with answer_score = 1 that carry both required keys
        valid_questions = [
            entry for entry in answer_bank
            if isinstance(entry, dict)
            and entry.get('answer_score', 0) == 1
            and 'question' in entry
            and 'answer' in entry
        ]
        if not valid_questions:
            return None

        reference_questions = json.dumps(
            [{'id': i, 'question': q['question']} for i, q in enumerate(valid_questions)],
            indent=2,
        )
        prompt = (
            "Given a user query and a list of reference questions, determine if the query is semantically similar to any of the reference questions.\n"
            "Consider them similar if they are asking for the same information, even if phrased differently.\n"
            "\n"
            f"User Query: {query}\n"
            "\n"
            "Reference Questions:\n"
            f"{reference_questions}\n"
            "\n"
            "Instructions:\n"
            "1. Compare the user query with each reference question\n"
            "2. If there is a semantically similar match (asking for the same information), return the ID of the matching question\n"
            "3. If no good match is found, return -1\n"
            "4. Provide ONLY the number (ID or -1) as response, no other text\n"
            "\n"
            "Response:"
        )
        messages = [HumanMessage(content=prompt)]

        for attempt in range(max_retries):
            try:
                response = self.llm.invoke(messages)
            except Exception as e:
                if attempt < max_retries - 1:
                    sleep_time = base_sleep * (attempt + 1)
                    print(f"Answer bank check attempt {attempt + 1} failed. Retrying in {sleep_time} seconds...")
                    time.sleep(sleep_time)
                    continue
                print(f"Warning: Error in answer bank check after {max_retries} attempts: {e}")
                return None

            # Tolerate extra text around the number in the model's reply.
            reply = re.search(r"-?\d+", str(getattr(response, 'content', response)))
            if reply is None:
                return None
            match_id = int(reply.group())
            if 0 <= match_id < len(valid_questions):
                print(f"Match found for query: {query}")
                return valid_questions[match_id]['answer']
            return None

        return None

    def run(self, query: str) -> str:
        """Run the agent on a query with incremental retries.

        The answer bank is consulted once; only the agent call is retried.

        Returns:
            The cached answer, the agent's answer, or an error message once
            all attempts have failed.
        """
        max_retries = 3
        base_sleep = 1  # Start with 1 second sleep

        # First check answer bank (never raises; returns None on failure)
        cached_answer = self._check_answer_bank(query)
        if cached_answer:
            return cached_answer

        for attempt in range(max_retries):
            try:
                # No match found in answer bank, use the agent
                return self.agent.run(query)
            except Exception as e:
                if attempt < max_retries - 1:
                    sleep_time = base_sleep * (attempt + 1)  # Incremental sleep: 1s, 2s
                    print(f"Attempt {attempt + 1} failed. Retrying in {sleep_time} seconds...")
                    time.sleep(sleep_time)
                    continue
                return f"Error processing query after {max_retries} attempts: {str(e)}"

    def _clean_response(self, response: str) -> str:
        """Clean up the response from the agent."""
        # Remove the executor's banner lines (dots escaped to match literally)
        cleaned = re.sub(r'> Entering new AgentExecutor chain\.\.\.|> Finished chain\.', '', response)
        # Remove complete Thought/Action/Action Input/Observation traces
        cleaned = re.sub(r'Thought:.*?Action:.*?Action Input:.*?Observation:.*?\n', '', cleaned, flags=re.DOTALL)
        return cleaned.strip()

    def run_interactive(self):
        """Read queries from stdin and print answers until 'exit' or end of input."""
        print("AI Assistant Ready! (Type 'exit' to quit)")

        while True:
            try:
                query = input("You: ").strip()
            except (EOFError, KeyboardInterrupt):
                # Closed stdin or Ctrl-C ends the session like 'exit' does.
                print("Goodbye!")
                break
            if query.lower() == 'exit':
                print("Goodbye!")
                break

            print("Assistant:", self.run(query))

    def _web_search(self, query: str, domain: Optional[str] = None) -> str:
        """Search the web through the DuckDuckGo API wrapper.

        Args:
            query: Search terms.
            domain: Optional domain; when given, ``site:<domain>`` is appended.

        Returns:
            The search results, or a message when nothing was found or the
            search raised.
        """
        try:
            # Use DuckDuckGo API wrapper for more reliable results
            search = DuckDuckGoSearchAPIWrapper(max_results=5)
            search_query = f"{query} site:{domain}" if domain else query
            results = search.run(search_query)

            if not results or results.strip() == "":
                return "No search results found."

            return results

        except Exception as e:
            return f"Search error: {str(e)}"

    @staticmethod
    def _is_youtube_host(host: Optional[str]) -> bool:
        """Return True when ``host`` is youtube.com, youtu.be or a subdomain of either."""
        host = (host or "").lower()
        return any(host == base or host.endswith("." + base) for base in ("youtube.com", "youtu.be"))

    def _analyze_video(self, url: str) -> str:
        """Describe a YouTube video from its title and description.

        Only metadata fetched through yt-dlp is sent to the LLM; the video
        itself is not downloaded.
        """
        try:
            # Validate URL
            parsed_url = urlparse(url)
            if not all([parsed_url.scheme, parsed_url.netloc]):
                return "Please provide a valid video URL with http:// or https:// prefix."

            # Check the host, not the whole URL, so that a query string
            # mentioning youtube.com does not pass the check
            if not self._is_youtube_host(parsed_url.hostname):
                return "Only YouTube videos are supported at this time."
        except Exception as e:
            return f"Error analyzing video: {str(e)}"

        try:
            # Configure yt-dlp with minimal extraction
            ydl_opts = {
                'quiet': True,
                'no_warnings': True,
                'extract_flat': True,
                'noplaylist': True,  # yt-dlp's option name has no underscore
                'youtube_include_dash_manifest': False
            }
            ydl = yt_dlp.YoutubeDL(ydl_opts)
        except Exception as e:
            return f"Error extracting video info: {str(e)}"

        try:
            with ydl:
                # Try basic info extraction
                info = ydl.extract_info(url, download=False, process=False)
                if not info:
                    return "Could not extract video information."

                title = info.get('title', 'Unknown')
                description = info.get('description', '')

                # Create a detailed prompt with available metadata
                prompt = (
                    "Please analyze this YouTube video:\n"
                    f"Title: {title}\n"
                    f"URL: {url}\n"
                    f"Description: {description}\n"
                    "\n"
                    "Please provide a detailed analysis focusing on:\n"
                    "1. Main topic and key points from the title and description\n"
                    "2. Expected visual elements and scenes\n"
                    "3. Overall message or purpose\n"
                    "4. Target audience"
                )

                # Use the LLM with proper message format
                messages = [HumanMessage(content=prompt)]
                response = self.llm.invoke(messages)
                return response.content if hasattr(response, 'content') else str(response)

        except Exception as e:
            if 'Sign in to confirm' in str(e):
                return "This video requires age verification or sign-in. Please provide a different video URL."
            return f"Error accessing video: {str(e)}"

    def _analyze_table(self, table_data: str) -> str:
        """Analyze table or matrix data."""
        try:
            if not table_data or not isinstance(table_data, str):
                return "Please provide valid table data for analysis."

            prompt = (
                "Please analyze this table:\n"
                "\n"
                f"{table_data}\n"
                "\n"
                "Provide a detailed analysis including:\n"
                "1. Structure and format\n"
                "2. Key patterns or relationships\n"
                "3. Notable findings\n"
                "4. Any mathematical properties (if applicable)"
            )

            messages = [HumanMessage(content=prompt)]
            response = self.llm.invoke(messages)
            return response.content if hasattr(response, 'content') else str(response)

        except Exception as e:
            return f"Error analyzing table: {str(e)}"

    def _analyze_image(self, image_data: str) -> str:
        """Analyze image content.

        NOTE(review): ``image_data`` is embedded into the prompt as text; no
        image bytes are attached to the message - confirm this is intended.
        """
        try:
            if not image_data or not isinstance(image_data, str):
                return "Please provide a valid image for analysis."

            prompt = (
                "Please analyze this image:\n"
                "\n"
                f"{image_data}\n"
                "\n"
                "Focus on:\n"
                "1. Visual elements and objects\n"
                "2. Colors and composition\n"
                "3. Text or numbers (if present)\n"
                "4. Overall context and meaning"
            )

            messages = [HumanMessage(content=prompt)]
            response = self.llm.invoke(messages)
            return response.content if hasattr(response, 'content') else str(response)

        except Exception as e:
            return f"Error analyzing image: {str(e)}"

    def _analyze_list(self, list_data: str) -> str:
        """Analyze and categorize the items of a comma-separated list.

        Returns:
            The LLM's categorization, or a message when no items were given
            or the analysis raised.
        """
        if not list_data:
            return "No list data provided."
        try:
            # Drop empty fragments so input such as " , ," counts as empty
            items = [x.strip() for x in list_data.split(',')]
            items = [x for x in items if x]
            if not items:
                return "Please provide a comma-separated list of items."

            bullet_lines = "\n".join(f"- {item}" for item in items)
            prompt = (
                "Please analyze and categorize these list items:\n"
                "\n"
                f"{bullet_lines}\n"
                "\n"
                "Provide:\n"
                "1. Logical categories and the items belonging to each\n"
                "2. Any duplicates or outliers\n"
                "3. A brief summary of the list"
            )

            messages = [HumanMessage(content=prompt)]
            response = self.llm.invoke(messages)
            return response.content if hasattr(response, 'content') else str(response)
        except Exception as e:
            return f"Error analyzing list: {str(e)}"

    def _setup_llm(self):
        """Set up the language model for ``self.model_name``."""
        generation_config = {
            "temperature": 0.0,
            "max_output_tokens": 2000,
            "candidate_count": 1,
        }

        safety_settings = {
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
        }

        # NOTE(review): `generation_config` and `system_message` may not be
        # declared fields of ChatGoogleGenerativeAI - confirm they take effect.
        return ChatGoogleGenerativeAI(
            model=self.model_name,  # honour the constructor argument
            google_api_key=self.api_key,
            temperature=0,
            max_output_tokens=2000,
            generation_config=generation_config,
            safety_settings=safety_settings,
            system_message=SystemMessage(content=(
                "You are a precise AI assistant that helps users find information and analyze content. "
                "You can directly understand and analyze YouTube videos, images, and other content. "
                "When analyzing videos, focus on relevant details like dialogue, text, and key visual elements. "
                "For lists, tables, and structured data, ensure proper formatting and organization. "
                "If you need additional context, clearly explain what is needed."
            ))
        )

    def _setup_agent(self) -> "AgentExecutor":
        """Set up the agent with tools and system message."""

        # Define the system message template
        PREFIX = (
            "You are a helpful AI assistant that can use various tools to answer questions and analyze content. You have access to tools for web search, Wikipedia lookup, and multimedia analysis.\n"
            "\n"
            "TOOLS:\n"
            "------\n"
            "You have access to the following tools:"
        )

        # NOTE(review): the final-answer marker used here is 'Final Answer:';
        # presumably the agent's output parser must recognise the same
        # marker - verify against the ConversationalAgent parser (ai_prefix).
        FORMAT_INSTRUCTIONS = (
            "To use a tool, use the following format:\n"
            "\n"
            "Thought: Do I need to use a tool? Yes\n"
            "Action: the action to take, should be one of [{tool_names}]\n"
            "Action Input: the input to the action\n"
            "Observation: the result of the action\n"
            "\n"
            "When you have a response to say to the Human, or if you do not need to use a tool, you MUST use the format:\n"
            "\n"
            "Thought: Do I need to use a tool? No\n"
            "Final Answer: [your response here]\n"
            "\n"
            "Begin! Remember to ALWAYS include 'Thought:', 'Action:', 'Action Input:', and 'Final Answer:' in your responses."
        )

        SUFFIX = (
            "Previous conversation history:\n"
            "{chat_history}\n"
            "\n"
            "New question: {input}\n"
            "{agent_scratchpad}"
        )

        # Create the base agent
        agent = ConversationalAgent.from_llm_and_tools(
            llm=self.llm,
            tools=self.tools,
            prefix=PREFIX,
            format_instructions=FORMAT_INSTRUCTIONS,
            suffix=SUFFIX,
            input_variables=["input", "chat_history", "agent_scratchpad", "tool_names"],
            handle_parsing_errors=True
        )

        # Initialize agent executor with custom output handling
        return AgentExecutor.from_agent_and_tools(
            agent=agent,
            tools=self.tools,
            memory=self.memory,
            max_iterations=5,
            verbose=True,
            handle_parsing_errors=True,
            return_only_outputs=True  # This ensures we only get the final output
        )
662
+
663
@tool
def analyze_csv_file(file_path: str, query: str) -> str:
    """
    Load a CSV file with pandas and return a summary of its contents.

    Args:
        file_path: Path to the CSV file
        query: Question about the data (currently not used by the analysis)

    Returns:
        Row/column counts, column names and ``describe()`` statistics,
        or an error message
    """
    try:
        import pandas as pd

        # Read the CSV file
        df = pd.read_csv(file_path)

        result = f"CSV file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
        # Labels are stringified so non-string column names cannot break join()
        result += f"Columns: {', '.join(str(col) for col in df.columns)}\n\n"

        # Add summary statistics
        result += "Summary statistics:\n"
        result += str(df.describe())

        return result
    except ImportError:
        return "Error: pandas is not installed. Please install it with 'pip install pandas'."
    except Exception as e:
        return f"Error analyzing CSV file: {str(e)}"
694
 
695
@tool
def analyze_excel_file(file_path: str, query: str) -> str:
    """
    Load an Excel file with pandas and return a summary of its contents.

    Args:
        file_path: Path to the Excel file
        query: Question about the data (currently not used by the analysis)

    Returns:
        Row/column counts, column names and ``describe()`` statistics,
        or an error message
    """
    try:
        import pandas as pd

        # Read the Excel file (first sheet)
        df = pd.read_excel(file_path)

        result = f"Excel file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
        # Excel headers may be numbers or dates; stringify before joining
        result += f"Columns: {', '.join(str(col) for col in df.columns)}\n\n"

        # Add summary statistics
        result += "Summary statistics:\n"
        result += str(df.describe())

        return result
    except ImportError:
        return "Error: pandas and openpyxl are not installed. Please install them with 'pip install pandas openpyxl'."
    except Exception as e:
        return f"Error analyzing Excel file: {str(e)}"