Spaces:
Sleeping
Sleeping
import os

import pdfplumber
import pytesseract
import torch
from docx import Document
from flask import Flask, jsonify, request
from PIL import Image
from pptx import Presentation
from transformers import T5ForConditionalGeneration, T5Tokenizer
# --- Runtime configuration -------------------------------------------------
# Inference runs on the CPU only. The intra-op thread count defaults to 4 and
# can be overridden with the TORCH_NUM_THREADS environment variable, so the
# service can be tuned to the host's core count without editing the source.
torch.set_num_threads(int(os.environ.get("TORCH_NUM_THREADS", "4")))
device = torch.device("cpu")

# --- Model -----------------------------------------------------------------
# Pretrained T5-Base checkpoint and its tokenizer, loaded once when the module
# is imported and shared by every request handled by this process.
model_name = "t5-base"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
# Inference only: make sure dropout and other training-time behavior is off.
# (from_pretrained normally does this already; being explicit costs nothing.)
model.eval()

# --- Flask app -------------------------------------------------------------
app = Flask(__name__)
def extract_text(file):
    """Extract plain text from an uploaded document or image.

    The parser is chosen from the file name's extension (case-insensitive):

    * ``.pdf``  - text of every page that yields any, via pdfplumber
    * ``.docx`` - text of every paragraph, via python-docx
    * ``.pptx`` - text of every slide shape that has a ``text`` attribute
    * ``.png`` / ``.jpg`` / ``.jpeg`` - OCR via pytesseract

    Args:
        file: Uploaded file object with a ``filename`` attribute that the
            parsing libraries can read from (presumably a Flask/werkzeug
            upload, given how the caller obtains it - confirm).

    Returns:
        The extracted pieces joined with single spaces (this may be an empty
        or whitespace-only string), or ``None`` when the extension is not
        supported or the file has no name.
    """
    # Uploads may arrive without a file name; treat that as unsupported
    # instead of crashing on None.lower().
    filename = (file.filename or "").lower()

    if filename.endswith(".pdf"):
        with pdfplumber.open(file) as pdf:
            # Extract each page's text exactly once; the previous version
            # called extract_text() twice per page (filter + value).
            page_texts = (page.extract_text() for page in pdf.pages)
            return " ".join(text for text in page_texts if text)

    if filename.endswith(".docx"):
        doc = Document(file)
        return " ".join(para.text for para in doc.paragraphs)

    if filename.endswith(".pptx"):
        prs = Presentation(file)
        # Only shapes exposing a `text` attribute contribute; shapes without
        # one (e.g. pictures) are skipped.
        return " ".join(
            shape.text
            for slide in prs.slides
            for shape in slide.shapes
            if hasattr(shape, "text")
        )

    if filename.endswith((".png", ".jpg", ".jpeg")):
        # Close the image once OCR is done instead of leaking it.
        with Image.open(file) as image:
            return pytesseract.image_to_string(image)

    return None
# NOTE(review): route path and method are assumed from the handler's name and
# the multipart upload it reads - confirm against the clients of this API.
@app.route("/summarize", methods=["POST"])
def summarize():
    """Summarize the text content of an uploaded file with T5.

    Reads the upload from the multipart form field ``file`` and extracts its
    text with ``extract_text``. The text is prefixed with the ``"summarize: "``
    task prompt, truncated to 512 tokens, and summarized with 4-beam search.

    Returns:
        A JSON response ``{"summary": str}`` on success. On failure, a
        ``({"error": str}, 400)`` tuple when no file was uploaded, the file
        type is unsupported, or no non-whitespace text could be extracted.
    """
    file = request.files.get("file")
    if not file:
        return jsonify({"error": "No file uploaded"}), 400

    text = extract_text(file)
    if text is None:
        # extract_text() returns None only for file types it cannot parse.
        return jsonify({"error": "Unsupported file type"}), 400

    # Strip *before* the emptiness check. Blank documents and OCR output are
    # often whitespace alone, which would otherwise reach the model as a bare
    # "summarize: " prompt.
    text = text.strip()
    if not text:
        return jsonify({"error": "No text found in file"}), 400

    # T5 is text-to-text; the "summarize: " prefix tells it which task to do.
    input_text = "summarize: " + text

    # Input beyond 512 tokens is truncated, so only the beginning of a long
    # document is summarized.
    input_ids = tokenizer.encode(
        input_text, return_tensors="pt", truncation=True, max_length=512
    ).to(device)

    # no_grad: pure inference, so skip autograd bookkeeping.
    with torch.no_grad():
        summary_ids = model.generate(
            input_ids,
            max_length=150,
            min_length=50,
            length_penalty=2.0,
            num_beams=4,
        )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return jsonify({"summary": summary})
if __name__ == "__main__":
    # The port defaults to 7860 and can be overridden with the PORT
    # environment variable. The startup message and the server share one
    # value so they cannot drift apart.
    port = int(os.environ.get("PORT", "7860"))
    print(f"API is running on port {port}")
    # 0.0.0.0 binds every network interface, not only localhost.
    app.run(host="0.0.0.0", port=port)