mike23415 committed on
Commit
d2d0219
·
verified ·
1 Parent(s): 030be39

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -20
app.py CHANGED
@@ -1,34 +1,70 @@
1
- import os
 
 
 
 
 
 
2
  from flask import Flask, request, jsonify
3
- from transformers import pipeline
4
 
5
- # Ensure HF doesn't request a token
6
- os.environ["HF_HOME"] = "/app/cache"
7
- os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
8
- os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
9
- os.environ["HF_HUB_OFFLINE"] = "0"
10
 
11
- # Load model
12
- summarizer = pipeline("summarization", model="t5-base")
 
 
13
 
 
14
  app = Flask(__name__)
15
 
16
- @app.route("/")
17
- def home():
18
- return "Summarization API is running!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  @app.route("/summarize", methods=["POST"])
21
- def summarize_text():
22
- data = request.get_json()
23
- text = data.get("text", "")
24
- max_length = data.get("max_length", 50)
25
- min_length = data.get("min_length", 10)
26
 
 
 
 
 
27
  if not text:
28
- return jsonify({"error": "No text provided"}), 400
 
 
 
 
 
 
29
 
30
- summary = summarizer(text, max_length=max_length, min_length=min_length, do_sample=False)
31
- return jsonify(summary)
 
 
 
 
 
32
 
33
  if __name__ == "__main__":
34
  print("🚀 API is running on port 7860")
 
1
+ import torch
2
+ import pdfplumber
3
+ import pytesseract
4
+ from PIL import Image
5
+ from docx import Document
6
+ from pptx import Presentation
7
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
8
  from flask import Flask, request, jsonify
 
9
 
10
# Optimize for CPU-only inference; no GPU is assumed in this deployment.
torch.set_num_threads(4)  # Adjust based on CPU cores
device = torch.device("cpu")

# Load T5-Base model once at import time so requests don't pay the load cost.
# NOTE(review): from_pretrained downloads weights on first run — assumes the
# container has network access or a warm cache; confirm for offline deploys.
model_name = "t5-base"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)

# Flask App
app = Flask(__name__)
21
 
22
+ # Function to extract text from files
23
def extract_text(file):
    """Extract plain text from an uploaded file, dispatching on extension.

    Supported formats: PDF (pdfplumber), DOCX (python-docx),
    PPTX (python-pptx), and PNG/JPG/JPEG images (pytesseract OCR).

    Args:
        file: An uploaded file object with a ``filename`` attribute and a
            readable binary stream (e.g. a werkzeug ``FileStorage``).

    Returns:
        Extracted text joined into a single space-separated string, or
        ``None`` when the extension is not supported.
    """
    filename = file.filename.lower()

    if filename.endswith(".pdf"):
        with pdfplumber.open(file) as pdf:
            # Call extract_text() exactly once per page (the original called
            # it twice — once to filter, once to join — parsing every page
            # twice). Pages with no text layer are skipped.
            page_texts = []
            for page in pdf.pages:
                page_text = page.extract_text()
                if page_text:
                    page_texts.append(page_text)
            return " ".join(page_texts)

    elif filename.endswith(".docx"):
        doc = Document(file)
        return " ".join(para.text for para in doc.paragraphs)

    elif filename.endswith(".pptx"):
        prs = Presentation(file)
        # Not every shape carries text (pictures, charts), so filter first.
        return " ".join(
            shape.text
            for slide in prs.slides
            for shape in slide.shapes
            if hasattr(shape, "text")
        )

    elif filename.endswith((".png", ".jpg", ".jpeg")):
        image = Image.open(file)
        return pytesseract.image_to_string(image)

    # Unsupported extension.
    return None
43
 
44
@app.route("/summarize", methods=["POST"])
def summarize():
    """Summarize the text content of an uploaded file using T5-Base.

    Expects a multipart upload under the ``file`` field. Returns a JSON
    object ``{"summary": ...}`` on success, or a 400 JSON error when no
    file is provided or no text could be extracted from it.
    """
    uploaded = request.files.get("file")
    if not uploaded:
        return jsonify({"error": "No file uploaded"}), 400

    text = extract_text(uploaded)
    if not text:
        return jsonify({"error": "No text found in file"}), 400

    # T5 is a text-to-text model and expects a task prefix on the input.
    prompt = "summarize: " + text.strip()

    # Tokenize, truncating to the model's 512-token input window.
    input_ids = tokenizer.encode(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    ).to(device)

    # Beam-search decoding; gradients are unnecessary at inference time.
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_length=150,
            min_length=50,
            length_penalty=2.0,
            num_beams=4,
        )

    summary_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return jsonify({"summary": summary_text})
68
 
69
  if __name__ == "__main__":
70
  print("🚀 API is running on port 7860")