import gradio as gr
import easyocr
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import pipeline
import logging
import time

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Check for GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
using_gpu = device == "cuda"
logger.info(f"Using device: {device}")

class SmartGlassesSystem:
    """Main class for the Police Smart Glasses AI system."""

    def __init__(self):
        self.supported_languages = {
            "Arabic": ["ar", "en"],
            "Hindi": ["hi", "en"],
            "Chinese": ["ch_sim", "en"],
            "Japanese": ["ja", "en"],
            "Korean": ["ko", "en"],
            "Russian": ["ru", "en"],
            "French": ["fr", "en"]
        }
        # Cache of OCR readers keyed by language choice, to avoid reloading.
        # initialize_models() is called last so the readers it preloads are
        # not wiped out by this cache initialization.
        self.ocr_readers = {}
        self.initialize_models()
    def initialize_models(self):
        """Initialize all AI models with proper error handling."""
        try:
            # Eagerly load OCR readers for the most common languages
            logger.info("Loading initial OCR readers...")
            self.ocr_readers = {
                "Arabic": easyocr.Reader(['ar', 'en'], gpu=using_gpu, verbose=False),
                "Hindi": easyocr.Reader(['hi', 'en'], gpu=using_gpu, verbose=False)
            }
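            # Readers for the remaining supported languages are created lazily
            # by get_ocr_reader() on first use and cached for later requests.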
            # Load translation model
            logger.info("Loading translation model...")
            self.translator = pipeline(
                "translation",
                model="Helsinki-NLP/opus-mt-mul-en",
                device=0 if using_gpu else -1
            )
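            # opus-mt-mul-en is a many-to-one model: it accepts text in many
            # source languages and translates into English, so no source
            # language needs to be specified at call time.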
            # Check whether timm is installed (required by the DETR detector)
            try:
                import timm  # noqa: F401
                logger.info("Loading object detection model...")
                self.detector = pipeline(
                    "object-detection",
                    model="facebook/detr-resnet-50",
                    device=0 if using_gpu else -1
                )
            except ImportError:
                logger.warning("timm library not found, using YOLOv5 as a fallback for object detection")
                try:
                    # Use YOLOv5 as a fallback (it has fewer dependencies)
                    self._original_detect = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
                    # Expose a wrapper so the interface matches the transformers pipeline
                    self.detector = self._yolo_detector_wrapper
                except Exception as e2:
                    logger.error(f"Fallback object detection also failed: {str(e2)}")
                    logger.warning("Object detection will be disabled")
                    self.detector = self._dummy_detector
            logger.info("All models loaded successfully!")
        except Exception as e:
            logger.error(f"Error initializing models: {str(e)}")
            raise RuntimeError(f"Failed to initialize AI models: {str(e)}")
    def _yolo_detector_wrapper(self, image):
        """Wrap YOLOv5 output in the transformers pipeline format."""
        results = self._original_detect(image)
        detections = []
        # Each row of results.xyxy[0] is (x1, y1, x2, y2, confidence, class_id)
        for x1, y1, x2, y2, conf, cls in results.xyxy[0]:
            detections.append({
                'score': float(conf),
                'label': results.names[int(cls)],
                'box': {
                    'xmin': int(x1),
                    'ymin': int(y1),
                    'xmax': int(x2),
                    'ymax': int(y2)
                }
            })
        return detections
    def _dummy_detector(self, image):
        """No-op detector used when no object detection backend is available."""
        logger.warning("Object detection is disabled due to missing dependencies")
        return []
    def get_ocr_reader(self, language_choice):
        """Get or create the appropriate OCR reader for the chosen language."""
        if language_choice in self.ocr_readers:
            return self.ocr_readers[language_choice]
        # Create a new reader if not already loaded
        if language_choice in self.supported_languages:
            logger.info(f"Loading new OCR reader for {language_choice}...")
            reader = easyocr.Reader(
                self.supported_languages[language_choice],
                gpu=using_gpu,
                verbose=False
            )
            # Cache for future use
            self.ocr_readers[language_choice] = reader
            return reader
        else:
            # Fall back to a general-purpose reader
            logger.warning(f"Unsupported language: {language_choice}, using default")
            if "Other" not in self.ocr_readers:
                self.ocr_readers["Other"] = easyocr.Reader(['en', 'fr', 'ru'], gpu=using_gpu, verbose=False)
            return self.ocr_readers["Other"]
    def extract_text(self, image, language_choice):
        """Extract text from the image using OCR."""
        start_time = time.time()
        reader = self.get_ocr_reader(language_choice)
        try:
            # readtext returns a list of (bounding_box, text, confidence) triples
            text_results = reader.readtext(image)
            extracted_texts = [res[1] for res in text_results]
            extracted_text = " ".join(extracted_texts) if extracted_texts else "No text detected."
            # Keep bounding boxes for visualization
            text_boxes = [(res[0], res[1]) for res in text_results]
            logger.info(f"OCR completed in {time.time() - start_time:.2f} seconds")
            return extracted_text, text_boxes
        except Exception as e:
            logger.error(f"OCR error: {str(e)}")
            return "Error during text extraction.", []
    def translate_text(self, text):
        """Translate the extracted text to English."""
        if not text or not text.strip() or text == "No text detected.":
            return "No text to translate."
        try:
            translation = self.translator(text)[0]['translation_text']
            return translation
        except Exception as e:
            logger.error(f"Translation error: {str(e)}")
            return f"Translation error: {str(e)}"
    def detect_objects(self, image_pil):
        """Detect objects in the image."""
        try:
            detections = self.detector(image_pil)
            return detections
        except Exception as e:
            logger.error(f"Object detection error: {str(e)}")
            return []
    def visualize_results(self, image, text_boxes, detections):
        """Create a visualization with detected objects and text regions."""
        image_draw = image.copy().convert("RGB")
        draw = ImageDraw.Draw(image_draw)
        # Try to load a nicer font; fall back to the default if unavailable
        try:
            font = ImageFont.truetype("Arial", 12)
        except IOError:
            font = ImageFont.load_default()
        # Draw text bounding boxes
        for box, text in text_boxes:
            # Convert box points to integer pixel coordinates
            points = np.array(box).astype(np.int32)
            draw.polygon([tuple(p) for p in points], outline="blue", width=2)
            # Add a small text label above the box
            draw.text((points[0][0], points[0][1] - 10), "Text", fill="blue", font=font)
        # Draw object detection boxes
        for det in detections:
            box = det['box']
            label = det['label']
            score = det['score']
            if score > 0.6:  # Higher confidence threshold to reduce false positives
                draw.rectangle(
                    [box['xmin'], box['ymin'], box['xmax'], box['ymax']],
                    outline="red",
                    width=3
                )
                label_text = f"{label} ({score:.2f})"
                draw.text((box['xmin'], box['ymin'] - 15), label_text, fill="red", font=font)
        return image_draw
    def process_image(self, image, language_choice):
        """Main processing pipeline."""
        if image is None:
            return (
                None,
                "No image provided. Please upload an image.",
                "No image to process."
            )
        # Convert to a numpy array if needed (EasyOCR expects an array)
        if not isinstance(image, np.ndarray):
            image = np.array(image)
        # Create a PIL image for detection and visualization
        image_pil = Image.fromarray(image)
        # Extract text
        extracted_text, text_boxes = self.extract_text(image, language_choice)
        # Translate text
        translation = self.translate_text(extracted_text)
        # Detect objects
        detections = self.detect_objects(image_pil)
        # Create visualization
        result_image = self.visualize_results(image_pil, text_boxes, detections)
        return result_image, extracted_text, translation

# Create a single shared system instance
smart_glasses = SmartGlassesSystem()
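
# Example (a sketch, not part of the app): the pipeline can also be driven
# directly from Python, assuming a local file such as "sample.jpg" exists:
#
#     img = Image.open("sample.jpg")
#     result_img, text, translation = smart_glasses.process_image(img, "Arabic")
#     print(text, "->", translation)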

def create_interface():
    """Create and configure the Gradio interface."""
    # Custom CSS for better appearance
    custom_css = """
    .gradio-container {
        background-color: #f0f4f8;
    }
    .output-image {
        border: 2px solid #2c3e50;
        border-radius: 5px;
    }
    """
    # Create the interface
    with gr.Blocks(css=custom_css, title="🚨 Police Smart Glasses - AI Demo") as iface:
        gr.Markdown("""
        # 🚨 Police Smart Glasses - Advanced AI Demo

        This system demonstrates real-time text recognition, translation, and object
        detection capabilities for law enforcement smart glasses technology.

        ### Instructions:
        1. Upload an image containing text in the selected language
        2. Choose the primary language in the image
        3. View the detection results, extracted text, and English translation
        """)
        with gr.Row():
            with gr.Column(scale=1):
                # Input components
                input_image = gr.Image(
                    type="pil",
                    label="Upload an Image (e.g., Signs, Documents, License Plates)"
                )
                language_choice = gr.Dropdown(
                    choices=list(smart_glasses.supported_languages.keys()) + ["Other"],
                    value="Arabic",
                    label="Select Primary Language in Image"
                )
                process_btn = gr.Button("Process Image", variant="primary")
            with gr.Column(scale=1):
                # Output components
                output_image = gr.Image(label="Analysis Results")
                extracted_text = gr.Textbox(label="Extracted Text")
                translated_text = gr.Textbox(label="Translated Text (English)")
        # Wire the button to the processing function
        process_btn.click(
            fn=smart_glasses.process_image,
            inputs=[input_image, language_choice],
            outputs=[output_image, extracted_text, translated_text]
        )
        # Examples for testing
        gr.Examples(
            examples=[
                ["examples/arabic_sign.jpg", "Arabic"],
                ["examples/hindi_text.jpg", "Hindi"],
                ["examples/russian_document.jpg", "Russian"]
            ],
            inputs=[input_image, language_choice]
        )
        # System information
        with gr.Accordion("System Information", open=False):
            gr.Markdown(f"""
            - **Device**: {'GPU' if using_gpu else 'CPU'}
            - **Supported Languages**: {', '.join(smart_glasses.supported_languages.keys())}
            - **AI Models**:
                - OCR: EasyOCR
                - Translation: Helsinki-NLP/opus-mt-mul-en
                - Object Detection: facebook/detr-resnet-50 (YOLOv5s fallback if timm is unavailable)
            """)
    return iface

if __name__ == "__main__":
    # Create and launch the interface
    iface = create_interface()
    iface.launch(
        share=True,  # Create a public share link
        debug=True   # Show debugging information
    )
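
# Note: when this app is hosted on Hugging Face Spaces, the Space serves the
# interface directly and the share=True flag is typically ignored; it matters
# mainly when running the script locally.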