File size: 5,033 Bytes
19736cf
025580f
104c39e
 
19736cf
c623da2
19736cf
c623da2
19736cf
104c39e
deb409e
b2a51c8
 
 
 
 
 
025580f
104c39e
c623da2
025580f
b2a51c8
104c39e
14299e0
a92b56c
104c39e
 
 
14299e0
104c39e
 
 
c623da2
b2a51c8
deb409e
 
b2a51c8
 
 
 
 
 
19736cf
 
 
b2a51c8
 
 
 
 
 
c4cf574
19736cf
025580f
 
 
b2a51c8
 
 
c4cf574
19736cf
104c39e
025580f
b2a51c8
 
 
 
 
 
 
 
 
 
 
025580f
7dedea0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4639dba
bf26c19
 
4639dba
104c39e
 
bf26c19
104c39e
 
 
7dedea0
104c39e
7dedea0
 
 
deb409e
7dedea0
 
 
 
2a250f6
7dedea0
 
a92b56c
17f2d95
104c39e
 
7dedea0
 
2a250f6
7dedea0
 
2a250f6
7dedea0
 
104c39e
7dedea0
104c39e
e08cb5e
4639dba
b2a51c8
4639dba
 
 
 
 
 
 
 
b2a51c8
4639dba
 
0ab07a8
4ee3a20
a7de18e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import gradio as gr
import torch
import json
import os
import cv2
import numpy as np
import easyocr
import keras_ocr
from paddleocr import PaddleOCR
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch.nn.functional as F
from PIL import Image
import pytesseract
import io

# Import save function
from save_results import save_results_to_repo  

# Paths
MODEL_PATH = "./distilbert_spam_model"

# Weight file names that a saved checkpoint may use: a safetensors file or the
# older PyTorch pickle, depending on how/when the model was saved.
_MODEL_WEIGHT_FILES = ("model.safetensors", "pytorch_model.bin")

# Ensure LLM Model exists
# Looking only for "pytorch_model.bin" misses a safetensors checkpoint, in which
# case the base model would be fetched and saved over MODEL_PATH on every start.
_has_local_weights = any(
    os.path.exists(os.path.join(MODEL_PATH, weight_file))
    for weight_file in _MODEL_WEIGHT_FILES
)

if not _has_local_weights:
    print(f"⚠️ Model not found in {MODEL_PATH}. Downloading from Hugging Face Hub...")
    # NOTE(review): "distilbert-base-uncased" is the generic base checkpoint, not a
    # spam-tuned one, so predictions from this fallback are presumably not
    # meaningful — confirm that a fine-tuned model is deployed at MODEL_PATH.
    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
    model.save_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    tokenizer.save_pretrained(MODEL_PATH)
    print(f"✅ Model saved at {MODEL_PATH}.")
else:
    model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)

# Ensure model is in evaluation mode
model.eval()

# Function to process image for OCR
def preprocess_image(image):
    """Return ``image`` as a NumPy array (the form the OCR helpers consume).

    Args:
        image: Any object ``np.array`` accepts, e.g. a PIL image.

    Returns:
        A new ``numpy.ndarray`` built from ``image``.
    """
    pixels = np.array(image)
    return pixels

# OCR Functions (same as ocr-api)
def ocr_with_paddle(img):
    """Run PaddleOCR on ``img`` and return the recognised lines with scores.

    NOTE: a later ``def ocr_with_paddle`` in this file rebinds the name, so
    this tuple-returning version is not the one callers reach at runtime.

    Args:
        img: Image handed unchanged to ``PaddleOCR.ocr``.

    Returns:
        ``(extracted_text, confidences)``: two parallel lists, both empty when
        nothing was recognised.
    """
    ocr = PaddleOCR(lang='en', use_angle_cls=True)
    result = ocr.ocr(img)
    extracted_text, confidences = [], []
    # A page with no detected text can come back as ``None`` (or the result can
    # be empty); iterating it directly would raise instead of returning nothing.
    if not result or result[0] is None:
        return extracted_text, confidences
    for line in result[0]:
        # Each entry is indexed as line[1] == (text, confidence).
        text, confidence = line[1]
        extracted_text.append(text)
        confidences.append(confidence)
    return extracted_text, confidences

