File size: 4,240 Bytes
52a9bf8
 
8964e4f
 
 
 
52a9bf8
 
 
 
8964e4f
 
 
52a9bf8
 
8964e4f
 
 
 
 
 
 
 
52a9bf8
 
 
 
 
 
 
 
8964e4f
 
52a9bf8
 
8964e4f
52a9bf8
 
 
 
8964e4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52a9bf8
8964e4f
 
 
 
 
52a9bf8
 
 
 
 
 
 
 
8964e4f
 
52a9bf8
 
 
 
 
8964e4f
 
 
 
52a9bf8
8964e4f
 
 
 
52a9bf8
8964e4f
 
 
 
 
52a9bf8
 
 
 
 
 
8964e4f
 
52a9bf8
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from init import ACCESS_TOKEN, SYSTEM_PROMPT
from utils import extract_sql, is_sql
from database import execute

# Load the tokenizer and model once at import time; respond() below reuses
# these module-level objects on every call.
model_name = "Qwen/Qwen2.5-3B-Instruct"
# `token=True` is the current spelling of the deprecated `use_auth_token=True`:
# it tells the Hub client to use the locally stored Hugging Face credentials.
# NOTE(review): ACCESS_TOKEN is imported at the top of this file but never
# used; if it holds the Hub token, `token=ACCESS_TOKEN` may be what was
# intended here -- confirm.
tokenizer = AutoTokenizer.from_pretrained(model_name, token=True)
# Half-precision weights; device placement is delegated to the library via
# device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)

def _generate_reply(messages, max_new_tokens, temperature, top_p):
    """Sample one assistant turn for *messages* and return only the new text.

    Args:
        messages: Chat history as a list of ``{"role", "content"}`` dicts,
            rendered with the tokenizer's chat template.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.

    Returns:
        The decoded completion, without the prompt and without special tokens.
    """
    prompt_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Calling the tokenizer (instead of .encode) also yields the attention
    # mask, which is forwarded to generate() together with the input ids.
    model_inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)

    output_ids = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )

    # generate() returns prompt + completion for decoder-only models; drop the
    # prompt tokens so the system prompt and history are not echoed back (and
    # are not mistaken for model-written SQL by the caller).
    prompt_length = model_inputs["input_ids"].shape[-1]
    completion_ids = output_ids[0][prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)


def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Answer a chat message, executing and summarising SQL when the model emits it.

    This is a generator: it yields the model's first reply, then (if that
    reply is recognised as SQL by ``is_sql``) each retry reply produced after
    a failed execution, and finally either a natural-language summary of the
    query result or a message asking the user for more details.

    Args:
        message: The new user message.
        history: Previous turns; each item is indexed as ``[0]`` (user text)
            and ``[1]`` (assistant text), and falsy entries are skipped.
        system_message: Accepted for interface compatibility but not used;
            the system turn is always ``SYSTEM_PROMPT``.
        max_tokens: ``max_new_tokens`` for the first reply and for retries.
        temperature: Sampling temperature for every generation.
        top_p: Nucleus-sampling threshold for every generation.

    Yields:
        str: Successive texts as described above.
    """
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    # Replay the chat history as alternating user / assistant turns.
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    messages.append({"role": "user", "content": message})

    response = _generate_reply(messages, max_tokens, temperature, top_p)
    yield response

    # Nothing more to do unless the reply contains SQL.
    if not is_sql(response):
        return

    sql_query = extract_sql(response)
    max_attempts = 3
    sql_result = None

    for attempt in range(1, max_attempts + 1):
        try:
            sql_result = execute(sql_query)
            break
        except Exception as e:
            # Deliberately broad: any execution failure is reported back to
            # the model (or the user) as text instead of crashing the chat.
            last_error = str(e)

        if attempt == max_attempts:
            retry_prompt = f"Tôi đã thử {max_attempts} lần nhưng vẫn gặp lỗi: {last_error}\nBạn có thể cung cấp thêm chi tiết về dữ liệu cần truy vấn không?"
            yield retry_prompt
            return

        # Feed the error back to the model and ask for another attempt.
        clarification_prompt = f"Tôi gặp lỗi khi thực hiện truy vấn SQL: {last_error}\nBạn có thể chỉnh sửa câu hỏi hoặc cung cấp thêm thông tin không?"
        messages += [
            {"role": "assistant", "content": response},
            {"role": "user", "content": clarification_prompt},
        ]

        response = _generate_reply(messages, max_tokens, temperature, top_p)
        yield response

        # If the new reply has no SQL, the previous query is kept and will be
        # re-executed on the next iteration.
        if is_sql(response):
            sql_query = extract_sql(response)

    # A successful execution that returned None produces no summary.
    if sql_result is not None:
        reformulation_prompt = f"Kết quả truy vấn SQL:\n{sql_result}\nHãy tóm tắt kết quả thành phản hồi tự nhiên."
        messages += [
            {"role": "assistant", "content": response},
            {"role": "user", "content": reformulation_prompt},
        ]

        # The summary uses a fixed 512-token budget rather than max_tokens.
        reformulated_response = _generate_reply(messages, 512, temperature, top_p)
        yield reformulated_response