winamnd commited on
Commit
dad8a00
·
verified ·
1 Parent(s): e08cb5e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -20
app.py CHANGED
@@ -13,7 +13,6 @@ from save_results import save_results_to_repo
13
 
14
  # Paths
15
  MODEL_PATH = "./distilbert_spam_model"
16
- RESULTS_JSON = "ocr_results.json"
17
 
18
  # Ensure model exists
19
  if not os.path.exists(os.path.join(MODEL_PATH, "pytorch_model.bin")):
@@ -51,7 +50,7 @@ def ocr_with_easy(img):
51
  # OCR Extraction Function
52
  def extract_text(method, img):
53
  if img is None:
54
- raise gr.Error("Please upload an image!")
55
 
56
  # Convert PIL Image to OpenCV format
57
  img = np.array(img)
@@ -68,13 +67,13 @@ def extract_text(method, img):
68
  text_output = text_output.strip()
69
 
70
  if len(text_output) == 0:
71
- return "No text detected!"
72
 
73
- return text_output
74
 
75
  # Classification Function
76
  def classify_text(text_output):
77
- if text_output.strip() == "No text detected!":
78
  return text_output, "Cannot classify"
79
 
80
  # Tokenize text
@@ -95,25 +94,20 @@ def classify_text(text_output):
95
  return text_output, label
96
 
97
  # Gradio Interface
98
- image_input = gr.Image()
99
- method_input = gr.Radio(["PaddleOCR", "EasyOCR", "KerasOCR"], value="PaddleOCR", label="Choose OCR Method")
100
- output_text = gr.Textbox(label="Extracted Text", interactive=True)
101
- output_label = gr.Textbox(label="Spam Classification", interactive=False)
102
-
103
- # Define UI layout
104
  with gr.Blocks() as demo:
105
  gr.Markdown("## OCR Spam Classifier")
106
-
107
- with gr.Row():
108
- method_input.render()
109
-
110
- with gr.Row():
111
- image_input.render()
112
-
113
  extract_button = gr.Button("Submit")
114
  classify_button = gr.Button("Classify")
115
-
116
- extract_button.click(fn=extract_text, inputs=[method_input, image_input], outputs=[output_text])
 
 
 
 
117
  classify_button.click(fn=classify_text, inputs=[output_text], outputs=[output_text, output_label])
118
 
119
  # Launch App
 
13
 
14
  # Paths
15
  MODEL_PATH = "./distilbert_spam_model"
 
16
 
17
  # Ensure model exists
18
  if not os.path.exists(os.path.join(MODEL_PATH, "pytorch_model.bin")):
 
50
  # OCR Extraction Function
51
  def extract_text(method, img):
52
  if img is None:
53
+ return "Error: Please upload an image!", ""
54
 
55
  # Convert PIL Image to OpenCV format
56
  img = np.array(img)
 
67
  text_output = text_output.strip()
68
 
69
  if len(text_output) == 0:
70
+ return "No text detected!", ""
71
 
72
+ return text_output, ""
73
 
74
  # Classification Function
75
  def classify_text(text_output):
76
+ if text_output.strip() in ["No text detected!", "Error: Please upload an image!"]:
77
  return text_output, "Cannot classify"
78
 
79
  # Tokenize text
 
94
  return text_output, label
95
 
96
  # Gradio Interface
 
 
 
 
 
 
97
  with gr.Blocks() as demo:
98
  gr.Markdown("## OCR Spam Classifier")
99
+
100
+ method_input = gr.Radio(["PaddleOCR", "EasyOCR", "KerasOCR"], value="PaddleOCR", label="Choose OCR Method")
101
+ image_input = gr.Image(label="Upload Image")
102
+
 
 
 
103
  extract_button = gr.Button("Submit")
104
  classify_button = gr.Button("Classify")
105
+
106
+ output_text = gr.Textbox(label="Extracted Text", interactive=True)
107
+ output_label = gr.Textbox(label="Spam Classification", interactive=False)
108
+
109
+ # Button Click Bindings
110
+ extract_button.click(fn=extract_text, inputs=[method_input, image_input], outputs=[output_text, output_label])
111
  classify_button.click(fn=classify_text, inputs=[output_text], outputs=[output_text, output_label])
112
 
113
  # Launch App