Tonic committed on
Commit
66a9100
·
unverified ·
1 Parent(s): 10d56fb

add history add trust remote code

Browse files
Files changed (1) hide show
  1. app.py +21 -2
app.py CHANGED
@@ -4,13 +4,28 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
4
 
5
  def load_model():
6
  model_id = "microsoft/bitnet-b1.58-2B-4T"
7
- tokenizer = AutoTokenizer.from_pretrained(model_id)
8
  model = AutoModelForCausalLM.from_pretrained(
9
  model_id,
10
- torch_dtype=torch.bfloat16
 
11
  )
12
  return model, tokenizer
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def generate_response(user_input, system_prompt, max_new_tokens, temperature, top_p, top_k, history):
15
  model, tokenizer = load_model()
16
 
@@ -38,6 +53,10 @@ def generate_response(user_input, system_prompt, max_new_tokens, temperature, to
38
  # Update history
39
  history.append({"role": "user", "content": user_input})
40
  history.append({"role": "assistant", "content": response})
 
 
 
 
41
  return history, history
42
 
43
  # Gradio interface
 
4
 
5
  def load_model():
6
  model_id = "microsoft/bitnet-b1.58-2B-4T"
7
+ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
8
  model = AutoModelForCausalLM.from_pretrained(
9
  model_id,
10
+ torch_dtype=torch.bfloat16,
11
+ trust_remote_code=True
12
  )
13
  return model, tokenizer
14
 
15
def manage_history(history, max_messages=6, max_chars=300):
    """Return the newest slice of *history* that satisfies both limits.

    The history is first capped to the last ``max_messages`` entries (the
    default of 6 is 3 turns * 2 messages per turn). Older entries are then
    dropped, one message at a time, until the combined length of the
    remaining ``"content"`` strings is at most ``max_chars``.

    Because trimming works per message rather than per turn, the result may
    begin with an assistant message, and it is empty when even the newest
    message alone exceeds ``max_chars``.

    Args:
        history: List of message dicts, each with a ``"content"`` string.
        max_messages: Maximum number of messages to keep.
        max_chars: Maximum total number of content characters to keep.

    Returns:
        A new list; the list passed in is never modified.
    """
    # A non-positive cap keeps nothing (``history[-0:]`` would keep everything).
    if max_messages <= 0:
        return []

    # Slicing always copies, so the caller's list is left untouched.
    recent = history[-max_messages:]

    # Walk from newest to oldest and count how many messages fit the character
    # budget. This yields the same suffix as repeatedly popping the oldest
    # message, without re-summing the whole list on every step.
    kept = 0
    total_chars = 0
    for msg in reversed(recent):
        total_chars += len(msg["content"])
        if total_chars > max_chars:
            break
        kept += 1

    return recent[len(recent) - kept:]
28
+
29
  def generate_response(user_input, system_prompt, max_new_tokens, temperature, top_p, top_k, history):
30
  model, tokenizer = load_model()
31
 
 
53
  # Update history
54
  history.append({"role": "user", "content": user_input})
55
  history.append({"role": "assistant", "content": response})
56
+
57
+ # Manage history limits
58
+ history = manage_history(history)
59
+
60
  return history, history
61
 
62
  # Gradio interface