# mrprimenotes's picture
# Update app.py
# a9e7554 verified
import gradio as gr
from gradio_bbox_annotator import BBoxAnnotator
import json
import os
from pathlib import Path
from PIL import Image
from io import BytesIO
import tempfile
import shutil
import logging
# Module-wide logger; INFO level keeps the image-processing steps visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Per-image cap on how many boxes may carry each label.
CATEGORY_LIMITS = dict(
    advertisement=1,  # at most one advertisement box per image
    text=2,  # at most two text boxes per image
)

# Label choices, in the same order as the limits above.
CATEGORIES = [*CATEGORY_LIMITS]

# Bounding [width, height] that uploaded images are shrunk to fit.
MAX_SIZE = [1024, 1024]
class ImageProcessor:
    """Shrink uploaded images and store the copies under the system temp dir.

    Resized copies are written to
    ``<tempfile.gettempdir()>/annotation_tool/resized_images`` and are named
    ``resized_<original basename>``. Because the name is derived from the
    basename only, two uploads that share a basename overwrite each other.
    """

    def __init__(self):
        # Fixed location below the temp dir, so every instance shares it.
        self.base_dir = os.path.join(tempfile.gettempdir(), "annotation_tool")
        self.resized_dir = os.path.join(self.base_dir, "resized_images")
        self._setup_directories()
        logger.info(f"Initialized ImageProcessor with directory: {self.base_dir}")

    def _setup_directories(self):
        """Create the resized-image directory (and its parents) if missing."""
        os.makedirs(self.resized_dir, exist_ok=True)
        logger.info(f"Set up directories: {self.resized_dir}")

    def resize_image(self, image_path):
        """Write a copy of ``image_path`` scaled to fit ``MAX_SIZE``.

        Args:
            image_path: Path of the source image (``str``, ``bytes`` or
                ``os.PathLike``).

        Returns:
            The ``str`` path of the resized copy inside ``self.resized_dir``.

        Raises:
            Exception: Any error raised while reading, decoding, scaling or
                saving the image is logged with its traceback and re-raised.
        """
        try:
            logger.info(f"Processing image: {image_path}")

            # Normalise bytes / PathLike input to str. Formatting a bytes
            # basename into the f-string below would otherwise produce a
            # name like "resized_b'photo.png'".
            source_path = os.fsdecode(image_path)

            # Read the whole file first so the image no longer depends on
            # the open file handle once the ``with`` block exits.
            with open(source_path, "rb") as f:
                img = Image.open(BytesIO(f.read()))
            img.thumbnail(MAX_SIZE, Image.Resampling.LANCZOS)

            # The target name is derived from the source basename.
            original_filename = os.path.basename(source_path)
            resized_filename = f"resized_{original_filename}"
            resized_path = os.path.join(self.resized_dir, resized_filename)

            img.save(resized_path)
            logger.info(f"Saved resized image to: {resized_path}")
            return resized_path
        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception(f"Error processing image: {str(e)}")
            raise
class AnnotationManager:
    """Validate bounding-box annotations and collect them per image.

    ``self.annotations`` maps an image filename to a list of
    ``{"annotation": [c0, c1, c2, c3], "label": <category>}`` dicts.
    """

    def __init__(self):
        self.annotations = {}
        self.image_processor = ImageProcessor()

    def validate_annotations(self, bbox_data):
        """Check annotator output against the format and category limits.

        Args:
            bbox_data: Expected to be a ``(image_path, annotations)`` tuple
                where each annotation is a 5-item sequence: four numeric
                coordinates followed by a label.
                NOTE(review): the coordinate order is whatever the annotator
                component emits (presumably left, top, right, bottom) -
                confirm against the gradio_bbox_annotator documentation.

        Returns:
            ``(True, "")`` when valid, otherwise ``(False, <reason>)``.
        """
        if not bbox_data or not isinstance(bbox_data, tuple):
            return False, "No image or annotations provided"

        image_path, annotations = bbox_data
        if not isinstance(image_path, str):
            return False, "Invalid image format"
        if not annotations:
            return False, "No annotations drawn"

        # Number of boxes seen so far for each label.
        seen = {}
        for ann in annotations:
            if len(ann) != 5:
                return False, "Invalid annotation format"
            *coords, label = ann

            if any(not isinstance(coord, (int, float)) for coord in coords):
                return False, "Invalid coordinate values"

            if not label or label not in CATEGORIES:
                return False, f"Invalid or missing label. Must be one of: {', '.join(CATEGORIES)}"

            seen[label] = seen.get(label, 0) + 1

        # Enforce the per-image cap of every category.
        for category, limit in CATEGORY_LIMITS.items():
            if seen.get(category, 0) > limit:
                return False, f"Too many {category} annotations. Maximum allowed: {limit}"

        return True, ""

    def process_upload(self, image_path):
        """Resize an uploaded image.

        Args:
            image_path: Path of the uploaded file (``str``, ``bytes`` or
                ``os.PathLike``).

        Returns:
            The path of the resized copy, or ``None`` when the argument is
            not path-like or resizing fails (the failure is logged).
        """
        if not isinstance(image_path, (str, bytes, os.PathLike)):
            logger.warning(f"Invalid image path type: {type(image_path)}")
            return None

        try:
            logger.info(f"Processing upload: {image_path}")
            resized_path = self.image_processor.resize_image(image_path)
            logger.info(f"Successfully processed upload: {resized_path}")
            return resized_path
        except Exception as e:
            # Best effort: callers get None, the log keeps the traceback.
            logger.exception(f"Error in process_upload: {str(e)}")
            return None

    def add_annotation(self, bbox_data):
        """Store (or replace) the annotations of one image.

        Args:
            bbox_data: ``(image_path, annotations)`` as accepted by
                ``validate_annotations``.

        Returns:
            ``(json_text, status_message)``: the JSON dump of every stored
            annotation plus a success or error message.
        """
        is_valid, error_msg = self.validate_annotations(bbox_data)
        if not is_valid:
            return self.get_json_annotations(), f"❌ Error: {error_msg}"

        image_path, annotations = bbox_data

        # Key by the original filename: drop the prefix added when resizing.
        prefix = "resized_"
        filename = os.path.basename(image_path)
        if filename.startswith(prefix):
            filename = filename[len(prefix):]

        formatted_annotations = []
        for ann in annotations:
            *coords, label = ann
            formatted_annotations.append({"annotation": coords, "label": label})
        self.annotations[filename] = formatted_annotations

        # Per-category totals for the status line.
        counts = {
            cat: sum(1 for ann in annotations if ann[4] == cat)
            for cat in CATEGORIES
        }
        counts_str = ", ".join(f"{count} {cat}" for cat, count in counts.items())
        success_msg = f"✅ Successfully saved for {filename}: {counts_str}"
        return self.get_json_annotations(), success_msg

    def get_json_annotations(self):
        """Return every stored annotation as an indented JSON string."""
        return json.dumps(self.annotations, indent=2)

    def clear_annotations(self):
        """Drop every stored annotation.

        Returns:
            ``("", status_message)``: an empty JSON text and a confirmation.
        """
        self.annotations = {}
        return "", "🗑️ All annotations cleared"
