benkada commited on
Commit
0133631
·
verified ·
1 Parent(s): 57d09d7

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +31 -37
main.py CHANGED
@@ -1,4 +1,5 @@
1
  import os, io
 
2
  from fastapi import FastAPI, UploadFile, File, Form
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from fastapi.responses import JSONResponse, HTMLResponse, FileResponse
@@ -12,13 +13,13 @@ from io import BytesIO
12
  # -----------------------------------------------------------------------------
13
  # CONFIGURATION
14
  # -----------------------------------------------------------------------------
15
- HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN") # injected as a secret in HF Spaces
16
- PORT = int(os.getenv("PORT", 7860)) # default for local, HF Spaces overrides
17
 
18
  app = FastAPI(
19
- title="AIPowered WebApp API",
20
- description="Backend endpoints for summarisation, captioning and QA",
21
- version="1.2.0",
22
  )
23
 
24
  app.add_middleware(
@@ -29,21 +30,22 @@ app.add_middleware(
29
  allow_headers=["*"],
30
  )
31
 
32
- # Serve optional static assets **only if the folder exists**
33
- from pathlib import Path
 
34
  static_dir = Path("static")
35
  if static_dir.exists():
36
- app.mount("/static", StaticFiles(directory="static"), name="static"), name="static")
37
 
38
  # -----------------------------------------------------------------------------
39
- # MODEL CLIENTS (remote Hugging Face Inference API)
40
  # -----------------------------------------------------------------------------
41
  summary_client = InferenceClient("facebook/bart-large-cnn", token=HUGGINGFACE_TOKEN)
42
  qa_client = InferenceClient("deepset/roberta-base-squad2", token=HUGGINGFACE_TOKEN)
43
  image_caption_client = InferenceClient("nlpconnect/vit-gpt2-image-captioning", token=HUGGINGFACE_TOKEN)
44
 
45
  # -----------------------------------------------------------------------------
46
- # UTILITY FUNCTIONS
47
  # -----------------------------------------------------------------------------
48
 
49
  def extract_text_from_pdf(content: bytes) -> str:
@@ -56,12 +58,12 @@ def extract_text_from_docx(content: bytes) -> str:
56
 
57
  def process_uploaded_file(file: UploadFile) -> str:
58
  content = file.file.read()
59
- extension = file.filename.split(".")[-1].lower()
60
- if extension == "pdf":
61
  return extract_text_from_pdf(content)
62
- if extension == "docx":
63
  return extract_text_from_docx(content)
64
- if extension == "txt":
65
  return content.decode("utf-8").strip()
66
  raise ValueError("Unsupported file type")
67
 
@@ -71,10 +73,10 @@ def process_uploaded_file(file: UploadFile) -> str:
71
 
72
  @app.get("/", response_class=HTMLResponse)
73
  async def serve_index():
74
- """Serve the frontend HTML file."""
75
  return FileResponse("index.html")
76
 
77
- # ---------- Summarisation -----------------------------------------------------
78
 
79
  @app.post("/api/summarize")
80
  async def summarize_document(file: UploadFile = File(...)):
@@ -84,7 +86,6 @@ async def summarize_document(file: UploadFile = File(...)):
84
  return {"result": "Document too short to summarise."}
85
 
86
  summary_raw = summary_client.summarization(text[:3000])
87
- # Normalise to plain string
88
  if isinstance(summary_raw, list):
89
  summary_txt = summary_raw[0].get("summary_text", str(summary_raw))
90
  elif isinstance(summary_raw, dict):
@@ -96,39 +97,36 @@ async def summarize_document(file: UploadFile = File(...)):
96
  except Exception as exc:
97
  return JSONResponse(status_code=500, content={"error": f"Summarisation failure: {exc}"})
98
 
99
- # ---------- Image Caption -----------------------------------------------------
100
 
101
  @app.post("/api/caption")
102
  async def caption_image(file: UploadFile = File(...)):
103
  try:
104
- image_bytes = await file.read()
105
- image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB")
106
- image_pil.thumbnail((1024, 1024))
107
- buf = BytesIO(); image_pil.save(buf, format="JPEG"); img = buf.getvalue()
108
-
109
- result = image_caption_client.image_to_text(img)
110
  if isinstance(result, dict):
111
  caption = result.get("generated_text") or result.get("caption") or "No caption found."
112
  elif isinstance(result, list):
113
  caption = result[0].get("generated_text", "No caption found.")
114
  else:
115
  caption = str(result)
116
-
117
  return {"result": caption}
118
  except Exception as exc:
119
  return JSONResponse(status_code=500, content={"error": f"Caption failure: {exc}"})
120
 
121
- # ---------- Question Answering ----------------------------------------------
122
 
123
  @app.post("/api/qa")
124
  async def question_answering(file: UploadFile = File(...), question: str = Form(...)):
125
  try:
126
- # If it's an image, first caption it to build context
127
  if file.content_type.startswith("image/"):
128
- image_bytes = await file.read()
129
- pil = Image.open(io.BytesIO(image_bytes)).convert("RGB"); pil.thumbnail((1024, 1024))
130
- b = BytesIO(); pil.save(b, format="JPEG"); img = b.getvalue()
131
- res = image_caption_client.image_to_text(img)
132
  context = res.get("generated_text") if isinstance(res, dict) else str(res)
133
  else:
134
  context = process_uploaded_file(file)[:3000]
