import os
import io
import logging
import tempfile

from flask import Flask, request, jsonify
from werkzeug.utils import secure_filename
from PyPDF2 import PdfReader
from docx import Document
from pptx import Presentation
from transformers import T5Tokenizer, T5ForConditionalGeneration
from flask_cors import CORS  # Enable cross-origin requests from browser clients

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize Flask app
app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

# Use a throwaway directory for the Hugging Face cache so downloaded model
# files do not accumulate in the default home-directory cache.
cache_dir = tempfile.mkdtemp()
os.environ["HF_HOME"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir

# Load the T5 model and tokenizer once at startup; fail fast (re-raise) if
# the model cannot be fetched, since the service is useless without it.
logger.info("Loading T5-Base model...")
try:
    tokenizer = T5Tokenizer.from_pretrained("t5-base", cache_dir=cache_dir)
    model = T5ForConditionalGeneration.from_pretrained("t5-base", cache_dir=cache_dir)
    logger.info("T5-Base model loaded successfully.")
except Exception as e:
    logger.error(f"Failed to load T5-Base: {str(e)}")
    raise

ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "txt"}


def allowed_file(filename):
    """Check if the uploaded file has an allowed extension."""
    return "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS


def summarize_text(text, max_length=300, min_length=100):
    """
    Summarize text using T5-Base with improved parameters for more
    comprehensive summaries.

    Long inputs are split into fixed-size character chunks; at most five
    chunks are summarized individually and their summaries concatenated.
    If the combined summary comes out very short, a second, longer pass is
    run on the first chunk alone.

    Args:
        text (str): The text to summarize.
        max_length (int): Maximum token length of the summary.
        min_length (int): Minimum token length of the summary.

    Returns:
        str: The generated summary, or an error message string on failure
        (this function never raises to its caller).
    """
    try:
        if not text.strip():
            return "No text found in the document to summarize."

        # Break text into chunks if it's very long.
        chunk_size = 4000  # Characters per chunk
        chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

        summaries = []
        for i, chunk in enumerate(chunks):
            # Only process up to 5 chunks to avoid very long processing times.
            if i >= 5:
                summaries.append("... (Document continues)")
                break

            input_text = "summarize: " + chunk
            inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)

            # Divide the length budget across the chunks actually processed.
            # Pass attention_mask explicitly so generate() does not have to
            # infer it from the pad token (avoids the transformers warning).
            summary_ids = model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=max_length // min(5, len(chunks)),
                min_length=min_length // min(5, len(chunks)),
                length_penalty=1.5,  # Reduced to avoid overly verbose summaries
                num_beams=4,
                early_stopping=True,
                no_repeat_ngram_size=3,  # Avoid repeating trigrams
            )
            summaries.append(tokenizer.decode(summary_ids[0], skip_special_tokens=True))

        # Combine summaries from all chunks.
        combined_summary = " ".join(summaries)

        # For very short summaries, retry on the first chunk with a longer
        # output budget and a stronger repetition penalty.
        if len(combined_summary.split()) < 50 and chunks:
            input_text = "summarize: " + chunks[0]
            inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
            summary_ids = model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=max_length,
                min_length=min_length,
                length_penalty=2.0,
                num_beams=5,
                early_stopping=True,
                repetition_penalty=2.5,  # Penalize repetition more heavily
            )
            combined_summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

        return combined_summary
    except Exception as e:
        logger.error(f"Error in T5 summarization: {str(e)}")
        return f"Error summarizing text: {str(e)}"


@app.route("/", methods=["GET"])
def index():
    """Root endpoint."""
    logger.info("Root endpoint accessed.")
    return "Document Summarizer API with T5-Base is running! Use /summarize endpoint for POST requests."
