File size: 8,488 Bytes
764d4f7
2a3fae3
80d0b8a
b4aa0e4
b7db40a
764d4f7
2a3fae3
764d4f7
2a3fae3
a911ba5
b4aa0e4
524f780
f8ea9fd
80d0b8a
 
 
dc17435
764d4f7
b4aa0e4
b7db40a
b4aa0e4
 
 
 
012dc5b
 
a911ba5
012dc5b
b4aa0e4
 
012dc5b
 
 
 
d2d0219
2a3fae3
d2d0219
2a3fae3
798ae00
2a3fae3
d2d0219
b46017c
 
 
 
 
 
 
 
 
 
 
 
a911ba5
dc17435
 
 
b46017c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a911ba5
 
b4aa0e4
92d0377
98e82be
 
dc17435
80d0b8a
dc17435
98e82be
3b4df89
2a3fae3
80d0b8a
b4aa0e4
 
 
 
 
 
3696c1f
764d4f7
b4aa0e4
 
3696c1f
dc17435
3696c1f
 
764d4f7
b4aa0e4
764d4f7
3696c1f
 
2a3fae3
80d0b8a
b4aa0e4
798ae00
3696c1f
764d4f7
2a3fae3
 
3696c1f
98e82be
 
a911ba5
98e82be
a911ba5
98e82be
a911ba5
98e82be
a911ba5
dc17435
3696c1f
dc17435
798ae00
3696c1f
b4aa0e4
dc17435
3696c1f
798ae00
b4aa0e4
 
 
 
 
3696c1f
98e82be
80d0b8a
98e82be
764d4f7
92d0377
dc17435
b4aa0e4
 
 
 
 
 
 
d2d0219
92d0377
dc17435
b4aa0e4
 
 
 
 
 
 
d2d0219
92d0377
dc17435
b4aa0e4
 
 
 
 
 
 
 
 
 
 
b7db40a
92d0377
dc17435
b4aa0e4
 
 
 
 
 
 
 
 
53425a8
9fd7d89
b4aa0e4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
import os
import io
import logging
import tempfile
from flask import Flask, request, jsonify
from werkzeug.utils import secure_filename
from PyPDF2 import PdfReader
from docx import Document
from pptx import Presentation
from transformers import T5Tokenizer, T5ForConditionalGeneration
from flask_cors import CORS  # Import CORS for cross-origin requests

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize Flask app
app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

# Set up a temporary directory for the Hugging Face cache.
# NOTE(review): mkdtemp() creates a new directory on every start and nothing in
# this file removes it, so the model is presumably downloaded again per process.
cache_dir = tempfile.mkdtemp()
# NOTE(review): `transformers` is imported above, before these variables are
# set, so they may not change its default cache location; the explicit
# `cache_dir=` arguments below are what this file actually relies on.
os.environ["HF_HOME"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir

# Load the T5 model and tokenizer once, at import time; the request handlers
# use these module-level objects.
logger.info("Loading T5-Base model...")
try:
    tokenizer = T5Tokenizer.from_pretrained("t5-base", cache_dir=cache_dir)
    model = T5ForConditionalGeneration.from_pretrained("t5-base", cache_dir=cache_dir)
    logger.info("T5-Base model loaded successfully.")
except Exception as e:
    # logger.exception also records the traceback, which is what you need when
    # a download or checkpoint load fails at startup.
    logger.exception(f"Failed to load T5-Base: {str(e)}")
    raise

# Lower-case file extensions (without the dot) that the /summarize route accepts.
ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "txt"}

def allowed_file(filename):
    """Return True if *filename* ends in an extension from ALLOWED_EXTENSIONS.

    The text after the last dot is compared case-insensitively. A name with no
    dot at all is rejected.
    """
    if "." not in filename:
        return False
    suffix = filename.rpartition(".")[2]
    return suffix.lower() in ALLOWED_EXTENSIONS