def ocr_with_keras(img):
    """Run the keras-ocr pipeline on ``img`` and split its predictions.

    NOTE: a later ``def ocr_with_keras`` in this file rebinds the name, so
    this tuple-returning version is not the one callers reach at runtime.

    Args:
        img: Image (or path) handed to ``keras_ocr.tools.read``.

    Returns:
        ``(extracted_text, confidences)``: the first and second element of
        every prediction for the single input image, as two parallel lists.
    """
    recognizer = keras_ocr.pipeline.Pipeline()
    batch = [keras_ocr.tools.read(img)]
    first_image_predictions = recognizer.recognize(batch)[0]

    extracted_text = []
    confidences = []
    # NOTE(review): keras-ocr documents `recognize` as yielding (word, box)
    # pairs, so the second list may hold boxes rather than scores — confirm.
    for word, score in first_image_predictions:
        extracted_text.append(word)
        confidences.append(score)
    return extracted_text, confidences

def ocr_with_easy(img):
    """Run EasyOCR on ``img`` and return the recognised strings with scores.

    NOTE: a later ``def ocr_with_easy`` in this file rebinds the name, so
    this tuple-returning version is not the one callers reach at runtime.

    Args:
        img: Colour image array in RGB channel order (what ``np.array`` yields
            for a PIL image and what this app passes in).

    Returns:
        ``(extracted_text, confidences)``: two parallel lists.
    """
    # The app supplies RGB data, so convert with the RGB code; the BGR code
    # would swap the red/blue weights in the grayscale image.
    gray_image = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    reader = easyocr.Reader(['en'])
    results = reader.readtext(gray_image)
    # Each result unpacks as (box, text, confidence); the box is not needed.
    extracted_text = [text for _, text, _ in results]
    confidences = [confidence for _, _, confidence in results]
    return extracted_text, confidences

def ocr_with_tesseract(img):
    """Run Tesseract on ``img`` and return its non-empty output lines.

    Args:
        img: Colour image array in RGB channel order (what ``np.array`` yields
            for a PIL image and what this app passes in).

    Returns:
        ``(extracted_text, confidences)``: the stripped, non-blank lines of
        Tesseract's output and a same-length list of placeholder ``1.0``
        confidences.
    """
    # The app supplies RGB data, so convert with the RGB code; the BGR code
    # would swap the red/blue weights in the grayscale image.
    gray_image = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    raw_lines = pytesseract.image_to_string(gray_image).split("\n")
    extracted_text = [line.strip() for line in raw_lines if line.strip()]
    # image_to_string yields text only, so every line gets a placeholder score.
    confidences = [1.0] * len(extracted_text)
    return extracted_text, confidences

# OCR helpers returning the recognised text as one string (these redefine the same-named functions above)
def ocr_with_paddle(img):
    """Run PaddleOCR on ``img`` and return the recognised text as one string.

    Args:
        img: Image handed unchanged to ``PaddleOCR.ocr``.

    Returns:
        The recognised text fragments joined with single spaces, or ``''``
        when nothing was recognised.
    """
    ocr = PaddleOCR(lang='en', use_angle_cls=True)
    result = ocr.ocr(img)
    # A page with no detected text can come back as ``None`` (or the result can
    # be empty); return an empty string so callers can report "no text"
    # instead of crashing on iteration.
    if not result or result[0] is None:
        return ''
    # Each entry is indexed as item[1] == (text, confidence).
    return ' '.join(item[1][0] for item in result[0])

