zRzRzRzRzRzRzR commited on
Commit
6ae9ac6
·
1 Parent(s): f6f6dfc

update with code fix

Browse files
Files changed (1) hide show
  1. README.md +9 -4
README.md CHANGED
@@ -63,8 +63,10 @@ Below is a simple workflow to help you quickly connect the pipeline.
63
 
64
  ```python
65
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
66
 
67
- MODEL_PATH = "THUDM/GLM-Z1-Rumination-32B-0414"
68
 
69
  tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
70
  model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto")
@@ -75,6 +77,7 @@ generate_kwargs = {
75
  "temperature": 0.95,
76
  "top_p": 0.7,
77
  "do_sample": True,
 
78
  }
79
 
80
  def get_assistant():
@@ -84,16 +87,17 @@ def get_assistant():
84
  add_generation_prompt=True,
85
  return_dict=True,
86
  ).to(model.device)
87
- out = model.generate(input_ids=input["input_ids"], **generate_kwargs)
88
  return tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip()
89
 
90
  def get_observation(function_name, args):
91
- if fucntion_name == "search":
 
92
  mock_search_res = [
93
  {"title": "t1", "url":"url1", "snippet": "snippet_content_1"},
94
  {"title": "t2", "url":"url2", "snippet": "snippet_content_2"}
95
  ]
96
- content = "\n\n".join([f"【{i}†{res['title']}†{res['url']}\n{res['snippet']}】"] for i, res in mock_search_res)
97
  elif function_name == "click":
98
  mock_click_res = "main content"
99
  content = mock_click_res
@@ -102,6 +106,7 @@ def get_observation(function_name, args):
102
  content = mock_open_res
103
  else:
104
  raise ValueError("unspport function name!")
 
105
 
106
  def get_func_name_args(llm_text):
107
  function_call = re.sub(r'.*?</think>', '', llm_text, flags=re.DOTALL)
 
63
 
64
  ```python
65
  from transformers import AutoModelForCausalLM, AutoTokenizer
66
+ import re
67
+ import json
68
 
69
+ MODEL_PATH = "THUDM/GLM-4-Z1-Rumination-32B-0414"
70
 
71
  tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
72
  model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto")
 
77
  "temperature": 0.95,
78
  "top_p": 0.7,
79
  "do_sample": True,
80
+ "max_new_tokens": 16384
81
  }
82
 
83
  def get_assistant():
 
87
  add_generation_prompt=True,
88
  return_dict=True,
89
  ).to(model.device)
90
+ out = model.generate(input_ids=inputs["input_ids"], **generate_kwargs)
91
  return tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip()
92
 
93
  def get_observation(function_name, args):
94
+ content = None
95
+ if function_name == "search":
96
  mock_search_res = [
97
  {"title": "t1", "url":"url1", "snippet": "snippet_content_1"},
98
  {"title": "t2", "url":"url2", "snippet": "snippet_content_2"}
99
  ]
100
+ content = "\n\n".join([f"【{i}†{res['title']}†{res['url']}\n{res['snippet']}】"] for i, res in enumerate(mock_search_res))
101
  elif function_name == "click":
102
  mock_click_res = "main content"
103
  content = mock_click_res
 
106
  content = mock_open_res
107
  else:
108
  raise ValueError("unspport function name!")
109
+ return content
110
 
111
  def get_func_name_args(llm_text):
112
  function_call = re.sub(r'.*?</think>', '', llm_text, flags=re.DOTALL)