ikraamkb commited on
Commit
094c949
·
verified ·
1 Parent(s): eb457e5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -188
app.py CHANGED
@@ -1,197 +1,19 @@
1
- import gradio as gr
2
- from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
3
- import fitz # PyMuPDF
4
- import docx
5
- import pptx
6
- import openpyxl
7
- import re
8
- import nltk
9
- from nltk.tokenize import sent_tokenize
10
- import torch
11
- from fastapi import FastAPI, UploadFile, Form, File
12
- from fastapi.responses import RedirectResponse, FileResponse, JSONResponse
13
- from fastapi.middleware.cors import CORSMiddleware
14
- from gtts import gTTS
15
- import tempfile
16
- import os
17
- import shutil
18
- import easyocr
19
- from fpdf import FPDF
20
- import datetime
21
- from concurrent.futures import ThreadPoolExecutor
22
- import hashlib
23
-
24
- nltk.download('punkt', quiet=True)
25
-
26
- app = FastAPI()
27
-
28
- app.add_middleware(
29
- CORSMiddleware,
30
- allow_origins=["*"],
31
- allow_credentials=True,
32
- allow_methods=["*"],
33
- allow_headers=["*"],
34
- )
35
-
36
- MODEL_NAME = "facebook/bart-large-cnn"
37
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
38
- model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
39
- model.eval()
40
- summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, device=-1, batch_size=4)
41
-
42
- reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())
43
- executor = ThreadPoolExecutor()
44
-
45
- summary_cache = {}
46
-
47
- def clean_text(text: str) -> str:
48
- text = re.sub(r'\s+', ' ', text)
49
- text = re.sub(r'\u2022\s*|\d\.\s+', '', text)
50
- text = re.sub(r'\[.*?\]|\(.*?\)', '', text)
51
- text = re.sub(r'\bPage\s*\d+\b', '', text, flags=re.IGNORECASE)
52
- return text.strip()
53
-
54
- def extract_text(file_path: str, file_extension: str):
55
- try:
56
- if file_extension == "pdf":
57
- with fitz.open(file_path) as doc:
58
- text = "\n".join(page.get_text("text") for page in doc)
59
- if len(text.strip()) < 50:
60
- images = [page.get_pixmap() for page in doc]
61
- temp_img = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
62
- images[0].save(temp_img.name)
63
- ocr_result = reader.readtext(temp_img.name, detail=0)
64
- os.unlink(temp_img.name)
65
- text = "\n".join(ocr_result) if ocr_result else text
66
- return clean_text(text), ""
67
-
68
- elif file_extension == "docx":
69
- doc = docx.Document(file_path)
70
- return clean_text("\n".join(p.text for p in doc.paragraphs)), ""
71
-
72
- elif file_extension == "pptx":
73
- prs = pptx.Presentation(file_path)
74
- text = [shape.text for slide in prs.slides for shape in slide.shapes if hasattr(shape, "text")]
75
- return clean_text("\n".join(text)), ""
76
-
77
- elif file_extension == "xlsx":
78
- wb = openpyxl.load_workbook(file_path, read_only=True)
79
- text = [" ".join(str(cell) for cell in row if cell) for sheet in wb.sheetnames for row in wb[sheet].iter_rows(values_only=True)]
80
- return clean_text("\n".join(text)), ""
81
-
82
- elif file_extension in ["jpg", "jpeg", "png"]:
83
- ocr_result = reader.readtext(file_path, detail=0)
84
- return clean_text("\n".join(ocr_result)), ""
85
-
86
- return "", "Unsupported file format"
87
- except Exception as e:
88
- return "", f"Error reading {file_extension.upper()} file: {str(e)}"
89
-
90
- def chunk_text(text: str, max_tokens: int = 950):
91
- try:
92
- sentences = sent_tokenize(text)
93
- except:
94
- words = text.split()
95
- sentences = [' '.join(words[i:i+20]) for i in range(0, len(words), 20)]
96
-
97
- chunks = []
98
- current_chunk = ""
99
- for sentence in sentences:
100
- token_length = len(tokenizer.encode(current_chunk + " " + sentence))
101
- if token_length <= max_tokens:
102
- current_chunk += " " + sentence
103
- else:
104
- chunks.append(current_chunk.strip())
105
- current_chunk = sentence
106
-
107
- if current_chunk:
108
- chunks.append(current_chunk.strip())
109
-
110
- return chunks
111
-
112
- def generate_summary(text: str, length: str = "medium") -> str:
113
- cache_key = hashlib.md5((text + length).encode()).hexdigest()
114
- if cache_key in summary_cache:
115
- return summary_cache[cache_key]
116
-
117
- length_params = {
118
- "short": {"max_length": 80, "min_length": 30},
119
- "medium": {"max_length": 200, "min_length": 80},
120
- "long": {"max_length": 300, "min_length": 210}
121
- }
122
- chunks = chunk_text(text)
123
- try:
124
- summaries = summarizer(
125
- chunks,
126
- max_length=length_params[length]["max_length"],
127
- min_length=length_params[length]["min_length"],
128
- do_sample=False,
129
- truncation=True,
130
- no_repeat_ngram_size=2,
131
- num_beams=2,
132
- early_stopping=True
133
- )
134
- summary_texts = [s['summary_text'] for s in summaries]
135
- except Exception as e:
136
- summary_texts = [f"[Batch error: {str(e)}]"]
137
-
138
- final_summary = " ".join(summary_texts)
139
- final_summary = ". ".join(s.strip().capitalize() for s in final_summary.split(". ") if s.strip())
140
- final_summary = final_summary if len(final_summary) > 25 else "Summary too short - document may be too brief"
141
-
142
- summary_cache[cache_key] = final_summary
143
- return final_summary
144
-
145
- def text_to_speech(text: str):
146
- try:
147
- tts = gTTS(text)
148
- temp_audio = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
149
- tts.save(temp_audio.name)
150
- return temp_audio.name
151
- except Exception as e:
152
- print(f"Error in text-to-speech: {e}")
153
- return ""
154
-
155
- def create_pdf(summary: str, original_filename: str):
156
- try:
157
- pdf = FPDF()
158
- pdf.add_page()
159
- pdf.set_font("Arial", size=12)
160
- pdf.set_font("Arial", 'B', 16)
161
- pdf.cell(200, 10, txt="Document Summary", ln=1, align='C')
162
- pdf.set_font("Arial", size=12)
163
- pdf.cell(200, 10, txt=f"Original file: {original_filename}", ln=1)
164
- pdf.cell(200, 10, txt=f"Generated on: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", ln=1)
165
- pdf.ln(10)
166
- pdf.multi_cell(0, 10, txt=summary)
167
- temp_pdf = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
168
- pdf.output(temp_pdf.name)
169
- return temp_pdf.name
170
- except Exception as e:
171
- print(f"Error creating PDF: {e}")
172
- return ""
173
-
174
  @app.post("/summarize/")
