AI / app.py
Starchik1's picture
Update app.py
4e8b762 verified
raw
history blame
4.7 kB
import os
import json
from flask import Flask, render_template, request, jsonify
from flask_cors import CORS
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from dotenv import load_dotenv
import requests
app = Flask(__name__)
CORS(app)
# Глобальные переменные для модели
model = None
tokenizer = None
# Загрузка модели Mistral (локально)
def load_model():
global model, tokenizer
try:
model_name = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Проверяем наличие GPU
if torch.cuda.is_available():
print("Загрузка модели на GPU...")
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map="auto",
load_in_8bit=True # Для оптимизации памяти
)
else:
print("GPU не обнаружен. Загрузка модели на CPU (это может быть медленно)...")
# Загрузка облегченной версии для CPU
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float32, # Используем float32 для CPU
low_cpu_mem_usage=True,
device_map="auto"
)
print("Модель успешно загружена!")
except Exception as e:
print(f"Ошибка при загрузке модели: {e}")
model = None
tokenizer = None
# Инициализация приложения
@app.before_request
def before_request():
global model, tokenizer
if model is None or tokenizer is None:
load_model()
# Функция для генерации ответа от модели
from dotenv import load_dotenv
import requests
# Загружаем переменные окружения
load_dotenv()
# Проверяем наличие API ключа
API_KEY = os.getenv("MISTRAL_API_KEY")
API_URL = os.getenv("MISTRAL_API_URL", "https://api.mistral.ai/v1/")
# Функция для генерации ответа через API
def generate_response_api(prompt, max_length=1024):
if not API_KEY:
return "Ошибка: API ключ не найден. Пожалуйста, добавьте MISTRAL_API_KEY в файл .env"
headers = {
"Authorization": f"Bearer {API_KEY}",
"Content-Type": "application/json"
}
data = {
"model": "mistral-medium", # или другая доступная модель
"messages": [
{"role": "user", "content": prompt}
],
"max_tokens": max_length,
"temperature": 0.7,
"top_p": 0.9
}
try:
response = requests.post(f"{API_URL}chat/completions", headers=headers, json=data)
response.raise_for_status()
result = response.json()
return result["choices"][0]["message"]["content"]
except Exception as e:
return f"Ошибка при обращении к API: {str(e)}"
# Обновление функции generate_response для использования API
def generate_response(prompt, max_length=1024):
return generate_response_api(prompt, max_length)
# Удаление старой функции load_model
# Маршруты
@app.route('/')
def index():
return render_template('index.html')
@app.route('/api/chat', methods=['POST'])
def chat():
data = request.json
prompt = data.get('prompt', '')
if not prompt:
return jsonify({"error": "Пустой запрос"}), 400
try:
response = generate_response(prompt)
return jsonify({"response": response})
except Exception as e:
return jsonify({"error": str(e)}), 500
@app.route('/api/code', methods=['POST'])
def code():
data = request.json
prompt = data.get('prompt', '')
language = data.get('language', 'python')
if not prompt:
return jsonify({"error": "Пустой запрос"}), 400
# Добавляем контекст для генерации кода
code_prompt = f"Напиши код на языке {language} для решения следующей задачи: {prompt}"
try:
response = generate_response(code_prompt)
return jsonify({"code": response})
except Exception as e:
return jsonify({"error": str(e)}), 500
if __name__ == '__main__':
app.run(debug=True, port=5000)