import gradio as gr
import torch
import json
import os
import cv2
import numpy as np
import easyocr
import keras_ocr
from paddleocr import PaddleOCR
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch.nn.functional as F
from PIL import Image
import pytesseract
import io
# Import save function
from save_results import save_results_to_repo
# Paths
MODEL_PATH = "./distilbert_spam_model"
# Ensure the classification model exists locally; otherwise download a base checkpoint.
# Note: recent transformers versions save weights as model.safetensors instead of pytorch_model.bin.
if not (os.path.exists(os.path.join(MODEL_PATH, "pytorch_model.bin"))
        or os.path.exists(os.path.join(MODEL_PATH, "model.safetensors"))):
    print(f"⚠️ Model not found in {MODEL_PATH}. Downloading from Hugging Face Hub...")
    # Fallback: base DistilBERT with a fresh 2-label head (not yet fine-tuned for spam).
    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
    model.save_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    tokenizer.save_pretrained(MODEL_PATH)
    print(f"✅ Model saved at {MODEL_PATH}.")
else:
    model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)

# Ensure the model is in evaluation mode
model.eval()
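# Assumed label convention for the fine-tuned checkpoint: index 0 = "Not Spam", index 1 = "Spam"
# (generate_ocr below reads the spam probability from logit index 1).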
# Function to process image for OCR
def preprocess_image(image):
    """Convert PIL image to OpenCV format (NumPy array)."""
    return np.array(image)
# OCR helpers (same as ocr-api). Each takes an RGB NumPy array and returns the
# extracted text as a single string for classification.
def ocr_with_paddle(img):
    ocr = PaddleOCR(lang='en', use_angle_cls=True)
    result = ocr.ocr(img)
    # Each entry in result[0] is (bounding_box, (text, confidence)); keep only the text.
    return ' '.join(item[1][0] for item in result[0])

def ocr_with_keras(img):
    pipeline = keras_ocr.pipeline.Pipeline()
    images = [keras_ocr.tools.read(img)]
    predictions = pipeline.recognize(images)
    return ' '.join(text for text, _ in predictions[0])

def ocr_with_easy(img):
    # Gradio/PIL delivers RGB, so convert with COLOR_RGB2GRAY before reading.
    gray_image = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    reader = easyocr.Reader(['en'])
    # detail=0 returns plain strings instead of (box, text, confidence) tuples.
    results = reader.readtext(gray_image, detail=0)
    return ' '.join(results)

def ocr_with_tesseract(img):
    gray_image = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    # Tesseract does not return per-line confidence scores here; keep non-empty lines only.
    lines = pytesseract.image_to_string(gray_image).split("\n")
    return ' '.join(line.strip() for line in lines if line.strip())

# OCR & Classification Function
def generate_ocr(method, img):
    if img is None:
        raise gr.Error("Please upload an image!")

    # Convert the uploaded image to an OpenCV-compatible NumPy array
    img = preprocess_image(img)

    # Select OCR method
    if method == "PaddleOCR":
        text_output = ocr_with_paddle(img)
    elif method == "EasyOCR":
        text_output = ocr_with_easy(img)
    elif method == "TesseractOCR":
        text_output = ocr_with_tesseract(img)
    else:  # KerasOCR
        text_output = ocr_with_keras(img)

    # Guard against empty OCR output
    text_output = text_output.strip()
    if len(text_output) == 0:
        return "No text detected!", "Cannot classify"

    # Tokenize the extracted text
    inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True, max_length=512)

    # Perform inference
    with torch.no_grad():
        outputs = model(**inputs)
        probs = F.softmax(outputs.logits, dim=1)  # Convert logits to probabilities
        spam_prob = probs[0][1].item()  # Probability of the "Spam" class

    # Classify using a 0.5 threshold on the spam probability
    label = "Spam" if spam_prob > 0.5 else "Not Spam"

    # Save results using the external helper
    save_results_to_repo(text_output, label)

    return text_output, label
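
# Optional local sanity check (not wired into the Gradio UI). The image path and OCR
# method below are placeholder assumptions; call _smoke_test() manually to use it.
def _smoke_test(image_path="sample.png", method="EasyOCR"):
    """Run the OCR + spam-classification pipeline once, outside Gradio."""
    if not os.path.exists(image_path):
        print(f"No test image found at {image_path}; skipping smoke test.")
        return
    text, label = generate_ocr(method, Image.open(image_path))
    print(f"Extracted text: {text}")
    print(f"Classification: {label}")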
# Gradio Interface
image_input = gr.Image()
method_input = gr.Radio(["PaddleOCR", "EasyOCR", "KerasOCR", "TesseractOCR"], value="PaddleOCR")
output_text = gr.Textbox(label="Extracted Text")
output_label = gr.Textbox(label="Spam Classification")
demo = gr.Interface(
    generate_ocr,
    inputs=[method_input, image_input],
    outputs=[output_text, output_label],
    title="OCR Spam Classifier",
    description="Upload an image, extract text using OCR, and classify it as Spam or Not Spam.",
    theme="compact",
)
# Launch App
if __name__ == "__main__":
    demo.launch()
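    # To try the app outside of a Hugging Face Space, run this file directly
    # (e.g. `python app.py`) and open the local URL that Gradio prints.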