175
  async def summarize_api(file: UploadFile = File(...), length: str = Form("medium")):
176
  with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp:
177
  shutil.copyfileobj(file.file, temp)
178
  temp.flush()
179
  class FileObj: name = temp.name
180
- summary, _, audio_path, pdf_path = summarize_document(FileObj, length)
181
-
182
- return {
 
 
 
 
 
 
183
  "summary": summary,
184
  "audio_url": f"/files/{os.path.basename(audio_path)}" if audio_path else None,
185
  "pdf_url": f"/files/{os.path.basename(pdf_path)}" if pdf_path else None
186
- }
187
-
188
- @app.get("/files/{file_name}")
189
- async def get_file(file_name: str):
190
- file_path = os.path.join(tempfile.gettempdir(), file_name)
191
- if os.path.exists(file_path):
192
- return FileResponse(file_path)
193
- return JSONResponse({"error": "File not found"}, status_code=404)
194
-
195
- @app.get("/")
196
- def redirect_to_interface():
197
- return RedirectResponse(url="/")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  @app.post("/summarize/")
2
  async def summarize_api(file: UploadFile = File(...), length: str = Form("medium")):
3
  with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp:
4
  shutil.copyfileobj(file.file, temp)
5
  temp.flush()
6
  class FileObj: name = temp.name
7
+ text, error = extract_text(FileObj.name, os.path.splitext(file.filename)[1][1:].lower())
8
+ if error:
9
+ return JSONResponse({"error": error}, status_code=400)
10
+
11
+ summary = generate_summary(text, length)
12
+ audio_path = text_to_speech(summary)
13
+ pdf_path = create_pdf(summary, file.filename)
14
+
15
+ return JSONResponse({
16
  "summary": summary,
17
  "audio_url": f"/files/{os.path.basename(audio_path)}" if audio_path else None,
18
  "pdf_url": f"/files/{os.path.basename(pdf_path)}" if pdf_path else None
19
+ })