Spaces:
Running
Running
File size: 6,942 Bytes
6f78a44 da10ca7 7839da1 6f78a44 da10ca7 6f78a44 da10ca7 7839da1 da10ca7 7839da1 6f78a44 7839da1 6f78a44 7839da1 6f78a44 7839da1 da10ca7 6f78a44 7839da1 6f78a44 7839da1 da10ca7 7839da1 da10ca7 7839da1 da10ca7 7839da1 6f78a44 7839da1 6f78a44 7839da1 6f78a44 7839da1 6f78a44 da10ca7 6f78a44 da10ca7 6f78a44 da10ca7 6f78a44 da10ca7 6f78a44 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 |
import datetime
import hashlib
import logging
import os
import re
import tempfile
import threading

import docx
import easyocr
import fitz  # PyMuPDF
import nltk
import openpyxl
import pptx
import torch
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import RedirectResponse, FileResponse, JSONResponse
from fpdf import FPDF
from gtts import gTTS
from nltk.tokenize import sent_tokenize
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
# Initialize
# NLTK >= 3.8.2 loads the Punkt sentence tokenizer from the "punkt_tab"
# resource, older releases from "punkt". Fetch both so sent_tokenize works on
# either version (nltk.download reports, but does not raise on, a failure).
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)
app = FastAPI()

# Load Summarizer Model
MODEL_NAME = "facebook/bart-large-cnn"
# Run the summariser on the GPU when one is available, matching the OCR reader
# below (transformers pipeline convention: 0 = first CUDA device, -1 = CPU).
SUMMARIZER_DEVICE = 0 if torch.cuda.is_available() else -1
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
model.eval()
summarizer = pipeline(
    "summarization",
    model=model,
    tokenizer=tokenizer,
    device=SUMMARIZER_DEVICE,
    batch_size=4,
)

# Load OCR Reader (English only; GPU when available).
reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())

# Cache of finished summaries, keyed by a hash of the input text and the
# requested summary length.
summary_cache = {}
# --- Helper Functions ---
def clean_text(text: str) -> str:
    """Normalise extracted document text before summarisation.

    Collapses all whitespace to single spaces, drops bullet glyphs and
    single-digit numbered-list markers ("1. "), removes bracketed and
    parenthesised asides and "Page N" labels, then collapses the whitespace
    those removals leave behind.

    Args:
        text: Raw text extracted from a document.

    Returns:
        The cleaned, single-line text with surrounding whitespace stripped.
    """
    text = re.sub(r'\s+', ' ', text)
    # A list marker is a lone digit + '.' + whitespace. The look-behind keeps
    # the pattern from eating the final digit of a number that ends a
    # sentence ("born in 1990. He" must not become "born in 199He").
    text = re.sub(r'\u2022\s*|(?<![\w.])\d\.\s+', '', text)
    text = re.sub(r'\[.*?\]|\(.*?\)', '', text)
    text = re.sub(r'\bPage\s*\d+\b', '', text, flags=re.IGNORECASE)
    # The removals above can leave runs of spaces ("foo [1] bar" -> "foo  bar").
    text = re.sub(r'\s+', ' ', text)
    return text.strip()
def _extract_pdf(file_path: str) -> str:
    """Return a PDF's text layer, OCR-ing every page when it is (nearly) empty."""
    with fitz.open(file_path) as doc:
        text = "\n".join(page.get_text("text") for page in doc)
        if len(text.strip()) < 50:
            # Probably a scanned document without a text layer: OCR every
            # page, not just the first. Render at 2x zoom (144 dpi) because
            # the default 72 dpi raster is coarse for OCR. The temporary
            # directory removes the page images even if OCR fails.
            ocr_lines = []
            with tempfile.TemporaryDirectory() as tmp_dir:
                for page_no, page in enumerate(doc):
                    img_path = os.path.join(tmp_dir, f"page_{page_no}.png")
                    page.get_pixmap(matrix=fitz.Matrix(2, 2)).save(img_path)
                    ocr_lines.extend(reader.readtext(img_path, detail=0))
            if ocr_lines:
                text = "\n".join(ocr_lines)
    return clean_text(text)


