# ocr-llm-test / app.py
# Author: winamnd
# Last commit: 7dedea0 ("Update app.py", verified)
import functools
import io
import json
import os

# Third-party imports are kept in their original relative order so that
# import-time behaviour of the ML frameworks is unchanged.
import gradio as gr
import torch
import cv2
import numpy as np
import easyocr
import keras_ocr
from paddleocr import PaddleOCR
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch.nn.functional as F
from PIL import Image
import pytesseract

# Import save function
from save_results import save_results_to_repo
# Paths
MODEL_PATH = "./distilbert_spam_model"

# Weight files `save_pretrained` may produce. Recent transformers releases
# write `model.safetensors` by default, older ones `pytorch_model.bin`.
# Large checkpoints are sharded behind an index file.
_MODEL_WEIGHT_FILES = (
    "pytorch_model.bin",
    "model.safetensors",
    "pytorch_model.bin.index.json",
    "model.safetensors.index.json",
)

# Ensure the classifier exists locally. Checking only for pytorch_model.bin
# would miss a safetensors checkpoint. The branch below would then
# re-download on every start and overwrite whatever model is in MODEL_PATH.
if not any(os.path.exists(os.path.join(MODEL_PATH, name)) for name in _MODEL_WEIGHT_FILES):
    print(f"⚠️ Model not found in {MODEL_PATH}. Downloading from Hugging Face Hub...")
    # NOTE(review): this is the generic base checkpoint with a newly created
    # 2-label classification head. It has not been fine-tuned for spam, so
    # its "Spam"/"Not Spam" output is not meaningful until it is trained.
    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
    model.save_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    tokenizer.save_pretrained(MODEL_PATH)
    print(f"✅ Model saved at {MODEL_PATH}.")
else:
    model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)

# Inference only: switch off training-time behaviour such as dropout.
model.eval()
# Helper: turn an uploaded image into an array the OCR engines accept.
# (Not called anywhere in this file; generate_ocr converts inline.)
def preprocess_image(image):
    """Return *image* (e.g. a PIL image) as a NumPy array in OpenCV style.

    ``np.array`` always builds a new array, so the caller's object is never
    modified through the result.
    """
    as_array = np.array(image)
    return as_array
# OCR functions returning (texts, confidences), mirroring ocr-api.
# NOTE(review): ocr_with_paddle, ocr_with_keras and ocr_with_easy are defined
# again further down in this file. Those later, string-returning definitions
# replace these at import time, so generate_ocr never reaches the ones here.
def ocr_with_paddle(img):
    """Run PaddleOCR on *img* and return ``(texts, confidences)``.

    Args:
        img: Image array passed straight to ``PaddleOCR.ocr``.

    Returns:
        Two parallel lists: recognised line texts and their confidence
        scores. Both are empty when nothing is detected.
    """
    ocr = PaddleOCR(lang='en', use_angle_cls=True)
    result = ocr.ocr(img)
    # PaddleOCR may report a page with no detected text as None rather than
    # an empty list; treat that (and an empty result) as "no lines".
    page = result[0] if result else None
    extracted_text, confidences = [], []
    for line in page or []:
        # Each line is (box, (text, confidence)); only the second part is used.
        text, confidence = line[1]
        extracted_text.append(text)
        confidences.append(confidence)
    return extracted_text, confidences
def ocr_with_keras(img):
    """Run keras-ocr on *img* and return ``(texts, confidences)``.

    ``Pipeline.recognize`` yields ``(word, box)`` pairs with no per-word
    score. Like ``ocr_with_tesseract``, every confidence is therefore the
    placeholder ``1.0`` rather than a bounding box.

    NOTE(review): shadowed by the later ``ocr_with_keras`` definition in this
    file, so this version is not the one generate_ocr calls.
    """
    pipeline = keras_ocr.pipeline.Pipeline()
    images = [keras_ocr.tools.read(img)]
    predictions = pipeline.recognize(images)
    # predictions[0] holds the results for the single image passed in.
    extracted_text = [text for text, _box in predictions[0]]
    # keras-ocr does not report confidences; use a neutral placeholder.
    confidences = [1.0] * len(extracted_text)
    return extracted_text, confidences
