benkada commited on
Commit
2dd1f0a
Β·
verified Β·
1 Parent(s): a8335ea

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +59 -27
main.py CHANGED
@@ -1,5 +1,6 @@
1
  import os, io
2
  from pathlib import Path
 
3
  from fastapi import FastAPI, UploadFile, File, Form
4
  from fastapi.middleware.cors import CORSMiddleware
5
  from fastapi.responses import JSONResponse, HTMLResponse, FileResponse
@@ -14,20 +15,20 @@ from io import BytesIO
14
  # CONFIGURATION
15
  # -----------------------------------------------------------------------------
16
  HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN")
17
- PORT = int(os.getenv("PORT", 7860))
18
 
19
  app = FastAPI(
20
- title="AI‑Powered Web‑App API",
21
- description="Backend for summarisation, captioning & QA",
22
- version="1.2.2",
23
  )
24
 
25
  app.add_middleware(
26
  CORSMiddleware,
27
- allow_origins=["*"],
28
- allow_credentials=True,
29
- allow_methods=["*"],
30
- allow_headers=["*"],
31
  )
32
 
33
  # -----------------------------------------------------------------------------
@@ -40,14 +41,31 @@ if static_dir.exists():
40
  # -----------------------------------------------------------------------------
41
  # HUGGING FACE INFERENCE CLIENTS
42
  # -----------------------------------------------------------------------------
43
- summary_client = InferenceClient("facebook/bart-large-cnn", token=HUGGINGFACE_TOKEN)
44
- qa_client = InferenceClient("deepset/roberta-base-squad2", token=HUGGINGFACE_TOKEN)
45
- image_caption_client = InferenceClient("nlpconnect/vit-gpt2-image-captioning", token=HUGGINGFACE_TOKEN)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  # -----------------------------------------------------------------------------
48
  # UTILITIES
49
  # -----------------------------------------------------------------------------
50
-
51
  def extract_text_from_pdf(content: bytes) -> str:
52
  reader = PdfReader(io.BytesIO(content))
53
  return "\n".join(page.extract_text() or "" for page in reader.pages).strip()
@@ -58,7 +76,7 @@ def extract_text_from_docx(content: bytes) -> str:
58
 
59
  def process_uploaded_file(file: UploadFile) -> str:
60
  content = file.file.read()
61
- ext = file.filename.split(".")[-1].lower()
62
  if ext == "pdf":
63
  return extract_text_from_pdf(content)
64
  if ext == "docx":
@@ -70,7 +88,6 @@ def process_uploaded_file(file: UploadFile) -> str:
70
  # -----------------------------------------------------------------------------
71
  # ROUTES
72
  # -----------------------------------------------------------------------------
73
-
74
  @app.get("/", response_class=HTMLResponse)
75
  async def serve_index():
76
  return FileResponse("index.html")
@@ -90,51 +107,66 @@ async def summarize_document(file: UploadFile = File(...)):
90
  )
91
  return {"result": summary_txt}
92
  except Exception as exc:
93
- return JSONResponse(status_code=500, content={"error": f"Summarisation failure: {exc}"})
 
 
94
 
95
- # -------------------- Image Caption -----------------------------------------
96
  @app.post("/api/caption")
97
  async def caption_image(image: UploadFile = File(...)):
98
  """`image` field name matches frontend (was `file` before)."""
99
  try:
100
  img_bytes = await image.read()
101
- img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
102
  img.thumbnail((1024, 1024))
103
  buf = BytesIO(); img.save(buf, format="JPEG")
104
  result = image_caption_client.image_to_text(buf.getvalue())
105
  if isinstance(result, dict):
106
- caption = result.get("generated_text") or result.get("caption") or "No caption found."
 
 
107
  elif isinstance(result, list):
108
  caption = result[0].get("generated_text", "No caption found.")
109
  else:
110
  caption = str(result)
111
  return {"result": caption}
112
  except Exception as exc:
113
- return JSONResponse(status_code=500, content={"error": f"Caption failure: {exc}"})
 
 
114
 
115
- # -------------------- Question Answering ------------------------------------
116
  @app.post("/api/qa")
117
- async def question_answering(file: UploadFile = File(...), question: str = Form(...)):
 
118
  try:
119
  if file.content_type.startswith("image/"):
120
  img_bytes = await file.read()
121
- img = Image.open(io.BytesIO(img_bytes)).convert("RGB"); img.thumbnail((1024, 1024))
 
122
  buf = BytesIO(); img.save(buf, format="JPEG")
123
- res = image_caption_client.image_to_text(buf.getvalue())
124
- context = res.get("generated_text") if isinstance(res, dict) else str(res)
 
125
  else:
126
  context = process_uploaded_file(file)[:3000]
 
127
  if not context:
128
  return {"result": "No context – cannot answer."}
 
129
  answer = qa_client.question_answering(question=question, context=context)
130
  return {"result": answer.get("answer", "No answer found.")}
131
  except Exception as exc:
132
- return JSONResponse(status_code=500, content={"error": f"QA failure: {exc}"})
 
 
133
 
134
- # -------------------- Health -------------------------------------------------
135
  @app.get("/api/health")
136
  async def health():
137
- return {"status": "healthy", "hf_token_set": bool(HUGGINGFACE_TOKEN), "version": app.version}
 
 
138
 
139
  # -----------------------------------------------------------------------------
140
  # ENTRYPOINT
 
1
  import os, io
2
  from pathlib import Path
