# Document Summarizer API — app.py (upstream update, commit f8ea9fd)
import os
import io
import logging
import tempfile
from flask import Flask, request, jsonify
from werkzeug.utils import secure_filename
from PyPDF2 import PdfReader
from docx import Document
from pptx import Presentation
from transformers import T5Tokenizer, T5ForConditionalGeneration
from flask_cors import CORS # Import CORS for cross-origin requests
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize Flask app
app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

# Point the Hugging Face cache at a throwaway temp directory so model files
# can be written even on read-only / ephemeral hosts (e.g. container images).
cache_dir = tempfile.mkdtemp()
os.environ["HF_HOME"] = cache_dir
# TRANSFORMERS_CACHE is the legacy cache variable; set both for compatibility
# across transformers versions.
os.environ["TRANSFORMERS_CACHE"] = cache_dir

# Load the T5 model and tokenizer once at startup; fail fast (re-raise) if the
# model cannot be fetched, since the service is useless without it.
logger.info("Loading T5-Base model...")
try:
    tokenizer = T5Tokenizer.from_pretrained("t5-base", cache_dir=cache_dir)
    model = T5ForConditionalGeneration.from_pretrained("t5-base", cache_dir=cache_dir)
    logger.info("T5-Base model loaded successfully.")
except Exception as e:
    logger.error(f"Failed to load T5-Base: {str(e)}")
    raise
ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "txt"}


def allowed_file(filename):
    """Return True if *filename* has an extension this service accepts."""
    _, dot, extension = filename.rpartition(".")
    return bool(dot) and extension.lower() in ALLOWED_EXTENSIONS
def summarize_text(text, max_length=300, min_length=100):
    """
    Summarize text using T5-Base.

    Long inputs are split into fixed-size character chunks; each chunk is
    summarized independently and the pieces are joined. At most five chunks
    are processed to bound request latency.

    Args:
        text (str): The text to summarize.
        max_length (int): Maximum token length of the overall summary.
        min_length (int): Minimum token length of the overall summary.

    Returns:
        str: The generated summary, a placeholder message for empty input,
        or an error message if generation fails.
    """
    try:
        if not text.strip():
            return "No text found in the document to summarize."

        # Break text into chunks if it's very long.
        chunk_size = 4000  # characters per chunk
        chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

        max_chunks = 5  # cap work to avoid very long processing times
        # Per-chunk length budget so the combined summary stays near the
        # requested bounds (hoisted out of the loop; it is loop-invariant).
        share = min(max_chunks, len(chunks))

        summaries = []
        for i, chunk in enumerate(chunks):
            if i >= max_chunks:
                summaries.append("... (Document continues)")
                break
            inputs = tokenizer("summarize: " + chunk, return_tensors="pt",
                               max_length=512, truncation=True)
            summary_ids = model.generate(
                inputs["input_ids"],
                max_length=max_length // share,
                min_length=min_length // share,
                length_penalty=1.5,      # keep chunk summaries from rambling
                num_beams=4,
                early_stopping=True,
                no_repeat_ngram_size=3,  # avoid repeating trigrams
            )
            summaries.append(tokenizer.decode(summary_ids[0], skip_special_tokens=True))

        combined_summary = " ".join(summaries)

        # If the combined result is very short, retry on the first chunk with
        # the full length budget and a stronger repetition penalty.
        if len(combined_summary.split()) < 50 and chunks:
            inputs = tokenizer("summarize: " + chunks[0], return_tensors="pt",
                               max_length=512, truncation=True)
            summary_ids = model.generate(
                inputs["input_ids"],
                max_length=max_length,
                min_length=min_length,
                length_penalty=2.0,
                num_beams=5,
                early_stopping=True,
                repetition_penalty=2.5,  # penalize repetition more heavily
            )
            combined_summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

        return combined_summary
    except Exception as e:
        # Deliberate best-effort: the endpoint returns the message to the
        # client rather than crashing the request.
        logger.error(f"Error in T5 summarization: {str(e)}")
        return f"Error summarizing text: {str(e)}"
@app.route("/", methods=["GET"])
def index():
    """Landing / health-check endpoint; confirms the API is up."""
    logger.info("Root endpoint accessed.")
    return "Document Summarizer API with T5-Base is running! Use /summarize endpoint for POST requests."
