nehakothari committed on
Commit
af76571
·
verified ·
1 Parent(s): f721551

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +188 -0
app.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ import sys
3
+ import os
4
+ import gradio as gr
5
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
6
+ from qwen_vl_utils import process_vision_info
7
+ import torch
8
+ import pandas as pd
9
+ import pytesseract
10
+ import cv2
11
+ import pymssql
12
+
13
# Address of the SQL Server instance that receives the extracted invoices.
# The literal below is the original hard-coded default; set the SERVER_IP
# environment variable to target a different server without editing the code.
SERVER_IP = os.environ.get("SERVER_IP", "35.227.148.156")
16
+
17
+ # Install dependencies in smaller chunks to avoid memory issues
18
def install_dependencies():
    """Install the app's Python dependencies with pip, one package at a time.

    Packages are organised in small groups and installed sequentially with
    the interpreter that is running this module, streaming pip's output to
    this process's stdout/stderr.

    Raises:
        subprocess.CalledProcessError: if any pip invocation exits non-zero.
    """
    # The "+cpu" torch builds carry a local version label and are not hosted
    # on PyPI; pip can only resolve them from the PyTorch CPU wheel index.
    pytorch_cpu_index = "https://download.pytorch.org/whl/cpu"

    dependency_groups = [
        ["pip==23.3.1", "setuptools", "wheel"],
        ["pytesseract"],
        ["torch==2.1.0+cpu", "torchvision==0.16.0+cpu", "torchaudio==2.1.0+cpu"],
        ["transformers==4.38.2", "auto-gptq==0.7.1", "autoawq==0.2.8"],
        ["qwen_vl_utils==0.0.8", "gradio==4.27.0"],
        ["pyodbc", "sqlalchemy", "azure-storage-blob", "pymssql", "pandas", "opencv-python"],
    ]

    for group in dependency_groups:
        for package in group:
            command = [sys.executable, "-m", "pip", "install", package]
            if "+cpu" in package:
                # Keep PyPI as the primary index and add the PyTorch one.
                command += ["--extra-index-url", pytorch_cpu_index]
            subprocess.check_call(command, stdout=sys.stdout, stderr=sys.stderr)
            print(f"Installed {package}")
32
+
33
+ install_dependencies()
34
+
35
+ # Install system dependencies (executed separately to avoid timeout issues)
36
def install_system_dependencies():
    """Install the OS-level packages the app needs via ``apt-get``.

    Runs, in order: ``apt-get update``, installation of ``unixodbc-dev`` and
    ``tesseract-ocr``, and installation of ``msodbcsql17`` with
    ``ACCEPT_EULA=Y`` set in the environment. Each step is printed once it
    has completed.

    Raises:
        subprocess.CalledProcessError: if a command exits non-zero.
        FileNotFoundError: if ``apt-get`` is not available on PATH.
    """
    # (extra environment variables, argv) for each step. Commands are passed
    # as argument lists so no shell is involved; the EULA flag that used to be
    # a shell-style prefix is supplied through the child's environment.
    steps = [
        ({}, ["apt-get", "update"]),
        ({}, ["apt-get", "install", "-y", "unixodbc-dev", "tesseract-ocr"]),
        ({"ACCEPT_EULA": "Y"}, ["apt-get", "install", "-y", "msodbcsql17"]),
    ]
    for extra_env, argv in steps:
        subprocess.run(argv, check=True, env={**os.environ, **extra_env})
        # Rebuild the same human-readable command line the log always showed.
        command = " ".join([f"{key}={value}" for key, value in extra_env.items()] + argv)
        print(f"Executed: {command}")
45
+
46
# Runs at import time: installs the apt packages before the model is loaded.
install_system_dependencies()

# Hugging Face access token, taken from the environment so that no secret is
# kept in the source. When the variable is unset this is None, which is the
# same as not passing a token at all.
HUGGINGFACE_API_KEY = os.environ.get("HUGGINGFACE_API_KEY")

