benkada committed (verified)
Commit cbdf3eb · 1 Parent(s): 6581e65

Update main.py

Files changed (1): main.py +132 -85
main.py CHANGED
@@ -1,26 +1,28 @@
- import os, io
  from fastapi import FastAPI, UploadFile, File, Form
  from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import JSONResponse, HTMLResponse, FileResponse
  from fastapi.staticfiles import StaticFiles
  from huggingface_hub import InferenceClient
  from PyPDF2 import PdfReader
  from docx import Document
  from PIL import Image
  from io import BytesIO

- # -----------------------------------------------------------------------------
- # CONFIGURATION
- # -----------------------------------------------------------------------------
- HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN")  # injected as a secret
- PORT = int(os.getenv("PORT", 7860))  # HF Spaces provides it

  app = FastAPI(
-     title="AI‑Powered WebApp API",
-     description="Backend endpoints for summarisation, captioning and QA",
-     version="1.1.0",
  )

  app.add_middleware(
      CORSMiddleware,
      allow_origins=["*"],
@@ -29,111 +31,156 @@ app.add_middleware(
      allow_headers=["*"],
  )

- # Optional: serve static assets from /static (images, css, js)
- app.mount("/static", StaticFiles(directory="static"), name="static")

- # -----------------------------------------------------------------------------
- # MODEL CLIENTS (remote HuggingFace Inference API)
- # -----------------------------------------------------------------------------
- summary_client = InferenceClient("facebook/bart-large-cnn", token=HUGGINGFACE_TOKEN)
- qa_client = InferenceClient("deepset/roberta-base-squad2", token=HUGGINGFACE_TOKEN)
- image_caption_client = InferenceClient("nlpconnect/vit-gpt2-image-captioning", token=HUGGINGFACE_TOKEN)

- # -----------------------------------------------------------------------------
- # UTILITY FUNCTIONS
- # -----------------------------------------------------------------------------

  def extract_text_from_pdf(content: bytes) -> str:
      reader = PdfReader(io.BytesIO(content))
-     return "\n".join(page.extract_text() or "" for page in reader.pages).strip()

  def extract_text_from_docx(content: bytes) -> str:
      doc = Document(io.BytesIO(content))
-     return "\n".join(p.text for p in doc.paragraphs).strip()

  def process_uploaded_file(file: UploadFile) -> str:
-     content = file.file.read()
-     extension = file.filename.split(".")[-1].lower()
      if extension == "pdf":
          return extract_text_from_pdf(content)
-     if extension == "docx":
          return extract_text_from_docx(content)
-     if extension == "txt":
          return content.decode("utf-8").strip()
-     raise ValueError("Unsupported file type")
-
- # -----------------------------------------------------------------------------
- # ROUTES
- # -----------------------------------------------------------------------------

  @app.get("/", response_class=HTMLResponse)
- async def serve_index():
-     """Send the frontend HTML."""
-     return FileResponse("index.html")
-
- # ---------- Summarisation -----------------------------------------------------

- @app.post("/api/summarize")
- async def summarize_document(file: UploadFile = File(...)):
      try:
          text = process_uploaded_file(file)
-         if len(text) < 20:
-             return {"result": "Document too short to summarise."}
-         summary_text = summary_client.summarization(text[:3000])
-         return {"result": str(summary_text)}
-     except Exception as exc:
-         return JSONResponse(status_code=500, content={"error": f"Analyse failure: {exc}"})

- # ---------- Image Caption -----------------------------------------------------

- @app.post("/api/caption")
- async def caption_image(file: UploadFile = File(...)):
-     try:
-         image_bytes = await file.read()
-         image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB")
-         image_pil.thumbnail((1024, 1024))
-         buf = BytesIO(); image_pil.save(buf, format="JPEG"); img = buf.getvalue()
-         result = image_caption_client.image_to_text(img)
-         if isinstance(result, dict):
-             caption = result.get("generated_text") or result.get("caption") or "No caption found."
-         elif isinstance(result, list):
-             caption = result[0].get("generated_text", "No caption found.")
-         else:
-             caption = str(result)
-         return {"result": str(caption)}
-     except Exception as exc:
-         return JSONResponse(status_code=500, content={"error": f"Caption failure: {exc}"})

- # ---------- Question Answering ----------------------------------------------

- @app.post("/api/qa")
- async def question_answering(file: UploadFile = File(...), question: str = Form(...)):
      try:
-         if file.content_type.startswith("image/"):
              image_bytes = await file.read()
-             pil = Image.open(io.BytesIO(image_bytes)).convert("RGB"); pil.thumbnail((1024, 1024))
-             buf = BytesIO(); pil.save(buf, format="JPEG"); img = buf.getvalue()
-             res = image_caption_client.image_to_text(img)
-             context = res.get("generated_text") if isinstance(res, dict) else str(res)
          else:
-             context = process_uploaded_file(file)[:3000]
          if not context:
-             return {"result": "No context cannot answer."}
-         answer = qa_client.question_answering(question=question, context=context)
-         return {"result": str(answer.get("answer", "No answer found."))}
-     except Exception as exc:
-         return JSONResponse(status_code=500, content={"error": f"QA failure: {exc}"})

- # ---------- Health check ------------------------------------------------------

- @app.get("/api/health")
- async def health():
-     return {"status": "healthy", "hf_token_set": bool(HUGGINGFACE_TOKEN)}

- # -----------------------------------------------------------------------------
- # ENTRYPOINT
- # -----------------------------------------------------------------------------

  if __name__ == "__main__":
      import uvicorn
-     uvicorn.run(app, host="0.0.0.0", port=PORT)
 
+ import os
  from fastapi import FastAPI, UploadFile, File, Form
  from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import JSONResponse, HTMLResponse
  from fastapi.staticfiles import StaticFiles
  from huggingface_hub import InferenceClient
  from PyPDF2 import PdfReader
  from docx import Document
  from PIL import Image
+ import io
  from io import BytesIO
+ import requests
+ from routers import ai

+ # Get environment variables
+ HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN")
+ PORT = int(os.getenv("PORT", 7860))

  app = FastAPI(
+     title="AI Web App API",
+     description="Backend API for AI-powered web application",
+     version="1.0.0"
  )

+ # Configure CORS
  app.add_middleware(
      CORSMiddleware,
      allow_origins=["*"],
      allow_headers=["*"],
  )

+ # Serve static files
+ app.mount("/", StaticFiles(directory=".", html=True), name="static")

+ # Include routers
+ app.include_router(ai.router)

+ # Initialize the Hugging Face clients with authentication
+ summary_client = InferenceClient(model="facebook/bart-large-cnn", token=HUGGINGFACE_TOKEN)
+ qa_client = InferenceClient(model="deepset/roberta-base-squad2", token=HUGGINGFACE_TOKEN)
+ image_caption_client = InferenceClient(model="nlpconnect/vit-gpt2-image-captioning", token=HUGGINGFACE_TOKEN)

+ # Text extraction from uploaded files
  def extract_text_from_pdf(content: bytes) -> str:
+     text = ""
      reader = PdfReader(io.BytesIO(content))
+     for page in reader.pages:
+         if page.extract_text():
+             text += page.extract_text() + "\n"
+     return text.strip()

  def extract_text_from_docx(content: bytes) -> str:
+     text = ""
      doc = Document(io.BytesIO(content))
+     for para in doc.paragraphs:
+         text += para.text + "\n"
+     return text.strip()

  def process_uploaded_file(file: UploadFile) -> str:
+     content = file.file.read()
+     extension = file.filename.split('.')[-1].lower()
+
      if extension == "pdf":
          return extract_text_from_pdf(content)
+     elif extension == "docx":
          return extract_text_from_docx(content)
+     elif extension == "txt":
          return content.decode("utf-8").strip()
+     else:
+         raise ValueError("Type de fichier non supporté")

+ # HTML entry point
  @app.get("/", response_class=HTMLResponse)
+ async def serve_homepage():
+     with open("index.html", "r", encoding="utf-8") as f:
+         return HTMLResponse(content=f.read(), status_code=200)

+ # Summarization
+ @app.post("/analyze")
+ async def analyze_file(file: UploadFile = File(...)):
      try:
          text = process_uploaded_file(file)

+         if len(text) < 20:
+             return {"summary": "Document trop court pour être résumé."}

+         summary = summary_client.summarization(text[:3000])
+         return {"summary": summary}

+     except Exception as e:
+         return JSONResponse(status_code=500, content={"error": f"Erreur lors de l'analyse: {str(e)}"})

+ # Question answering
+ @app.post("/ask")
+ async def ask_question(file: UploadFile = File(...), question: str = Form(...)):
      try:
+         # Determine if the file is an image
+         content_type = file.content_type
+         if content_type.startswith("image/"):
              image_bytes = await file.read()
+             image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB")
+             image_pil.thumbnail((1024, 1024))
+
+             img_byte_arr = BytesIO()
+             image_pil.save(img_byte_arr, format='JPEG')
+             img_byte_arr = img_byte_arr.getvalue()
+
+             # Generate image description
+             result = image_caption_client.image_to_text(img_byte_arr)
+             if isinstance(result, dict):
+                 context = result.get("generated_text") or result.get("caption") or ""
+             elif isinstance(result, list) and len(result) > 0:
+                 context = result[0].get("generated_text", "")
+             elif isinstance(result, str):
+                 context = result
+             else:
+                 context = ""
+
          else:
+             # Not an image, process as document
+             text = process_uploaded_file(file)
+             if len(text) < 20:
+                 return {"answer": "Document trop court pour répondre à la question."}
+             context = text[:3000]
+
          if not context:
+             return {"answer": "Aucune information disponible pour répondre à la question."}

+         result = qa_client.question_answering(question=question, context=context)
+         return {"answer": result.get("answer", "Aucune réponse trouvée.")}

+     except Exception as e:
+         return JSONResponse(status_code=500, content={"error": f"Erreur lors de la recherche de réponse: {str(e)}"})
+
+ # Image interpretation
+ @app.post("/interpret_image")
+ async def interpret_image(image: UploadFile = File(...)):
+     try:
+         # Read the image
+         image_bytes = await image.read()
+
+         # Open the image with PIL
+         image_pil = Image.open(io.BytesIO(image_bytes))
+         image_pil = image_pil.convert("RGB")
+         image_pil.thumbnail((1024, 1024))

+         # Convert to bytes (JPEG)
+         img_byte_arr = BytesIO()
+         image_pil.save(img_byte_arr, format='JPEG')
+         img_byte_arr = img_byte_arr.getvalue()

+         # Call the model
+         result = image_caption_client.image_to_text(img_byte_arr)
+
+         # 🔍 Print the raw result for debugging
+         print("Résultat brut du modèle image-to-text:", result)
+
+         # Extract the description if available
+         if isinstance(result, dict):
+             description = result.get("generated_text") or result.get("caption") or "Description non trouvée."
+         elif isinstance(result, list) and len(result) > 0:
+             description = result[0].get("generated_text", "Description non trouvée.")
+         elif isinstance(result, str):
+             description = result
+         else:
+             description = "Description non trouvée."
+
+         return {"description": description}
+
+     except Exception as e:
+         return JSONResponse(status_code=500, content={"error": f"Erreur lors de l'interprétation de l'image: {str(e)}"})
+
+ @app.get("/api/health")
+ async def health_check():
+     return {
+         "status": "healthy",
+         "version": "1.0.0",
+         "hf_token_set": bool(HUGGINGFACE_TOKEN)
+     }
+
+ # Local startup
  if __name__ == "__main__":
      import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=PORT)
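
For reference, a minimal client-side sketch of how the updated endpoints could be exercised once the app is running. The base URL, file names, and question text below are illustrative assumptions, not part of the commit; the endpoint paths and field names come from the new main.py.

import requests

BASE_URL = "http://localhost:7860"  # assumed local URL; PORT defaults to 7860 in main.py

# Summarize a document via the new /analyze endpoint
with open("report.pdf", "rb") as f:  # "report.pdf" is a placeholder file
    print(requests.post(f"{BASE_URL}/analyze", files={"file": f}).json())

# Ask a question about a document (or an image) via /ask
with open("report.pdf", "rb") as f:
    print(requests.post(f"{BASE_URL}/ask", files={"file": f},
                        data={"question": "What is the main topic?"}).json())

# Caption an image via /interpret_image (note the form field is named "image")
with open("photo.jpg", "rb") as f:  # "photo.jpg" is a placeholder file
    print(requests.post(f"{BASE_URL}/interpret_image", files={"image": f}).json())

# Health check
print(requests.get(f"{BASE_URL}/api/health").json())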