# Input length (in tokens) that each T5 call below is given; the tokenizer
# truncates anything longer, so chunks are sized to fit inside it.
_T5_MAX_INPUT_TOKENS = 512
# Head-room for the "summarize: " prefix, the end-of-sequence token, and small
# differences when a decoded chunk is tokenized a second time.
_T5_PROMPT_RESERVE = 16
# Only this many chunks are summarized, to bound processing time.
_MAX_SUMMARY_CHUNKS = 5


def _split_into_token_chunks(text, max_tokens):
    """Split *text* into consecutive pieces of at most *max_tokens* T5 tokens."""
    # The whole document is tokenized only to measure and cut it, never passed
    # to the model as one sequence, so the "longer than model max" warning is
    # silenced with verbose=False.
    token_ids = tokenizer(text, add_special_tokens=False, verbose=False)["input_ids"]
    return [
        tokenizer.decode(token_ids[start:start + max_tokens])
        for start in range(0, len(token_ids), max_tokens)
    ]


def _generate_summary(chunk, **generate_kwargs):
    """Run T5 on one chunk with the "summarize: " task prefix and decode the output."""
    inputs = tokenizer(
        "summarize: " + chunk,
        return_tensors="pt",
        max_length=_T5_MAX_INPUT_TOKENS,
        truncation=True,
    )
    summary_ids = model.generate(inputs["input_ids"], **generate_kwargs)
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)


def summarize_text(text, max_length=300, min_length=100):
    """
    Summarize text with T5-Base, chunk by chunk.

    The text is cut into chunks measured in T5 tokens, so each one fits the
    512-token input window. Previously the text was cut into 4000-character
    chunks, and the truncation at 512 tokens silently discarded much of each
    chunk. At most five chunks are summarized, each getting an equal share of
    the length budget, and the results are joined. If the joined summary has
    fewer than 50 words, the first chunk is summarized again with the full
    budget and that result is returned instead.

    Args:
        text (str): The text to summarize.
        max_length (int): Overall `max_length` budget for `model.generate`;
            divided by the number of summarized chunks.
        min_length (int): Overall `min_length` budget for `model.generate`;
            divided the same way.

    Returns:
        str: One of the following:
            - the generated summary, with "... (Document continues)" appended
              when chunks beyond the fifth were skipped;
            - "No text found in the document to summarize." for blank input;
            - an "Error summarizing text: ..." message if anything raised.
              Errors are reported this way rather than propagated.
    """
    try:
        if not text.strip():
            return "No text found in the document to summarize."

        chunks = _split_into_token_chunks(
            text, _T5_MAX_INPUT_TOKENS - _T5_PROMPT_RESERVE
        )
        if not chunks:
            # e.g. text made only of characters the tokenizer drops entirely.
            return "No text found in the document to summarize."

        # Share the length budget between the chunks that will be summarized.
        chunk_count = min(_MAX_SUMMARY_CHUNKS, len(chunks))
        summaries = [
            _generate_summary(
                chunk,
                max_length=max_length // chunk_count,
                min_length=min_length // chunk_count,
                length_penalty=1.5,
                num_beams=4,
                early_stopping=True,
                no_repeat_ngram_size=3,  # Avoid repeating trigrams
            )
            for chunk in chunks[:_MAX_SUMMARY_CHUNKS]
        ]
        if len(chunks) > _MAX_SUMMARY_CHUNKS:
            summaries.append("... (Document continues)")

        combined_summary = " ".join(summaries)

        # Very short result: retry the first chunk with the whole length budget.
        if len(combined_summary.split()) < 50:
            combined_summary = _generate_summary(
                chunks[0],
                max_length=max_length,
                min_length=min_length,
                length_penalty=2.0,
                num_beams=5,
                early_stopping=True,
                repetition_penalty=2.5,  # Penalize repetition more heavily
            )

        return combined_summary
    except Exception as e:
        logger.exception(f"Error in T5 summarization: {str(e)}")
        return f"Error summarizing text: {str(e)}"