# Initialize model and processor with CPU mode
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct-AWQ",
    torch_dtype="auto",
    token=HUGGINGFACE_API_KEY,  # `token` replaces the deprecated `use_auth_token`
)

# Force model to use CPU to avoid memory issues on Hugging Face Spaces
model.to("cpu")

processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct-AWQ",
    token=HUGGINGFACE_API_KEY,
)

# pytesseract looks up the binary via the `tesseract_cmd` attribute of its
# inner `pytesseract` module; setting any other attribute has no effect.
pytesseract.pytesseract.tesseract_cmd = r'/usr/bin/tesseract'
61
+
62
+ # Function to preprocess the image for OCR
63
def preprocess_image(image_path):
    """Load the image at *image_path* and return a binarised copy for OCR.

    The image is read with OpenCV, converted from BGR to grayscale and then
    thresholded with ``cv2.THRESH_BINARY`` using a threshold of 150 and a
    maximum value of 255.
    """
    grayscale = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2GRAY)
    # cv2.threshold returns a (value, image) pair; only the image is needed.
    return cv2.threshold(grayscale, 150, 255, cv2.THRESH_BINARY)[1]
68
+
69
+ # Function to extract text using OCR
70
def ocr_extract_text(image_path):
    """Return the text pytesseract reads from the image at *image_path*.

    The image is first run through ``preprocess_image`` and the resulting
    binary image is handed to ``pytesseract.image_to_string``.
    """
    return pytesseract.image_to_string(preprocess_image(image_path))
73
+
74
+ # Function to process image and extract details
75
def process_image(image_path):
    """Extract invoice fields from one image, using OCR as a fallback.

    The image is sent to the vision-language model together with a prompt
    listing the wanted fields; the model's answer is parsed by
    ``parse_details``. If anything in the model path raises, the error is
    printed and the text obtained from ``ocr_extract_text`` is parsed instead.

    Args:
        image_path: Path of the invoice image.

    Returns:
        The dict produced by ``parse_details``.
    """
    try:
        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": image_path},
                {"type": "text", "text": (
                    "Extract the following details from the invoice:\n"
                    "- 'invoice_number'\n"
                    "- 'date'\n"
                    "- 'place'\n"
                    "- 'amount' (monetary value in the relevant currency)\n"
                    "- 'category' (based on the invoice type)"
                )},
            ],
        }]

        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
        inputs = inputs.to(model.device)

        generated_ids = model.generate(**inputs, max_new_tokens=128)
        # generate() returns the prompt followed by the completion. Drop the
        # prompt tokens so the instruction text above (which itself mentions
        # "invoice", "date", ...) is not fed to parse_details as if it were
        # the model's answer.
        completion_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            completion_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        return parse_details(output_text[0])

    except Exception as e:
        # Deliberate best-effort: any model failure degrades to plain OCR.
        print(f"Model failed, falling back to OCR: {e}")
        ocr_text = ocr_extract_text(image_path)
        return parse_details(ocr_text)
106
+
107
+ # Function to parse details from extracted text
108
def parse_details(details):
    """Parse ``label: value`` style text into the invoice field dict.

    Each non-blank line is split at its first colon. The part before the
    colon (or the whole line when there is no colon) is matched
    case-insensitively against per-field keywords; the part after the colon
    (or the whole line when there is no colon) becomes the field value. When
    several lines match the same field, the last one wins.

    Args:
        details: Text produced by the model or by OCR.

    Returns:
        dict with the keys "Invoice Number", "Date", "Place", "Amount" and
        "Category". Fields that were not found are None, except "Category",
        which defaults to "General".
    """
    parsed_data = {
        "Invoice Number": None,
        "Date": None,
        "Place": None,
        "Amount": None,
        "Category": None,
    }

    # Checked in order, first match wins for a given line. "date" is tested
    # before "invoice" so that a label such as "Invoice Date" is stored as
    # the date rather than as the invoice number.
    field_keywords = (
        ("Category", ("category",)),
        ("Date", ("date",)),
        ("Invoice Number", ("invoice",)),
        ("Place", ("place",)),
        ("Amount", ("total", "amount", "cost")),
    )

    for line in details.splitlines():
        if not line.strip():
            continue
        # Split on the FIRST colon only, so values that contain colons
        # themselves (times, URLs) are kept whole.
        label, separator, value = line.partition(":")
        if not separator:
            value = line  # no colon: keep the whole line as the value
        label = label.lower()
        for field, keywords in field_keywords:
            if any(keyword in label for keyword in keywords):
                parsed_data[field] = value.strip()
                break

    # Fall back to a generic category when none was stated in the text.
    if not parsed_data["Category"]:
        parsed_data["Category"] = "General"

    return parsed_data