def ocr_with_easy(img):
    """Run EasyOCR on a grayscale copy of *img*; return ``(texts, confidences)``.

    Args:
        img: Image array. 2-D arrays are treated as already grayscale;
            3- or 4-channel arrays are treated as RGB / RGBA (the channel
            order NumPy arrays built from PIL images use).

    NOTE(review): shadowed by the later ``ocr_with_easy`` definition in this
    file, so this version is not the one generate_ocr calls.
    """
    # Use RGB-based conversions: applying COLOR_BGR2GRAY to RGB data swaps
    # the red/blue luminance weights. Single-channel input needs no conversion.
    if img.ndim == 2:
        gray_image = img
    elif img.shape[2] == 4:
        gray_image = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
    else:
        gray_image = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    reader = easyocr.Reader(['en'])
    results = reader.readtext(gray_image)
    # Each result is (box, text, confidence).
    extracted_text = [text for _, text, confidence in results]
    confidences = [confidence for _, text, confidence in results]
    return extracted_text, confidences
def ocr_with_tesseract(img):
    """Run Tesseract on a grayscale copy of *img*; return ``(texts, confidences)``.

    Args:
        img: Image array. 2-D arrays are treated as already grayscale;
            3- or 4-channel arrays are treated as RGB / RGBA (the channel
            order NumPy arrays built from PIL images use).

    Returns:
        Non-empty, stripped output lines and a parallel list of ``1.0``
        placeholders, because ``image_to_string`` reports no confidence.
    """
    # Use RGB-based conversions: applying COLOR_BGR2GRAY to RGB data swaps
    # the red/blue luminance weights. Single-channel input needs no conversion.
    if img.ndim == 2:
        gray_image = img
    elif img.shape[2] == 4:
        gray_image = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
    else:
        gray_image = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    raw_lines = pytesseract.image_to_string(gray_image).split("\n")
    extracted_text = [line.strip() for line in raw_lines if line.strip()]
    confidences = [1.0] * len(extracted_text)  # Tesseract doesn't return confidence scores here
    return extracted_text, confidences
# Single-string OCR wrappers used by generate_ocr. These definitions replace
# any earlier functions of the same name in this file.
@functools.lru_cache(maxsize=1)
def _paddle_engine():
    """Build the PaddleOCR engine once and reuse it for every request."""
    return PaddleOCR(lang='en', use_angle_cls=True)


def ocr_with_paddle(img):
    """Run PaddleOCR on *img* and return the recognised text as one string.

    Args:
        img: Image array passed straight to ``PaddleOCR.ocr``.

    Returns:
        Recognised line texts joined with single spaces, or ``""`` when
        nothing is detected (so generate_ocr reports "No text detected!").
    """
    result = _paddle_engine().ocr(img)
    # PaddleOCR may report a page with no detected text as None rather than
    # an empty list; treat that (and an empty result) as "no lines".
    page = result[0] if result else None
    # Each line is (box, (text, confidence)); keep only the text.
    return ' '.join(item[1][0] for item in page or [])
@functools.lru_cache(maxsize=1)
def _keras_pipeline():
    """Build the keras-ocr pipeline once and reuse it for every request."""
    return keras_ocr.pipeline.Pipeline()


def ocr_with_keras(img):
    """Run keras-ocr on *img* and return the recognised words as one string.

    Args:
        img: Image array (``keras_ocr.tools.read`` accepts arrays directly).

    Returns:
        Recognised words joined with single spaces; ``""`` if none.
    """
    images = [keras_ocr.tools.read(img)]
    predictions = _keras_pipeline().recognize(images)
    # predictions[0] holds the (word, box) pairs for the single image passed.
    return ' '.join(text for text, _ in predictions[0])
