Spaces:
Sleeping
Sleeping
import os
import re
import tempfile
from io import BytesIO
from typing import Any, Optional
from urllib.parse import urlparse

import requests
from smolagents.agent_types import AgentImage
from smolagents.tools import Tool


class FindImageOnlineTool(Tool):
    """Find an image on the web for a text query and download it locally.

    The tool delegates to two injected tools: ``web_search_tool`` (queried
    with ``"<query> image"``) and ``visit_webpage_tool`` (used to fetch each
    result page). Image URLs are extracted from the page content with
    regular expressions, and the first one that downloads successfully is
    returned.

    NOTE(review): ``output_type`` is declared as ``"image"`` but ``forward``
    returns a dict on success and a plain string on failure -- confirm how
    the calling agent framework handles that mismatch before changing it.
    """

    name = "find_image_online"
    description = "Searches for images online based on a query and returns an image that matches the description."
    inputs = {'query': {'type': 'string', 'description': 'The search query for the image you want to find.'}}
    output_type = "image"

    # Upper bounds that keep a single call cheap: how many search-result
    # pages are visited, and how many image candidates are tried per page.
    max_pages = 3
    max_images_per_page = 5

    def __init__(self, web_search_tool=None, visit_webpage_tool=None):
        """Store the helper tools.

        Args:
            web_search_tool: Object whose ``forward(query)`` returns search
                results as text containing ``(http...)`` style links.
            visit_webpage_tool: Object whose ``forward(url)`` returns the
                page content as text.
        """
        super().__init__()
        self.web_search_tool = web_search_tool
        self.visit_webpage_tool = visit_webpage_tool
        self.is_initialized = True

    def extract_image_urls(self, markdown_content):
        """Return the image URLs found in ``markdown_content``.

        Both markdown image syntax (``![alt](url)``) and bare URLs ending in
        a known image extension are recognised. Matching is case-insensitive
        so ``.JPG`` / ``.PNG`` links are not missed.

        Returns:
            A list of unique URLs in the order they first appear in the
            content. Order is preserved so callers that only try the first
            few candidates behave deterministically.
        """
        if not markdown_content:
            return []

        md_image_pattern = r'!\[.*?\]\((https?://[^)]+\.(jpg|jpeg|png|gif|webp))\)'
        direct_url_pattern = r'(https?://[^\s)]+\.(jpg|jpeg|png|gif|webp))'

        md_images = re.findall(md_image_pattern, markdown_content, flags=re.IGNORECASE)
        direct_urls = re.findall(direct_url_pattern, markdown_content, flags=re.IGNORECASE)

        candidates = [url for url, _ in md_images] + [url for url, _ in direct_urls]
        # dict.fromkeys de-duplicates while keeping first-seen order
        # (a set would shuffle the candidates between runs).
        return list(dict.fromkeys(candidates))

    def download_image(self, url):
        """Download ``url`` into a temporary file.

        Returns:
            ``(path, url)`` where ``path`` is the temporary file name, or
            ``(None, url)`` if the download failed. The temporary file is
            created with ``delete=False``; the caller owns its cleanup.
        """
        allowed_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.webp')
        temp_path = None
        try:
            # stream=True lets us inspect the headers before pulling the body;
            # the context manager guarantees the connection is released.
            with requests.get(url, stream=True, timeout=10) as response:
                response.raise_for_status()
                content_type = response.headers.get('Content-Type', '')
                if content_type.lower().startswith('text/'):
                    # Typically an HTML error/landing page served with 200.
                    raise ValueError(f"unexpected content type '{content_type}'")
                data = response.content
            if not data:
                raise ValueError("empty response body")

            # Take the extension from the URL *path* only, so query strings
            # and fragments (e.g. "photo.png?w=200") do not hide it.
            ext = os.path.splitext(urlparse(url).path)[1].lower()
            if ext not in allowed_extensions:
                ext = '.jpg'  # Default extension

            with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as temp_file:
                temp_path = temp_file.name
                temp_file.write(data)
            return temp_path, url
        except (requests.RequestException, OSError, ValueError) as e:
            # Do not leave a partially written file behind.
            if temp_path is not None:
                try:
                    os.remove(temp_path)
                except OSError:
                    pass
            print(f"Error downloading image from {url}: {str(e)}")
            return None, url

    def forward(self, query: str) -> Any:
        """Search for ``query`` and return the first image that downloads.

        Returns:
            On success, a dict with keys ``"image"`` (an ``AgentImage`` built
            from the downloaded file path), ``"source_url"``, ``"page_url"``
            and ``"query"``. Otherwise a human-readable error/explanation
            string.
        """
        if not self.web_search_tool or not self.visit_webpage_tool:
            return "Error: Web search and visit webpage tools must be provided."
        try:
            # Step 1: Search for the query + "image"
            search_query = f"{query} image"
            search_results = self.web_search_tool.forward(search_query)

            # Step 2: Extract URLs from search results (markdown-style links),
            # dropping duplicates so they do not eat into the page budget.
            url_pattern = r'\((https?://[^)]+)\)'
            page_urls = list(dict.fromkeys(re.findall(url_pattern, search_results)))

            # Step 3: Visit each page and look for images
            for url in page_urls[:self.max_pages]:
                try:
                    page_content = self.visit_webpage_tool.forward(url)
                    image_urls = self.extract_image_urls(page_content)
                    # Step 4: Download the first valid image found
                    for img_url in image_urls[:self.max_images_per_page]:
                        img_path, source_url = self.download_image(img_url)
                        if img_path:
                            # Return both the image and the source information
                            return {
                                "image": AgentImage(img_path),
                                "source_url": source_url,
                                "page_url": url,
                                "query": query,
                            }
                except Exception as e:
                    # Best effort: report the failure and try the next page.
                    print(f"Error processing page {url}: {e}")
                    continue
            return f"Could not find a suitable image for '{query}'. Please try a different query."
        except Exception as e:
            return f"Error finding image: {str(e)}"