def _extract_docx(file_path: str) -> str:
    """Return the paragraph text of a Word document."""
    doc = docx.Document(file_path)
    return clean_text("\n".join(p.text for p in doc.paragraphs))


def _extract_pptx(file_path: str) -> str:
    """Return the text of every text-bearing shape on every slide."""
    prs = pptx.Presentation(file_path)
    text = [
        shape.text
        for slide in prs.slides
        for shape in slide.shapes
        if hasattr(shape, "text")
    ]
    return clean_text("\n".join(text))


def _extract_xlsx(file_path: str) -> str:
    """Return worksheet cell values, one line per row, across all worksheets."""
    wb = openpyxl.load_workbook(file_path, read_only=True)
    try:
        # ``worksheets`` skips chartsheets, which have no rows to iterate.
        # ``is not None`` keeps legitimate 0 values that ``if cell`` dropped.
        rows = [
            " ".join(str(cell) for cell in row if cell is not None)
            for sheet in wb.worksheets
            for row in sheet.iter_rows(values_only=True)
        ]
    finally:
        # Read-only workbooks hold the file open until explicitly closed.
        wb.close()
    return clean_text("\n".join(rows))


def _extract_image(file_path: str) -> str:
    """Return the OCR'd text of an image file."""
    return clean_text("\n".join(reader.readtext(file_path, detail=0)))


# Lower-case extension (without the dot) -> extractor for that format.
_EXTRACTORS = {
    "pdf": _extract_pdf,
    "docx": _extract_docx,
    "pptx": _extract_pptx,
    "xlsx": _extract_xlsx,
    "jpg": _extract_image,
    "jpeg": _extract_image,
    "png": _extract_image,
}


def extract_text(file_path: str, file_extension: str):
    """Extract and clean the text of a supported document.

    Args:
        file_path: Path of the document on disk.
        file_extension: Document type such as "pdf"; letter case and a
            leading dot are ignored.

    Returns:
        A ``(text, error)`` tuple: ``(cleaned_text, "")`` on success, or
        ``("", message)`` for an unsupported type or a parsing failure.
    """
    ext = file_extension.lower().lstrip(".")
    extractor = _EXTRACTORS.get(ext)
    if extractor is None:
        return "", "Unsupported file format"
    try:
        return extractor(file_path), ""
    except Exception as e:
        # Any parser/OCR failure is reported to the caller, not raised.
        return "", f"Error reading {ext.upper()} file: {str(e)}"
def chunk_text(text: str, max_tokens: int = 950):
    """Split text into sentence-aligned chunks sized for the summariser.

    Sentences are packed greedily into a chunk while the tokenizer-encoded
    length of that chunk stays within ``max_tokens``. A single sentence longer
    than ``max_tokens`` becomes a chunk on its own; the summariser is called
    with ``truncation=True`` to cope with it.

    Args:
        text: Cleaned document text.
        max_tokens: Upper bound on the encoded length of each chunk.

    Returns:
        A list of non-empty chunk strings (empty when ``text`` is blank).
    """
    try:
        sentences = sent_tokenize(text)
    except LookupError:
        # NLTK's Punkt data is unavailable: fall back to fixed 20-word windows.
        words = text.split()
        sentences = [' '.join(words[i:i + 20]) for i in range(0, len(words), 20)]

    chunks = []
    current_chunk = ""
    for sentence in sentences:
        candidate = f"{current_chunk} {sentence}" if current_chunk else sentence
        if len(tokenizer.encode(candidate)) <= max_tokens:
            current_chunk = candidate
        else:
            # Never emit an empty chunk (happens when the very first sentence
            # alone exceeds max_tokens).
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = sentence
    if current_chunk:
        chunks.append(current_chunk.strip())
    return chunks
# Maximum number of summaries kept in ``summary_cache``. The oldest entry is
# evicted first (dicts preserve insertion order) so memory stays bounded.
_SUMMARY_CACHE_MAX = 256


