import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import os
from dotenv import load_dotenv
from huggingface_hub import login
from transformers import BitsAndBytesConfig
import logging

# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

load_dotenv()

# Login to Hugging Face (required for the gated Mistral weights)
hf_token = os.getenv('HF_TOKEN')
if hf_token:
    login(hf_token)
else:
    logger.warning("HF_TOKEN is not set; loading gated models will likely fail")

# Model configuration
model_path = "mistralai/Mistral-Large-Instruct-2411"

# Automatically pick the optimal compute dtype (bfloat16 on Ampere or newer GPUs, float16 otherwise)
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16
logger.info(f"Using dtype: {dtype}")

# 4-bit quantization configuration
logger.info("Configuring 4-bit quantization")
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=dtype,  # use the optimal compute dtype picked above
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True
)
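# NF4 ("normal float 4") stores the weights in 4 bits, and double quantization also
# compresses the quantization constants themselves, saving additional memory at a
# small accuracy cost (the scheme bitsandbytes implements from the QLoRA paper).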

# Model initialization
logger.info(f"Loading tokenizer from {model_path}")
tokenizer = AutoTokenizer.from_pretrained(model_path)
logger.info("Tokenizer loaded successfully")

logger.info(f"Loading model from {model_path} with 4-bit quantization")
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    quantization_config=quantization_config
)
logger.info("Model loaded successfully")

logger.info("Creating inference pipeline")
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
logger.info("Inference pipeline created successfully")

def generate_response(message, temperature=0.7, max_new_tokens=256):
    try:
        logger.info(f"Generating response for message: {message[:50]}...")
        parameters = {
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
            "do_sample": True,  # sampling must be enabled, otherwise temperature is ignored
            # "top_k": 50,
            # "top_p": 0.9,
            # "pad_token_id": tokenizer.pad_token_id,
            # "eos_token_id": tokenizer.eos_token_id,
            # "batch_size": 1
        }
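        # Note (assumption about the intended behaviour): by default the pipeline's
        # 'generated_text' contains the prompt followed by the completion; adding
        # "return_full_text": False above would return only the newly generated text.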
        logger.info(f"Parameters: {parameters}")
        
        response = pipe(message, **parameters)
        logger.info("Response generated successfully")
        return response[0]['generated_text']
    except Exception as e:
        logger.error(f"Error during generation: {str(e)}")
        return f"Une erreur s'est produite : {str(e)}"

# Gradio interface
demo = gr.Interface(
    fn=generate_response,
    inputs=[
        gr.Textbox(label="Votre message", placeholder="Entrez votre message ici..."),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Température"),
        gr.Slider(minimum=10, maximum=3000, value=256, step=10, label="Nombre de tokens")
    ],
    outputs=gr.Textbox(label="Réponse"),
    title="Chat avec Sacha-Mistral",
    description="Un assistant conversationnel en français basé sur le modèle Sacha-Mistral"
)
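# Optional (not in the original code): for long generations it is common to enable
# Gradio's request queue, e.g. demo.queue() before demo.launch(), so that concurrent
# requests are serialized instead of timing out.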

if __name__ == "__main__":
    logger.info("Starting Gradio interface")
    demo.launch(share=True)