def ocr_with_keras(img):
    """Run the keras-ocr pipeline on ``img`` and return its words as a string.

    Args:
        img: Image (or path) handed to ``keras_ocr.tools.read``.

    Returns:
        The first element of every prediction for the single input image,
        joined with single spaces.
    """
    recognizer = keras_ocr.pipeline.Pipeline()
    batch = [keras_ocr.tools.read(img)]
    first_image_predictions = recognizer.recognize(batch)[0]

    words = []
    for word, _ in first_image_predictions:
        words.append(word)
    return ' '.join(words)

def ocr_with_easy(img):
    """Run EasyOCR on ``img`` and return the recognised text as one string.

    Args:
        img: Colour image array in RGB channel order (what ``np.array`` yields
            for a PIL image and what this app passes in).

    Returns:
        The recognised strings joined with single spaces.
    """
    # The app supplies RGB data, so convert with the RGB code; the BGR code
    # would swap the red/blue weights in the grayscale image.
    gray_image = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    reader = easyocr.Reader(['en'])
    # detail=0 makes readtext return plain strings.
    results = reader.readtext(gray_image, detail=0)
    return ' '.join(results)

# OCR & Classification Function
def generate_ocr(method, img):
    """Extract text from ``img`` with the chosen OCR engine and classify it.

    Args:
        method: OCR engine name: "PaddleOCR", "EasyOCR" or "TesseractOCR".
            Any other value (including "KerasOCR") selects the keras-ocr
            pipeline.
        img: Uploaded image; converted with ``np.array`` before OCR.

    Returns:
        ``(text, label)`` where ``label`` is "Spam" or "Not Spam", or
        ``("No text detected!", "Cannot classify")`` when OCR finds no text.

    Raises:
        gr.Error: If no image was supplied.
    """
    if img is None:
        raise gr.Error("Please upload an image!")

    # Convert PIL Image to OpenCV format
    img = np.array(img)

    # Select OCR method
    if method == "PaddleOCR":
        text_output = ocr_with_paddle(img)
    elif method == "EasyOCR":
        text_output = ocr_with_easy(img)
    elif method == "TesseractOCR":
        # This choice is offered in the UI but used to fall through to the
        # KerasOCR branch. ocr_with_tesseract returns (lines, confidences),
        # so join the lines into the single string the rest of the flow expects.
        tesseract_lines, _ = ocr_with_tesseract(img)
        text_output = ' '.join(tesseract_lines)
    else:  # KerasOCR
        text_output = ocr_with_keras(img)

    # Preprocess text properly
    text_output = text_output.strip()
    if not text_output:
        return "No text detected!", "Cannot classify"

    # Tokenize text
    inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True, max_length=512)

    # Perform inference
    with torch.no_grad():
        outputs = model(**inputs)
        probs = F.softmax(outputs.logits, dim=1)  # Convert logits to probabilities
        spam_prob = probs[0][1].item()  # Probability of Spam

    # Label as Spam when the spam probability exceeds the 0.5 threshold
    label = "Spam" if spam_prob > 0.5 else "Not Spam"

    # Save results using external function
    save_results_to_repo(text_output, label)

    return text_output, label

# Gradio Interface
# Input components: the uploaded image and the OCR engine selector.
image_input = gr.Image()
method_input = gr.Radio(
    ["PaddleOCR", "EasyOCR", "KerasOCR", "TesseractOCR"],
    value="PaddleOCR",
)

# Output components: the OCR text and the classifier's verdict.
output_text = gr.Textbox(label="Extracted Text")
output_label = gr.Textbox(label="Spam Classification")

# Wire generate_ocr(method, img) to the components above; the input order
# matches the function's parameter order.
demo = gr.Interface(
    fn=generate_ocr,
    inputs=[method_input, image_input],
    outputs=[output_text, output_label],
    title="OCR Spam Classifier",
    description="Upload an image, extract text using OCR, and classify it as Spam or Not Spam.",
    theme="compact",
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()