@app.route("/", methods=["GET"])
def index():
    """Root route: log the hit and return a plain-text banner pointing at /summarize."""
    banner = "Document Summarizer API with T5-Base is running! Use /summarize endpoint for POST requests."
    logger.info("Root endpoint accessed.")
    return banner

@app.route("/summarize", methods=["POST"])
def summarize():
    """Extract text from an uploaded document and return its summary as JSON.

    The document must be sent as a multipart upload in the ``file`` form field.
    Its extension must be one of ALLOWED_EXTENSIONS.

    Returns:
        A Flask response with one of these outcomes:
            - 200 with ``{"filename", "summary", "textLength"}`` on success;
            - 400 with ``{"error"}`` when the upload is missing, unnamed, or
              of an unsupported type;
            - 500 with ``{"error"}`` when text extraction or summarization
              raises.
    """
    logger.info("Summarize endpoint called.")

    # Request dumps can contain cookies / Authorization headers, so they are
    # kept at DEBUG rather than written to the INFO log on every call.
    logger.debug(f"Request headers: {request.headers}")
    logger.debug(f"Request files: {request.files}")
    logger.debug(f"Request form: {request.form}")

    # Check if a file is in the request
    if "file" not in request.files:
        logger.error("No file found in request.files")
        return jsonify({"error": "No file uploaded. Make sure to use 'file' as the form field name."}), 400

    file = request.files["file"]

    # `not` catches both an empty and a missing (None) filename; allowed_file()
    # would raise TypeError on None.
    if not file.filename:
        logger.error("File has no filename")
        return jsonify({"error": "No selected file"}), 400

    # Check if file has an allowed extension
    if not allowed_file(file.filename):
        logger.error(f"Unsupported file format: {file.filename}")
        allowed = ", ".join(sorted(ALLOWED_EXTENSIONS))  # sorted: stable message
        return jsonify({"error": f"Unsupported file format. Allowed types are: {allowed}"}), 400

    # Process the file
    filename = secure_filename(file.filename)
    # Take the extension from the client-supplied name that allowed_file() just
    # validated. secure_filename() can strip the dot (".pdf" becomes "pdf").
    # Splitting its result, as before, then raised an unhandled IndexError.
    file_ext = file.filename.rsplit(".", 1)[1].lower()
    file_content = file.read()

    extractors = {
        "pdf": summarize_pdf,
        "docx": summarize_docx,
        "pptx": summarize_pptx,
        "txt": summarize_txt,
    }
    extract = extractors.get(file_ext)
    if extract is None:
        # Defensive: only reachable if ALLOWED_EXTENSIONS gains a type with no extractor.
        logger.error("Unsupported file format received.")
        return jsonify({"error": "Unsupported file format"}), 400

    try:
        text = extract(file_content)

        # Generate summary
        logger.info(f"Generating summary for {filename} with text length {len(text)}")
        summary = summarize_text(text)

        logger.info(f"File {filename} summarized successfully.")
        return jsonify({
            "filename": filename,
            "summary": summary,
            "textLength": len(text)
        })

    except Exception as e:
        logger.exception(f"Error processing file {filename}: {str(e)}")
        return jsonify({"error": f"Error processing file: {str(e)}"}), 500

def summarize_pdf(file_content):
    """Extract text from PDF.

    Despite the name, this only extracts text; the summary itself is produced
    by summarize_text().

    Args:
        file_content (bytes): Raw bytes of the uploaded PDF.

    Returns:
        str: The text of every page, one page per line, stripped. A page for
        which PyPDF2 extracts nothing contributes an empty line.

    Raises:
        Exception: If PyPDF2 cannot read the document. The underlying error is
        chained as ``__cause__`` so its traceback is preserved.
    """
    try:
        reader = PdfReader(io.BytesIO(file_content))
        # extract_text() can return None for pages without a text layer.
        text = "\n".join(page.extract_text() or "" for page in reader.pages)
    except Exception as e:
        logger.error(f"Error extracting text from PDF: {str(e)}")
        raise Exception(f"Failed to extract text from PDF: {str(e)}") from e
    return text.strip()

