from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from init import ACCESS_TOKEN, SYSTEM_PROMPT
from utils import extract_sql, is_sql
from database import execute
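# Assumed interfaces of the local modules (not shown in this file), inferred from usage below:
#   init.SYSTEM_PROMPT      - default system prompt string (ACCESS_TOKEN is imported but unused here)
#   utils.is_sql(text)      - True if the model reply contains a SQL statement
#   utils.extract_sql(text) - pulls the SQL statement out of the reply
#   database.execute(query) - runs the query against the database and returns the result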
# Load the model and tokenizer
model_name = "Qwen/Qwen2.5-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
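# Optional sketch: place the model on a GPU when one is available. This is an assumption
# about the deployment environment; the original file loads the model on the default device.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()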
def respond(message, history, system_message, max_tokens, temperature, top_p):
    # Rebuild the conversation on every call so state does not leak between requests;
    # fall back to the default system prompt if no system message is supplied
    messages = [{"role": "system", "content": system_message or SYSTEM_PROMPT}]
    # Replay the chat history as alternating user/assistant turns
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})
    # Tokenize the conversation with the model's chat template
    input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)
    # Generate the assistant's reply
    output_ids = model.generate(
        input_ids,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )
    # Decode only the newly generated tokens, not the prompt
    response = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
    yield response
    # SQL processing and retry logic: if the reply contains SQL, run it against the database
    if is_sql(response):
        sql_query = extract_sql(response)
        max_attempts = 3
        attempts = 0
        sql_result = None
        last_error = None
        while attempts < max_attempts:
            try:
                sql_result = execute(sql_query)
                break
            except Exception as e:
                last_error = str(e)
                attempts += 1
                if attempts < max_attempts:
                    # Feed the error back to the model and ask for a corrected query
                    clarification_prompt = (
                        f"I ran into an error while executing the SQL query: {last_error}\n"
                        "Could you revise the question or provide more information?"
                    )
                    messages += [
                        {"role": "assistant", "content": response},
                        {"role": "user", "content": clarification_prompt},
                    ]
                    # Re-tokenize with the clarification turn appended
                    input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)
                    # Generate a corrected response
                    output_ids = model.generate(
                        input_ids,
                        max_new_tokens=max_tokens,
                        temperature=temperature,
                        top_p=top_p,
                        do_sample=True,
                    )
                    response = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
                    yield response
                    # If the corrected reply contains SQL, retry with the new query
                    if is_sql(response):
                        sql_query = extract_sql(response)
                else:
                    # Give up after max_attempts failures and ask the user for more detail
                    retry_prompt = (
                        f"I tried {max_attempts} times but still got an error: {last_error}\n"
                        "Could you provide more detail about the data you want to query?"
                    )
                    yield retry_prompt
                    return
        if sql_result is not None:
            # Ask the model to turn the raw SQL result into a natural-language answer
            reformulation_prompt = (
                f"SQL query result:\n{sql_result}\n"
                "Please summarise the result as a natural-language reply."
            )
            messages += [
                {"role": "assistant", "content": response},
                {"role": "user", "content": reformulation_prompt},
            ]
            input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)
            output_ids = model.generate(
                input_ids,
                max_new_tokens=512,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
            )
            reformulated_response = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
            yield reformulated_response
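# Minimal sketch of how respond() might be wired into a UI. The signature
# (message, history, system_message, max_tokens, temperature, top_p) matches Gradio's
# ChatInterface with additional_inputs; the control labels and default values below are
# assumptions, not part of the original file.
import gradio as gr

demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value=SYSTEM_PROMPT, label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

if __name__ == "__main__":
    demo.launch()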