Spaces:
Running
Running
import functools
import io
import json
import os

import cv2
import easyocr
import gradio as gr
import keras_ocr
import numpy as np
import pytesseract
import torch
import torch.nn.functional as F
from paddleocr import PaddleOCR
from PIL import Image
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

# Import save function
from save_results import save_results_to_repo
# Paths
MODEL_PATH = "./distilbert_spam_model"

# Weight-file names that `save_pretrained` can produce: `pytorch_model.bin`
# (older transformers releases, or safe_serialization=False) and
# `model.safetensors` (the default in newer releases). Checking only the .bin
# name would treat a safetensors checkpoint as missing and overwrite it with
# the base model on every start-up.
_WEIGHT_FILES = ("pytorch_model.bin", "model.safetensors")

# Ensure the classifier model exists locally; otherwise fetch and cache it.
if not any(os.path.exists(os.path.join(MODEL_PATH, name)) for name in _WEIGHT_FILES):
    print(f"⚠️ Model not found in {MODEL_PATH}. Downloading from Hugging Face Hub...")
    # NOTE(review): this fallback is the generic distilbert-base-uncased
    # checkpoint with a new 2-label head, not a spam-trained model, and it is
    # saved to MODEL_PATH so later runs load it as if it were the real one.
    # Confirm that is intended.
    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
    model.save_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    tokenizer.save_pretrained(MODEL_PATH)
    print(f"✅ Model saved at {MODEL_PATH}.")
else:
    model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)

# Ensure model is in evaluation mode (inference only; disables dropout).
model.eval()
# Helper that turns an uploaded image into an array the OCR engines accept.
def preprocess_image(image):
    """Return a new NumPy array holding the pixels of *image*.

    The channel order of the source is kept as-is: a PIL RGB image becomes an
    RGB array, not OpenCV's BGR layout.
    """
    pixels = np.array(image)
    return pixels
# OCR Functions (same as ocr-api): each returns parallel (texts, confidences) lists.
def ocr_with_paddle(img):
    """Run PaddleOCR on *img*, returning recognized lines and their scores.

    NOTE(review): ``ocr_with_paddle`` is defined again later in this file as a
    version that returns one joined string. That later definition replaces
    this one when the module loads, so this variant is never reached under
    this name.

    Args:
        img: image as a NumPy array.

    Returns:
        ``(extracted_text, confidences)``: parallel lists with one entry per
        detected text line. Both are empty when no text is found.
    """
    ocr = PaddleOCR(lang='en', use_angle_cls=True)
    result = ocr.ocr(img)
    extracted_text, confidences = [], []
    # PaddleOCR reports "no text found" as a None entry rather than an empty
    # list; iterating it directly would raise TypeError.
    if not result or result[0] is None:
        return extracted_text, confidences
    for line in result[0]:
        # Each detection is [box, (text, confidence)].
        text, confidence = line[1]
        extracted_text.append(text)
        confidences.append(confidence)
    return extracted_text, confidences
def ocr_with_keras(img):
    """Run keras-ocr on *img*, returning recognized words and placeholder scores.

    keras-ocr's ``Pipeline.recognize`` yields ``(word, box)`` pairs and gives
    no confidence score. Like ``ocr_with_tesseract``, this returns 1.0 for
    every word so callers still receive parallel lists.

    NOTE(review): ``ocr_with_keras`` is defined again later in this file as a
    string-returning version, which replaces this one when the module loads.

    Args:
        img: image as a NumPy array (or anything ``keras_ocr.tools.read`` accepts).

    Returns:
        ``(extracted_text, confidences)``: parallel lists, one entry per word.
    """
    pipeline = keras_ocr.pipeline.Pipeline()
    images = [keras_ocr.tools.read(img)]
    predictions = pipeline.recognize(images)
    # predictions[0] holds the (word, box) pairs for our single image.
    extracted_text = [word for word, _box in predictions[0]]
    confidences = [1.0] * len(extracted_text)  # keras-ocr provides no scores
    return extracted_text, confidences