def summarize_docx(file_content):
    """Extract text from DOCX.

    Only the document's top-level paragraphs are read, and blank paragraphs
    are skipped.

    Args:
        file_content (bytes): Raw bytes of the uploaded .docx file.

    Returns:
        str: Non-blank paragraph texts, one per line, stripped.

    Raises:
        Exception: If python-docx cannot open the file. The underlying error is
        chained as ``__cause__`` so its traceback is preserved.
    """
    try:
        doc = Document(io.BytesIO(file_content))
        text = "\n".join(para.text for para in doc.paragraphs if para.text.strip())
    except Exception as e:
        logger.error(f"Error extracting text from DOCX: {str(e)}")
        raise Exception(f"Failed to extract text from DOCX: {str(e)}") from e
    return text.strip()

def summarize_pptx(file_content):
    """Extract text from PPTX.

    Args:
        file_content (bytes): Raw bytes of the uploaded .pptx file.

    Returns:
        str: The stripped text of each shape with non-blank text, one shape
        per line. Shapes appear in slide order, then in each slide's shape
        order.

    Raises:
        Exception: If python-pptx cannot open or read the file. The underlying
        error is chained as ``__cause__`` so its traceback is preserved.
    """
    try:
        ppt = Presentation(io.BytesIO(file_content))
        text = []
        for slide in ppt.slides:
            for shape in slide.shapes:
                # Not every shape exposes a `.text` attribute, hence the hasattr check.
                if hasattr(shape, "text") and shape.text.strip():
                    text.append(shape.text.strip())
    except Exception as e:
        logger.error(f"Error extracting text from PPTX: {str(e)}")
        raise Exception(f"Failed to extract text from PPTX: {str(e)}") from e
    return "\n".join(text).strip()

def summarize_txt(file_content):
    """Extract text from TXT file.

    Encodings are tried in this order:
      * UTF-16, when the data starts with a UTF-16 byte-order mark. Before,
        such files fell through to Latin-1 and came out as garbage.
      * Otherwise UTF-8 via "utf-8-sig". This also drops a leading UTF-8 BOM;
        plain "utf-8" kept it as U+FEFF, which str.strip() does not remove.
      * Latin-1 as the last resort. It maps every byte, so it always succeeds.

    Args:
        file_content (bytes): Raw bytes of the uploaded text file.

    Returns:
        str: The decoded text with surrounding whitespace stripped.

    Raises:
        Exception: If no candidate encoding can decode the data. This is kept
        for parity with the other extractors; it cannot be reached while
        Latin-1 is the final candidate.
    """
    import codecs  # only needed for the BOM constants

    if file_content.startswith((codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE)):
        # The utf-16 codec reads the BOM to choose the byte order and drops it.
        candidates = ("utf-16", "latin-1")
    else:
        # utf-8-sig accepts plain UTF-8 too; it only strips a BOM if present.
        candidates = ("utf-8-sig", "latin-1")

    last_error = None
    for encoding in candidates:
        try:
            return file_content.decode(encoding).strip()
        except UnicodeDecodeError as e:
            last_error = e

    logger.error(f"Error decoding text file: {str(last_error)}")
    raise Exception(f"Failed to decode text file: {str(last_error)}") from last_error

if __name__ == "__main__":
    # The Werkzeug interactive debugger lets anyone who can reach the server
    # execute arbitrary code. Since the app binds to all interfaces, debug mode
    # is opt-in only: set FLASK_DEBUG=1 for local development.
    app.run(host="0.0.0.0", port=7860, debug=os.environ.get("FLASK_DEBUG") == "1")