from fastapi import FastAPI
import torch
import spacy
from pathlib import Path
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModelForTokenClassification, AutoModelForSequenceClassification
from pydantic import BaseModel
# Define the request input model
class TextInput(BaseModel):
    text: str

# Initialize FastAPI
app = FastAPI()
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    # You can restrict this to your specific frontend
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Get base directory
base_dir = Path(__file__).parent.absolute()

# Your Hugging Face Hub username
HF_USERNAME = "YassineJedidi"  # Replace with your actual username
# Valid entities for each detected type
entites_valides = {
    "Tâche": {"TITRE", "DELAI", "PRIORITE"},
    "Événement": {"TITRE", "DATE_HEURE"},
}
# Try to load models from the Hugging Face Hub
try:
    print("Loading models from Hugging Face Hub")
    # Model repositories on Hugging Face
    tokenizer_repo = f"{HF_USERNAME}/tasks-tokenizer"
    ner_model_repo = f"{HF_USERNAME}/tasks-ner"
    type_model_repo = f"{HF_USERNAME}/tasks-type"
    print(f"Loading tokenizer from: {tokenizer_repo}")
    print(f"Loading NER model from: {ner_model_repo}")
    print(f"Loading type model from: {type_model_repo}")
    # Load models from the Hub
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_repo)
    ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_repo)
    type_model = AutoModelForSequenceClassification.from_pretrained(type_model_repo)
except Exception as e:
    print(f"Error loading models from Hugging Face Hub: {e}")
    # Fall back to local files if available
    try:
        # Convert paths to strings with forward slashes
        tokenizer_path = str(base_dir / "models" / "tasks-tokenizer").replace("\\", "/")
        ner_model_path = str(base_dir / "models" / "tasks-ner").replace("\\", "/")
        type_model_path = str(base_dir / "models" / "tasks-types").replace("\\", "/")
        print("Falling back to local models")
        print(f"Loading tokenizer from: {tokenizer_path}")
        print(f"Loading NER model from: {ner_model_path}")
        print(f"Loading type model from: {type_model_path}")
        # Load models from local files
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, local_files_only=True)
        ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_path, local_files_only=True)
        type_model = AutoModelForSequenceClassification.from_pretrained(type_model_path, local_files_only=True)
    except Exception as e:
        print(f"Error loading local models: {e}")
        # Final fallback: base CamemBERT model from the Hugging Face Hub
        print("Falling back to base CamemBERT model from Hugging Face Hub")
        tokenizer = AutoTokenizer.from_pretrained("camembert-base")
        ner_model = AutoModelForTokenClassification.from_pretrained("camembert-base")
        type_model = AutoModelForSequenceClassification.from_pretrained("camembert-base")
# Load spaCy for tokenization
nlp = spacy.load("fr_core_news_lg")

# Set device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ner_model = ner_model.to(device)
type_model = type_model.to(device)

# Retrieve label mappings from the model configs
id2label = ner_model.config.id2label
id2type = type_model.config.id2label
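# Note: id2label is expected to map class ids to BIO-style NER tags,
# e.g. {0: "O", 1: "B-TITRE", 2: "I-TITRE", ...}, and id2type to the
# sequence labels ("Tâche", "Événement"). These mappings are shown here
# only as an illustration; the real ones come from the fine-tuned models.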
# Health-check endpoint (route path assumed)
@app.get("/")
def root():
    return {"message": "FastAPI NLP Model is running!"}
# Classify the text as a task or an event (route path assumed)
@app.post("/predict-type")
async def predict_type(input_data: TextInput):
    text = input_data.text
    inputs = tokenizer(text, return_tensors="pt",
                       truncation=True, padding=True).to(device)
    with torch.no_grad():
        outputs = type_model(**inputs)
    predicted_class_id = outputs.logits.argmax().item()
    predicted_type = id2type[predicted_class_id]
    confidence = torch.softmax(outputs.logits, dim=1).max().item()
    return {"type": predicted_type, "confidence": confidence}
# Extract named entities from the text (route path assumed)
@app.post("/extract-entities")
async def extract_entities(input_data: TextInput):
    text = input_data.text
    # Pre-tokenize with spaCy, then classify each word with the NER model
    doc = nlp(text)
    tokens = [token.text for token in doc]
    inputs = tokenizer(tokens, is_split_into_words=True,
                       return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        outputs = ner_model(**inputs)
    predictions = outputs.logits.argmax(dim=2)

    # Decode BIO tags into entity spans, keeping one prediction per word
    entities = {}
    current_entity = None
    current_text = []
    word_ids = inputs.word_ids(0)
    for idx, word_idx in enumerate(word_ids):
        # Skip special tokens and subword continuations
        if word_idx is None:
            continue
        if idx > 0 and word_ids[idx - 1] == word_idx:
            continue
        prediction = predictions[0, idx].item()
        predicted_label = id2label[prediction]
        if predicted_label.startswith("B-"):
            # Close the previous entity, then start a new one
            if current_entity:
                entity_type = current_entity[2:]
                entities.setdefault(entity_type, []).append(" ".join(current_text))
            current_entity = predicted_label
            current_text = [tokens[word_idx]]
        elif predicted_label.startswith("I-") and current_entity and predicted_label[2:] == current_entity[2:]:
            # Continuation of the current entity
            current_text.append(tokens[word_idx])
        else:
            # An "O" tag or a mismatched "I-" tag ends the current entity
            if current_entity:
                entity_type = current_entity[2:]
                entities.setdefault(entity_type, []).append(" ".join(current_text))
            current_entity = None
            current_text = []
    # Flush the last open entity, if any
    if current_entity:
        entity_type = current_entity[2:]
        entities.setdefault(entity_type, []).append(" ".join(current_text))
    return {"entities": entities}
# Full analysis: classify the text, then keep only the entities that are
# valid for the detected type (route path assumed)
@app.post("/analyze")
async def analyze_text(input_data: TextInput):
    type_result = await predict_type(input_data)
    text_type = type_result["type"]
    confidence = type_result["confidence"]
    raw_entities = (await extract_entities(input_data))["entities"]
    # Filter entities according to the detected type
    allowed = entites_valides.get(text_type, set())
    filtered_entities = {k: v for k, v in raw_entities.items() if k in allowed}
    return {
        "type": text_type,
        "confidence": confidence,
        "entities": filtered_entities,
    }
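# Minimal local-run sketch, not part of the original file. Assumptions:
# this module is saved as app.py, uvicorn is installed, and the routes match
# the (assumed) decorators above; port 7860 is the Hugging Face Spaces default.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example client call against the assumed /analyze route:
#   curl -X POST http://localhost:7860/analyze \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Préparer le rapport avant vendredi"}'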