def ocr_with_easy(img):
    """Run EasyOCR on a grayscale copy of *img*; return (texts, confidences).

    NOTE(review): ``ocr_with_easy`` is defined again later in this file as a
    string-returning version, which replaces this one when the module loads.

    Args:
        img: image as a NumPy array, assumed RGB (``np.array`` of a PIL or
            Gradio image), or already single-channel (2-D).

    Returns:
        ``(extracted_text, confidences)``: parallel lists, one entry per
        detected text region.
    """
    # The input is RGB, not OpenCV's BGR, so use the RGB conversion code.
    # A 2-D array is already grayscale and cvtColor would reject it.
    gray_image = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) if img.ndim == 3 else img
    reader = easyocr.Reader(['en'])
    results = reader.readtext(gray_image)
    # With the default detail level each result is (box, text, confidence).
    extracted_text = [text for _box, text, _conf in results]
    confidences = [conf for _box, _text, conf in results]
    return extracted_text, confidences
def ocr_with_tesseract(img):
    """Run Tesseract on a grayscale copy of *img*; return (lines, confidences).

    Args:
        img: image as a NumPy array, assumed RGB (``np.array`` of a PIL or
            Gradio image), or already single-channel (2-D).

    Returns:
        ``(extracted_text, confidences)``: the non-blank, stripped text lines
        and a placeholder confidence of 1.0 for each. ``image_to_string``
        reports no scores.
    """
    # The input is RGB, not OpenCV's BGR, so use the RGB conversion code.
    # A 2-D array is already grayscale and cvtColor would reject it.
    gray_image = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) if img.ndim == 3 else img
    raw_lines = pytesseract.image_to_string(gray_image).split("\n")
    extracted_text = [line.strip() for line in raw_lines if line.strip()]
    confidences = [1.0] * len(extracted_text)  # Tesseract doesn't return confidence scores here
    return extracted_text, confidences
# String-returning OCR helpers used by generate_ocr.
@functools.lru_cache(maxsize=1)
def _get_paddle_ocr():
    """Return the PaddleOCR engine, created on first use and reused afterwards."""
    # Constructing the engine sets up its models; do it once per process
    # rather than once per request.
    return PaddleOCR(lang='en', use_angle_cls=True)


def ocr_with_paddle(img):
    """Run PaddleOCR on *img* and return all recognized text joined by spaces.

    Args:
        img: image as a NumPy array.

    Returns:
        The recognized text lines joined with single spaces, or ``""`` when
        PaddleOCR detects no text.
    """
    result = _get_paddle_ocr().ocr(img)
    # PaddleOCR reports "no text found" as a None entry rather than an empty
    # list; iterating it directly would raise TypeError.
    if not result or result[0] is None:
        return ''
    # Each detection is [box, (text, confidence)]; keep only the text.
    return ' '.join(item[1][0] for item in result[0])
@functools.lru_cache(maxsize=1)
def _get_keras_pipeline():
    """Return the keras-ocr pipeline, created on first use and reused afterwards."""
    # Building the pipeline sets up its detector and recognizer; do it once
    # per process rather than once per request.
    return keras_ocr.pipeline.Pipeline()


def ocr_with_keras(img):
    """Run keras-ocr on *img* and return the recognized words joined by spaces.

    Args:
        img: image as a NumPy array (or anything ``keras_ocr.tools.read`` accepts).

    Returns:
        The recognized words joined with single spaces (``""`` if none).
    """
    images = [keras_ocr.tools.read(img)]
    predictions = _get_keras_pipeline().recognize(images)
    # predictions[0] holds the (word, box) pairs for our single image.
    return ' '.join(word for word, _box in predictions[0])
@functools.lru_cache(maxsize=1)
def _get_easyocr_reader():
    """Return the English EasyOCR reader, created on first use and reused afterwards."""
    # Constructing a Reader sets up its models; do it once per process
    # rather than once per request.
    return easyocr.Reader(['en'])