132
+
133
+ # Store extracted data in Azure SQL Database
134
def store_to_azure_sql(dataframe):
    """Append the rows of *dataframe* to the ``Invoices`` table.

    Connects with pymssql to the ``Invoices`` database on ``SERVER_IP``,
    creates the ``Invoices`` table if it does not exist, inserts one row per
    DataFrame row and commits. Any exception is caught and reported with
    ``print``; nothing is raised to the caller.

    Args:
        dataframe: DataFrame with the columns "Invoice Number", "Date",
            "Place", "Amount" and "Category".
    """
    # NOTE(security): the defaults are the credentials that were hard-coded
    # here originally. Override them with SQL_USER / SQL_PASSWORD in the
    # environment and rotate the password that has been committed.
    user = os.environ.get("SQL_USER", "pio-admin")
    password = os.environ.get("SQL_PASSWORD", "Poctest123#")

    create_table_query = """
    IF NOT EXISTS (SELECT * FROM sysobjects WHERE name='Invoices' AND xtype='U')
    CREATE TABLE Invoices (
        InvoiceNumber NVARCHAR(255),
        Date NVARCHAR(255),
        Place NVARCHAR(255),
        Amount NVARCHAR(255),
        Category NVARCHAR(255)
    )
    """
    # Parameterised statement; built once instead of once per row.
    insert_query = """
    INSERT INTO Invoices (InvoiceNumber, Date, Place, Amount, Category)
    VALUES (%s, %s, %s, %s, %s)
    """
    columns = ["Invoice Number", "Date", "Place", "Amount", "Category"]

    try:
        with pymssql.connect(SERVER_IP, user, password, "Invoices") as conn:
            cursor = conn.cursor()
            cursor.execute(create_table_query)

            for _, row in dataframe.iterrows():
                cursor.execute(insert_query, tuple(row[column] for column in columns))
            conn.commit()
            print("Data successfully stored in Azure SQL Database.")
    except Exception as e:
        # Storage is best-effort: report the failure and carry on.
        print(f"Error storing data to database: {e}")
167
+
168
+ # Gradio interface for invoice processing
169
def gradio_interface(image_files):
    """Gradio callback: extract details from every uploaded invoice image.

    Each file is passed to ``process_image``; the results are collected in a
    DataFrame, stored with ``store_to_azure_sql`` and returned for display.

    Args:
        image_files: The uploaded files as delivered by the ``gr.Files``
            input, or None/empty when nothing was uploaded.

    Returns:
        pandas.DataFrame with one row per image. When no files were given an
        empty DataFrame with the expected columns is returned and nothing is
        written to the database.
    """
    if not image_files:
        # Nothing uploaded: avoid iterating None and skip the database call.
        return pd.DataFrame(
            columns=["Invoice Number", "Date", "Place", "Amount", "Category"]
        )

    results = [process_image(image_file) for image_file in image_files]

    df = pd.DataFrame(results)
    store_to_azure_sql(df)
    return df
178
+
179
+ # Launch Gradio interface
180
# Web UI definition: a multi-file upload feeds `gradio_interface`, and the
# DataFrame it returns is rendered as an editable table.
grpc_interface = gr.Interface(
    title="Invoice Extraction System",
    fn=gradio_interface,
    inputs=gr.Files(label="Upload Invoice Images"),
    outputs=gr.Dataframe(interactive=True),
)

# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    grpc_interface.launch(share=True)