benkada committed on
Commit
91fed73
·
verified ·
1 Parent(s): 89d20c1

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +88 -42
main.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- import io
3
  from fastapi import FastAPI, UploadFile, File, Form
4
  from fastapi.middleware.cors import CORSMiddleware
5
  from fastapi.responses import JSONResponse, HTMLResponse
@@ -7,13 +6,17 @@ from huggingface_hub import InferenceClient
7
  from PyPDF2 import PdfReader
8
  from docx import Document
9
  from PIL import Image
 
10
  from io import BytesIO
 
11
 
12
- # Load Hugging Face Token securely
13
  HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN")
14
 
 
15
  app = FastAPI()
16
 
 
17
  app.add_middleware(
18
  CORSMiddleware,
19
  allow_origins=["*"],
@@ -22,22 +25,31 @@ app.add_middleware(
22
  allow_headers=["*"],
23
  )
24
 
25
- # Initialize Hugging Face clients
26
  summary_client = InferenceClient(model="facebook/bart-large-cnn", token=HUGGINGFACE_TOKEN)
27
  qa_client = InferenceClient(model="deepset/roberta-base-squad2", token=HUGGINGFACE_TOKEN)
28
  image_caption_client = InferenceClient(model="nlpconnect/vit-gpt2-image-captioning", token=HUGGINGFACE_TOKEN)
29
 
 
30
  def extract_text_from_pdf(content: bytes) -> str:
 
31
  reader = PdfReader(io.BytesIO(content))
32
- return "\n".join(page.extract_text() or "" for page in reader.pages).strip()
 
 
 
33
 
34
  def extract_text_from_docx(content: bytes) -> str:
 
35
  doc = Document(io.BytesIO(content))
36
- return "\n".join(para.text for para in doc.paragraphs).strip()
 
 
37
 
38
  def process_uploaded_file(file: UploadFile) -> str:
39
  content = file.file.read()
40
  extension = file.filename.split('.')[-1].lower()
 
41
  if extension == "pdf":
42
  return extract_text_from_pdf(content)
43
  elif extension == "docx":
@@ -45,76 +57,110 @@ def process_uploaded_file(file: UploadFile) -> str:
45
  elif extension == "txt":
46
  return content.decode("utf-8").strip()
47
  else:
48
- raise ValueError("Unsupported file type.")
49
 
 
50
  @app.get("/", response_class=HTMLResponse)
51
  async def serve_homepage():
52
  with open("index.html", "r", encoding="utf-8") as f:
53
  return HTMLResponse(content=f.read(), status_code=200)
54
 
55
- @app.post("/api/summarize")
56
- async def summarize_document(file: UploadFile = File(...)):
 
57
  try:
58
  text = process_uploaded_file(file)
59
- if len(text) < 20:
60
- return {"result": "Document too short to summarize."}
61
- summary = summary_client.summarization(text[:3000])
62
- return {"result": summary}
63
- except Exception as e:
64
- return JSONResponse(status_code=500, content={"error": str(e)})
65
 
66
- @app.post("/api/caption")
67
- async def caption_image(file: UploadFile = File(...)):
68
- try:
69
- image_bytes = await file.read()
70
- image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB")
71
- image_pil.thumbnail((1024, 1024))
72
- img_byte_arr = BytesIO()
73
- image_pil.save(img_byte_arr, format='JPEG')
74
- img_byte_arr = img_byte_arr.getvalue()
75
- result = image_caption_client.image_to_text(img_byte_arr)
76
 
77
- if isinstance(result, dict):
78
- caption = result.get("generated_text") or result.get("caption") or "No caption found."
79
- elif isinstance(result, list) and result:
80
- caption = result[0].get("generated_text", "No caption found.")
81
- elif isinstance(result, str):
82
- caption = result
83
- else:
84
- caption = "No caption found."
85
 
86
- return {"result": caption}
87
  except Exception as e:
88
- return JSONResponse(status_code=500, content={"error": str(e)})
89
 
90
- @app.post("/api/qa")
91
- async def question_answering(file: UploadFile = File(...), question: str = Form(...)):
 
92
  try:
 
93
  content_type = file.content_type
94
  if content_type.startswith("image/"):
95
  image_bytes = await file.read()
96
  image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB")
97
  image_pil.thumbnail((1024, 1024))
 
98
  img_byte_arr = BytesIO()
99
  image_pil.save(img_byte_arr, format='JPEG')
100
  img_byte_arr = img_byte_arr.getvalue()
 
 
101
  result = image_caption_client.image_to_text(img_byte_arr)
102
- context = result.get("generated_text") if isinstance(result, dict) else result
 
 
 
 
 
 
 
 
103
  else:
 
104
  text = process_uploaded_file(file)
105
  if len(text) < 20:
106
- return {"result": "Document too short to answer questions."}
107
  context = text[:3000]
108
 
109
  if not context:
110
- return {"result": "No context available to answer."}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
- answer = qa_client.question_answering(question=question, context=context)
113
- return {"result": answer.get("answer", "No answer found.")}
114
 
115
  except Exception as e:
116
- return JSONResponse(status_code=500, content={"error": str(e)})
117
 
 
118
  if __name__ == "__main__":
119
  import uvicorn
120
  uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
 
1
  import os
 
2
  from fastapi import FastAPI, UploadFile, File, Form
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from fastapi.responses import JSONResponse, HTMLResponse
 
6
  from PyPDF2 import PdfReader
7
  from docx import Document
8
  from PIL import Image
9
+ import io
10
  from io import BytesIO
11
+ import requests
12
 
13
+ # Remplace ce token par le tien de manière sécurisée (variable d'environnement recommandée en production)
14
  HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN")
15
 
16
+ # Initialisation de l'app FastAPI
17
  app = FastAPI()
18
 
19
+ # Autoriser les requêtes Cross-Origin
20
  app.add_middleware(
21
  CORSMiddleware,
22
  allow_origins=["*"],
 
25
  allow_headers=["*"],
26
  )
27
 
28
+ # Initialisation des clients Hugging Face avec authentification
29
  summary_client = InferenceClient(model="facebook/bart-large-cnn", token=HUGGINGFACE_TOKEN)
30
  qa_client = InferenceClient(model="deepset/roberta-base-squad2", token=HUGGINGFACE_TOKEN)
31
  image_caption_client = InferenceClient(model="nlpconnect/vit-gpt2-image-captioning", token=HUGGINGFACE_TOKEN)
32
 
# Text extraction helpers for uploaded files
def extract_text_from_pdf(content: bytes) -> str:
    """Return the text of all pages of a PDF document.

    Args:
        content: Raw bytes of the PDF file.

    Returns:
        The text of every page that yields text, joined with newlines and
        stripped of leading/trailing whitespace. Pages with no extractable
        text are skipped.
    """
    reader = PdfReader(io.BytesIO(content))
    # Extract each page exactly once; the previous version called
    # extract_text() twice per page (once to test, once to append).
    page_texts = (page.extract_text() for page in reader.pages)
    return "\n".join(page_text for page_text in page_texts if page_text).strip()
41
 
def extract_text_from_docx(content: bytes) -> str:
    """Return the text of all paragraphs of a .docx document.

    Args:
        content: Raw bytes of the .docx file.

    Returns:
        The paragraph texts joined with newlines and stripped of
        leading/trailing whitespace.
    """
    doc = Document(io.BytesIO(content))
    # A single join replaces repeated string concatenation in a loop.
    return "\n".join(para.text for para in doc.paragraphs).strip()
48
 
49
  def process_uploaded_file(file: UploadFile) -> str:
50
  content = file.file.read()
51
  extension = file.filename.split('.')[-1].lower()
52
+
53
  if extension == "pdf":
54
  return extract_text_from_pdf(content)
55
  elif extension == "docx":
 
57
  elif extension == "txt":
58
  return content.decode("utf-8").strip()
59
  else:
60
+ raise ValueError("Type de fichier non supporté")
61
 
# HTML entry point
@app.get("/", response_class=HTMLResponse)
async def serve_homepage():
    """Serve the contents of ``index.html`` as the landing page.

    The path is relative, so it is resolved against the process's current
    working directory.
    """
    with open("index.html", "r", encoding="utf-8") as page_file:
        html = page_file.read()
    return HTMLResponse(content=html, status_code=200)
67
 
# Summarization
@app.post("/analyze")
async def analyze_file(file: UploadFile = File(...)):
    """Summarize an uploaded document.

    Args:
        file: The uploaded document; its text is obtained through
            ``process_uploaded_file``.

    Returns:
        ``{"summary": ...}`` on success (or a fixed message when the
        extracted text is shorter than 20 characters). On failure, a JSON
        ``{"error": ...}`` response: status 400 when text extraction raises
        ``ValueError`` (e.g. unsupported file type), status 500 otherwise.
    """
    try:
        text = process_uploaded_file(file)
    except ValueError as e:
        # A file we cannot read as text is a client error, not a server fault.
        return JSONResponse(status_code=400, content={"error": f"Erreur lors de l'analyse: {str(e)}"})
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": f"Erreur lors de l'analyse: {str(e)}"})

    if len(text) < 20:
        return {"summary": "Document trop court pour être résumé."}

    try:
        # Only the first 3000 characters are sent to the summarization model.
        summary = summary_client.summarization(text[:3000])
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": f"Erreur lors de l'analyse: {str(e)}"})
    return {"summary": summary}
82
 
# Question answering
@app.post("/ask")
async def ask_question(file: UploadFile = File(...), question: str = Form(...)):
    """Answer a question about an uploaded image or document.

    For uploads whose content type starts with ``image/``, the image is
    captioned and the caption is used as the context. Any other upload is
    processed as a document and the first 3000 characters of its text are
    used as the context.

    Args:
        file: The uploaded image or document.
        question: The question to answer from the context.

    Returns:
        ``{"answer": ...}`` on success or when no usable context exists.
        On failure, a JSON ``{"error": ...}`` response: status 400 when
        document text extraction raises ``ValueError`` (e.g. unsupported
        file type), status 500 otherwise.
    """
    try:
        # The upload may carry no content type; treat that as "not an image"
        # instead of failing on None.startswith.
        content_type = file.content_type or ""
        if content_type.startswith("image/"):
            image_bytes = await file.read()
            image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            image_pil.thumbnail((1024, 1024))

            # Re-encode as JPEG before sending to the captioning model.
            jpeg_buffer = BytesIO()
            image_pil.save(jpeg_buffer, format='JPEG')
            jpeg_bytes = jpeg_buffer.getvalue()

            # Generate the image description used as QA context.
            result = image_caption_client.image_to_text(jpeg_bytes)
            if isinstance(result, dict):
                context = result.get("generated_text") or result.get("caption") or ""
            elif isinstance(result, list) and len(result) > 0:
                context = result[0].get("generated_text", "")
            elif isinstance(result, str):
                context = result
            else:
                context = ""
        else:
            # Not an image: process as a document.
            try:
                text = process_uploaded_file(file)
            except ValueError as e:
                # A file we cannot read as text is a client error.
                return JSONResponse(
                    status_code=400,
                    content={"error": f"Erreur lors de la recherche de réponse: {str(e)}"},
                )
            if len(text) < 20:
                return {"answer": "Document trop court pour répondre à la question."}
            context = text[:3000]

        if not context:
            return {"answer": "Aucune information disponible pour répondre à la question."}

        result = qa_client.question_answering(question=question, context=context)
        return {"answer": result.get("answer", "Aucune réponse trouvée.")}

    except Exception as e:
        return JSONResponse(status_code=500, content={"error": f"Erreur lors de la recherche de réponse: {str(e)}"})
124
+
# Image interpretation
@app.post("/interpret_image")
async def interpret_image(image: UploadFile = File(...)):
    """Describe an uploaded image with the image-captioning model.

    Args:
        image: The uploaded image file.

    Returns:
        ``{"description": ...}`` on success, or a JSON ``{"error": ...}``
        response with status 500 if any step fails.
    """
    try:
        raw_bytes = await image.read()

        # Normalize to RGB and cap the size at 1024x1024 before re-encoding.
        picture = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
        picture.thumbnail((1024, 1024))

        # Re-encode as JPEG bytes for the model call.
        jpeg_buffer = BytesIO()
        picture.save(jpeg_buffer, format='JPEG')
        result = image_caption_client.image_to_text(jpeg_buffer.getvalue())

        # Raw model output, printed for debugging.
        print("Résultat brut du modèle image-to-text:", result)

        # Start from the fallback and override it when the result has a caption.
        description = "Description non trouvée."
        if isinstance(result, dict):
            description = result.get("generated_text") or result.get("caption") or description
        elif isinstance(result, list) and result:
            description = result[0].get("generated_text", description)
        elif isinstance(result, str):
            description = result

        return {"description": description}

    except Exception as e:
        return JSONResponse(status_code=500, content={"error": f"Erreur lors de l'interprétation de l'image: {str(e)}"})
162
 
# Local development entry point
if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces on port 8000 with auto-reload enabled.
    server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("main:app", **server_options)