@functools.lru_cache(maxsize=1)
def _easy_reader():
    """Build the English EasyOCR reader once and reuse it for every request."""
    return easyocr.Reader(['en'])


def ocr_with_easy(img):
    """Run EasyOCR on a grayscale copy of *img*; return the text as one string.

    Args:
        img: Image array. 2-D arrays are treated as already grayscale;
            3- or 4-channel arrays are treated as RGB / RGBA (the channel
            order NumPy arrays built from PIL images use).

    Returns:
        Recognised text fragments joined with single spaces; ``""`` if none.
    """
    # Use RGB-based conversions: applying COLOR_BGR2GRAY to RGB data swaps
    # the red/blue luminance weights. Single-channel input needs no conversion.
    if img.ndim == 2:
        gray_image = img
    elif img.shape[2] == 4:
        gray_image = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
    else:
        gray_image = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    # detail=0 makes readtext return just the text strings.
    results = _easy_reader().readtext(gray_image, detail=0)
    return ' '.join(results)
# OCR & Classification Function
def generate_ocr(method, img):
    """Run the selected OCR engine on *img* and classify the extracted text.

    Args:
        method: OCR engine label from the Gradio radio button: "PaddleOCR",
            "EasyOCR", "KerasOCR" or "TesseractOCR". Any other value falls
            back to KerasOCR.
        img: The uploaded image; converted with ``np.array`` before OCR.

    Returns:
        ``(text, label)``: the extracted text and either "Spam" or
        "Not Spam". When no text is found, returns
        ``("No text detected!", "Cannot classify")`` without saving.

    Raises:
        gr.Error: If no image was uploaded.
    """
    if img is None:
        raise gr.Error("Please upload an image!")
    # np.array accepts both PIL images and arrays.
    img = np.array(img)

    # Select OCR method; every branch must yield a single string.
    if method == "PaddleOCR":
        text_output = ocr_with_paddle(img)
    elif method == "EasyOCR":
        text_output = ocr_with_easy(img)
    elif method == "TesseractOCR":
        # Offered in the UI but previously routed to KerasOCR.
        # ocr_with_tesseract returns (lines, confidences); keep the text only.
        tesseract_lines, _ = ocr_with_tesseract(img)
        text_output = ' '.join(tesseract_lines)
    else:  # KerasOCR (also the fallback for unrecognised values)
        text_output = ocr_with_keras(img)

    text_output = text_output.strip()
    if not text_output:
        return "No text detected!", "Cannot classify"

    # Tokenize; DistilBERT accepts at most 512 tokens, so longer text is cut.
    inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    probs = F.softmax(outputs.logits, dim=1)  # Convert logits to probabilities
    spam_prob = probs[0][1].item()  # class index 1 is treated as "Spam"
    # With two classes, a 0.5 cut-off picks the same class as comparing the
    # logits. Change the threshold to trade spam precision against recall.
    label = "Spam" if spam_prob > 0.5 else "Not Spam"

    # Save results using external function
    save_results_to_repo(text_output, label)
    return text_output, label
# Gradio Interface: an OCR-engine selector plus an image upload feed
# generate_ocr, which returns the extracted text and the spam label.
_OCR_METHODS = ["PaddleOCR", "EasyOCR", "KerasOCR", "TesseractOCR"]

image_input = gr.Image()
method_input = gr.Radio(_OCR_METHODS, value="PaddleOCR")
output_text = gr.Textbox(label="Extracted Text")
output_label = gr.Textbox(label="Spam Classification")

# NOTE(review): "compact" is not one of Gradio's built-in theme names in
# recent releases; confirm it resolves to the intended look.
demo = gr.Interface(
    fn=generate_ocr,
    inputs=[method_input, image_input],
    outputs=[output_text, output_label],
    title="OCR Spam Classifier",
    description="Upload an image, extract text using OCR, and classify it as Spam or Not Spam.",
    theme="compact",
)

# Launch App only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()