import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import os
from dotenv import load_dotenv
from huggingface_hub import login
from transformers import BitsAndBytesConfig
import logging
# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
load_dotenv()
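# Expects HF_TOKEN to be provided via a .env file or the environment
# (access to the Mistral-Large weights typically requires an accepted license on the Hub)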
# Login to Hugging Face
hf_token = os.getenv('HF_TOKEN')
if not hf_token:
    raise ValueError("HF_TOKEN is not set; add it to your .env file or environment")
login(hf_token)
# Model configuration
model_path = "mistralai/Mistral-Large-Instruct-2411"
# Pick the compute dtype automatically: bfloat16 on GPUs with compute capability >= 8 (Ampere or newer), float16 otherwise
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16
logger.info(f"Using dtype: {dtype}")
# 4-bit quantization configuration (NF4 with double quantization via bitsandbytes)
logger.info("Configuring 4-bit quantization")
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
    bnb_4bit_compute_dtype=dtype,  # compute in the dtype chosen above
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True
)
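# Note: NF4 with double quantization stores the weights in roughly 4 bits per parameter,
# so the ~123B-parameter Mistral-Large should still need on the order of 60+ GB of GPU memory once loaded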
# Model initialization
logger.info(f"Loading tokenizer from {model_path}")
tokenizer = AutoTokenizer.from_pretrained(model_path)
logger.info("Tokenizer loaded successfully")
logger.info(f"Loading model from {model_path} with 4-bit quantization")
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",  # let accelerate place the layers across the available GPUs (and CPU if necessary)
    quantization_config=quantization_config
)
logger.info("Model loaded successfully")
logger.info("Creating inference pipeline")
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)  # reuses the already-loaded quantized model
logger.info("Inference pipeline created successfully")
def generate_response(message, temperature=0.7, max_new_tokens=256):
try:
logger.info(f"Generating response for message: {message[:50]}...")
        parameters = {
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
            "do_sample": True,  # sampling must be enabled for temperature to take effect
            # "top_k": 50,
            # "top_p": 0.9,
            # "pad_token_id": tokenizer.pad_token_id,
            # "eos_token_id": tokenizer.eos_token_id,
            # "batch_size": 1
        }
logger.info(f"Parameters: {parameters}")
response = pipe(message, **parameters)
logger.info("Response generated successfully")
        return response[0]['generated_text']  # includes the prompt; pass return_full_text=False to return only the completion
except Exception as e:
logger.error(f"Error during generation: {str(e)}")
return f"Une erreur s'est produite : {str(e)}"
# Gradio interface
demo = gr.Interface(
fn=generate_response,
inputs=[
gr.Textbox(label="Votre message", placeholder="Entrez votre message ici..."),
gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Température"),
gr.Slider(minimum=10, maximum=3000, value=256, step=10, label="Nombre de tokens")
],
outputs=gr.Textbox(label="Réponse"),
title="Chat avec Sacha-Mistral",
description="Un assistant conversationnel en français basé sur le modèle Sacha-Mistral"
)
if __name__ == "__main__":
logger.info("Starting Gradio interface")
    demo.launch(share=True)
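# To run locally (assuming the script is saved as app.py and gradio, transformers, torch,
# accelerate, bitsandbytes and python-dotenv are installed):
#   python app.py
# share=True additionally exposes a temporary public gradio.live link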