from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import JSONResponse, RedirectResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from transformers import (
    pipeline,
    M2M100ForConditionalGeneration,
    M2M100Tokenizer,
    MarianMTModel,
    MarianTokenizer,
    AutoTokenizer,
    AutoModelForCausalLM,
)
import shutil
import os
import io
import re
import logging
from PyPDF2 import PdfReader
import docx
from docx import Document
from PIL import Image  # To open images before analysis
import openpyxl  # 📌 To read Excel (.xlsx) files
from pptx import Presentation
import fitz  # PyMuPDF
import torch
import pandas as pd
import matplotlib
matplotlib.use("Agg")  # Non-interactive backend: figures are rendered to files, not to a display
import matplotlib.pyplot as plt
import seaborn as sns
# Logging configuration
logging.basicConfig(level=logging.INFO)

app = FastAPI()

# CORS configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

UPLOAD_DIR = "uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Translation model (M2M100)
model_name = "facebook/m2m100_418M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# Text extraction helpers
def extract_text_from_pdf(file):
    doc = fitz.open(stream=file.file.read(), filetype="pdf")
    return "\n".join([page.get_text() for page in doc]).strip()

def extract_text_from_docx(file):
    doc = Document(io.BytesIO(file.file.read()))
    return "\n".join([para.text for para in doc.paragraphs]).strip()

def extract_text_from_pptx(file):
    prs = Presentation(io.BytesIO(file.file.read()))
    return "\n".join([shape.text for slide in prs.slides for shape in slide.shapes if hasattr(shape, "text")]).strip()

def extract_text_from_excel(file):
    wb = openpyxl.load_workbook(io.BytesIO(file.file.read()), data_only=True)
    text = [str(cell) for sheet in wb.worksheets for row in sheet.iter_rows(values_only=True) for cell in row if cell]
    return "\n".join(text).strip()
# Route decorator assumed: the original file did not include one for this handler.
@app.post("/translate_document/")
async def translate_document(file: UploadFile = File(...), target_lang: str = Form(...)):
    """Endpoint that translates an uploaded document."""
    try:
        logging.info(f"📥 File received: {file.filename}")
        logging.info(f"🌍 Target language received: {target_lang}")

        if model is None or tokenizer is None:
            return JSONResponse(status_code=500, content={"error": "Translation model not loaded"})

        # Text extraction
        if file.filename.endswith(".pdf"):
            text = extract_text_from_pdf(file)
        elif file.filename.endswith(".docx"):
            text = extract_text_from_docx(file)
        elif file.filename.endswith(".pptx"):
            text = extract_text_from_pptx(file)
        elif file.filename.endswith(".xlsx"):
            text = extract_text_from_excel(file)
        else:
            return JSONResponse(status_code=400, content={"error": "Unsupported format"})

        logging.info(f"📜 Extracted text: {text[:50]}...")
        if not text:
            return JSONResponse(status_code=400, content={"error": "No text found in the document"})

        # Check that the target language is supported
        # (get_lang_id raises a KeyError for unknown codes, so test membership first)
        if target_lang not in tokenizer.lang_code_to_id:
            return JSONResponse(
                status_code=400,
                content={"error": f"Target language '{target_lang}' is not supported. Available languages: {list(tokenizer.lang_code_to_id.keys())}"}
            )
        target_lang_id = tokenizer.get_lang_id(target_lang)

        # Translation
        tokenizer.src_lang = "fr"
        encoded_text = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        logging.info(f"🔍 Target language ID: {target_lang_id}")
        generated_tokens = model.generate(**encoded_text, forced_bos_token_id=target_lang_id)
        translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

        logging.info(f"✅ Translation succeeded: {translated_text[:50]}...")
        return {"translated_text": translated_text}
    except Exception as e:
        logging.error(f"❌ Error during translation: {e}")
        return JSONResponse(status_code=500, content={"error": "Translation failed"})
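# Example client call (illustrative sketch only): it assumes the route path added
# above ("/translate_document/") and a server running locally on port 8000 --
# neither detail comes from the original file.
#
#   import requests
#   with open("rapport.pdf", "rb") as f:
#       resp = requests.post(
#           "http://localhost:8000/translate_document/",
#           files={"file": ("rapport.pdf", f, "application/pdf")},
#           data={"target_lang": "en"},
#       )
#   print(resp.json().get("translated_text"))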
# Load the model used for code generation
codegen_model_name = "Salesforce/codegen-350M-mono"
device = "cuda" if torch.cuda.is_available() else "cpu"
codegen_tokenizer = AutoTokenizer.from_pretrained(codegen_model_name)
codegen_model = AutoModelForCausalLM.from_pretrained(codegen_model_name).to(device)

VALID_PLOTS = {"histplot", "scatterplot", "barplot", "lineplot", "boxplot"}
# Route decorator assumed: the original file did not include one for this handler.
@app.post("/generate_viz/")
async def generate_viz(file: UploadFile = File(...), query: str = Form(...)):
    try:
        if query not in VALID_PLOTS:
            return {"error": f"Invalid plot type. Choose from: {', '.join(VALID_PLOTS)}"}

        df = pd.read_excel(file.file)
        numeric_cols = df.select_dtypes(include=["number"]).columns
        if len(numeric_cols) < 2:
            return {"error": "The file must contain at least two numeric columns."}
        x_col, y_col = numeric_cols[:2]

        # Specific constraint to avoid the error with histplot (it accepts only one axis)
        if query == "histplot":
            prompt_y = ""
        else:
            prompt_y = f', y="{y_col}"'

        # Build the prompt for the model
        prompt = f"""
### Generate only working Python code that draws a {query} with Matplotlib and Seaborn ###
# Constraints:
# - Use 'df' without recreating new data
# - X axis: '{x_col}'
# - Save the figure as 'plot.png'
# - Generate only valid Python code, with no explanatory text
# Specific constraints for sns.histplot:
# - Do not include "y=" because histplot supports only one axis
import matplotlib.pyplot as plt
import seaborn as sns
plt.figure(figsize=(8,6))
sns.{query}(data=df, x="{x_col}"{prompt_y})
plt.savefig("plot.png")
plt.close()
"""

        # Code generation
        inputs = codegen_tokenizer(prompt, return_tensors="pt").to(device)
        outputs = codegen_model.generate(**inputs, max_new_tokens=120, pad_token_id=codegen_tokenizer.eos_token_id)
        generated_code = codegen_tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

        # Clean up the generated code
        generated_code = re.sub(
            r"(import matplotlib\.pyplot as plt\nimport seaborn as sns\n)+",
            "import matplotlib.pyplot as plt\nimport seaborn as sns\n",
            generated_code,
        )
        if generated_code.strip().endswith("sns."):
            generated_code = generated_code.rsplit("\n", 1)[0]  # Drop the last, incomplete line

        print("🔹 Code generated by the model:\n", generated_code)

        # Syntax check before execution
        try:
            compile(generated_code, "<string>", "exec")
        except SyntaxError as e:
            return {"error": f"Syntax error detected: {e}\nGenerated code:\n{generated_code}"}

        # Data checks
        print(df.head())   # Show the first rows of the dataframe
        print(df.dtypes)   # Check the column types
        print(f"Column '{x_col}' - unique values:", df[x_col].unique())
        if df.empty or x_col not in df.columns or df[x_col].isnull().all():
            return {"error": f"Column '{x_col}' is missing or contains no valid data."}

        # Execute the generated code
        exec_env = {"df": df, "plt": plt, "sns": sns, "pd": pd}
        exec(generated_code, exec_env)

        # Check the generated image
        img_path = "plot.png"
        if not os.path.exists(img_path):
            return {"error": "The file plot.png was not generated."}
        if os.path.getsize(img_path) == 0:
            return {"error": "The file plot.png is empty."}

        plt.close()
        return FileResponse(img_path, media_type="image/png")
    except Exception as e:
        return {"error": f"Error while generating the plot: {str(e)}"}
# Serve static files (HTML, CSS, JS)
app.mount("/static", StaticFiles(directory="static", html=True), name="static")

# Route decorator assumed: the original file did not include one for this handler.
@app.get("/")
async def root():
    return RedirectResponse(url="/static/principal.html")
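# Local run sketch (not part of the original file): a Hugging Face Space or any
# ASGI host would normally start the app itself; the block below is only one
# assumed way to run it locally with uvicorn.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)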