3
+
4
  from fastapi import FastAPI, UploadFile, File, Form
5
  from fastapi.middleware.cors import CORSMiddleware
6
  from fastapi.responses import JSONResponse, HTMLResponse, FileResponse
 
15
  # CONFIGURATION
16
  # -----------------------------------------------------------------------------
17
  HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN")
18
+ PORT = int(os.getenv("PORT", 7860))
19
 
20
  app = FastAPI(
21
+ title = "AI-Powered Web-App API",
22
+ description = "Backend for summarisation, captioning & QA",
23
+ version = "1.2.3", # <-- bumped
24
  )
25
 
26
  app.add_middleware(
27
  CORSMiddleware,
28
+ allow_origins = ["*"],
29
+ allow_credentials = True,
30
+ allow_methods = ["*"],
31
+ allow_headers = ["*"],
32
  )
33
 
34
  # -----------------------------------------------------------------------------
 
41
  # -----------------------------------------------------------------------------
42
  # HUGGING FACE INFERENCE CLIENTS
43
  # -----------------------------------------------------------------------------
44
+ summary_client = InferenceClient(
45
+ "facebook/bart-large-cnn",
46
+ token = HUGGINGFACE_TOKEN,
47
+ timeout = 120,
48
+ )
49
+
50
+ # ➜ Upgraded QA model (higher accuracy than roberta-base)
51
+ qa_client = InferenceClient(
52
+ "deepset/roberta-large-squad2",
53
+ token = HUGGINGFACE_TOKEN,
54
+ timeout = 120,
55
+ )
56
+ # If you need multilingual support, swap for:
57
+ # qa_client = InferenceClient("deepset/xlm-roberta-large-squad2",
58
+ # token=HUGGINGFACE_TOKEN, timeout=120)
59
+
60
+ image_caption_client = InferenceClient(
61
+ "nlpconnect/vit-gpt2-image-captioning",
62
+ token = HUGGINGFACE_TOKEN,
63
+ timeout = 60,
64
+ )
65
 
66
  # -----------------------------------------------------------------------------
67
  # UTILITIES
68
  # -----------------------------------------------------------------------------
 
69
  def extract_text_from_pdf(content: bytes) -> str:
70
  reader = PdfReader(io.BytesIO(content))
71
  return "\n".join(page.extract_text() or "" for page in reader.pages).strip()
 
76
 
77
  def process_uploaded_file(file: UploadFile) -> str:
78
  content = file.file.read()
79
+ ext = file.filename.split(".")[-1].lower()
80
  if ext == "pdf":
81
  return extract_text_from_pdf(content)
82
  if ext == "docx":
 
88
  # -----------------------------------------------------------------------------
89
  # ROUTES
90
  # -----------------------------------------------------------------------------
 
91
  @app.get("/", response_class=HTMLResponse)
92
  async def serve_index():
93
  return FileResponse("index.html")
 
107
  )
108
  return {"result": summary_txt}
109
  except Exception as exc:
110
+ return JSONResponse(status_code=500,
111
+ content={"error": f"Summarisation failure: {exc}"})
112
+
113
 
114
+ # -------------------- Image Caption ------------------------------------------
115
  @app.post("/api/caption")
116
  async def caption_image(image: UploadFile = File(...)):
117
  """`image` field name matches frontend (was `file` before)."""
118
  try:
119
  img_bytes = await image.read()
120
+ img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
121
  img.thumbnail((1024, 1024))
122
  buf = BytesIO(); img.save(buf, format="JPEG")
123
  result = image_caption_client.image_to_text(buf.getvalue())
124
  if isinstance(result, dict):
125
+ caption = (result.get("generated_text")
126
+ or result.get("caption")
127
+ or "No caption found.")
128
  elif isinstance(result, list):
129
  caption = result[0].get("generated_text", "No caption found.")
130
  else:
131
  caption = str(result)
132
  return {"result": caption}
133
  except Exception as exc:
134
+ return JSONResponse(status_code=500,
135
+ content={"error": f"Caption failure: {exc}"})
136
+
137
 
138
+ # -------------------- Question Answering -------------------------------------
139
  @app.post("/api/qa")
140
+ async def question_answering(file: UploadFile = File(...),
141
+ question: str = Form(...)):
142
  try:
143
  if file.content_type.startswith("image/"):
144
  img_bytes = await file.read()
145
+ img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
146
+ img.thumbnail((1024, 1024))
147
  buf = BytesIO(); img.save(buf, format="JPEG")
148
+ res = image_caption_client.image_to_text(buf.getvalue())
149
+ context = (res.get("generated_text") if isinstance(res, dict)
150
+ else str(res))
151
  else:
152
  context = process_uploaded_file(file)[:3000]
153
+
154
  if not context:
155
  return {"result": "No context – cannot answer."}
156
+
157
  answer = qa_client.question_answering(question=question, context=context)
158
  return {"result": answer.get("answer", "No answer found.")}
159
  except Exception as exc:
160
+ return JSONResponse(status_code=500,
161
+ content={"error": f"QA failure: {exc}"})
162
+
163
 
164
+ # -------------------- Health --------------------------------------------------
165
  @app.get("/api/health")
166
  async def health():
167
+ return {"status": "healthy",
168
+ "hf_token_set": bool(HUGGINGFACE_TOKEN),
169
+ "version": app.version}
170
 
171
  # -----------------------------------------------------------------------------
172
  # ENTRYPOINT