AI / app.py
Starchik1's picture
Upload 9 files
fb91b72 verified
raw
history blame
4.21 kB
import os
import json
from flask import Flask, render_template, request, jsonify
from flask_cors import CORS
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
app = Flask(__name__)
CORS(app)
# Глобальные переменные для модели
model = None
tokenizer = None
# Загрузка модели Mistral (локально)
def load_model():
global model, tokenizer
try:
model_name = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Проверяем наличие GPU
if torch.cuda.is_available():
print("Загрузка модели на GPU...")
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map="auto",
load_in_8bit=True # Для оптимизации памяти
)
else:
print("GPU не обнаружен. Загрузка модели на CPU (это может быть медленно)...")
# Загрузка облегченной версии для CPU
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float32, # Используем float32 для CPU
low_cpu_mem_usage=True,
device_map="auto"
)
print("Модель успешно загружена!")
except Exception as e:
print(f"Ошибка при загрузке модели: {e}")
model = None
tokenizer = None
# Инициализация приложения
@app.before_request
def before_request():
global model, tokenizer
if model is None or tokenizer is None:
load_model()
# Функция для генерации ответа от модели
def generate_response(prompt, max_length=1024):
if model is None or tokenizer is None:
return "Ошибка: Модель не загружена"
# Форматирование запроса в формате Mistral Instruct
formatted_prompt = f"<s>[INST] {prompt} [/INST]"
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
# Генерация ответа
with torch.no_grad():
outputs = model.generate(
inputs["input_ids"],
max_new_tokens=max_length,
temperature=0.7,
top_p=0.9,
do_sample=True,
pad_token_id=tokenizer.eos_token_id
)
# Декодирование ответа
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Извлечение только ответа модели (после [/INST])
response = response.split("[/INST]")[-1].strip()
return response
# Маршруты
@app.route('/')
def index():
return render_template('index.html')
@app.route('/api/chat', methods=['POST'])
def chat():
data = request.json
prompt = data.get('prompt', '')
if not prompt:
return jsonify({"error": "Пустой запрос"}), 400
try:
response = generate_response(prompt)
return jsonify({"response": response})
except Exception as e:
return jsonify({"error": str(e)}), 500
@app.route('/api/code', methods=['POST'])
def code():
data = request.json
prompt = data.get('prompt', '')
language = data.get('language', 'python')
if not prompt:
return jsonify({"error": "Пустой запрос"}), 400
# Добавляем контекст для генерации кода
code_prompt = f"Напиши код на языке {language} для решения следующей задачи: {prompt}"
try:
response = generate_response(code_prompt)
return jsonify({"code": response})
except Exception as e:
return jsonify({"error": str(e)}), 500
if __name__ == '__main__':
app.run(debug=True, port=5000)