from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from init import ACCESS_TOKEN, SYSTEM_PROMPT
from utils import extract_sql, is_sql
from database import execute
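
# Assumptions about the local helpers (their source is not shown here): is_sql() and
# extract_sql() are taken to detect and pull a SQL statement out of a model reply, and
# database.execute() to run it against the target database. ACCESS_TOKEN is imported
# but unused below; for a gated checkpoint it could be passed as
# from_pretrained(model_name, token=ACCESS_TOKEN).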

# Load the model and tokenizer
model_name = "Qwen/Qwen2.5-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
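# A minimal sketch (an addition, not in the original): place the model on a GPU when
# one is available, using the torch import above; generate() below reads model.device.
model = model.to("cuda" if torch.cuda.is_available() else "cpu")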


messages = [{"role": "system", "content": SYSTEM_PROMPT}]
def respond(message, history, system_message, max_tokens, temperature, top_p):
    
    
    # Process chat history
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    messages.append({"role": "user", "content": message})

    # Tokenize input
    input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)

    # Generate response
    output_ids = model.generate(
        input_ids,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True
    )

    # Decode only the newly generated tokens, not the echoed prompt
    response = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
    yield response

    # SQL processing and retry logic: each failed execution feeds the database error
    # back to the model so it can emit a corrected query; after max_attempts failures
    # the user is asked for more detail instead
    if is_sql(response):
        sql_query = extract_sql(response)
        max_attempts = 3
        attempts = 0
        sql_result = None
        last_error = None

        while attempts < max_attempts:
            try:
                sql_result = execute(sql_query)
                break
            except Exception as e:
                last_error = str(e)
                attempts += 1
                if attempts < max_attempts:
                    clarification_prompt = f"Tôi gặp lỗi khi thực hiện truy vấn SQL: {last_error}\nBạn có thể chỉnh sửa câu hỏi hoặc cung cấp thêm thông tin không?"
                    messages += [
                        {"role": "assistant", "content": response},
                        {"role": "user", "content": clarification_prompt},
                    ]

                    # Tokenize clarification prompt
                    input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)

                    # Generate new response
                    output_ids = model.generate(
                        input_ids,
                        max_new_tokens=max_tokens,
                        temperature=temperature,
                        top_p=top_p,
                        do_sample=True
                    )

                    # Again, decode only the newly generated tokens
                    response = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
                    yield response

                    if is_sql(response):
                        sql_query = extract_sql(response)
                    else:
                        # The model replied without SQL; stop re-running the failed query
                        return
                else:
                    retry_prompt = f"Tôi đã thử {max_attempts} lần nhưng vẫn gặp lỗi: {last_error}\nBạn có thể cung cấp thêm chi tiết về dữ liệu cần truy vấn không?"
                    yield retry_prompt
                    return

        if sql_result is not None:
            reformulation_prompt = f"Kết quả truy vấn SQL:\n{sql_result}\nHãy tóm tắt kết quả thành phản hồi tự nhiên."
            messages += [
                {"role": "assistant", "content": response},
                {"role": "user", "content": reformulation_prompt},
            ]

            input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)

            output_ids = model.generate(
                input_ids,
                max_new_tokens=512,
                temperature=temperature,
                top_p=top_p,
                do_sample=True
            )

            reformulated_response = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
            yield reformulated_response
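

# Usage sketch (an assumption, not part of the original file): respond() matches the
# standard Gradio ChatInterface template used on Hugging Face Spaces, so it could be
# wired to a chat UI like this. The slider defaults below are illustrative placeholders.
if __name__ == "__main__":
    import gradio as gr

    demo = gr.ChatInterface(
        respond,
        additional_inputs=[
            gr.Textbox(value=SYSTEM_PROMPT, label="System message"),
            gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
            gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
        ],
    )
    demo.launch()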