import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import os
from dotenv import load_dotenv
from huggingface_hub import login
from transformers import BitsAndBytesConfig
import logging

# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

load_dotenv()

# Login to Hugging Face
hf_token = os.getenv('HF_TOKEN')
login(hf_token)

# Model configuration
model_path = "mistralai/Mistral-Large-Instruct-2411"

# Automatically pick the optimal dtype: bfloat16 requires compute capability >= 8
# (Ampere or newer), so the check uses >= rather than == to also cover Hopper and later
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
logger.info(f"Using dtype: {dtype}")

# 4-bit quantization configuration
logger.info("Configuring 4-bit quantization")
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=dtype,  # use the optimal dtype chosen above
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True
)

# Model initialization
logger.info(f"Loading tokenizer from {model_path}")
tokenizer = AutoTokenizer.from_pretrained(model_path)
logger.info("Tokenizer loaded successfully")

logger.info(f"Loading model from {model_path} with 4-bit quantization")
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    quantization_config=quantization_config
)
logger.info("Model loaded successfully")

logger.info("Creating inference pipeline")
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
logger.info("Inference pipeline created successfully")


def generate_response(message, temperature=0.7, max_new_tokens=256):
    try:
        logger.info(f"Generating response for message: {message[:50]}...")
        parameters = {
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
            "do_sample": True,  # sampling must be enabled for temperature to take effect
            # "top_k": 50,
            # "top_p": 0.9,
            # "pad_token_id": tokenizer.pad_token_id,
            # "eos_token_id": tokenizer.eos_token_id,
            # "batch_size": 1
        }
        logger.info(f"Parameters: {parameters}")
        response = pipe(message, **parameters)
        logger.info("Response generated successfully")
        return response[0]['generated_text']
    except Exception as e:
        logger.error(f"Error during generation: {str(e)}")
        return f"Une erreur s'est produite : {str(e)}"


# Gradio interface
demo = gr.Interface(
    fn=generate_response,
    inputs=[
        gr.Textbox(label="Votre message", placeholder="Entrez votre message ici..."),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Température"),
        gr.Slider(minimum=10, maximum=3000, value=256, step=10, label="Nombre de tokens")
    ],
    outputs=gr.Textbox(label="Réponse"),
    title="Chat avec Sacha-Mistral",
    description="Un assistant conversationnel en français basé sur le modèle Sacha-Mistral"
)

if __name__ == "__main__":
    logger.info("Starting Gradio interface")
    demo.launch(share=True)
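

# --- Optional sketch, not part of the original app ---
# Mistral-Large-Instruct-2411 is an instruct-tuned model, so it generally responds
# better when the prompt is wrapped in its chat template rather than passed as raw
# text the way `generate_response` does above. The function below is a minimal,
# hypothetical alternative: the name `generate_chat_response` is illustrative only.
# If actually wired into the UI, it would be defined before `demo` and passed as
# `fn=generate_chat_response` in place of `generate_response`.
def generate_chat_response(message, temperature=0.7, max_new_tokens=256):
    messages = [{"role": "user", "content": message}]
    # apply_chat_template renders the conversation with the model's special tokens
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # return_full_text=False makes the pipeline return only the newly generated
    # text instead of echoing the prompt back
    response = pipe(
        prompt,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        return_full_text=False,
    )
    return response[0]["generated_text"]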