@app.route("/summarize", methods=["POST"])
def summarize():
    """Accept an uploaded document ('file' form field) and return a JSON summary.

    Returns:
        200 with {"filename", "summary", "textLength"} on success,
        400 for missing/empty/unsupported uploads,
        500 if text extraction or summarization fails.
    """
    logger.info("Summarize endpoint called.")

    # Debug the incoming request.
    logger.info(f"Request headers: {request.headers}")
    logger.info(f"Request files: {request.files}")
    logger.info(f"Request form: {request.form}")

    # Check if a file is in the request.
    if "file" not in request.files:
        logger.error("No file found in request.files")
        return jsonify({"error": "No file uploaded. Make sure to use 'file' as the form field name."}), 400

    file = request.files["file"]

    # Check if file is empty.
    if file.filename == "":
        logger.error("File has no filename")
        return jsonify({"error": "No selected file"}), 400

    # Check if file has an allowed extension.
    if not allowed_file(file.filename):
        logger.error(f"Unsupported file format: {file.filename}")
        return jsonify({"error": f"Unsupported file format. Allowed types are: {', '.join(ALLOWED_EXTENSIONS)}"}), 400

    # Process the file. Derive the extension from the *original* filename,
    # which allowed_file() has already validated to contain a dot —
    # secure_filename() can strip characters and leave a name whose
    # rsplit(".", 1)[1] would raise IndexError.
    filename = secure_filename(file.filename)
    file_content = file.read()
    file_ext = file.filename.rsplit(".", 1)[1].lower()

    try:
        # Dispatch to the extractor for this file type.
        extractors = {
            "pdf": summarize_pdf,
            "docx": summarize_docx,
            "pptx": summarize_pptx,
            "txt": summarize_txt,
        }
        extractor = extractors.get(file_ext)
        if extractor is None:
            logger.error("Unsupported file format received.")
            return jsonify({"error": "Unsupported file format"}), 400
        text = extractor(file_content)

        # Generate summary (log the real filename, not a placeholder).
        logger.info(f"Generating summary for {filename} with text length {len(text)}")
        summary = summarize_text(text)
        logger.info(f"File {filename} summarized successfully.")

        return jsonify({
            "filename": filename,
            "summary": summary,
            "textLength": len(text)
        })
    except Exception as e:
        logger.error(f"Error processing file {filename}: {str(e)}")
        return jsonify({"error": f"Error processing file: {str(e)}"}), 500


def summarize_pdf(file_content):
    """Extract text from a PDF given as raw bytes.

    Raises:
        Exception: if PyPDF2 cannot parse the document.
    """
    try:
        reader = PdfReader(io.BytesIO(file_content))
        # extract_text() may return None for image-only pages; substitute "".
        text = "\n".join(page.extract_text() or "" for page in reader.pages)
        return text.strip()
    except Exception as e:
        logger.error(f"Error extracting text from PDF: {str(e)}")
        raise Exception(f"Failed to extract text from PDF: {str(e)}") from e


def summarize_docx(file_content):
    """Extract text from a DOCX document given as raw bytes.

    Raises:
        Exception: if python-docx cannot parse the document.
    """
    try:
        doc = Document(io.BytesIO(file_content))
        text = "\n".join(para.text for para in doc.paragraphs if para.text.strip())
        return text.strip()
    except Exception as e:
        logger.error(f"Error extracting text from DOCX: {str(e)}")
        raise Exception(f"Failed to extract text from DOCX: {str(e)}") from e


def summarize_pptx(file_content):
    """Extract text from a PPTX presentation given as raw bytes.

    Collects the text of every shape on every slide, in order.

    Raises:
        Exception: if python-pptx cannot parse the document.
    """
    try:
        ppt = Presentation(io.BytesIO(file_content))
        text = []
        for slide in ppt.slides:
            for shape in slide.shapes:
                # Not every shape carries text (pictures, connectors, ...).
                if hasattr(shape, "text") and shape.text.strip():
                    text.append(shape.text.strip())
        return "\n".join(text).strip()
    except Exception as e:
        logger.error(f"Error extracting text from PPTX: {str(e)}")
        raise Exception(f"Failed to extract text from PPTX: {str(e)}") from e


def summarize_txt(file_content):
    """Decode a plain-text file, trying UTF-8 first and Latin-1 as a fallback.

    Latin-1 maps every byte value, so the fallback effectively cannot fail;
    the inner guard is kept for unexpected errors.

    Raises:
        Exception: if the bytes cannot be decoded at all.
    """
    try:
        return file_content.decode("utf-8").strip()
    except UnicodeDecodeError:
        # Try a different encoding if UTF-8 fails.
        try:
            return file_content.decode("latin-1").strip()
        except Exception as e:
            logger.error(f"Error decoding text file: {str(e)}")
            raise Exception(f"Failed to decode text file: {str(e)}") from e


if __name__ == "__main__":
    # NOTE(review): debug=True enables the interactive Werkzeug debugger and
    # auto-reload — fine for local development, unsafe in production.
    app.run(host="0.0.0.0", port=7860, debug=True)