|
import json
import os
import threading
import traceback

from flask import Flask, render_template, request, jsonify
from flask_cors import CORS
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
# WSGI application serving the HTML page and the JSON API defined below.
app = Flask(__name__)

# CORS(app) with no options enables CORS on every route using flask_cors's
# defaults. NOTE(review): flask_cors's default allows any origin — confirm that
# exposing the model endpoints to arbitrary websites is intended.
CORS(app)

# Filled in by load_model(); both stay None until loading succeeds and are
# reset to None when loading fails.
model, tokenizer = None, None
|
|
|
|
|
|
def load_model():
    """Load the Mistral instruct model and its tokenizer into the module globals.

    With CUDA available, the weights are loaded in float16 with 8-bit
    bitsandbytes quantization and ``device_map="auto"``. Otherwise they are
    loaded in float32 on the CPU. On any failure the error and traceback are
    printed and both ``model`` and ``tokenizer`` are left as ``None``, so
    callers can detect that the model is unavailable.
    """
    global model, tokenizer

    model_name = "mistralai/Mistral-7B-Instruct-v0.3"

    try:
        loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)

        if torch.cuda.is_available():
            # Imported lazily: only the GPU path needs the quantization config.
            from transformers import BitsAndBytesConfig

            print("Загрузка модели на GPU...")
            loaded_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                # Passing load_in_8bit directly to from_pretrained is
                # deprecated; transformers expects a BitsAndBytesConfig.
                quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            )
        else:
            print("GPU не обнаружен. Загрузка модели на CPU (это может быть медленно)...")
            loaded_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
                device_map="auto",
            )
    except Exception as e:
        # Top-level boundary: report the failure but keep the server running.
        print(f"Ошибка при загрузке модели: {e}")
        traceback.print_exc()
        model = None
        tokenizer = None
        return

    # Publish both objects together, only once everything has loaded, so that
    # no reader ever sees a tokenizer without its model.
    model, tokenizer = loaded_model, loaded_tokenizer
    print("Модель успешно загружена!")
|
|
|
|
|
|
# Serializes lazy model loading. Flask's development server handles requests
# in threads by default, so two concurrent first API requests could otherwise
# both call load_model() and load the weights twice.
_model_load_lock = threading.Lock()


@app.before_request
def before_request():
    """Lazily load the model before the first API request.

    Only ``/api/`` routes use the model. The HTML page and static files are
    therefore served without triggering (or waiting for) a model load. If an
    earlier load failed, the next API request retries it.
    """
    if not request.path.startswith('/api/'):
        return None
    if model is None or tokenizer is None:
        with _model_load_lock:
            # Re-check under the lock: another thread may have finished
            # loading while this one was waiting.
            if model is None or tokenizer is None:
                load_model()
    return None
|
|
|
|
|
|
def generate_response(prompt, max_length=1024, temperature=0.7, top_p=0.9):
    """Generate a sampled model reply to ``prompt`` in the Mistral [INST] format.

    Args:
        prompt: User text, wrapped as a single-turn instruction.
        max_length: Maximum number of *new* tokens to generate. It is passed as
            ``max_new_tokens``, so the prompt length is not counted.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The decoded completion without the prompt, stripped of surrounding
        whitespace. If the model or tokenizer is not loaded, returns the
        string "Ошибка: Модель не загружена" instead.
    """
    if model is None or tokenizer is None:
        return "Ошибка: Модель не загружена"

    formatted_prompt = f"<s>[INST] {prompt} [/INST]"

    # The template already starts with "<s>". add_special_tokens=False stops
    # the tokenizer from prepending a second BOS token.
    inputs = tokenizer(
        formatted_prompt, return_tensors="pt", add_special_tokens=False
    ).to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,  # forwards attention_mask together with input_ids
            max_new_tokens=max_length,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens. Splitting the full decoded text
    # on "[/INST]" is fragile for two reasons: skip_special_tokens may remove
    # the marker (if the tokenizer treats it as a special token), and the user
    # prompt may itself contain it.
    new_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
|
|
|
|
|
|
@app.route('/')
def index():
    """Render the chat UI page from the app's template folder."""
    page = 'index.html'
    return render_template(page)
|
|
|
|
def _read_json_object():
    """Return the request body parsed as a JSON object, or {} if it is not one.

    ``request.json`` aborts the request for a missing or non-JSON body. A JSON
    body that is not an object (e.g. a list or null) made ``data.get`` raise an
    unhandled error. Both cases now fall through to the normal 400 validation.
    """
    data = request.get_json(silent=True)
    return data if isinstance(data, dict) else {}


def _prompt_error(prompt):
    """Return a (json, 400) error response if ``prompt`` is unusable, else None."""
    if not isinstance(prompt, str):
        return jsonify({"error": "Поле 'prompt' должно быть строкой"}), 400
    if not prompt.strip():
        return jsonify({"error": "Пустой запрос"}), 400
    return None


@app.route('/api/chat', methods=['POST'])
def chat():
    """Answer a free-form prompt.

    Expects a JSON body ``{"prompt": str}``. Returns ``{"response": str}``.
    Returns 400 for a missing, empty or non-string prompt, and 500 if
    generation raises.
    """
    data = _read_json_object()
    prompt = data.get('prompt', '')

    error = _prompt_error(prompt)
    if error is not None:
        return error

    try:
        response = generate_response(prompt)
    except Exception as e:
        app.logger.exception("Chat generation failed")
        return jsonify({"error": str(e)}), 500
    return jsonify({"response": response})


@app.route('/api/code', methods=['POST'])
def code():
    """Generate code for a task description.

    Expects a JSON body ``{"prompt": str, "language": str}``. ``language``
    defaults to "python" when it is missing, empty or not a string. Returns
    ``{"code": str}``. Returns 400 for an unusable prompt, and 500 if
    generation raises.
    """
    data = _read_json_object()
    prompt = data.get('prompt', '')
    language = data.get('language', 'python')

    error = _prompt_error(prompt)
    if error is not None:
        return error
    if not isinstance(language, str) or not language.strip():
        language = 'python'

    # Wrap the task in a Russian instruction that names the target language.
    code_prompt = f"Напиши код на языке {language} для решения следующей задачи: {prompt}"

    try:
        response = generate_response(code_prompt)
    except Exception as e:
        app.logger.exception("Code generation failed")
        return jsonify({"error": str(e)}), 500
    return jsonify({"code": response})
|
|
|
|
if __name__ == '__main__':
    # The Werkzeug interactive debugger allows executing arbitrary code from
    # the browser, so it is opt-in (FLASK_DEBUG=1) rather than always on.
    debug_enabled = os.environ.get('FLASK_DEBUG', '0') == '1'
    port = int(os.environ.get('PORT', '5000'))
    app.run(debug=debug_enabled, port=port)