@@ -141,15 +139,11 @@ async def question_answering(file: UploadFile = File(...), question: str = Form(
141
  except Exception as exc:
142
  return JSONResponse(status_code=500, content={"error": f"QA failure: {exc}"})
143
 
144
- # ---------- Health check ------------------------------------------------------
145
 
146
  @app.get("/api/health")
147
  async def health():
148
- return {
149
- "status": "healthy",
150
- "hf_token_set": bool(HUGGINGFACE_TOKEN),
151
- "version": app.version,
152
- }
153
 
154
  # -----------------------------------------------------------------------------
155
  # ENTRYPOINT
 
1
  import os, io
2
+ from pathlib import Path
3
  from fastapi import FastAPI, UploadFile, File, Form
4
  from fastapi.middleware.cors import CORSMiddleware
5
  from fastapi.responses import JSONResponse, HTMLResponse, FileResponse
 
13
  # -----------------------------------------------------------------------------
14
  # CONFIGURATION
15
  # -----------------------------------------------------------------------------
16
+ HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN") # set in HF Space secrets or env
17
+ PORT = int(os.getenv("PORT", 7860)) # Spaces auto-set PORT; default 7860 locally
18
 
19
  app = FastAPI(
20
+ title="AI-Powered Web-App API",
21
+ description="Backend for summarisation, captioning & QA",
22
+ version="1.2.1",
23
  )
24
 
25
  app.add_middleware(
 
30
  allow_headers=["*"],
31
  )
32
 
33
+ # -----------------------------------------------------------------------------
34
+ # OPTIONAL STATIC FILES (only if ./static exists)
35
+ # -----------------------------------------------------------------------------
36
  static_dir = Path("static")
37
  if static_dir.exists():
38
+ app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
39
 
40
  # -----------------------------------------------------------------------------
41
+ # HUGGING FACE INFERENCE CLIENTS
42
  # -----------------------------------------------------------------------------
43
  summary_client = InferenceClient("facebook/bart-large-cnn", token=HUGGINGFACE_TOKEN)
44
  qa_client = InferenceClient("deepset/roberta-base-squad2", token=HUGGINGFACE_TOKEN)
45
  image_caption_client = InferenceClient("nlpconnect/vit-gpt2-image-captioning", token=HUGGINGFACE_TOKEN)
46
 
47
  # -----------------------------------------------------------------------------
48
+ # UTILITIES
49
  # -----------------------------------------------------------------------------
50
 
51
  def extract_text_from_pdf(content: bytes) -> str:
 
58
 
59
  def process_uploaded_file(file: UploadFile) -> str:
60
  content = file.file.read()
61
+ ext = file.filename.split(".")[-1].lower()
62
+ if ext == "pdf":
63
  return extract_text_from_pdf(content)
64
+ if ext == "docx":
65
  return extract_text_from_docx(content)
66
+ if ext == "txt":
67
  return content.decode("utf-8").strip()
68
  raise ValueError("Unsupported file type")
69
 
 
73
 
74
  @app.get("/", response_class=HTMLResponse)
75
  async def serve_index():
76
+ """Return the frontend HTML page."""
77
  return FileResponse("index.html")
78
 
79
+ # -------------------- Summarisation ------------------------------------------
80
 
81
  @app.post("/api/summarize")
82
  async def summarize_document(file: UploadFile = File(...)):
 
86
  return {"result": "Document too short to summarise."}
87
 
88
  summary_raw = summary_client.summarization(text[:3000])
 
89
  if isinstance(summary_raw, list):
90
  summary_txt = summary_raw[0].get("summary_text", str(summary_raw))
91
  elif isinstance(summary_raw, dict):
 
97
  except Exception as exc:
98
  return JSONResponse(status_code=500, content={"error": f"Summarisation failure: {exc}"})
99
 
100
+ # -------------------- Image Caption -----------------------------------------
101
 
102
  @app.post("/api/caption")
103
  async def caption_image(file: UploadFile = File(...)):
104
  try:
105
+ img_bytes = await file.read()
106
+ img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
107
+ img.thumbnail((1024, 1024))
108
+ buf = BytesIO(); img.save(buf, format="JPEG")
109
+ result = image_caption_client.image_to_text(buf.getvalue())
 
110
  if isinstance(result, dict):
111
  caption = result.get("generated_text") or result.get("caption") or "No caption found."
112
  elif isinstance(result, list):
113
  caption = result[0].get("generated_text", "No caption found.")
114
  else:
115
  caption = str(result)
 
116
  return {"result": caption}
117
  except Exception as exc:
118
  return JSONResponse(status_code=500, content={"error": f"Caption failure: {exc}"})
119
 
120
+ # -------------------- Question Answering ------------------------------------
121
 
122
  @app.post("/api/qa")
123
  async def question_answering(file: UploadFile = File(...), question: str = Form(...)):
124
  try:
 
125
  if file.content_type.startswith("image/"):
126
+ img_bytes = await file.read()
127
+ img = Image.open(io.BytesIO(img_bytes)).convert("RGB"); img.thumbnail((1024, 1024))
128
+ b = BytesIO(); img.save(b, format="JPEG")
129
+ res = image_caption_client.image_to_text(b.getvalue())
130
  context = res.get("generated_text") if isinstance(res, dict) else str(res)
131
  else:
132
  context = process_uploaded_file(file)[:3000]
 
139
  except Exception as exc:
140
  return JSONResponse(status_code=500, content={"error": f"QA failure: {exc}"})
141
 
142
+ # -------------------- Health -------------------------------------------------
143
 
144
  @app.get("/api/health")
145
  async def health():
146
+ return {"status": "healthy", "hf_token_set": bool(HUGGINGFACE_TOKEN), "version": app.version}
 
 
 
 
147
 
148
  # -----------------------------------------------------------------------------
149
  # ENTRYPOINT