winamnd committed on
Commit
4639dba
·
verified ·
1 Parent(s): dad8a00

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -38
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  import torch
3
  import json
 
4
  import os
5
  import cv2
6
  import numpy as np
@@ -9,7 +10,7 @@ import keras_ocr
9
  from paddleocr import PaddleOCR
10
  from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
11
  import torch.nn.functional as F
12
- from save_results import save_results_to_repo
13
 
14
  # Paths
15
  MODEL_PATH = "./distilbert_spam_model"
@@ -26,10 +27,10 @@ else:
26
  model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
27
  tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)
28
 
29
- # Set model to evaluation mode
30
  model.eval()
31
 
32
- # OCR Methods
33
  def ocr_with_paddle(img):
34
  ocr = PaddleOCR(lang='en', use_angle_cls=True)
35
  result = ocr.ocr(img)
@@ -47,10 +48,10 @@ def ocr_with_easy(img):
47
  results = reader.readtext(gray_image, detail=0)
48
  return ' '.join(results)
49
 
50
- # OCR Extraction Function
51
- def extract_text(method, img):
52
  if img is None:
53
- return "Error: Please upload an image!", ""
54
 
55
  # Convert PIL Image to OpenCV format
56
  img = np.array(img)
@@ -63,52 +64,42 @@ def extract_text(method, img):
63
  else: # KerasOCR
64
  text_output = ocr_with_keras(img)
65
 
66
- # Clean extracted text
67
  text_output = text_output.strip()
68
-
69
  if len(text_output) == 0:
70
- return "No text detected!", ""
71
 
72
- return text_output, ""
73
-
74
- # Classification Function
75
- def classify_text(text_output):
76
- if text_output.strip() in ["No text detected!", "Error: Please upload an image!"]:
77
- return text_output, "Cannot classify"
78
-
79
- # Tokenize text
80
  inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True, max_length=512)
81
 
82
- # Model inference
83
  with torch.no_grad():
84
  outputs = model(**inputs)
85
- probs = F.softmax(outputs.logits, dim=1)
86
- prediction = torch.argmax(probs, dim=1).item()
87
 
88
- label_map = {0: "Not Spam", 1: "Spam"}
89
- label = label_map[prediction]
90
 
91
- # Save results automatically
92
  save_results_to_repo(text_output, label)
93
 
94
  return text_output, label
95
 
96
  # Gradio Interface
97
- with gr.Blocks() as demo:
98
- gr.Markdown("## OCR Spam Classifier")
99
-
100
- method_input = gr.Radio(["PaddleOCR", "EasyOCR", "KerasOCR"], value="PaddleOCR", label="Choose OCR Method")
101
- image_input = gr.Image(label="Upload Image")
102
-
103
- extract_button = gr.Button("Submit")
104
- classify_button = gr.Button("Classify")
105
-
106
- output_text = gr.Textbox(label="Extracted Text", interactive=True)
107
- output_label = gr.Textbox(label="Spam Classification", interactive=False)
108
-
109
- # Button Click Bindings
110
- extract_button.click(fn=extract_text, inputs=[method_input, image_input], outputs=[output_text, output_label])
111
- classify_button.click(fn=classify_text, inputs=[output_text], outputs=[output_text, output_label])
112
 
113
  # Launch App
114
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import torch
3
  import json
4
+ import csv
5
  import os
6
  import cv2
7
  import numpy as np
 
10
  from paddleocr import PaddleOCR
11
  from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
12
  import torch.nn.functional as F
13
+ from save_results import save_results_to_repo # Import the save function
14
 
15
  # Paths
16
  MODEL_PATH = "./distilbert_spam_model"
 
27
  model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
28
  tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)
29
 
30
+ # 🔹 Ensure model is in evaluation mode
31
  model.eval()
32
 
33
+ # OCR Functions (No changes here)
34
  def ocr_with_paddle(img):
35
  ocr = PaddleOCR(lang='en', use_angle_cls=True)
36
  result = ocr.ocr(img)
 
48
  results = reader.readtext(gray_image, detail=0)
49
  return ' '.join(results)
50
 
51
+ # OCR & Classification Function
52
+ def generate_ocr(method, img):
53
  if img is None:
54
+ raise gr.Error("Please upload an image!")
55
 
56
  # Convert PIL Image to OpenCV format
57
  img = np.array(img)
 
64
  else: # KerasOCR
65
  text_output = ocr_with_keras(img)
66
 
67
+ # 🔹 Preprocess text properly
68
  text_output = text_output.strip()
 
69
  if len(text_output) == 0:
70
+ return "No text detected!", "Cannot classify"
71
 
72
+ # 🔹 Tokenize text
 
 
 
 
 
 
 
73
  inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True, max_length=512)
74
 
75
+ # 🔹 Perform inference
76
  with torch.no_grad():
77
  outputs = model(**inputs)
78
+ probs = F.softmax(outputs.logits, dim=1) # Convert logits to probabilities
79
+ spam_prob = probs[0][1].item() # Probability of Spam
80
 
81
+ # 🔹 Classify by thresholding the spam probability (at 0.5 this matches argmax for two classes; tune the threshold to trade precision for recall)
82
+ label = "Spam" if spam_prob > 0.5 else "Not Spam"
83
 
84
+ # 🔹 Save results using external function
85
  save_results_to_repo(text_output, label)
86
 
87
  return text_output, label
88
 
89
  # Gradio Interface
90
+ image_input = gr.Image()
91
+ method_input = gr.Radio(["PaddleOCR", "EasyOCR", "KerasOCR"], value="PaddleOCR")
92
+ output_text = gr.Textbox(label="Extracted Text")
93
+ output_label = gr.Textbox(label="Spam Classification")
94
+
95
+ demo = gr.Interface(
96
+ generate_ocr,
97
+ inputs=[method_input, image_input],
98
+ outputs=[output_text, output_label],
99
+ title="OCR Spam Classifier",
100
+ description="Upload an image, extract text, and classify it as Spam or Not Spam.",
101
+ theme="compact",
102
+ )
 
 
103
 
104
  # Launch App
105
  if __name__ == "__main__":