winamnd commited on
Commit
c623da2
·
verified ·
1 Parent(s): a6b2047

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -68
app.py CHANGED
@@ -1,94 +1,87 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
4
- import keras_ocr
5
  import cv2
 
6
  import easyocr
 
7
  from paddleocr import PaddleOCR
8
- import numpy as np
9
 
10
- # Load tokenizer
11
- tokenizer = DistilBertTokenizer.from_pretrained("./distilbert_spam_model")
12
 
13
- # Load model
14
- model = DistilBertForSequenceClassification.from_pretrained("./distilbert_spam_model")
15
- model.load_state_dict(torch.load("./distilbert_spam_model/model.pth", map_location=torch.device('cpu')))
16
- model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- """
19
- Paddle OCR
20
- """
21
  def ocr_with_paddle(img):
22
- finaltext = ''
23
  ocr = PaddleOCR(lang='en', use_angle_cls=True)
24
  result = ocr.ocr(img)
25
-
26
- for i in range(len(result[0])):
27
- text = result[0][i][1][0]
28
- finaltext += ' ' + text
29
- return finaltext
30
 
31
- """
32
- Keras OCR
33
- """
34
  def ocr_with_keras(img):
35
- output_text = ''
36
  pipeline = keras_ocr.pipeline.Pipeline()
37
  images = [keras_ocr.tools.read(img)]
38
  predictions = pipeline.recognize(images)
39
-
40
- for text, _ in predictions[0]:
41
- output_text += ' ' + text
42
- return output_text
43
 
44
- """
45
- Easy OCR
46
- """
47
  def ocr_with_easy(img):
48
  reader = easyocr.Reader(['en'])
49
- bounds = reader.readtext(img, paragraph=True, detail=0)
50
- return ' '.join(bounds)
51
 
52
- """
53
- Generate OCR and classify spam
54
- """
55
- def generate_ocr_and_classify(Method, img):
56
- if img is None:
57
- raise gr.Error("Please upload an image!")
58
-
59
- # Perform OCR
60
- text_output = ''
61
- if Method == 'EasyOCR':
62
- text_output = ocr_with_easy(img)
63
- elif Method == 'KerasOCR':
64
- text_output = ocr_with_keras(img)
65
- elif Method == 'PaddleOCR':
66
- text_output = ocr_with_paddle(img)
67
 
68
- # Classify extracted text
69
- inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True)
70
- with torch.no_grad():
71
- outputs = model(**inputs)
72
-
73
- prediction = torch.argmax(outputs.logits, dim=1).item()
74
- classification = "Spam" if prediction == 1 else "Not Spam"
 
 
 
 
75
 
76
- return text_output, classification
 
77
 
78
- """
79
- Create user interface
80
- """
81
- image = gr.Image()
82
- method = gr.Radio(["PaddleOCR", "EasyOCR", "KerasOCR"], value="PaddleOCR")
83
- output_text = gr.Textbox(label="Extracted Text")
84
- output_label = gr.Label(label="Classification")
85
 
86
- demo = gr.Interface(
87
- generate_ocr_and_classify,
88
- [method, image],
89
- [output_text, output_label],
90
- title="OCR & Spam Classification",
91
- description="Upload an image with text, extract the text using OCR, and classify whether it is spam or not.",
 
92
  )
93
 
94
- demo.launch()
 
 
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import DistilBertForSequenceClassification, DistilBertTokenizer, DistilBertConfig
 
4
  import cv2
5
+ import numpy as np
6
  import easyocr
7
+ import keras_ocr
8
  from paddleocr import PaddleOCR
9
+ import os
10
 
11
+ # Ensure model config exists
12
+ MODEL_PATH = "./distilbert_spam_model"
13
 
14
+ if not os.path.exists(os.path.join(MODEL_PATH, "config.json")):
15
+ print("config.json not found. Generating default configuration...")
16
+ config = DistilBertConfig.from_pretrained("distilbert-base-uncased", num_labels=2)
17
+ config.save_pretrained(MODEL_PATH)
18
+
19
+ # Load tokenizer and model
20
+ tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)
21
+ model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
22
+
23
+ # Define Spam Classification Function
24
+ def classify_text(text):
25
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
26
+ with torch.no_grad():
27
+ outputs = model(**inputs)
28
+ logits = outputs.logits
29
+ prediction = torch.argmax(logits, dim=-1).item()
30
+ return "Spam" if prediction == 1 else "Not Spam"
31
 
32
+ # OCR Methods
 
 
33
  def ocr_with_paddle(img):
 
34
  ocr = PaddleOCR(lang='en', use_angle_cls=True)
35
  result = ocr.ocr(img)
36
+ extracted_text = ' '.join([entry[1][0] for entry in result[0]])
37
+ return extracted_text
 
 
 
38
 
 
 
 
39
  def ocr_with_keras(img):
 
40
  pipeline = keras_ocr.pipeline.Pipeline()
41
  images = [keras_ocr.tools.read(img)]
42
  predictions = pipeline.recognize(images)
43
+ extracted_text = ' '.join([text for text, _ in predictions[0]])
44
+ return extracted_text
 
 
45
 
 
 
 
46
  def ocr_with_easy(img):
47
  reader = easyocr.Reader(['en'])
48
+ results = reader.readtext(img, detail=0)
49
+ return ' '.join(results)
50
 
51
+ # OCR + Spam Detection
52
+ def process_image(ocr_method, image):
53
+ if image is None:
54
+ return "Error: No image uploaded."
 
 
 
 
 
 
 
 
 
 
 
55
 
56
+ if ocr_method == "PaddleOCR":
57
+ extracted_text = ocr_with_paddle(image)
58
+ elif ocr_method == "KerasOCR":
59
+ extracted_text = ocr_with_keras(image)
60
+ elif ocr_method == "EasyOCR":
61
+ extracted_text = ocr_with_easy(image)
62
+ else:
63
+ return "Invalid OCR method."
64
+
65
+ if not extracted_text.strip():
66
+ return "No text detected in the image."
67
 
68
+ classification = classify_text(extracted_text)
69
+ return f"Extracted Text: {extracted_text}\n\nClassification: {classification}"
70
 
71
+ # Gradio UI
72
+ image_input = gr.Image(type="numpy")
73
+ ocr_method_input = gr.Radio(["PaddleOCR", "EasyOCR", "KerasOCR"], value="PaddleOCR", label="OCR Method")
74
+ output_text = gr.Textbox(label="OCR & Classification Result")
 
 
 
75
 
76
+ interface = gr.Interface(
77
+ fn=process_image,
78
+ inputs=[ocr_method_input, image_input],
79
+ outputs=output_text,
80
+ title="OCR + Spam Detection",
81
+ description="Upload an image with text, extract the text using OCR, and classify it as Spam or Not Spam using DistilBERT.",
82
+ theme="compact"
83
  )
84
 
85
+ # Launch app
86
+ if __name__ == "__main__":
87
+ interface.launch()