def ocr_with_easy(img):
    """Run EasyOCR on a grayscale copy of *img*; return the text joined by spaces.

    Args:
        img: image as a NumPy array, assumed RGB (``np.array`` of a PIL or
            Gradio image), or already single-channel (2-D).

    Returns:
        The recognized text fragments joined with single spaces (``""`` if none).
    """
    # The input is RGB, not OpenCV's BGR, so use the RGB conversion code.
    # A 2-D array is already grayscale and cvtColor would reject it.
    gray_image = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) if img.ndim == 3 else img
    # detail=0 makes readtext return only the text strings.
    results = _get_easyocr_reader().readtext(gray_image, detail=0)
    return ' '.join(results)
# OCR & Classification Function
def generate_ocr(method, img):
    """Extract text from an uploaded image and classify it as spam or not.

    Args:
        method: Name of the OCR engine: "PaddleOCR", "EasyOCR", "KerasOCR" or
            "TesseractOCR" (the choices offered by the interface's radio button).
        img: The uploaded image, or None when nothing was uploaded.

    Returns:
        A ``(text, label)`` tuple: the extracted text and "Spam" or
        "Not Spam". When OCR finds no text, returns
        ``("No text detected!", "Cannot classify")`` instead.

    Raises:
        gr.Error: If no image was uploaded or ``method`` is not recognized.
    """
    if img is None:
        raise gr.Error("Please upload an image!")

    # Convert the upload to a NumPy array. np.array keeps the source channel
    # order (RGB for PIL images), not OpenCV's BGR.
    img = np.array(img)

    # Each entry returns one string. ocr_with_tesseract returns
    # (lines, confidences), so its lines are space-joined here the same way
    # the other helpers join theirs.
    ocr_engines = {
        "PaddleOCR": ocr_with_paddle,
        "EasyOCR": ocr_with_easy,
        "KerasOCR": ocr_with_keras,
        "TesseractOCR": lambda image: ' '.join(ocr_with_tesseract(image)[0]),
    }
    run_ocr = ocr_engines.get(method)
    if run_ocr is None:
        raise gr.Error(f"Unknown OCR method: {method}")
    text_output = run_ocr(img).strip()

    if not text_output:
        return "No text detected!", "Cannot classify"

    # Tokenize; text longer than max_length=512 tokens is truncated.
    inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True, max_length=512)

    # Inference only: no_grad avoids building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    probs = F.softmax(outputs.logits, dim=1)  # logits -> probabilities
    spam_prob = probs[0][1].item()  # column 1 is the "Spam" class

    # With two classes a 0.5 cut-off is equivalent to argmax (an exact tie
    # goes to "Not Spam"); change the threshold to trade precision for recall.
    label = "Spam" if spam_prob > 0.5 else "Not Spam"

    # Save results using the external helper.
    save_results_to_repo(text_output, label)
    return text_output, label
# Gradio Interface: input/output components, then the Interface wiring them
# to generate_ocr. The radio choices must match the method names that
# generate_ocr dispatches on.
image_input = gr.Image()
method_input = gr.Radio(
    choices=["PaddleOCR", "EasyOCR", "KerasOCR", "TesseractOCR"],
    value="PaddleOCR",
)
output_text = gr.Textbox(label="Extracted Text")
output_label = gr.Textbox(label="Spam Classification")

demo = gr.Interface(
    fn=generate_ocr,
    # Order matters: generate_ocr(method, img) -> (text, label).
    inputs=[method_input, image_input],
    outputs=[output_text, output_label],
    title="OCR Spam Classifier",
    description="Upload an image, extract text using OCR, and classify it as Spam or Not Spam.",
    # NOTE(review): "compact" is not among the built-in theme names of recent
    # Gradio releases; confirm the installed version still honours it.
    theme="compact",
)
# Launch App
def _launch():
    """Start the Gradio server for the OCR spam-classifier interface."""
    demo.launch()


if __name__ == "__main__":
    _launch()