def generate_summary(text: str, length: str = "medium") -> str:
    """Summarise text with the summarisation pipeline, memoising results.

    Args:
        text: Cleaned document text.
        length: "short", "medium" or "long" (case-insensitive); any other
            value falls back to "medium".

    Returns:
        The joined chunk summaries with the first letter of each sentence
        upper-cased, or a notice string when the result is 25 characters or
        fewer.
    """
    length_params = {
        "short": {"max_length": 80, "min_length": 30},
        "medium": {"max_length": 200, "min_length": 80},
        "long": {"max_length": 300, "min_length": 210},
    }
    # Normalise the requested length and fall back to "medium" rather than
    # failing with a KeyError on an unexpected value.
    length = (length or "medium").strip().lower()
    if length not in length_params:
        length = "medium"
    params = length_params[length]

    # The NUL separator keeps distinct (length, text) pairs from colliding.
    cache_key = hashlib.md5(f"{length}\0{text}".encode("utf-8")).hexdigest()
    cached = summary_cache.get(cache_key)
    if cached is not None:
        return cached

    chunks = chunk_text(text)
    summaries = []
    if chunks:  # skip the model entirely when there is nothing to summarise
        summaries = summarizer(
            chunks,
            max_length=params["max_length"],
            min_length=params["min_length"],
            do_sample=False,
            truncation=True,
            no_repeat_ngram_size=2,
            num_beams=2,
            early_stopping=True,
        )

    final_summary = " ".join(s['summary_text'] for s in summaries)
    # Upper-case only the first character of each sentence; str.capitalize()
    # would also lower-case the rest ("NASA launched" -> "Nasa launched").
    sentences = (part.strip() for part in final_summary.split(". "))
    final_summary = ". ".join(s[:1].upper() + s[1:] for s in sentences if s)
    if len(final_summary) <= 25:
        final_summary = "Summary too short - document may be too brief"

    if len(summary_cache) >= _SUMMARY_CACHE_MAX:
        summary_cache.pop(next(iter(summary_cache)), None)
    summary_cache[cache_key] = final_summary
    return final_summary
def text_to_speech(text: str, lang: str = "en") -> str:
    """Synthesise text to an MP3 file in the system temp directory via gTTS.

    Best-effort: failures are logged and reported as an empty path rather
    than raised.

    Args:
        text: Text to speak.
        lang: gTTS language code; "en" (gTTS's own default) if omitted.

    Returns:
        Path of the written MP3 file, or "" if synthesis failed.
    """
    audio_path = ""
    try:
        tts = gTTS(text, lang=lang)
        # Only reserve a unique name here and close the handle immediately:
        # gTTS reopens the path itself, and a still-open NamedTemporaryFile
        # cannot be opened a second time on every platform (e.g. Windows).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
            audio_path = tmp.name
        tts.save(audio_path)
        return audio_path
    except Exception:
        logging.getLogger(__name__).exception("Text-to-speech generation failed")
        if audio_path:
            # Don't leave a partial/empty file behind in the temp directory.
            try:
                os.unlink(audio_path)
            except OSError:
                pass
        return ""
# Typographic characters common in model output that FPDF's built-in
# (latin-1) fonts cannot encode, mapped to ASCII look-alikes.
_PDF_CHAR_MAP = str.maketrans({
    "\u2018": "'",
    "\u2019": "'",
    "\u201c": '"',
    "\u201d": '"',
    "\u2013": "-",
    "\u2014": "-",
    "\u2026": "...",
    "\u2022": "-",
})


def _pdf_safe(value) -> str:
    """Coerce value to text FPDF's core fonts can render ('?' for anything non-latin-1)."""
    return str(value).translate(_PDF_CHAR_MAP).encode("latin-1", "replace").decode("latin-1")


def create_pdf(summary: str, original_filename: str) -> str:
    """Render the summary as a one-document PDF in the system temp directory.

    Best-effort: failures are logged and reported as an empty path rather
    than raised.

    Args:
        summary: Summary text for the body of the PDF.
        original_filename: Name of the uploaded file, printed in the header.

    Returns:
        Path of the written PDF file, or "" if generation failed.
    """
    pdf_path = ""
    try:
        pdf = FPDF()
        pdf.add_page()
        pdf.set_font("Arial", 'B', 16)
        pdf.cell(200, 10, txt="Document Summary", ln=1, align='C')
        pdf.set_font("Arial", size=12)
        # Core fonts are latin-1 only; sanitise user/model-supplied text so a
        # curly quote or non-latin file name doesn't abort PDF generation.
        pdf.cell(200, 10, txt=_pdf_safe(f"Original file: {original_filename}"), ln=1)
        pdf.cell(200, 10, txt=f"Generated on: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", ln=1)
        pdf.ln(10)
        pdf.multi_cell(0, 10, txt=_pdf_safe(summary))
        # Reserve a unique name and close the handle before FPDF writes the
        # file itself (a second open of an open temp file fails on Windows).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
            pdf_path = tmp.name
        pdf.output(pdf_path)
        return pdf_path
    except Exception:
        logging.getLogger(__name__).exception("PDF generation failed")
        if pdf_path:
            try:
                os.unlink(pdf_path)
            except OSError:
                pass
        return ""
