File size: 4,701 Bytes
77db817
 
 
 
 
 
4e8b762
 
77db817
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e8b762
77db817
 
 
 
 
 
 
 
 
 
 
4e8b762
 
 
 
 
 
 
 
 
 
 
 
 
 
77db817
4e8b762
 
 
 
77db817
4e8b762
 
 
 
 
 
 
 
 
77db817
4e8b762
 
 
 
 
 
 
 
 
 
 
 
 
77db817
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb91b72
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import os
import json
from flask import Flask, render_template, request, jsonify
from flask_cors import CORS
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from dotenv import load_dotenv
import requests

app = Flask(__name__)
CORS(app)

# Глобальные переменные для модели
model = None
tokenizer = None

# Загрузка модели Mistral (локально)
def load_model():
    global model, tokenizer
    try:
        model_name = "mistralai/Mistral-7B-Instruct-v0.3"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        
        # Проверяем наличие GPU
        if torch.cuda.is_available():
            print("Загрузка модели на GPU...")
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                load_in_8bit=True  # Для оптимизации памяти
            )
        else:
            print("GPU не обнаружен. Загрузка модели на CPU (это может быть медленно)...")
            # Загрузка облегченной версии для CPU
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float32,  # Используем float32 для CPU
                low_cpu_mem_usage=True,
                device_map="auto"
            )
        print("Модель успешно загружена!")
    except Exception as e:
        print(f"Ошибка при загрузке модели: {e}")
        model = None
        tokenizer = None

# Инициализация приложения
@app.before_request
def before_request():
    global model, tokenizer
    if model is None or tokenizer is None:
        load_model()

# Функция для генерации ответа от модели
from dotenv import load_dotenv
import requests

# Загружаем переменные окружения
load_dotenv()

# Проверяем наличие API ключа
API_KEY = os.getenv("MISTRAL_API_KEY")
API_URL = os.getenv("MISTRAL_API_URL", "https://api.mistral.ai/v1/")

# Функция для генерации ответа через API
def generate_response_api(prompt, max_length=1024):
    if not API_KEY:
        return "Ошибка: API ключ не найден. Пожалуйста, добавьте MISTRAL_API_KEY в файл .env"
    
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json"
    }
    
    data = {
        "model": "mistral-medium",  # или другая доступная модель
        "messages": [
            {"role": "user", "content": prompt}
        ],
        "max_tokens": max_length,
        "temperature": 0.7,
        "top_p": 0.9
    }
    
    try:
        response = requests.post(f"{API_URL}chat/completions", headers=headers, json=data)
        response.raise_for_status()
        result = response.json()
        return result["choices"][0]["message"]["content"]
    except Exception as e:
        return f"Ошибка при обращении к API: {str(e)}"

# Обновление функции generate_response для использования API
def generate_response(prompt, max_length=1024):
    return generate_response_api(prompt, max_length)

# Удаление старой функции load_model

# Маршруты
@app.route('/')
def index():
    return render_template('index.html')

@app.route('/api/chat', methods=['POST'])
def chat():
    data = request.json
    prompt = data.get('prompt', '')
    
    if not prompt:
        return jsonify({"error": "Пустой запрос"}), 400
    
    try:
        response = generate_response(prompt)
        return jsonify({"response": response})
    except Exception as e:
        return jsonify({"error": str(e)}), 500

@app.route('/api/code', methods=['POST'])
def code():
    data = request.json
    prompt = data.get('prompt', '')
    language = data.get('language', 'python')
    
    if not prompt:
        return jsonify({"error": "Пустой запрос"}), 400
    
    # Добавляем контекст для генерации кода
    code_prompt = f"Напиши код на языке {language} для решения следующей задачи: {prompt}"
    
    try:
        response = generate_response(code_prompt)
        return jsonify({"code": response})
    except Exception as e:
        return jsonify({"error": str(e)}), 500

if __name__ == '__main__':
    app.run(debug=True, port=5000)