# First_agent_template / tools /find_image_online.py
# PLBot's picture
# Create find_image_online.py
# 300f4a0 verified
import os
import re
import tempfile
from io import BytesIO
from typing import Any, Optional
from urllib.parse import urlparse

import requests
from smolagents.agent_types import AgentImage
from smolagents.tools import Tool


class FindImageOnlineTool(Tool):
    """Find an image on the web for a text query and download it locally.

    The tool chains two collaborator tools supplied at construction time:
    a web-search tool (its ``forward(query)`` result is scanned for
    ``(http...)`` links) and a visit-webpage tool (its ``forward(url)``
    result is scanned for image URLs). The first image that downloads
    successfully is saved to a temporary file and returned.
    """

    name = "find_image_online"
    description = "Searches for images online based on a query and returns an image that matches the description."
    inputs = {'query': {'type': 'string', 'description': 'The search query for the image you want to find.'}}
    output_type = "image"

    # File extensions treated as images when choosing the temp-file suffix.
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.gif', '.webp')

    def __init__(self, web_search_tool=None, visit_webpage_tool=None):
        """Store the collaborator tools.

        Args:
            web_search_tool: Object whose ``forward(query)`` returns search
                results as text containing ``(url)`` links.
            visit_webpage_tool: Object whose ``forward(url)`` returns the
                page content as text (markdown).
        """
        super().__init__()
        self.web_search_tool = web_search_tool
        self.visit_webpage_tool = visit_webpage_tool
        self.is_initialized = True

    def extract_image_urls(self, markdown_content):
        """Return the distinct image URLs found in *markdown_content*.

        Markdown image links (``![alt](url)``) are listed first, followed by
        any bare URLs ending in a known image extension. Order of first
        appearance is preserved so callers that only try the first few
        candidates get a deterministic selection.

        Args:
            markdown_content: Page text to scan. ``None`` or an empty string
                yields an empty list.

        Returns:
            A list of unique URL strings.
        """
        if not markdown_content:
            return []

        # Look for standard markdown image patterns ![alt](url)
        md_image_pattern = r'!\[.*?\]\((https?://[^)]+\.(jpg|jpeg|png|gif|webp))\)'
        # Also look for direct URLs that end with image extensions
        direct_url_pattern = r'(https?://[^\s)]+\.(jpg|jpeg|png|gif|webp))'

        # IGNORECASE so that upper-case extensions such as ".JPG" are found too.
        md_images = re.findall(md_image_pattern, markdown_content, re.IGNORECASE)
        direct_urls = re.findall(direct_url_pattern, markdown_content, re.IGNORECASE)

        # Each match is a (url, extension) tuple because of the two groups.
        candidates = [url for url, _ in md_images] + [url for url, _ in direct_urls]

        # dict.fromkeys de-duplicates while keeping first-seen order
        # (a plain set would scramble which images are tried first).
        return list(dict.fromkeys(candidates))

    def download_image(self, url):
        """Download *url* into a temporary file.

        Args:
            url: Absolute HTTP(S) URL of the image.

        Returns:
            ``(path, url)`` where *path* is the temporary file holding the
            downloaded bytes, or ``(None, url)`` if the download failed.
            The temporary file is not deleted automatically; the caller
            owns it.
        """
        try:
            # Take the extension from the URL *path* so a query string or
            # fragment (e.g. "photo.png?w=200") does not hide it.
            ext = os.path.splitext(urlparse(url).path)[1].lower()
            if ext not in self._IMAGE_EXTENSIONS:
                ext = '.jpg'  # Default extension

            response = requests.get(url, timeout=10)
            response.raise_for_status()

            # A text/* body (typically an HTML error or hotlink-protection
            # page served with status 200) is certainly not an image.
            content_type = response.headers.get('Content-Type') or ''
            if content_type.lower().startswith('text/'):
                raise ValueError(f"unexpected content type '{content_type}'")
            if not response.content:
                raise ValueError("empty response body")

            temp_file = tempfile.NamedTemporaryFile(suffix=ext, delete=False)
            try:
                with temp_file:
                    temp_file.write(response.content)
            except Exception:
                # Do not leave a half-written file behind.
                try:
                    os.unlink(temp_file.name)
                except OSError:
                    pass
                raise
            return temp_file.name, url
        except Exception as e:
            print(f"Error downloading image from {url}: {str(e)}")
            return None, url

    def forward(self, query: str) -> Any:
        """Search for *query* and return the first image that can be downloaded.

        Args:
            query: Free-text description of the wanted image.

        Returns:
            On success, a dict with keys ``"image"`` (an ``AgentImage`` built
            from the downloaded file path), ``"source_url"``, ``"page_url"``
            and ``"query"``. On failure, an explanatory string.

        NOTE(review): ``output_type`` is declared as "image" but this method
        returns a dict or a str - confirm how the calling agent framework
        handles that before relying on it. The return shape is kept as-is
        for existing callers.
        """
        if not self.web_search_tool or not self.visit_webpage_tool:
            return "Error: Web search and visit webpage tools must be provided."
        try:
            # Step 1: Search for the query + "image"
            search_query = f"{query} image"
            search_results = self.web_search_tool.forward(search_query)

            # Step 2: Extract URLs from search results; drop repeats so the
            # page budget below is not spent on the same page twice.
            url_pattern = r'\((https?://[^)]+)\)'
            urls = list(dict.fromkeys(re.findall(url_pattern, search_results)))

            # Step 3: Visit each page and look for images
            for url in urls[:3]:  # Limit to first 3 results for efficiency
                try:
                    page_content = self.visit_webpage_tool.forward(url)
                    image_urls = self.extract_image_urls(page_content)

                    # Step 4: Download the first valid image found
                    for img_url in image_urls[:5]:  # Try up to 5 images per page
                        img_path, source_url = self.download_image(img_url)
                        if img_path:
                            # Return both the image and the source information
                            return {
                                "image": AgentImage(img_path),
                                "source_url": source_url,
                                "page_url": url,
                                "query": query,
                            }
                except Exception as e:
                    # Best effort: report the failure and try the next page.
                    print(f"Error processing page {url}: {str(e)}")
                    continue

            return f"Could not find a suitable image for '{query}'. Please try a different query."
        except Exception as e:
            return f"Error finding image: {str(e)}"