FatimaGr commited on
Commit
efd633c
·
verified ·
1 Parent(s): b52b283
Files changed (1) hide show
  1. app.py +248 -0
app.py CHANGED
@@ -2,8 +2,256 @@ from fastapi import FastAPI
2
  from fastapi.staticfiles import StaticFiles
3
  from fastapi.responses import RedirectResponse
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  app = FastAPI()
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  # Servir les fichiers statiques (HTML, CSS, JS)
8
  app.mount("/static", StaticFiles(directory="static", html=True), name="static")
9
 
 
2
  from fastapi.staticfiles import StaticFiles
3
  from fastapi.responses import RedirectResponse
4
 
5
+ from fastapi import FastAPI, File, UploadFile, Form
6
+ from fastapi.responses import JSONResponse, RedirectResponse
7
+ from fastapi.staticfiles import StaticFiles
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from transformers import pipeline, M2M100ForConditionalGeneration, M2M100Tokenizer, MarianMTModel, MarianTokenizer
10
+ import shutil
11
+ #
12
+ import os
13
+ import logging
14
+ from PyPDF2 import PdfReader
15
+ import docx
16
+ from PIL import Image
17
+ import openpyxl # 📌 Pour lire les fichiers Excel (.xlsx)
18
+ from pptx import Presentation
19
+ import fitz # PyMuPDF
20
+ import io
21
+ from docx import Document
22
+ import matplotlib.pyplot as plt
23
+ import seaborn as sns
24
+ import torch
25
+ import re
26
+ import pandas as pd
27
+ from transformers import AutoTokenizer, AutoModelForCausalLM
28
+ from fastapi.responses import FileResponse
29
+ import os
30
+ from fastapi.middleware.cors import CORSMiddleware
31
+ import matplotlib
32
+ matplotlib.use('Agg')
33
+
34
+ import re
35
+ import torch
36
+ import pandas as pd
37
+ import matplotlib.pyplot as plt
38
+ import seaborn as sns
39
+ from transformers import AutoTokenizer, AutoModelForCausalLM
40
+ from fastapi import FastAPI, File, UploadFile, Form
41
+ from fastapi.responses import FileResponse
42
+ import os
43
+ from fastapi.middleware.cors import CORSMiddleware
44
+ from fastapi import FastAPI, File, UploadFile, Form
45
+ from fastapi.responses import JSONResponse, RedirectResponse
46
+ from fastapi.staticfiles import StaticFiles
47
+ from transformers import pipeline, M2M100ForConditionalGeneration, M2M100Tokenizer
48
+ import shutil
49
+ import os
50
+ import logging
51
+ from fastapi.middleware.cors import CORSMiddleware
52
+ from PyPDF2 import PdfReader
53
+ import docx
54
+ from PIL import Image # Pour ouvrir les images avant analyse
55
+ from transformers import MarianMTModel, MarianTokenizer
56
+ import os
57
+ import fitz
58
+ from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
59
+
60
+ import logging
61
+ import openpyxl
62
+
63
+
64
+ # Configuration du logging
65
+ logging.basicConfig(level=logging.INFO)
66
+
67
+
68
  app = FastAPI()
69
 