# --- API Endpoints ---
# Serialises the heavy pipeline (OCR, summariser, TTS, PDF) so running it off
# the event loop never drives the shared model/tokenizer objects concurrently.
_pipeline_lock = threading.Lock()


def _summarize_file(tmp_path: str, file_ext: str, length: str, original_filename: str):
    """Run extraction -> summary -> audio/PDF synchronously.

    Returns:
        ``(payload, status_code)`` for the JSON response.
    """
    with _pipeline_lock:
        text, error = extract_text(tmp_path, file_ext)
        if error:
            return {"detail": error}, 400
        if not text or len(text.split()) < 30:
            return {"detail": "Document too short to summarize"}, 400
        summary = generate_summary(text, length)
        audio_path = text_to_speech(summary)
        pdf_path = create_pdf(summary, original_filename)
    response = {"summary": summary}
    if audio_path:
        response["audioUrl"] = f"/files/{os.path.basename(audio_path)}"
    if pdf_path:
        response["pdfUrl"] = f"/files/{os.path.basename(pdf_path)}"
    return response, 200


@app.post("/summarize/")
async def summarize_api(file: UploadFile = File(...), length: str = Form("medium")):
    """Summarise an uploaded document.

    Args:
        file: The uploaded document; its type is taken from its file name.
        length: Requested summary length, forwarded to ``generate_summary``.

    Returns:
        JSON ``{"summary", ["audioUrl"], ["pdfUrl"]}`` on success, or
        ``{"detail": ...}`` with status 400 for unreadable/too-short input.
    """
    contents = await file.read()
    # The document type must come from the client's file name: the temp path
    # has no meaningful extension of its own. The temp file also keeps that
    # suffix because some parsers (e.g. openpyxl) refuse paths without it.
    file_ext = os.path.splitext(file.filename or "")[1].lstrip(".").lower()
    suffix = f".{file_ext}" if file_ext.isalnum() else ""
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
        tmp_file.write(contents)
        tmp_path = tmp_file.name
    try:
        # Extraction and inference are blocking; run them in a worker thread
        # so the event loop keeps serving other requests meanwhile.
        payload, status_code = await run_in_threadpool(
            _summarize_file, tmp_path, file_ext, length, file.filename
        )
    finally:
        # The upload is only needed for extraction; don't leak it on disk.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
    return JSONResponse(payload, status_code=status_code)
# Only the kinds of artefact this app generates (speech MP3s, summary PDFs)
# may be served.
_SERVABLE_SUFFIXES = (".mp3", ".pdf")


@app.get("/files/{file_name}")
async def serve_file(file_name: str):
    """Serve a generated audio/PDF file from the system temp directory.

    Anything that is not a plain ``.mp3``/``.pdf`` file name resolving to a
    regular file directly inside the temp directory gets a 404. This stops
    the endpoint from exposing arbitrary temp files or escaping the
    directory via ".." or path separators.
    """
    temp_dir = os.path.realpath(tempfile.gettempdir())
    path = os.path.realpath(os.path.join(temp_dir, file_name))
    if (
        os.path.basename(file_name) == file_name
        and file_name.lower().endswith(_SERVABLE_SUFFIXES)
        and os.path.dirname(path) == temp_dir
        and os.path.isfile(path)
    ):
        return FileResponse(path)
    return JSONResponse({"error": "File not found"}, status_code=404)
@app.get("/")
def home():
    """Send visitors of the bare root URL to the interactive API docs.

    Redirecting "/" to "/" itself would loop forever in the browser.
    """
    return RedirectResponse(url=app.docs_url or "/docs")
|