beyoru commited on
Commit
52a9bf8
·
verified ·
1 Parent(s): 8964e4f

Update client.py

Browse files
Files changed (1) hide show
  1. client.py +49 -51
client.py CHANGED
@@ -1,15 +1,18 @@
1
- from huggingface_hub import InferenceClient
 
2
  from init import ACCESS_TOKEN, SYSTEM_PROMPT
3
  from utils import extract_sql, is_sql
4
  from database import execute
5
- import os
6
-
7
- client = InferenceClient(api_key=os.environ.get('HF_TOKEN'))
8
 
 
 
 
 
9
 
10
  def respond(message, history, system_message, max_tokens, temperature, top_p):
11
  messages = [{"role": "system", "content": SYSTEM_PROMPT}]
12
- # Xử lý lịch sử chat
 
13
  for val in history:
14
  if val[0]:
15
  messages.append({"role": "user", "content": val[0]})
@@ -18,21 +21,23 @@ def respond(message, history, system_message, max_tokens, temperature, top_p):
18
 
19
  messages.append({"role": "user", "content": message})
20
 
21
- # Tạo response đầu tiên
22
- response = ""
23
- for message in client.chat.completions.create(
24
- model="Qwen/Qwen2.5-3B-Instruct",
25
- max_tokens=max_tokens,
26
- stream=True,
 
 
27
  temperature=temperature,
28
  top_p=top_p,
29
- messages=messages,
30
- ):
31
- token = message.choices[0].delta.content
32
- response += token
33
- yield response
34
 
35
- # Xử logic SQL và retry
 
 
 
36
  if is_sql(response):
37
  sql_query = extract_sql(response)
38
  max_attempts = 3
@@ -48,59 +53,52 @@ def respond(message, history, system_message, max_tokens, temperature, top_p):
48
  last_error = str(e)
49
  attempts += 1
50
  if attempts < max_attempts:
51
- # Thêm thông tin lỗi vào context yêu cầu hình hỏi lại người dùng
52
- clarification_prompt = f"""Tôi gặp lỗi khi thực hiện truy vấn SQL: {last_error}
53
- Bạn có thể cung cấp thêm thông tin hoặc chỉnh sửa câu hỏi để tôi có thể sửa truy vấn không?"""
54
  messages += [
55
  {"role": "assistant", "content": response},
56
  {"role": "user", "content": clarification_prompt},
57
  ]
58
 
59
- # Tạo response yêu cầu thông tin thêm
60
- response = ""
61
- for message in client.chat.completions.create(
62
- model="Qwen/Qwen2.5-3B-Instruct",
63
- max_tokens=max_tokens,
64
- stream=True,
 
 
65
  temperature=temperature,
66
  top_p=top_p,
67
- messages=messages,
68
- ):
69
- token = message.choices[0].delta.content
70
- response += token
71
- yield response
72
 
73
- # Nếu mô hình cung cấp SQL mới, tiếp tục thử
74
  if is_sql(response):
75
  sql_query = extract_sql(response)
76
  else:
77
- # Nếu sau 3 lần vẫn lỗi, tiếp tục hỏi lại người dùng thay in lỗi
78
- retry_prompt = f"""Tôi đã thử {max_attempts} lần nhưng vẫn gặp lỗi: {last_error}
79
- Bạn có thể cung cấp thêm chi tiết về dữ liệu cần truy vấn không?"""
80
- messages.append({"role": "assistant", "content": retry_prompt})
81
  yield retry_prompt
82
  return
83
 
84
- # Nếu thực hiện truy vấn thành công
85
  if sql_result is not None:
86
- reformulation_prompt = f"""Kết quả truy vấn SQL:
87
- {sql_result}
88
- Hãy tóm tắt kết quả thành phản hồi tự nhiên cho người dùng."""
89
  messages += [
90
  {"role": "assistant", "content": response},
91
  {"role": "user", "content": reformulation_prompt},
92
  ]
93
 
94
- # Tạo response tóm tắt
95
- reformulated_response = ""
96
- for message in client.chat.completions.create(
97
- model="Qwen/Qwen2.5-3B-Instruct",
98
- max_tokens=512,
99
- stream=True,
100
  temperature=temperature,
101
  top_p=top_p,
102
- messages=messages,
103
- ):
104
- token = message.choices[0].delta.content
105
- reformulated_response += token
106
- yield reformulated_response
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer
2
+ import torch
3
  from init import ACCESS_TOKEN, SYSTEM_PROMPT
4
  from utils import extract_sql, is_sql
5
  from database import execute
 
 
 
6
 
7
+ # Load the model and tokenizer
8
+ model_name = "Qwen/Qwen2.5-3B-Instruct"
9
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)
10
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
11
 
12
  def respond(message, history, system_message, max_tokens, temperature, top_p):
13
  messages = [{"role": "system", "content": SYSTEM_PROMPT}]
14
+
15
+ # Process chat history
16
  for val in history:
17
  if val[0]:
18
  messages.append({"role": "user", "content": val[0]})
 
21
 
22
  messages.append({"role": "user", "content": message})
23
 
24
+ # Tokenize input
25
+ input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
26
+ input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)
27
+
28
+ # Generate response
29
+ output_ids = model.generate(
30
+ input_ids,
31
+ max_new_tokens=max_tokens,
32
  temperature=temperature,
33
  top_p=top_p,
34
+ do_sample=True
35
+ )
 
 
 
36
 
37
+ response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
38
+ yield response
39
+
40
+ # SQL Processing and Retry Logic
41
  if is_sql(response):
42
  sql_query = extract_sql(response)
43
  max_attempts = 3
 
53
  last_error = str(e)
54
  attempts += 1
55
  if attempts < max_attempts:
56
+ clarification_prompt = f"Tôi gặp lỗi khi thực hiện truy vấn SQL: {last_error}\nBạn có thể chỉnh sửa câu hỏi hoặc cung cấp thêm thông tin không?"
 
 
57
  messages += [
58
  {"role": "assistant", "content": response},
59
  {"role": "user", "content": clarification_prompt},
60
  ]
61
 
62
+ # Tokenize clarification prompt
63
+ input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
64
+ input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)
65
+
66
+ # Generate new response
67
+ output_ids = model.generate(
68
+ input_ids,
69
+ max_new_tokens=max_tokens,
70
  temperature=temperature,
71
  top_p=top_p,
72
+ do_sample=True
73
+ )
74
+
75
+ response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
76
+ yield response
77
 
 
78
  if is_sql(response):
79
  sql_query = extract_sql(response)
80
  else:
81
+ retry_prompt = f"Tôi đã thử {max_attempts} lần nhưng vẫn gặp lỗi: {last_error}\nBạn thể cung cấp thêm chi ti���t về dữ liệu cần truy vấn không?"
 
 
 
82
  yield retry_prompt
83
  return
84
 
 
85
  if sql_result is not None:
86
+ reformulation_prompt = f"Kết quả truy vấn SQL:\n{sql_result}\nHãy tóm tắt kết quả thành phản hồi tự nhiên."
 
 
87
  messages += [
88
  {"role": "assistant", "content": response},
89
  {"role": "user", "content": reformulation_prompt},
90
  ]
91
 
92
+ input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
93
+ input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)
94
+
95
+ output_ids = model.generate(
96
+ input_ids,
97
+ max_new_tokens=512,
98
  temperature=temperature,
99
  top_p=top_p,
100
+ do_sample=True
101
+ )
102
+
103
+ reformulated_response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
104
+ yield reformulated_response