70
+ # Configuration CORS
71
+ app.add_middleware(
72
+ CORSMiddleware,
73
+ allow_origins=["*"],
74
+ allow_credentials=True,
75
+ allow_methods=["*"],
76
+ allow_headers=["*"],
77
+ )
78
+
79
+ UPLOAD_DIR = "uploads"
80
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
81
+
82
+
83
+
84
+
85
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
86
+ model_name = "facebook/m2m100_418M"
87
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
88
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
89
+
90
+
91
+ # Fonction pour extraire le texte
92
+ def extract_text_from_pdf(file):
93
+ doc = fitz.open(stream=file.file.read(), filetype="pdf")
94
+ return "\n".join([page.get_text() for page in doc]).strip()
95
+
96
+ def extract_text_from_docx(file):
97
+ doc = Document(io.BytesIO(file.file.read()))
98
+ return "\n".join([para.text for para in doc.paragraphs]).strip()
99
+
100
+ def extract_text_from_pptx(file):
101
+ prs = Presentation(io.BytesIO(file.file.read()))
102
+ return "\n".join([shape.text for slide in prs.slides for shape in slide.shapes if hasattr(shape, "text")]).strip()
103
+
104
+ def extract_text_from_excel(file):
105
+ wb = openpyxl.load_workbook(io.BytesIO(file.file.read()), data_only=True)
106
+ text = [str(cell) for sheet in wb.worksheets for row in sheet.iter_rows(values_only=True) for cell in row if cell]
107
+ return "\n".join(text).strip()
108
+
109
+ @app.post("/translate/")
110
+ async def translate_document(file: UploadFile = File(...), target_lang: str = Form(...)):
111
+ """API pour traduire un document."""
112
+ try:
113
+ logging.info(f"📥 Fichier reçu : {file.filename}")
114
+ logging.info(f"🌍 Langue cible reçue : {target_lang}")
115
+
116
+ if model is None or tokenizer is None:
117
+ return JSONResponse(status_code=500, content={"error": "Modèle de traduction non chargé"})
118
+
119
+ # Extraction du texte
120
+ if file.filename.endswith(".pdf"):
121
+ text = extract_text_from_pdf(file)
122
+ elif file.filename.endswith(".docx"):
123
+ text = extract_text_from_docx(file)
124
+ elif file.filename.endswith(".pptx"):
125
+ text = extract_text_from_pptx(file)
126
+ elif file.filename.endswith(".xlsx"):
127
+ text = extract_text_from_excel(file)
128
+ else:
129
+ return JSONResponse(status_code=400, content={"error": "Format non supporté"})
130
+
131
+ logging.info(f"📜 Texte extrait : {text[:50]}...")
132
+
133
+ if not text:
134
+ return JSONResponse(status_code=400, content={"error": "Aucun texte trouvé dans le document"})
135
+
136
+ # Vérifier si la langue cible est supportée
137
+ target_lang_id = tokenizer.get_lang_id(target_lang)
138
+
139
+ if target_lang_id is None:
140
+ return JSONResponse(
141
+ status_code=400,
142
+ content={"error": f"Langue cible '{target_lang}' non supportée. Langues disponibles : {list(tokenizer.lang_code_to_id.keys())}"}
143
+ )
144
+
145
+ # Traduction
146
+ tokenizer.src_lang = "fr"
147
+ encoded_text = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
148
+
149
+ logging.info(f"🔍 ID de la langue cible : {target_lang_id}")
150
+
151
+ generated_tokens = model.generate(**encoded_text, forced_bos_token_id=target_lang_id)
152
+
153
+ translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
154
+
155
+ logging.info(f"✅ Traduction réussie : {translated_text[:50]}...")
156
+ return {"translated_text": translated_text}
157
+
158
+ except Exception as e:
159
+ logging.error(f"❌ Erreur lors de la traduction : {e}")
160
+ return JSONResponse(status_code=500, content={"error": "Échec de la traduction"})
161
+
162
+
163
+
164
+
165
+ # Charger le modèle pour la génération de code
166
+ codegen_model_name = "Salesforce/codegen-350M-mono"
167
+ device = "cuda" if torch.cuda.is_available() else "cpu"
168
+
169
+ codegen_tokenizer = AutoTokenizer.from_pretrained(codegen_model_name)
170
+ codegen_model = AutoModelForCausalLM.from_pretrained(codegen_model_name).to(device)
171
+
172
+ VALID_PLOTS = {"histplot", "scatterplot", "barplot", "lineplot", "boxplot"}
173
+
174
+ @app.post("/generate_viz/")
175
+ async def generate_viz(file: UploadFile = File(...), query: str = Form(...)):
176
+ try:
177
+ if query not in VALID_PLOTS:
178
+ return {"error": f"Type de graphique invalide. Choisissez parmi : {', '.join(VALID_PLOTS)}"}
179
+
180
+ df = pd.read_excel(file.file)
181
+
182
+ numeric_cols = df.select_dtypes(include=["number"]).columns
183
+ if len(numeric_cols) < 2:
184
+ return {"error": "Le fichier doit contenir au moins deux colonnes numériques."}
185
+
186
+ x_col, y_col = numeric_cols[:2]
187
+
188
+ # Contraintes spécifiques pour éviter l'erreur avec histplot
189
+ if query == "histplot":
190
+ prompt_y = ""
191
+ else:
192
+ prompt_y = f', y="{y_col}"'
193
+
194
+ # Générer l'invite pour le modèle
195
+ prompt = f"""
196
+ ### Génère uniquement du code Python fonctionnel pour tracer un {query} avec Matplotlib et Seaborn ###
197
+ # Contraintes :
198
+ # - Utilise 'df' sans recréer de nouvelles données
199
+ # - Axe X : '{x_col}'
200
+ # - Enregistre le graphique sous 'plot.png'
201
+ # - Ne génère que du code Python valide, sans texte explicatif
202
+ # Contraintes spécifiques pour sns.histplot :
203
+ # - N'inclut pas "y=" car histplot ne supporte qu'un axe
204
+ import matplotlib.pyplot as plt
205
+ import seaborn as sns
206
+ plt.figure(figsize=(8,6))
207
+ sns.{query}(data=df, x="{x_col}"{prompt_y})
208
+ plt.savefig("plot.png")
209
+ plt.close()
210
+ """
211
+
212
+ # Génération du code
213
+ inputs = codegen_tokenizer(prompt, return_tensors="pt").to(device)
214
+ outputs = codegen_model.generate(**inputs, max_new_tokens=120, pad_token_id=codegen_tokenizer.eos_token_id)
215
+ generated_code = codegen_tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
216
+ # Nettoyage du code
217
+ generated_code = re.sub(r"(import matplotlib.pyplot as plt\nimport seaborn as sns\n)+", "import matplotlib.pyplot as plt\nimport seaborn as sns\n", generated_code)
218
+ if generated_code.strip().endswith("sns."):
219
+ generated_code = generated_code.rsplit("\n", 1)[0] # Supprime la dernière ligne incomplète
220
+
221
+ print("🔹 Code généré par l'IA :\n", generated_code)
222
+
223
+ # Vérification syntaxique avant exécution
224
+ try:
225
+ compile(generated_code, "<string>", "exec")
226
+ except SyntaxError as e:
227
+ return {"error": f"Erreur de syntaxe détectée : {e}\nCode généré :\n{generated_code}"}
228
+
229
+ # Vérification des données
230
+ print(df.head()) # Affiche les premières lignes du dataframe
231
+ print(df.dtypes) # Vérifie les types de colonnes
232
+ print(f"Colonne '{x_col}' - Valeurs uniques:", df[x_col].unique())
233
+
234
+ if df.empty or x_col not in df.columns or df[x_col].isnull().all():
235
+ return {"error": f"La colonne '{x_col}' est absente ou ne contient pas de données valides."}
236
+
237
+ # Exécution du code généré
238
+ exec_env = {"df": df, "plt": plt, "sns": sns, "pd": pd}
239
+ exec(generated_code, exec_env)
240
+
241
+ # Vérification de l'image générée
242
+ img_path = "plot.png"
243
+ if not os.path.exists(img_path):
244
+ return {"error": "Le fichier plot.png n'a pas été généré."}
245
+ if os.path.getsize(img_path) == 0:
246
+ return {"error": "Le fichier plot.png est vide."}
247
+
248
+ plt.close()
249
+ return FileResponse(img_path, media_type="image/png")
250
+
251
+ except Exception as e:
252
+ return {"error": f"Erreur lors de la génération du graphique : {str(e)}"}
253
+
254
+
255
  # Servir les fichiers statiques (HTML, CSS, JS)
256
  app.mount("/static", StaticFiles(directory="static", html=True), name="static")
257