from typing import Any, Optional
import os
import re
import tempfile
from io import BytesIO
from urllib.parse import urlparse

import requests
from smolagents.agent_types import AgentImage
from smolagents.tools import Tool


class FindImageOnlineTool(Tool):
    """smolagents tool that finds an image on the web for a text query.

    Searching and page fetching are delegated to two injected helper tools;
    this class extracts candidate image URLs from the fetched pages and
    downloads the first usable one into a temporary file.
    """

    name = "find_image_online"
    description = "Searches for images online based on a query and returns an image that matches the description."
    inputs = {'query': {'type': 'string', 'description': 'The search query for the image you want to find.'}}
    output_type = "image"

    def __init__(self, web_search_tool=None, visit_webpage_tool=None):
        """Store the helper tools used by ``forward``.

        Args:
            web_search_tool: Object with ``forward(query)`` returning search
                results as text; links are expected in markdown ``(url)`` form.
            visit_webpage_tool: Object with ``forward(url)`` returning the page
                content as text (presumably markdown -- TODO confirm).
        """
        super().__init__()
        self.web_search_tool = web_search_tool
        self.visit_webpage_tool = visit_webpage_tool
        self.is_initialized = True

    def extract_image_urls(self, markdown_content):
        """Return image URLs found in ``markdown_content``, in first-seen order.

        Recognises markdown image syntax ``![alt](url)`` and bare URLs ending
        in .jpg/.jpeg/.png/.gif/.webp (case-insensitive). Duplicates are
        removed while preserving order, so callers that take "the first N"
        candidates get a deterministic result.

        Args:
            markdown_content: Page text to scan. ``None`` or empty yields ``[]``.

        Returns:
            list[str]: Unique image URLs.
        """
        if not markdown_content:
            return []
        # Non-capturing extension group so findall returns the URL string itself.
        ext = r'\.(?:jpg|jpeg|png|gif|webp)'
        md_image_pattern = r'!\[.*?\]\((https?://[^)]+' + ext + r')\)'
        direct_url_pattern = r'(https?://[^\s)]+' + ext + r')'
        md_images = re.findall(md_image_pattern, markdown_content, flags=re.IGNORECASE)
        direct_urls = re.findall(direct_url_pattern, markdown_content, flags=re.IGNORECASE)
        # dict.fromkeys dedupes while keeping insertion order (set() does not).
        return list(dict.fromkeys(md_images + direct_urls))

    def download_image(self, url):
        """Download ``url`` into a temporary file.

        Args:
            url: Absolute image URL.

        Returns:
            tuple: ``(path, url)`` on success, or ``(None, url)`` on failure
            (network/HTTP error, text/* response, empty body). On success the
            caller owns the temporary file; on failure any partially written
            file is removed.
        """
        allowed_exts = ('.jpg', '.jpeg', '.png', '.gif', '.webp')
        temp_path = None
        try:
            # Using the response as a context manager releases the streamed
            # connection even if an error occurs mid-download.
            with requests.get(url, stream=True, timeout=10) as response:
                response.raise_for_status()
                content_type = response.headers.get('Content-Type', '').lower()
                # "Image" links frequently resolve to HTML error/login pages;
                # don't hand those to AgentImage.
                if content_type.startswith('text/'):
                    print(f"Skipping {url}: non-image Content-Type '{content_type}'")
                    return None, url
                # Derive the extension from the URL path only, so query strings
                # and fragments (".../photo.png?w=600") don't defeat detection.
                ext = os.path.splitext(urlparse(url).path)[1].lower()
                if ext not in allowed_exts:
                    ext = '.jpg'  # Default extension
                bytes_written = 0
                with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as temp_file:
                    temp_path = temp_file.name
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            temp_file.write(chunk)
                            bytes_written += len(chunk)
            if bytes_written == 0:
                raise ValueError("empty response body")
            return temp_path, url
        except Exception as e:
            print(f"Error downloading image from {url}: {str(e)}")
            # Don't leak a partially written temp file.
            if temp_path is not None:
                try:
                    os.remove(temp_path)
                except OSError:
                    pass
            return None, url

    def forward(self, query: str) -> Any:
        """Search for ``query`` and return the first image that downloads.

        Visits up to 3 unique result pages and tries up to 5 image URLs per
        page.

        Args:
            query: Free-text description of the wanted image.

        Returns:
            On success, a dict with keys ``image`` (AgentImage), ``source_url``,
            ``page_url`` and ``query``; otherwise an error/notice string.

        NOTE(review): ``output_type`` is "image" but this returns a dict or a
        str; smolagents may expect an AgentImage here -- confirm how callers
        consume the result before changing the return shape.
        """
        if not self.web_search_tool or not self.visit_webpage_tool:
            return "Error: Web search and visit webpage tools must be provided."
        try:
            # Step 1: Search for the query + "image"
            search_results = self.web_search_tool.forward(f"{query} image")

            # Step 2: Extract markdown "(url)" links; dedupe, keep result order.
            page_urls = list(dict.fromkeys(re.findall(r'\((https?://[^)]+)\)', str(search_results))))

            # Step 3: Visit each page and look for images
            for page_url in page_urls[:3]:  # Limit to first 3 results for efficiency
                try:
                    page_content = self.visit_webpage_tool.forward(page_url)
                    # Step 4: Download the first valid image found
                    for img_url in self.extract_image_urls(page_content)[:5]:  # Try up to 5 images per page
                        img_path, source_url = self.download_image(img_url)
                        if img_path:
                            return {
                                "image": AgentImage(img_path),
                                "source_url": source_url,
                                "page_url": page_url,
                                "query": query,
                            }
                except Exception as e:
                    # Best-effort: report the failure, then try the next page.
                    print(f"Error processing {page_url}: {str(e)}")
                    continue

            return f"Could not find a suitable image for '{query}'. Please try a different query."
        except Exception as e:
            return f"Error finding image: {str(e)}"