winamnd committed on
Commit
a7de18e
·
verified ·
1 Parent(s): f56fe40

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -25
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  import torch
3
  import json
 
4
  import os
5
  import cv2
6
  import numpy as np
@@ -9,10 +10,10 @@ import keras_ocr
9
  from paddleocr import PaddleOCR
10
  from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
11
  import torch.nn.functional as F
 
12
 
13
  # Paths
14
  MODEL_PATH = "./distilbert_spam_model"
15
- RESULTS_JSON = "results.json"
16
 
17
  # Ensure model exists
18
  if not os.path.exists(os.path.join(MODEL_PATH, "pytorch_model.bin")):
@@ -26,10 +27,10 @@ else:
26
  model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
27
  tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)
28
 
29
- # Ensure model is in evaluation mode
30
  model.eval()
31
 
32
- # OCR Functions
33
  def ocr_with_paddle(img):
34
  ocr = PaddleOCR(lang='en', use_angle_cls=True)
35
  result = ocr.ocr(img)
@@ -47,22 +48,6 @@ def ocr_with_easy(img):
47
  results = reader.readtext(gray_image, detail=0)
48
  return ' '.join(results)
49
 
50
- # Save results to JSON
51
- def save_to_json(text, label):
52
- data = {"text": text, "classification": label}
53
- if os.path.exists(RESULTS_JSON):
54
- with open(RESULTS_JSON, "r") as file:
55
- try:
56
- results = json.load(file)
57
- except json.JSONDecodeError:
58
- results = []
59
- else:
60
- results = []
61
-
62
- results.append(data)
63
- with open(RESULTS_JSON, "w") as file:
64
- json.dump(results, file, indent=4)
65
-
66
  # OCR & Classification Function
67
  def generate_ocr(method, img):
68
  if img is None:
@@ -79,6 +64,7 @@ def generate_ocr(method, img):
79
  else: # KerasOCR
80
  text_output = ocr_with_keras(img)
81
 
 
82
  text_output = text_output.strip()
83
  if len(text_output) == 0:
84
  return "No text detected!", "Cannot classify"
@@ -89,13 +75,14 @@ def generate_ocr(method, img):
89
  # Perform inference
90
  with torch.no_grad():
91
  outputs = model(**inputs)
92
- probs = F.softmax(outputs.logits, dim=1)
93
- spam_prob = probs[0][1].item()
94
 
 
95
  label = "Spam" if spam_prob > 0.5 else "Not Spam"
96
 
97
- # Save results to JSON
98
- save_to_json(text_output, label)
99
 
100
  return text_output, label
101
 
@@ -115,5 +102,5 @@ demo = gr.Interface(
115
  )
116
 
117
  # Launch App
118
- if __name__ == "_main_":
119
- demo.launch()
 
1
  import gradio as gr
2
  import torch
3
  import json
4
+ import csv
5
  import os
6
  import cv2
7
  import numpy as np
 
10
  from paddleocr import PaddleOCR
11
  from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
12
  import torch.nn.functional as F
13
+ from save_results import save_results_to_repo # Import the save function
14
 
15
  # Paths
16
  MODEL_PATH = "./distilbert_spam_model"
 
17
 
18
  # Ensure model exists
19
  if not os.path.exists(os.path.join(MODEL_PATH, "pytorch_model.bin")):
 
27
  model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
28
  tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)
29
 
30
+ # 🔹 Ensure model is in evaluation mode
31
  model.eval()
32
 
33
+ # OCR Functions (No changes here)
34
  def ocr_with_paddle(img):
35
  ocr = PaddleOCR(lang='en', use_angle_cls=True)
36
  result = ocr.ocr(img)
 
48
  results = reader.readtext(gray_image, detail=0)
49
  return ' '.join(results)
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  # OCR & Classification Function
52
  def generate_ocr(method, img):
53
  if img is None:
 
64
  else: # KerasOCR
65
  text_output = ocr_with_keras(img)
66
 
67
+ # Preprocess text properly
68
  text_output = text_output.strip()
69
  if len(text_output) == 0:
70
  return "No text detected!", "Cannot classify"
 
75
  # Perform inference
76
  with torch.no_grad():
77
  outputs = model(**inputs)
78
+ probs = F.softmax(outputs.logits, dim=1) # Convert logits to probabilities
79
+ spam_prob = probs[0][1].item() # Probability of Spam
80
 
81
+ # Adjust classification based on threshold (better than argmax)
82
  label = "Spam" if spam_prob > 0.5 else "Not Spam"
83
 
84
+ # Save results using external function
85
+ save_results_to_repo(text_output, label)
86
 
87
  return text_output, label
88
 
 
102
  )
103
 
104
  # Launch App
105
+ if __name__ == "__main__":
106
+ demo.launch()