import os
import io
import logging
import tempfile
from flask import Flask, request, jsonify
from werkzeug.utils import secure_filename
from PyPDF2 import PdfReader
from docx import Document
from pptx import Presentation
from transformers import T5Tokenizer, T5ForConditionalGeneration
from flask_cors import CORS  # Import CORS for cross-origin requests
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize Flask app
app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

# Set up a temporary directory for Hugging Face cache
cache_dir = tempfile.mkdtemp()
os.environ["HF_HOME"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir
# Load T5 model and tokenizer
logger.info("Loading T5-Base model...")
try:
    tokenizer = T5Tokenizer.from_pretrained("t5-base", cache_dir=cache_dir)
    model = T5ForConditionalGeneration.from_pretrained("t5-base", cache_dir=cache_dir)
    logger.info("T5-Base model loaded successfully.")
except Exception as e:
    logger.error(f"Failed to load T5-Base: {str(e)}")
    raise
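# Optional sketch (assumes torch is available, which transformers requires):
# moving the model to a GPU, when one is present, speeds up generation.
# Inputs would then also need .to(device) before each model.generate call.
#
# import torch
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model = model.to(device)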
ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "txt"}

def allowed_file(filename):
    """Check if the uploaded file has an allowed extension."""
    return "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS
def summarize_text(text, max_length=300, min_length=100):
    """
    Summarize text with T5-Base, chunking long inputs for more comprehensive summaries.

    Args:
        text (str): The text to summarize.
        max_length (int): Maximum token length of the overall summary.
        min_length (int): Minimum token length of the overall summary.

    Returns:
        str: The generated summary.
    """
    try:
        if not text.strip():
            return "No text found in the document to summarize."

        # Break text into chunks if it's very long
        chunks = []
        chunk_size = 4000  # Characters per chunk
        for i in range(0, len(text), chunk_size):
            chunks.append(text[i:i + chunk_size])

        summaries = []
        for i, chunk in enumerate(chunks):
            # Only process up to 5 chunks to avoid very long processing times
            if i >= 5:
                summaries.append("... (Document continues)")
                break

            input_text = "summarize: " + chunk
            inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
            # Generate with beam search; per-chunk lengths are scaled down so the
            # combined summary stays near the requested overall length
            summary_ids = model.generate(
                inputs["input_ids"],
                max_length=max_length // min(5, len(chunks)),  # Adjust max_length based on chunk count
                min_length=min_length // min(5, len(chunks)),  # Adjust min_length based on chunk count
                length_penalty=1.5,  # Moderate penalty to avoid overly verbose summaries
                num_beams=4,
                early_stopping=True,
                no_repeat_ngram_size=3  # Avoid repeating trigrams
            )
            chunk_summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
            summaries.append(chunk_summary)
        # Combine summaries from all chunks
        combined_summary = " ".join(summaries)

        # For very short summaries, try again with the first chunk but longer output
        if len(combined_summary.split()) < 50 and chunks:
            input_text = "summarize: " + chunks[0]
            inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
            summary_ids = model.generate(
                inputs["input_ids"],
                max_length=max_length,
                min_length=min_length,
                length_penalty=2.0,
                num_beams=5,
                early_stopping=True,
                repetition_penalty=2.5  # Penalize repetition more heavily
            )
            combined_summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

        return combined_summary
    except Exception as e:
        logger.error(f"Error in T5 summarization: {str(e)}")
        return f"Error summarizing text: {str(e)}"
@app.route("/")
def index():
    """Root endpoint."""
    logger.info("Root endpoint accessed.")
    return "Document Summarizer API with T5-Base is running! Use /summarize endpoint for POST requests."
@app.route("/summarize", methods=["POST"])
def summarize():
    logger.info("Summarize endpoint called.")

    # Debug the incoming request
    logger.info(f"Request headers: {request.headers}")
    logger.info(f"Request files: {request.files}")
    logger.info(f"Request form: {request.form}")

    # Check if a file is in the request
    if "file" not in request.files:
        logger.error("No file found in request.files")
        return jsonify({"error": "No file uploaded. Make sure to use 'file' as the form field name."}), 400

    file = request.files["file"]

    # Check if file is empty
    if file.filename == "":
        logger.error("File has no filename")
        return jsonify({"error": "No selected file"}), 400

    # Check if file has an allowed extension
    if not allowed_file(file.filename):
        logger.error(f"Unsupported file format: {file.filename}")
        return jsonify({"error": f"Unsupported file format. Allowed types are: {', '.join(ALLOWED_EXTENSIONS)}"}), 400

    # Process the file
    filename = secure_filename(file.filename)
    file_content = file.read()
    file_ext = filename.rsplit(".", 1)[1].lower()

    try:
        if file_ext == "pdf":
            text = summarize_pdf(file_content)
        elif file_ext == "docx":
            text = summarize_docx(file_content)
        elif file_ext == "pptx":
            text = summarize_pptx(file_content)
        elif file_ext == "txt":
            text = summarize_txt(file_content)
        else:
            logger.error("Unsupported file format received.")
            return jsonify({"error": "Unsupported file format"}), 400

        # Generate summary
        logger.info(f"Generating summary for {filename} with text length {len(text)}")
        summary = summarize_text(text)
        logger.info(f"File {filename} summarized successfully.")

        return jsonify({
            "filename": filename,
            "summary": summary,
            "textLength": len(text)
        })
    except Exception as e:
        logger.error(f"Error processing file {filename}: {str(e)}")
        return jsonify({"error": f"Error processing file: {str(e)}"}), 500
def summarize_pdf(file_content):
    """Extract text from PDF."""
    try:
        reader = PdfReader(io.BytesIO(file_content))
        text = "\n".join([page.extract_text() or "" for page in reader.pages])
        return text.strip()
    except Exception as e:
        logger.error(f"Error extracting text from PDF: {str(e)}")
        raise Exception(f"Failed to extract text from PDF: {str(e)}")
def summarize_docx(file_content):
    """Extract text from DOCX."""
    try:
        doc = Document(io.BytesIO(file_content))
        text = "\n".join([para.text for para in doc.paragraphs if para.text.strip()])
        return text.strip()
    except Exception as e:
        logger.error(f"Error extracting text from DOCX: {str(e)}")
        raise Exception(f"Failed to extract text from DOCX: {str(e)}")
def summarize_pptx(file_content):
    """Extract text from PPTX."""
    try:
        ppt = Presentation(io.BytesIO(file_content))
        text = []
        for slide in ppt.slides:
            for shape in slide.shapes:
                if hasattr(shape, "text") and shape.text.strip():
                    text.append(shape.text.strip())
        return "\n".join(text).strip()
    except Exception as e:
        logger.error(f"Error extracting text from PPTX: {str(e)}")
        raise Exception(f"Failed to extract text from PPTX: {str(e)}")
def summarize_txt(file_content):
    """Extract text from TXT file."""
    try:
        return file_content.decode("utf-8").strip()
    except UnicodeDecodeError:
        # Try a different encoding if UTF-8 fails
        try:
            return file_content.decode("latin-1").strip()
        except Exception as e:
            logger.error(f"Error decoding text file: {str(e)}")
            raise Exception(f"Failed to decode text file: {str(e)}")
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860, debug=True)
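# A minimal client sketch (assumes the `requests` package is installed and the
# server above is running locally; the filename is a placeholder):
#
# import requests
# with open("slides.pptx", "rb") as f:
#     resp = requests.post("http://localhost:7860/summarize", files={"file": f})
# resp.raise_for_status()
# print(resp.json()["summary"])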