Spaces:
Sleeping
Sleeping
File size: 8,488 Bytes
764d4f7 2a3fae3 80d0b8a b4aa0e4 b7db40a 764d4f7 2a3fae3 764d4f7 2a3fae3 a911ba5 b4aa0e4 524f780 f8ea9fd 80d0b8a dc17435 764d4f7 b4aa0e4 b7db40a b4aa0e4 012dc5b a911ba5 012dc5b b4aa0e4 012dc5b d2d0219 2a3fae3 d2d0219 2a3fae3 798ae00 2a3fae3 d2d0219 b46017c a911ba5 dc17435 b46017c a911ba5 b4aa0e4 92d0377 98e82be dc17435 80d0b8a dc17435 98e82be 3b4df89 2a3fae3 80d0b8a b4aa0e4 3696c1f 764d4f7 b4aa0e4 3696c1f dc17435 3696c1f 764d4f7 b4aa0e4 764d4f7 3696c1f 2a3fae3 80d0b8a b4aa0e4 798ae00 3696c1f 764d4f7 2a3fae3 3696c1f 98e82be a911ba5 98e82be a911ba5 98e82be a911ba5 98e82be a911ba5 dc17435 3696c1f dc17435 798ae00 3696c1f b4aa0e4 dc17435 3696c1f 798ae00 b4aa0e4 3696c1f 98e82be 80d0b8a 98e82be 764d4f7 92d0377 dc17435 b4aa0e4 d2d0219 92d0377 dc17435 b4aa0e4 d2d0219 92d0377 dc17435 b4aa0e4 b7db40a 92d0377 dc17435 b4aa0e4 53425a8 9fd7d89 b4aa0e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 |
import os
import io
import logging
import tempfile
from flask import Flask, request, jsonify
from werkzeug.utils import secure_filename
from PyPDF2 import PdfReader
from docx import Document
from pptx import Presentation
from transformers import T5Tokenizer, T5ForConditionalGeneration
from flask_cors import CORS # Import CORS for cross-origin requests
# Configure loggin
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize Flask app
app = Flask(__name__)
CORS(app) # Enable CORS for all routes
# Set up a temporary directory for Hugging Face cache
cache_dir = tempfile.mkdtemp()
os.environ["HF_HOME"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir
# Load T5 model and tokenizer
logger.info("Loading T5-Base model...")
try:
tokenizer = T5Tokenizer.from_pretrained("t5-base", cache_dir=cache_dir)
model = T5ForConditionalGeneration.from_pretrained("t5-base", cache_dir=cache_dir)
logger.info("T5-Base model loaded successfully.")
except Exception as e:
logger.error(f"Failed to load T5-Base: {str(e)}")
raise
ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "txt"}
def allowed_file(filename):
"""Check if the uploaded file has an allowed extension."""
return "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS
def summarize_text(text, max_length=300, min_length=100):
"""
Summarize text using T5-Base with improved parameters for more comprehensive summaries.
Args:
text (str): The text to summarize
max_length (int): Maximum length of the summary (increased from 150)
min_length (int): Minimum length of the summary (increased from 30)
Returns:
str: The generated summary
"""
try:
if not text.strip():
return "No text found in the document to summarize."
# Break text into chunks if it's very long
chunks = []
chunk_size = 4000 # Characters per chunk
for i in range(0, len(text), chunk_size):
chunks.append(text[i:i + chunk_size])
summaries = []
for i, chunk in enumerate(chunks):
# Only process up to 5 chunks to avoid very long processing times
if i >= 5:
summaries.append("... (Document continues)")
break
input_text = "summarize: " + chunk
inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
# Generate with improved parameters
summary_ids = model.generate(
inputs["input_ids"],
max_length=max_length // min(5, len(chunks)), # Adjust max_length based on chunks
min_length=min_length // min(5, len(chunks)), # Adjust min_length based on chunks
length_penalty=1.5, # Reduced to avoid overly verbose summaries
num_beams=4,
early_stopping=True,
no_repeat_ngram_size=3 # Avoid repeating trigrams
)
chunk_summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
summaries.append(chunk_summary)
# Combine summaries from all chunks
combined_summary = " ".join(summaries)
# For very short summaries, try again with the first chunk but longer output
if len(combined_summary.split()) < 50 and chunks:
input_text = "summarize: " + chunks[0]
inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
summary_ids = model.generate(
inputs["input_ids"],
max_length=max_length,
min_length=min_length,
length_penalty=2.0,
num_beams=5,
early_stopping=True,
repetition_penalty=2.5 # Penalize repetition more heavily
)
combined_summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
return combined_summary
except Exception as e:
logger.error(f"Error in T5 summarization: {str(e)}")
return f"Error summarizing text: {str(e)}"
@app.route("/", methods=["GET"])
def index():
"""Root endpoint."""
logger.info("Root endpoint accessed.")
return "Document Summarizer API with T5-Base is running! Use /summarize endpoint for POST requests."
@app.route("/summarize", methods=["POST"])
def summarize():
logger.info("Summarize endpoint called.")
# Debug the incoming request
logger.info(f"Request headers: {request.headers}")
logger.info(f"Request files: {request.files}")
logger.info(f"Request form: {request.form}")
# Check if a file is in the request
if "file" not in request.files:
logger.error("No file found in request.files")
return jsonify({"error": "No file uploaded. Make sure to use 'file' as the form field name."}), 400
file = request.files["file"]
# Check if file is empty
if file.filename == "":
logger.error("File has no filename")
return jsonify({"error": "No selected file"}), 400
# Check if file has an allowed extension
if not allowed_file(file.filename):
logger.error(f"Unsupported file format: {file.filename}")
return jsonify({"error": f"Unsupported file format. Allowed types are: {', '.join(ALLOWED_EXTENSIONS)}"}), 400
# Process the file
filename = secure_filename(file.filename)
file_content = file.read()
file_ext = filename.rsplit(".", 1)[1].lower()
try:
if file_ext == "pdf":
text = summarize_pdf(file_content)
elif file_ext == "docx":
text = summarize_docx(file_content)
elif file_ext == "pptx":
text = summarize_pptx(file_content)
elif file_ext == "txt":
text = summarize_txt(file_content)
else:
logger.error("Unsupported file format received.")
return jsonify({"error": "Unsupported file format"}), 400
# Generate summary
logger.info(f"Generating summary for {filename} with text length {len(text)}")
summary = summarize_text(text)
logger.info(f"File {filename} summarized successfully.")
return jsonify({
"filename": filename,
"summary": summary,
"textLength": len(text)
})
except Exception as e:
logger.error(f"Error processing file {filename}: {str(e)}")
return jsonify({"error": f"Error processing file: {str(e)}"}), 500
def summarize_pdf(file_content):
"""Extract text from PDF."""
try:
reader = PdfReader(io.BytesIO(file_content))
text = "\n".join([page.extract_text() or "" for page in reader.pages])
return text.strip()
except Exception as e:
logger.error(f"Error extracting text from PDF: {str(e)}")
raise Exception(f"Failed to extract text from PDF: {str(e)}")
def summarize_docx(file_content):
"""Extract text from DOCX."""
try:
doc = Document(io.BytesIO(file_content))
text = "\n".join([para.text for para in doc.paragraphs if para.text.strip()])
return text.strip()
except Exception as e:
logger.error(f"Error extracting text from DOCX: {str(e)}")
raise Exception(f"Failed to extract text from DOCX: {str(e)}")
def summarize_pptx(file_content):
"""Extract text from PPTX."""
try:
ppt = Presentation(io.BytesIO(file_content))
text = []
for slide in ppt.slides:
for shape in slide.shapes:
if hasattr(shape, "text") and shape.text.strip():
text.append(shape.text.strip())
return "\n".join(text).strip()
except Exception as e:
logger.error(f"Error extracting text from PPTX: {str(e)}")
raise Exception(f"Failed to extract text from PPTX: {str(e)}")
def summarize_txt(file_content):
"""Extract text from TXT file."""
try:
return file_content.decode("utf-8").strip()
except UnicodeDecodeError:
# Try different encodings if UTF-8 fails
try:
return file_content.decode("latin-1").strip()
except Exception as e:
logger.error(f"Error decoding text file: {str(e)}")
raise Exception(f"Failed to decode text file: {str(e)}")
if __name__ == "__main__":
app.run(host="0.0.0.0", port=7860, debug=True) |