from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

from init import ACCESS_TOKEN, SYSTEM_PROMPT
from utils import extract_sql, is_sql
from database import execute

# Load the model and tokenizer.
# token= assumes ACCESS_TOKEN is a Hugging Face access token (only required
# for gated or private repos; the Qwen2.5 checkpoints themselves are public).
model_name = "Qwen/Qwen2.5-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=ACCESS_TOKEN)
model = AutoModelForCausalLM.from_pretrained(model_name, token=ACCESS_TOKEN)
model.to("cuda" if torch.cuda.is_available() else "cpu")


def generate_reply(messages, max_tokens, temperature, top_p):
    """Apply the chat template, generate, and decode only the new tokens."""
    input_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )
    # Slice off the prompt so only the newly generated text is returned,
    # not an echo of the whole conversation.
    return tokenizer.decode(
        output_ids[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True
    )


def respond(message, history, system_message, max_tokens, temperature, top_p):
    # Rebuild the conversation on every call. Gradio-style callers pass the
    # full history each time, so appending to a module-level list would
    # duplicate old turns across calls.
    messages = [{"role": "system", "content": system_message or SYSTEM_PROMPT}]
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": message})

    # Generate the initial response
    response = generate_reply(messages, max_tokens, temperature, top_p)
    yield response

    # SQL processing and retry logic
    if is_sql(response):
        sql_query = extract_sql(response)
        max_attempts = 3
        attempts = 0
        sql_result = None
        last_error = None
        while attempts < max_attempts:
            try:
                sql_result = execute(sql_query)
                break
            except Exception as e:
                last_error = str(e)
                attempts += 1
                if attempts < max_attempts:
                    clarification_prompt = (
                        f"I hit an error while executing the SQL query: {last_error}\n"
                        "Could you rephrase the question or provide more information?"
                    )
                    messages += [
                        {"role": "assistant", "content": response},
                        {"role": "user", "content": clarification_prompt},
                    ]
                    # Ask the model to repair or replace the failing query
                    response = generate_reply(messages, max_tokens, temperature, top_p)
                    yield response
                    # If the new response contains no SQL, the previous
                    # (failing) query is retried unchanged.
                    if is_sql(response):
                        sql_query = extract_sql(response)
                else:
                    retry_prompt = (
                        f"I tried {max_attempts} times but still got an error: {last_error}\n"
                        "Could you provide more detail about the data you need to query?"
                    )
                    yield retry_prompt
                    return

        if sql_result is not None:
            reformulation_prompt = (
                f"SQL query result:\n{sql_result}\n"
                "Please summarize this result as a natural-language response."
            )
            messages += [
                {"role": "assistant", "content": response},
                {"role": "user", "content": reformulation_prompt},
            ]
            yield generate_reply(messages, 512, temperature, top_p)
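
# --- Hypothetical sketch of the utils helpers ---
# The real utils.py is not shown in this file. The commented-out code below
# is one plausible implementation of is_sql/extract_sql, assuming the system
# prompt asks the model to wrap queries in a ```sql fenced block; it is kept
# as a comment so it does not shadow the actual imports above.
#
# import re
#
# SQL_BLOCK = re.compile(r"```sql\s*(.*?)```", re.DOTALL | re.IGNORECASE)
#
# def is_sql(text: str) -> bool:
#     """True if the model output contains a fenced SQL block."""
#     return SQL_BLOCK.search(text) is not None
#
# def extract_sql(text: str) -> str:
#     """Return the first fenced SQL query, stripped of whitespace."""
#     match = SQL_BLOCK.search(text)
#     return match.group(1).strip() if match else ""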
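
# --- Usage sketch (an assumption, not part of the original file) ---
# respond()'s signature (message, history, *extra args) matches Gradio's
# ChatInterface convention with additional_inputs, so one plausible way to
# serve it is shown below. The widget labels, ranges, and default values are
# illustrative choices, not values taken from this project.
if __name__ == "__main__":
    import gradio as gr

    demo = gr.ChatInterface(
        respond,
        additional_inputs=[
            gr.Textbox(value="", label="System message"),
            gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
            gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(minimum=0.05, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
        ],
    )
    demo.launch()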