def create_interface():
    """Build the Gradio Blocks UI of the annotation tool.

    The layout holds a bounding-box annotator, a text area showing the
    combined annotations JSON, a save button, a clear button and a status
    line. Uploads are resized through ``AnnotationManager.process_upload``.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    annotation_mgr = AnnotationManager()

    # Generate the limits list from CATEGORY_LIMITS so the help text cannot
    # drift from the limits that are actually enforced.
    limit_lines = [
        f"- {category}: Maximum {limit} annotation{'' if limit == 1 else 's'}"
        for category, limit in CATEGORY_LIMITS.items()
    ]
    # Blank entries separate the Markdown sections so the heading, the
    # numbered steps and the limits list render as distinct blocks.
    instructions = "\n".join([
        "# Advertisement and Text Annotation Tool",
        "",
        "**Instructions:**",
        f"1. Upload an image (will be automatically resized to max {MAX_SIZE[0]}x{MAX_SIZE[1]})",
        "2. Draw bounding boxes and select the appropriate label",
        "3. Click 'Save Annotations' to add to the collection",
        "4. Repeat for all images",
        "5. Copy the combined JSON when finished",
        "",
        "**Annotation Limits per Image:**",
        *limit_lines,
    ])

    with gr.Blocks() as demo:
        gr.Markdown(instructions)

        with gr.Row():
            with gr.Column(scale=2):
                bbox_input = BBoxAnnotator(
                    show_label=True,
                    label="Draw Bounding Boxes",
                    show_download_button=True,
                    interactive=True,
                    categories=CATEGORIES,
                )
            with gr.Column(scale=1):
                json_output = gr.TextArea(
                    label="Combined Annotations JSON",
                    interactive=True,
                    lines=15,
                    show_copy_button=True,
                )

        with gr.Row():
            save_btn = gr.Button("Save Current Image Annotations", variant="primary")
            clear_btn = gr.Button("Clear All Annotations", variant="secondary")

        # Status line shared by the upload, save and clear handlers.
        status_msg = gr.Markdown(label="Status")

        def handle_image_upload(bbox_data):
            """Swap a freshly uploaded image for its resized copy.

            Returns a ``(annotator_value, status)`` pair. On success the
            annotator gets ``(resized_path, annotations)`` and the status is
            left unchanged; on failure the annotator is cleared and the
            status shows the reason.
            """
            try:
                if not bbox_data or not isinstance(bbox_data, tuple):
                    return None, "No image uploaded"

                image_path, annotations = bbox_data
                if not image_path:
                    return None, "No image path provided"

                logger.info(f"Handling upload for: {image_path}")
                resized_path = annotation_mgr.process_upload(image_path)

                if resized_path and os.path.exists(resized_path):
                    logger.info(f"Processed image path: {resized_path}")
                    # Keep any boxes already drawn; gr.update() with no
                    # arguments leaves the status text as it is.
                    return (resized_path, annotations), gr.update()

                error_msg = "Failed to process image"
                logger.error(error_msg)
                return None, error_msg
            except Exception as e:
                error_msg = f"Error in upload handler: {str(e)}"
                logger.exception(error_msg)
                return None, error_msg

        # Two outputs: with only the annotator wired, a failure result such
        # as (None, "message") was passed to the annotator as its value and
        # the message never reached the status line.
        bbox_input.upload(
            fn=handle_image_upload,
            inputs=[bbox_input],
            outputs=[bbox_input, status_msg],
        )

        save_btn.click(
            fn=annotation_mgr.add_annotation,
            inputs=[bbox_input],
            outputs=[json_output, status_msg],
        )

        clear_btn.click(
            fn=annotation_mgr.clear_annotations,
            inputs=[],
            outputs=[json_output, status_msg],
        )

    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and start the Gradio server.
    demo = create_interface()
    demo.launch()