# app.py
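"""Document summarization backend.

Extracts text from PDF/DOCX/PPTX/XLSX uploads (with an EasyOCR fallback for
scanned PDFs), summarizes it with facebook/bart-large-cnn, and optionally
renders the summary as MP3 audio (gTTS) and a PDF report (FPDF).
"""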
from fastapi import UploadFile, File
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import fitz # PyMuPDF
import docx
import pptx
import openpyxl
import re
import nltk
import torch
from nltk.tokenize import sent_tokenize
from gtts import gTTS
from fpdf import FPDF
import tempfile
import os
import easyocr
import datetime
import hashlib
# Setup
# Newer NLTK releases load the sentence tokenizer from 'punkt_tab';
# fetch both resources so sent_tokenize works across versions
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)
# Load Models
MODEL_NAME = "facebook/bart-large-cnn"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
model.eval()
summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, device=-1, batch_size=4)
reader = easyocr.Reader(['en','fr'], gpu=torch.cuda.is_available())
summary_cache = {}
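# NOTE: summary_cache is an unbounded in-process dict keyed by
# md5(text + length); a long-running service may want to cap it
# (e.g. with an LRU policy). That bound is a suggestion, not original code.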
# Allowed file extensions
ALLOWED_EXTENSIONS = {'pdf', 'docx', 'pptx', 'xlsx'}
# --- Helper Functions ---
def clean_text(text: str) -> str:
text = re.sub(r'\s+', ' ', text)
text = re.sub(r'\u2022\s*|\d\.\s+', '', text)
text = re.sub(r'\[.*?\]|\(.*?\)', '', text)
text = re.sub(r'\bPage\s*\d+\b', '', text, flags=re.IGNORECASE)
return text.strip()
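# Illustrative example (worked by hand, not part of the original app):
#   clean_text("\u2022 Intro [1] (see details)\nPage 3")  ->  "Intro"
# Bullets, bracketed/parenthesised asides, and "Page N" markers are removed.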
def extract_text(file_path: str, extension: str):
try:
if extension == "pdf":
with fitz.open(file_path) as doc:
text = "\n".join(page.get_text("text") for page in doc)
                if len(text.strip()) < 50:
                    # Little or no text layer: fall back to OCR, rasterizing
                    # and reading every page rather than just the first
                    ocr_lines = []
                    for page in doc:
                        pix = page.get_pixmap()
                        temp_img = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
                        temp_img.close()
                        pix.save(temp_img.name)
                        ocr_lines.extend(reader.readtext(temp_img.name, detail=0))
                        os.unlink(temp_img.name)
                    text = "\n".join(ocr_lines) if ocr_lines else text
elif extension == "docx":
doc = docx.Document(file_path)
text = "\n".join(p.text for p in doc.paragraphs)
elif extension == "pptx":
prs = pptx.Presentation(file_path)
text = "\n".join(shape.text for slide in prs.slides for shape in slide.shapes if hasattr(shape, "text"))
elif extension == "xlsx":
wb = openpyxl.load_workbook(file_path, read_only=True)
text = "\n".join(
[" ".join(str(cell) for cell in row if cell) for sheet in wb.sheetnames for row in wb[sheet].iter_rows(values_only=True)]
)
else:
return "", "Unsupported file format."
return clean_text(text), ""
except Exception as e:
return "", f"Error reading {extension.upper()} file: {str(e)}"
def chunk_text(text: str, max_tokens: int = 950):
try:
sentences = sent_tokenize(text)
    except LookupError:
        # punkt model unavailable: fall back to naive 20-word windows
        words = text.split()
        sentences = [' '.join(words[i:i+20]) for i in range(0, len(words), 20)]
chunks = []
current_chunk = ""
for sentence in sentences:
token_length = len(tokenizer.encode(current_chunk + " " + sentence))
if token_length <= max_tokens:
current_chunk += " " + sentence
else:
if current_chunk.strip():
chunks.append(current_chunk.strip())
current_chunk = sentence
if current_chunk.strip():
chunks.append(current_chunk.strip())
return chunks
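# Illustrative usage (assumes `text` is already cleaned):
#   chunks = chunk_text(text)   # sentence-aligned pieces of <= ~950 tokens
# Keeping chunks under 950 BART tokens stays inside the model's 1024-token
# encoder limit; a single sentence longer than max_tokens becomes its own
# chunk and is truncated later by the summarizer's truncation=True.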
def generate_summary(text: str, length: str = "medium"):
cache_key = hashlib.md5((text + length).encode()).hexdigest()
if cache_key in summary_cache:
return summary_cache[cache_key]
length_params = {
"short": {"max_length": 50, "min_length": 30},
"medium": {"max_length": 200, "min_length": 80},
"long": {"max_length": 300, "min_length": 210}
}
    chunks = chunk_text(text)
    # Fall back to "medium" instead of raising KeyError on unknown lengths
    params = length_params.get(length, length_params["medium"])
    summaries = summarizer(
        chunks,
        max_length=params["max_length"],
        min_length=params["min_length"],
do_sample=False,
truncation=True,
no_repeat_ngram_size=2,
num_beams=2,
early_stopping=True
)
summary_texts = [s['summary_text'] for s in summaries]
final_summary = " ".join(summary_texts)
final_summary = ". ".join(s.strip().capitalize() for s in final_summary.split(". ") if s.strip())
final_summary = final_summary if len(final_summary) > 25 else "Summary too short - document may be too brief"
summary_cache[cache_key] = final_summary
return final_summary
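# Illustrative call (hypothetical variable name):
#   summary = generate_summary(extracted_text, length="short")
# Each chunk is summarized independently and the pieces concatenated, so
# longer documents naturally produce proportionally longer summaries.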
def text_to_speech(text: str):
try:
        # gTTS calls Google's remote TTS endpoint, so this needs network
        # access; on failure we fall through and return "" (no audio)
        tts = gTTS(text)
temp_audio = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
tts.save(temp_audio.name)
return temp_audio.name
except Exception:
return ""
def create_pdf(summary: str, filename: str):
try:
pdf = FPDF()
pdf.add_page()
pdf.set_font("Arial", 'B', 16)
pdf.cell(200, 10, txt=f"Summary of {filename}", ln=1, align='C')
pdf.set_font("Arial", size=12)
pdf.cell(200, 10, txt=f"Generated on: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", ln=1)
pdf.ln(10)
pdf.set_font("Arial", size=10)
pdf.multi_cell(0, 10, txt=summary)
temp_pdf = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
pdf.output(temp_pdf.name)
return temp_pdf.name
except Exception:
return ""
# --- Public API Function ---
async def summarize_document(file: UploadFile, length: str = "medium"):
try:
filename = file.filename
extension = os.path.splitext(filename)[-1].lower().replace('.', '')
if extension not in ALLOWED_EXTENSIONS:
raise Exception(f"Unsupported file type: {extension.upper()}. Only PDF, DOCX, PPTX, XLSX are allowed.")
# Save uploaded file
with tempfile.NamedTemporaryFile(delete=False, suffix=f".{extension}") as tmp_file:
tmp_file.write(await file.read())
tmp_path = tmp_file.name
        # Extract text and summarize; the temp file is removed in `finally`
        # even when extraction or summarization raises
        try:
            text, error = extract_text(tmp_path, extension)
            if error:
                raise Exception(error)
            if not text or len(text.split()) < 30:
                raise Exception("Document too short to summarize.")
            # Summarize
            summary = generate_summary(text, length)
        finally:
            # Clean temp file
            os.unlink(tmp_path)
        # Create audio + PDF
        audio_path = text_to_speech(summary)
        pdf_path = create_pdf(summary, filename)
# Prepare response
response = {"summary": summary}
if audio_path:
response["audioUrl"] = f"/files/{os.path.basename(audio_path)}"
if pdf_path:
response["pdfUrl"] = f"/files/{os.path.basename(pdf_path)}"
return response
    except Exception as e:
        # Chain the original exception so the traceback isn't lost
        raise Exception(f"Summarization failed: {str(e)}") from e
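# --- Example wiring (sketch) ---
# A minimal sketch of how summarize_document could be mounted in a FastAPI
# app. The route path, port, and the /files static mount are assumptions,
# not part of the original module: the audioUrl/pdfUrl values above only
# resolve if the directory holding the generated MP3/PDF files (the system
# temp dir here) is actually served under /files.
if __name__ == "__main__":
    from fastapi import FastAPI
    from fastapi.staticfiles import StaticFiles
    import uvicorn

    demo_app = FastAPI()
    # Serve the temp dir so the returned audioUrl/pdfUrl links work
    demo_app.mount("/files", StaticFiles(directory=tempfile.gettempdir()), name="files")

    @demo_app.post("/summarize")
    async def summarize_endpoint(file: UploadFile = File(...), length: str = "medium"):
        return await summarize_document(file, length)

    uvicorn.run(demo_app, host="0.0.0.0", port=7860)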