@app.route("/summarize", methods=["POST"])
def summarize():
    """Accept a document upload and return a JSON summary.

    Expects a multipart/form-data POST with the document under the form
    field name 'file'. Responds with JSON containing the sanitized filename,
    the generated summary, and the length of the extracted text; 400 for
    client errors, 500 for extraction/summarization failures.
    """
    logger.info("Summarize endpoint called.")
    # Debug the incoming request.
    logger.info(f"Request headers: {request.headers}")
    logger.info(f"Request files: {request.files}")
    logger.info(f"Request form: {request.form}")

    # Check if a file is in the request.
    if "file" not in request.files:
        logger.error("No file found in request.files")
        return jsonify({"error": "No file uploaded. Make sure to use 'file' as the form field name."}), 400

    file = request.files["file"]

    # Check if file is empty.
    if file.filename == "":
        logger.error("File has no filename")
        return jsonify({"error": "No selected file"}), 400

    # Check if file has an allowed extension.
    if not allowed_file(file.filename):
        logger.error(f"Unsupported file format: {file.filename}")
        return jsonify({"error": f"Unsupported file format. Allowed types are: {', '.join(ALLOWED_EXTENSIONS)}"}), 400

    # Sanitize the name before using it anywhere (paths, logs, response).
    filename = secure_filename(file.filename)
    file_content = file.read()
    file_ext = filename.rsplit(".", 1)[1].lower()

    # Dispatch table: extension -> text extractor.
    extractors = {
        "pdf": summarize_pdf,
        "docx": summarize_docx,
        "pptx": summarize_pptx,
        "txt": summarize_txt,
    }

    try:
        extractor = extractors.get(file_ext)
        if extractor is None:
            # Defensive: allowed_file() should already have rejected this.
            logger.error("Unsupported file format received.")
            return jsonify({"error": "Unsupported file format"}), 400
        text = extractor(file_content)

        # Generate summary. (Fixed: logs previously printed the literal
        # placeholder "(unknown)" instead of the actual filename.)
        logger.info(f"Generating summary for {filename} with text length {len(text)}")
        summary = summarize_text(text)
        logger.info(f"File {filename} summarized successfully.")
        return jsonify({
            "filename": filename,
            "summary": summary,
            "textLength": len(text)
        })
    except Exception as e:
        logger.error(f"Error processing file {filename}: {str(e)}")
        return jsonify({"error": f"Error processing file: {str(e)}"}), 500
def summarize_pdf(file_content):
    """Extract text from a PDF.

    Args:
        file_content (bytes): Raw bytes of the uploaded PDF.

    Returns:
        str: Page texts joined by newlines, stripped. Pages with no
        extractable text contribute an empty string.

    Raises:
        Exception: If PyPDF2 fails to parse the document.
    """
    try:
        reader = PdfReader(io.BytesIO(file_content))
        # extract_text() may return None for image-only pages; coerce to "".
        text = "\n".join(page.extract_text() or "" for page in reader.pages)
        return text.strip()
    except Exception as e:
        logger.error(f"Error extracting text from PDF: {str(e)}")
        # Chain the original exception so the root cause is preserved.
        raise Exception(f"Failed to extract text from PDF: {str(e)}") from e
def summarize_docx(file_content):
    """Extract text from a DOCX document.

    Args:
        file_content (bytes): Raw bytes of the uploaded DOCX.

    Returns:
        str: Non-empty paragraph texts joined by newlines, stripped.

    Raises:
        Exception: If python-docx fails to parse the document.
    """
    try:
        doc = Document(io.BytesIO(file_content))
        text = "\n".join(para.text for para in doc.paragraphs if para.text.strip())
        return text.strip()
    except Exception as e:
        logger.error(f"Error extracting text from DOCX: {str(e)}")
        # Chain the original exception so the root cause is preserved.
        raise Exception(f"Failed to extract text from DOCX: {str(e)}") from e
def summarize_pptx(file_content):
    """Extract text from a PPTX presentation.

    Args:
        file_content (bytes): Raw bytes of the uploaded PPTX.

    Returns:
        str: Text from every shape that carries text, joined by newlines.

    Raises:
        Exception: If python-pptx fails to parse the presentation.
    """
    try:
        ppt = Presentation(io.BytesIO(file_content))
        text = []
        for slide in ppt.slides:
            for shape in slide.shapes:
                # Not every shape has text (pictures, connectors, ...).
                if hasattr(shape, "text") and shape.text.strip():
                    text.append(shape.text.strip())
        return "\n".join(text).strip()
    except Exception as e:
        logger.error(f"Error extracting text from PPTX: {str(e)}")
        # Chain the original exception so the root cause is preserved.
        raise Exception(f"Failed to extract text from PPTX: {str(e)}") from e
def summarize_txt(file_content):
    """Decode the raw bytes of a .txt upload into stripped text.

    Tries UTF-8 first, then falls back to Latin-1.

    Args:
        file_content (bytes): Raw bytes of the uploaded text file.

    Returns:
        str: The decoded text with surrounding whitespace removed.
    """
    try:
        return file_content.decode("utf-8").strip()
    except UnicodeDecodeError:
        # Latin-1 maps every possible byte value (0x00-0xFF), so this decode
        # cannot fail — the previous nested try/except here was dead code.
        return file_content.decode("latin-1").strip()
if __name__ == "__main__":
    # Security: debug=True on a host bound to 0.0.0.0 exposes the Werkzeug
    # interactive debugger, which allows arbitrary code execution. Keep debug
    # off unless explicitly requested via the environment.
    debug = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true", "yes")
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", "7860")), debug=debug)