kz209 committed
Commit 031841d · 1 Parent(s): 8e22bd4

update format

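The changes below are purely stylistic: single-quoted strings become double-quoted, long calls are wrapped with one argument per line and trailing commas, and blank lines around top-level definitions are normalized. This is consistent with running an autoformatter such as black over the repository; the commit message only says "update format" and does not name a tool, so the command below is just an assumed sketch of how the formatting could be reproduced:

    # Assumption: the commit does not state which formatter was used.
    pip install black
    black app.py pages/ utils/
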
app.py CHANGED
@@ -13,13 +13,10 @@ This application is for **display** and is designed to facilitate **fast prototy
 
 Select a demo from the sidebar below to begin experimentation."""
 
+
 with gr.Blocks() as demo:
     with gr.Column(scale=4):
-        content = content = gr.Blocks(
-            gr.Markdown(
-                welcome_message()
-            )
-        )
+        content = content = gr.Blocks(gr.Markdown(welcome_message()))
 
     with gr.Tabs() as tabs:
         with gr.TabItem("Summarization"):

pages/__init__.py CHANGED
@@ -1,4 +1,4 @@
 # This is the __init__.py file for the utils package
 # You can add any initialization code or import statements here
 
-__all__ = ['arena', 'batch_evaluation', 'leaderboard', 'summarization_playground']
+__all__ = ["arena", "batch_evaluation", "leaderboard", "summarization_playground"]

pages/arena.py CHANGED
@@ -10,9 +10,10 @@ from utils.multiple_stream import stream_data
 
 def random_data_selection():
     datapoint = random.choice(dataset)
-    datapoint = datapoint['section_text'] + '\n\nDialogue:\n' + datapoint['dialogue']
+    datapoint = datapoint["section_text"] + "\n\nDialogue:\n" + datapoint["dialogue"]
     return datapoint
 
+
 def create_arena():
     with open("prompt/prompt.json", "r") as file:
         json_data = file.read()
@@ -21,19 +22,24 @@ def create_arena():
     with gr.Blocks(css=custom_css) as demo:
         with gr.Group():
             datapoint = random_data_selection()
-            gr.Markdown("""This arena is designed to compare different prompts. Click the button to stream responses from randomly shuffled prompts. Each column represents a response generated from one randomly selected prompt.
+            gr.Markdown(
+                """This arena is designed to compare different prompts. Click the button to stream responses from randomly shuffled prompts. Each column represents a response generated from one randomly selected prompt.
 
-Once the streaming is complete, you can choose the best response.\u2764\ufe0f""")
+Once the streaming is complete, you can choose the best response.\u2764\ufe0f"""
+            )
 
-            data_textbox = gr.Textbox(label="Data", lines=10, placeholder="Datapoints to test...", value=datapoint)
+            data_textbox = gr.Textbox(
+                label="Data",
+                lines=10,
+                placeholder="Datapoints to test...",
+                value=datapoint,
+            )
             with gr.Row():
                 random_selection_button = gr.Button("Change Data")
                 stream_button = gr.Button("✨ Click to Streaming ✨")
 
             random_selection_button.click(
-                fn=random_data_selection,
-                inputs=[],
-                outputs=[data_textbox]
+                fn=random_data_selection, inputs=[], outputs=[data_textbox]
             )
 
             random.shuffle(prompts)
@@ -42,43 +48,56 @@ Once the streaming is complete, you can choose the best response.\u2764\ufe0f"""
             # Store prompts in state components
             state_prompts = gr.State(value=prompts)
             state_random_selected_prompts = gr.State(value=random_selected_prompts)
-
+
             with gr.Row():
-                columns = [gr.Textbox(label=f"Prompt {i+1}", lines=10) for i in range(len(random_selected_prompts))]
-
+                columns = [
+                    gr.Textbox(label=f"Prompt {i+1}", lines=10)
+                    for i in range(len(random_selected_prompts))
+                ]
+
             model = get_model_batch_generation("Qwen/Qwen2-1.5B-Instruct")
 
             def start_streaming(data, random_selected_prompts):
-                content_list = [prompt['prompt'] + '\n{' + data + '}\n\nsummary:' for prompt in random_selected_prompts]
+                content_list = [
+                    prompt["prompt"] + "\n{" + data + "}\n\nsummary:"
+                    for prompt in random_selected_prompts
+                ]
                 for response_data in stream_data(content_list, model):
-                    updates = [gr.update(value=response_data[i]) for i in range(len(columns))]
+                    updates = [
+                        gr.update(value=response_data[i]) for i in range(len(columns))
+                    ]
                     yield tuple(updates)
-
+
             stream_button.click(
                 fn=start_streaming,
                 inputs=[data_textbox, state_random_selected_prompts],
                 outputs=columns,
-                show_progress=False
+                show_progress=False,
+            )
+
+            choice = gr.Radio(
+                label="Choose the best response:",
+                choices=["Response 1", "Response 2", "Response 3"],
             )
 
-            choice = gr.Radio(label="Choose the best response:", choices=["Response 1", "Response 2", "Response 3"])
-
             submit_button = gr.Button("Submit")
 
             output = gr.Textbox(label="You selected:", visible=False)
 
-            def update_prompt_metrics(selected_choice, prompts, random_selected_prompts):
+            def update_prompt_metrics(
+                selected_choice, prompts, random_selected_prompts
+            ):
                 if selected_choice == "Response 1":
-                    prompt_id = random_selected_prompts[0]['id']
+                    prompt_id = random_selected_prompts[0]["id"]
                 elif selected_choice == "Response 2":
-                    prompt_id = random_selected_prompts[1]['id']
+                    prompt_id = random_selected_prompts[1]["id"]
                 elif selected_choice == "Response 3":
-                    prompt_id = random_selected_prompts[2]['id']
+                    prompt_id = random_selected_prompts[2]["id"]
                 else:
                     raise ValueError(f"No corresponding response of {selected_choice}")
 
                 for prompt in prompts:
-                    if prompt['id'] == prompt_id:
+                    if prompt["id"] == prompt_id:
                         prompt["metric"]["winning_number"] += 1
                         break
                 else:
@@ -87,7 +106,11 @@ Once the streaming is complete, you can choose the best response.\u2764\ufe0f"""
                 with open("prompt/prompt.json", "w") as f:
                     json.dump(prompts, f)
 
-                return gr.update(value=f"You selected: {selected_choice}", visible=True), gr.update(interactive=False), gr.update(interactive=False)
+                return (
+                    gr.update(value=f"You selected: {selected_choice}", visible=True),
+                    gr.update(interactive=False),
+                    gr.update(interactive=False),
+                )
 
             submit_button.click(
                 fn=update_prompt_metrics,
@@ -97,6 +120,7 @@ Once the streaming is complete, you can choose the best response.\u2764\ufe0f"""
 
     return demo
 
+
 if __name__ == "__main__":
     demo = create_arena()
     demo.queue()

pages/batch_evaluation.py CHANGED
@@ -12,21 +12,22 @@ from utils.model import Model
 
 load_dotenv()
 
+
 def display_results(response_list):
-    overall_score = np.mean([r['metric_score']['rouge_score'] for r in response_list])
-
+    overall_score = np.mean([r["metric_score"]["rouge_score"] for r in response_list])
+
     html_output = f"<h2>Overall Score: {overall_score:.2f}</h2>"
-
+
     for i, item in enumerate(response_list, 1):
-        dialogue = item['dialogue']
-        summary = item['summary']
-        response = item['response']
-        rouge_score = item['metric_score']['rouge_score']
-
-        dialogue = html.escape(item['dialogue']).replace('\n', '<br>')
-        summary = html.escape(item['summary']).replace('\n', '<br>')
-        response = html.escape(item['response']).replace('\n', '<br>')
-
+        dialogue = item["dialogue"]
+        summary = item["summary"]
+        response = item["response"]
+        rouge_score = item["metric_score"]["rouge_score"]
+
+        dialogue = html.escape(item["dialogue"]).replace("\n", "<br>")
+        summary = html.escape(item["summary"]).replace("\n", "<br>")
+        response = html.escape(item["response"]).replace("\n", "<br>")
+
         html_output += f"""
         <details>
         <summary>Response {i} (Rouge Score: {rouge_score:.2f})</summary>
@@ -49,6 +50,7 @@ def display_results(response_list):
 
     return html_output
 
+
 def process(model_selection, prompt, num=10):
     response_list = []
     with open("test_samples/test_data.json", "r") as file:
@@ -57,21 +59,21 @@ def process(model_selection, prompt, num=10):
 
     for i, data in enumerate(dataset):
         logging.info(f"Start testing datapoint {i+1}")
-        dialogue = data['dialogue']
-        format = data['format']
-        summary = data['summary']
-        response = generate_answer(dialogue, model_selection, prompt + f' Output following {format} format.')
+        dialogue = data["dialogue"]
+        format = data["format"]
+        summary = data["summary"]
+        response = generate_answer(
+            dialogue, model_selection, prompt + f" Output following {format} format."
+        )
 
         rouge_score = metric_rouge_score(response, summary)
 
         response_list.append(
             {
-                'dialogue': dialogue,
-                'summary': summary,
-                'response': response,
-                'metric_score': {
-                    'rouge_score': rouge_score
-                }
+                "dialogue": dialogue,
+                "summary": summary,
+                "response": response,
+                "metric_score": {"rouge_score": rouge_score},
             }
         )
 
@@ -81,22 +83,34 @@ def process(model_selection, prompt, num=10):
 
 
 def create_batch_evaluation_interface():
-    with gr.Blocks(theme=gr.themes.Soft(spacing_size="sm",text_size="sm"), css=custom_css) as demo:
-        gr.Markdown("## Here are evaluation setups. It will run though datapoints in test_data.josn to generate and evaluate. Show results once finished.")
+    with gr.Blocks(
+        theme=gr.themes.Soft(spacing_size="sm", text_size="sm"), css=custom_css
+    ) as demo:
+        gr.Markdown(
+            "## Here are evaluation setups. It will run though datapoints in test_data.josn to generate and evaluate. Show results once finished."
+        )
 
-        model_dropdown = gr.Dropdown(choices=Model.__model_list__, label="Choose a model", value=Model.__model_list__[0])
-        Template_text = gr.Textbox(value="""Summarize the following dialogue""", label='Input Prompting Template', lines=8, placeholder='Input your prompts')
+        model_dropdown = gr.Dropdown(
+            choices=Model.__model_list__,
+            label="Choose a model",
+            value=Model.__model_list__[0],
+        )
+        Template_text = gr.Textbox(
+            value="""Summarize the following dialogue""",
+            label="Input Prompting Template",
+            lines=8,
+            placeholder="Input your prompts",
+        )
         submit_button = gr.Button("✨ Submit ✨")
         output = gr.HTML(label="Results")
 
         submit_button.click(
-            process,
-            inputs=[model_dropdown, Template_text],
-            outputs=output
+            process, inputs=[model_dropdown, Template_text], outputs=output
        )
 
    return demo
 
+
 if __name__ == "__main__":
     demo = create_batch_evaluation_interface()
-    demo.launch()
+    demo.launch()

pages/leaderboard.py CHANGED
@@ -9,72 +9,90 @@ import pandas as pd
 def create_html_with_tooltip(id, base_url):
     return f'<a href="{base_url}"target="_blank">{id}</a>'
 
+
 # Load prompts from JSON
 with open("prompt/prompt.json", "r") as file:
     json_data = file.read()
     prompts = json.loads(json_data)
 
 # Prepare leaderboard data
-winning_rate = [prompt['metric']['winning_number'] for prompt in prompts]
-winning_rate = [round(num / sum(winning_rate), 4)for num in winning_rate]
+winning_rate = [prompt["metric"]["winning_number"] for prompt in prompts]
+winning_rate = [round(num / sum(winning_rate), 4) for num in winning_rate]
 data = {
-    'Rank': [i+1 for i in range(len(prompts))],
-    'Methods': [create_html_with_tooltip(prompt['id'], prompt['url']) for prompt in prompts],
-    'Rouge Score': [prompt['metric']['Rouge'] for prompt in prompts],
-    'Winning Rate': winning_rate,
-    'Authors': [prompt['author'] for prompt in prompts],
+    "Rank": [i + 1 for i in range(len(prompts))],
+    "Methods": [
+        create_html_with_tooltip(prompt["id"], prompt["url"]) for prompt in prompts
+    ],
+    "Rouge Score": [prompt["metric"]["Rouge"] for prompt in prompts],
+    "Winning Rate": winning_rate,
+    "Authors": [prompt["author"] for prompt in prompts],
 }
 
 # Create DataFrame and sort by Rouge Score
 df = pd.DataFrame(data)
-df.sort_values(by='Rouge Score', ascending=False, inplace=True, ignore_index=True)
-df['Rank'] = range(1, len(df) + 1)
+df.sort_values(by="Rouge Score", ascending=False, inplace=True, ignore_index=True)
+df["Rank"] = range(1, len(df) + 1)
 
 # Assign medals for top 3 authors
-medals = ['🏅', '🥈', '🥉']
+medals = ["🏅", "🥈", "🥉"]
 for i in range(3):
-    df.loc[i, 'Authors'] = f"{medals[i]} {df.loc[i, 'Authors']}"
+    df.loc[i, "Authors"] = f"{medals[i]} {df.loc[i, 'Authors']}"
+
 
 # Function to update the leaderboard
 def update_leaderboard(sort_by):
     sorted_df = df.sort_values(by=sort_by, ascending=False, ignore_index=True)
-    sorted_df['Rank'] = range(1, len(sorted_df) + 1)
+    sorted_df["Rank"] = range(1, len(sorted_df) + 1)
 
     # Convert DataFrame to HTML with clickable headers for sorting
     table_html = sorted_df.to_html(index=False, escape=False)
 
     # Add sorting links to column headers
     for column in sorted_df.columns:
-        table_html = table_html.replace(f'<th>{column}</th>',
-            f'<th><a href="#" onclick="sortBy(\'{column}\'); return false;">{column}</a></th>')
+        table_html = table_html.replace(
+            f"<th>{column}</th>",
+            f'<th><a href="#" onclick="sortBy(\'{column}\'); return false;">{column}</a></th>',
+        )
 
     return table_html
 
+
 # Define Gradio interface
 def create_leaderboard():
-    with gr.Blocks(css="""
+    with gr.Blocks(
+        css="""
     .tooltip { cursor: pointer; color: blue; text-decoration: underline; }
     table { border-collapse: collapse; width: 100%; }
     th, td { border: 1px solid #ddd; padding: 8px; text-align: left; }
     th { background-color: #f2f2f2; }
     #prompt-display { display: none; }
-    """) as demo:
+    """
+    ) as demo:
         gr.Markdown("# 🏆 Summarization Arena Leaderboard")
         with gr.Row():
-            gr.Markdown("[Blog](placeholder) | [GitHub](placeholder) | [Paper](placeholder) | [Dataset](placeholder) | [Twitter](placeholder) | [Discord](placeholder)")
-            gr.Markdown("Welcome to our open platform for evaluating LLM summarization capabilities.")
-
+            gr.Markdown(
+                "[Blog](placeholder) | [GitHub](placeholder) | [Paper](placeholder) | [Dataset](placeholder) | [Twitter](placeholder) | [Discord](placeholder)"
+            )
+            gr.Markdown(
+                "Welcome to our open platform for evaluating LLM summarization capabilities."
+            )
+
         # Dropdown for sorting
         sort_by = gr.Dropdown(list(df.columns), label="Sort by", value="Rouge Score")
 
         # Display the leaderboard
        leaderboard = gr.HTML(update_leaderboard("Rouge Score"), elem_id="leaderboard")
-
+
         # Change sorting when dropdown is changed
-        sort_by.change(fn=lambda sort: update_leaderboard(sort), inputs=sort_by, outputs=leaderboard)
+        sort_by.change(
+            fn=lambda sort: update_leaderboard(sort),
+            inputs=sort_by,
+            outputs=leaderboard,
+        )
 
     return demo
 
+
 # Launch Gradio interface
 if __name__ == "__main__":
     demo = create_leaderboard()

pages/summarization_playground.py CHANGED
@@ -65,27 +65,26 @@ input-label {
 }
 """
 
-__model_on_gpu__ = ''
+__model_on_gpu__ = ""
 model = {model_name: None for model_name in Model.__model_list__}
 
-random_label = '🔀 Random dialogue from dataset'
+random_label = "🔀 Random dialogue from dataset"
 examples = {
     "example 1": """Boston's injury reporting for Kristaps Porziņģis has been fairly coy. He missed Game 3, but his coach told reporters just before Game 4 that was technically available, but with a catch.
 Joe Mazzulla said Porziņģis would "only be used in specific instances, if necessary." That sounds like the team doesn't want to risk further injury to his dislocated Posterior Tibialis (or some other body part, due to overcompensation for the ankle), unless it's in a desperate situation.
 Being up 3-1, with Game 5 at home, doesn't qualify as desperate. So, expect the Celtics to continue slow-playing KP's return.
 It'd obviously be nice for Boston to have his rim protection and jump shooting back. It was missed in the Game 4 blowout, but the Celtics have also demonstrated they can win without the big man throughout this campaign.
 On top of winning Game 3 of this series, Boston is plus-10.9 points per 100 possessions when Porziņģis has been off the floor this regular and postseason.""",
-
     "example 2": """Prior to the Finals, we predicted that Dereck Lively II's minutes would swell over the course of the series, and that's starting to play out.
 He averaged 18.8 minutes in Games 1 and 2 and was up to 26.2 in Games 3 and 4. That's with the regulars being pulled long before the final buzzer in Friday's game, too.
 Expect the rookie's playing time to continue to climb in Game 5. It seems increasingly clear that coach Jason Kidd trusts him over the rest of Dallas' bigs, and it's not hard to see why.
 Lively has been absolutely relentless on the offensive glass all postseason. He makes solid decisions as a passer when his rolls don't immediately lead to dunks. And he's not a liability when caught defending guards or wings outside.
 All of that has led to postseason averages of 8.2 points, 7.6 rebounds, 1.4 assists and 1.0 blocks in just 21.9 minutes, as well as a double-double in 22 minutes of Game 4.
 Back in Boston, Kidd is going to rely on Lively even more. He'll play close to 30 minutes and reach double-figures in both scoring and rebounding again.""",
-
-    random_label: ""
+    random_label: "",
 }
 
+
 def model_device_check(model_name):
     global __model_on_gpu__
 
@@ -106,56 +105,134 @@ def get_model_batch_generation(model_name):
     return model[model_name]
 
 
-def generate_answer(sources, model_name, prompt, temperature=0.0001, max_new_tokens=500, do_sample=True):
+def generate_answer(
+    sources, model_name, prompt, temperature=0.0001, max_new_tokens=500, do_sample=True
+):
     model_device_check(model_name)
-    content = prompt + '\n{' + sources + '}\n\nsummary:'
-    answer = model[model_name].gen(content,temperature,max_new_tokens,do_sample)[0].strip()
+    content = prompt + "\n{" + sources + "}\n\nsummary:"
+    answer = (
+        model[model_name]
+        .gen(content, temperature, max_new_tokens, do_sample)[0]
+        .strip()
+    )
 
     return answer
 
-def process_input(input_text, model_selection, prompt, temperature=0.0001, max_new_tokens=500, do_sample=True):
+
+def process_input(
+    input_text,
+    model_selection,
+    prompt,
+    temperature=0.0001,
+    max_new_tokens=500,
+    do_sample=True,
+):
     if input_text:
         logging.info("Start generation")
-        response = generate_answer(input_text, model_selection, prompt, temperature, max_new_tokens, do_sample)
-        return f"## Original Dialogue:\n\n{input_text}\n\n## Summarization:\n\n{response}"
+        response = generate_answer(
+            input_text, model_selection, prompt, temperature, max_new_tokens, do_sample
+        )
+        return (
+            f"## Original Dialogue:\n\n{input_text}\n\n## Summarization:\n\n{response}"
+        )
     else:
         return "Please fill the input to generate outputs."
 
+
 def update_input(example):
     if example == random_label:
         datapoint = random.choice(dataset)
-        return datapoint['section_text'] + '\n\nDialogue:\n' + datapoint['dialogue']
+        return datapoint["section_text"] + "\n\nDialogue:\n" + datapoint["dialogue"]
     return examples[example]
 
+
 def create_summarization_interface():
-    with gr.Blocks(theme=gr.themes.Soft(spacing_size="sm",text_size="sm"), css=custom_css) as demo:
-        gr.Markdown("## This is a playground to test prompts for clinical dialogue summarizations")
+    with gr.Blocks(
+        theme=gr.themes.Soft(spacing_size="sm", text_size="sm"), css=custom_css
+    ) as demo:
+        gr.Markdown(
+            "## This is a playground to test prompts for clinical dialogue summarizations"
+        )
 
         with gr.Row():
-            example_dropdown = gr.Dropdown(choices=list(examples.keys()), label="Choose an example", value=random_label)
-            model_dropdown = gr.Dropdown(choices=Model.__model_list__, label="Choose a model", value=Model.__model_list__[0])
-
-        gr.Markdown("<div style='border: 4px solid white; padding: 3px; border-radius: 5px;width:100px;padding-top: 0.5px;padding-bottom: 10px;'><h3>Prompt 👥</h3></center></div>")
-        Template_text = gr.Textbox(value="""Summarize the following dialogue""", label='Input Prompting Template', lines=4, placeholder='Input your prompts')
+            example_dropdown = gr.Dropdown(
+                choices=list(examples.keys()),
+                label="Choose an example",
+                value=random_label,
+            )
+            model_dropdown = gr.Dropdown(
+                choices=Model.__model_list__,
+                label="Choose a model",
+                value=Model.__model_list__[0],
+            )
+
+        gr.Markdown(
+            "<div style='border: 4px solid white; padding: 3px; border-radius: 5px;width:100px;padding-top: 0.5px;padding-bottom: 10px;'><h3>Prompt 👥</h3></center></div>"
+        )
+        Template_text = gr.Textbox(
+            value="""Summarize the following dialogue""",
+            label="Input Prompting Template",
+            lines=4,
+            placeholder="Input your prompts",
+        )
         datapoint = random.choice(dataset)
-        input_text = gr.Textbox(label="Input Dialogue", lines=7, placeholder="Enter text here...", value=datapoint['section_text'] + '\n\nDialogue:\n' + datapoint['dialogue'])
+        input_text = gr.Textbox(
+            label="Input Dialogue",
+            lines=7,
+            placeholder="Enter text here...",
+            value=datapoint["section_text"] + "\n\nDialogue:\n" + datapoint["dialogue"],
+        )
         submit_button = gr.Button("✨ Submit ✨")
 
         with gr.Row():
            with gr.Column(scale=1):
-                gr.Markdown("<div style='border: 4px solid white; padding: 2px; border-radius: 5px;width:130px;padding-bottom: 10px;'><b><h3>Parameters 📈</h3></center></b></div>")
+                gr.Markdown(
+                    "<div style='border: 4px solid white; padding: 2px; border-radius: 5px;width:130px;padding-bottom: 10px;'><b><h3>Parameters 📈</h3></center></b></div>"
+                )
                 with gr.Column():
-                    temperature = gr.Number(label="Temperature",elem_classes="parameter-text", value=0.0001, minimum=0.000001, maximum=1.0)
-                    max_new_tokens = gr.Number(label="Max New Tokens",elem_classes="parameter-text", value=500, precision=0, minimum=0, maximum=500)
-                    do_sample = gr.Dropdown([True,False],label="Do Sample",elem_classes="parameter-text", value=True)
+                    temperature = gr.Number(
+                        label="Temperature",
+                        elem_classes="parameter-text",
+                        value=0.0001,
+                        minimum=0.000001,
+                        maximum=1.0,
+                    )
+                    max_new_tokens = gr.Number(
+                        label="Max New Tokens",
+                        elem_classes="parameter-text",
+                        value=500,
+                        precision=0,
+                        minimum=0,
+                        maximum=500,
+                    )
+                    do_sample = gr.Dropdown(
+                        [True, False],
+                        label="Do Sample",
+                        elem_classes="parameter-text",
+                        value=True,
+                    )
             with gr.Column(scale=3):
                 output = gr.Markdown(line_breaks=True)
 
-        example_dropdown.change(update_input, inputs=[example_dropdown], outputs=[input_text])
-        submit_button.click(process_input, inputs=[input_text,model_dropdown,Template_text,temperature,max_new_tokens,do_sample], outputs=[output])
+        example_dropdown.change(
+            update_input, inputs=[example_dropdown], outputs=[input_text]
+        )
+        submit_button.click(
+            process_input,
+            inputs=[
+                input_text,
+                model_dropdown,
+                Template_text,
+                temperature,
+                max_new_tokens,
+                do_sample,
+            ],
+            outputs=[output],
+        )
 
     return demo
 
+
 if __name__ == "__main__":
     demo = create_summarization_interface()
     demo.launch()

utils/__init__.py CHANGED
@@ -1,4 +1,4 @@
 # This is the __init__.py file for the utils package
 # You can add any initialization code or import statements here
 
-__all__ = ['multiple_stream', 'model', 'data', 'metric']
+__all__ = ["multiple_stream", "model", "data", "metric"]

utils/data.py CHANGED
@@ -1,4 +1,4 @@
-from datasets import load_dataset
-dialogsum = load_dataset('har1/MTS_Dialogue-Clinical_Note')
-dataset = list(dialogsum['train'])
-
+from datasets import load_dataset
+
+dialogsum = load_dataset("har1/MTS_Dialogue-Clinical_Note")
+dataset = list(dialogsum["train"])

utils/metric.py CHANGED
@@ -1,6 +1,7 @@
 from rouge_score import rouge_scorer
 
-scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
+scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
+
 
 def metric_rouge_score(pred, ref):
-    return scorer.score(pred, ref)['rougeL'].fmeasure
+    return scorer.score(pred, ref)["rougeL"].fmeasure

utils/model.py CHANGED
@@ -6,7 +6,8 @@ from huggingface_hub import login
 from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
 from vllm import LLM, SamplingParams
 
-login(token=os.getenv('HF_TOKEN'))
+login(token=os.getenv("HF_TOKEN"))
+
 
 class Model(torch.nn.Module):
     number_of_models = 0
@@ -15,17 +16,17 @@ class Model(torch.nn.Module):
         "lmsys/vicuna-7b-v1.5",
         "google-t5/t5-large",
         "mistralai/Mistral-7B-Instruct-v0.1",
-        "meta-llama/Meta-Llama-3.1-8B-Instruct"
+        "meta-llama/Meta-Llama-3.1-8B-Instruct",
     ]
 
     def __init__(self, model_name="Qwen/Qwen2-1.5B-Instruct") -> None:
         super(Model, self).__init__()
-
+
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.name = model_name
         self.use_vllm = model_name != "google-t5/t5-large"
 
-        logging.info(f'Start loading model {self.name}')
+        logging.info(f"Start loading model {self.name}")
 
         if self.use_vllm:
             # 使用vLLM加载模型
@@ -33,18 +34,16 @@ class Model(torch.nn.Module):
                 model=model_name,
                 dtype="half",
                 tokenizer=model_name,
-                trust_remote_code=True
+                trust_remote_code=True,
             )
         else:
             # 加载原始transformers模型
             self.model = AutoModelForSeq2SeqLM.from_pretrained(
-                model_name,
-                torch_dtype=torch.bfloat16,
-                device_map="auto"
+                model_name, torch_dtype=torch.bfloat16, device_map="auto"
             )
             self.model.eval()
 
-        logging.info(f'Loaded model {self.name}')
+        logging.info(f"Loaded model {self.name}")
         self.update()
 
     @classmethod
@@ -56,13 +55,15 @@ class Model(torch.nn.Module):
             sampling_params = SamplingParams(
                 temperature=temp,
                 max_tokens=max_length,
-                #top_p=0.95 if do_sample else 1.0,
-                stop_token_ids=[self.tokenizer.eos_token_id]
+                # top_p=0.95 if do_sample else 1.0,
+                stop_token_ids=[self.tokenizer.eos_token_id],
             )
             outputs = self.llm.generate(content_list, sampling_params)
             return [output.outputs[0].text for output in outputs]
         else:
-            input_ids = self.tokenizer(content_list, return_tensors="pt", padding=True, truncation=True).input_ids.to(self.model.device)
+            input_ids = self.tokenizer(
+                content_list, return_tensors="pt", padding=True, truncation=True
+            ).input_ids.to(self.model.device)
             outputs = self.model.generate(
                 input_ids,
                 max_new_tokens=max_length,
@@ -70,7 +71,9 @@ class Model(torch.nn.Module):
                 temperature=temp,
                 eos_token_id=self.tokenizer.eos_token_id,
             )
-            return self.tokenizer.batch_decode(outputs[:, input_ids.shape[1]:], skip_special_tokens=True)
+            return self.tokenizer.batch_decode(
+                outputs[:, input_ids.shape[1] :], skip_special_tokens=True
+            )
 
     def streaming(self, content_list, temp=0.001, max_length=500, do_sample=True):
         if self.use_vllm:
@@ -78,24 +81,28 @@ class Model(torch.nn.Module):
                 temperature=temp,
                 max_tokens=max_length,
                 top_p=0.95 if do_sample else 1.0,
-                stop_token_ids=[self.tokenizer.eos_token_id]
+                stop_token_ids=[self.tokenizer.eos_token_id],
             )
             outputs = self.llm.generate(content_list, sampling_params, stream=True)
-
+
             prev_token_ids = [[] for _ in content_list]
-
+
             for output in outputs:
                 for i, request_output in enumerate(output.outputs):
                     current_token_ids = request_output.token_ids
-                    new_token_ids = current_token_ids[len(prev_token_ids[i]):]
+                    new_token_ids = current_token_ids[len(prev_token_ids[i]) :]
                     prev_token_ids[i] = current_token_ids.copy()
-
+
                     for token_id in new_token_ids:
-                        token_text = self.tokenizer.decode(token_id, skip_special_tokens=True)
+                        token_text = self.tokenizer.decode(
+                            token_id, skip_special_tokens=True
+                        )
                         yield i, token_text
         else:
-            input_ids = self.tokenizer(content_list, return_tensors="pt", padding=True, truncation=True).input_ids.to(self.model.device)
-
+            input_ids = self.tokenizer(
+                content_list, return_tensors="pt", padding=True, truncation=True
+            ).input_ids.to(self.model.device)
+
             gen_kwargs = {
                 "input_ids": input_ids,
                 "do_sample": do_sample,
@@ -103,7 +110,7 @@ class Model(torch.nn.Module):
                 "eos_token_id": self.tokenizer.eos_token_id,
                 "max_new_tokens": 1,
                 "return_dict_in_generate": True,
-                "output_scores": True
+                "output_scores": True,
             }
 
             generated_tokens = 0
@@ -113,16 +120,26 @@ class Model(torch.nn.Module):
             while generated_tokens < max_length and len(active_sequences) > 0:
                 with torch.no_grad():
                     output = self.model.generate(**gen_kwargs)
-
+
                 next_tokens = output.sequences[:, -1].unsqueeze(-1)
-
+
                 for i, token in zip(active_sequences, next_tokens):
-                    yield i.item(), self.tokenizer.decode(token[0], skip_special_tokens=True)
+                    yield i.item(), self.tokenizer.decode(
+                        token[0], skip_special_tokens=True
+                    )
 
-                gen_kwargs["input_ids"] = torch.cat([gen_kwargs["input_ids"], next_tokens], dim=-1)
+                gen_kwargs["input_ids"] = torch.cat(
+                    [gen_kwargs["input_ids"], next_tokens], dim=-1
+                )
                 generated_tokens += 1
 
-                completed = (next_tokens.squeeze(-1) == self.tokenizer.eos_token_id).nonzero().squeeze(-1)
-                active_sequences = torch.tensor([i for i in active_sequences if i not in completed])
+                completed = (
+                    (next_tokens.squeeze(-1) == self.tokenizer.eos_token_id)
+                    .nonzero()
+                    .squeeze(-1)
+                )
+                active_sequences = torch.tensor(
+                    [i for i in active_sequences if i not in completed]
+                )
                 if len(active_sequences) > 0:
-                    gen_kwargs["input_ids"] = gen_kwargs["input_ids"][active_sequences]
+                    gen_kwargs["input_ids"] = gen_kwargs["input_ids"][active_sequences]

utils/multiple_stream.py CHANGED
@@ -7,32 +7,36 @@ TEST = """ Test of Time. A Benchmark for Evaluating LLMs on Temporal Reasoning.
 showcased remarkable reasoning capabilities, yet they remain susceptible to errors, particularly in temporal
 reasoning tasks involving complex temporal logic. """
 
+
 def generate_data_test():
     """Generator to yield words"""
     temp = copy.deepcopy(TEST)
     l1 = temp.split()
     random.shuffle(l1)
-    temp = ' '.join(l1)
+    temp = " ".join(l1)
     for word in temp.split(" "):
         yield word + " "
 
+
 def stream_data(content_list, model):
     """Stream data to three columns"""
     outputs = ["" for _ in content_list]
 
     # Use the gen method to handle batch generation
     generator = model.streaming(content_list)
-
+
     while True:
         updated = False
 
         try:
-            id, word = next(generator)  # Get the next generated word for the corresponding content
+            id, word = next(
+                generator
+            )  # Get the next generated word for the corresponding content
            outputs[id] += f"{word} "
             updated = True
         except StopIteration:
             break
-
+
         if updated:
             yield tuple(outputs)
 
@@ -41,21 +45,22 @@ def create_interface():
     with gr.Blocks() as demo:
         with gr.Group():
             with gr.Row():
-                columns = [gr.Textbox(label=f"Column {i+1}", lines=10) for i in range(3)]
-
+                columns = [
+                    gr.Textbox(label=f"Column {i+1}", lines=10) for i in range(3)
+                ]
+
         start_btn = gr.Button("Start Streaming")
-
+
         def start_streaming():
-            content_list = [col.value for col in columns]  # Get input texts from text boxes
+            content_list = [
+                col.value for col in columns
+            ]  # Get input texts from text boxes
             for data in stream_data(content_list):
                 updates = [gr.update(value=data[i]) for i in range(len(columns))]
                 yield tuple(updates)
-
+
         start_btn.click(
-            fn=start_streaming,
-            inputs=[],
-            outputs=columns,
-            show_progress=False
+            fn=start_streaming, inputs=[], outputs=columns, show_progress=False
         )
 
     return demo
@@ -64,4 +69,4 @@ def create_interface():
 if __name__ == "__main__":
     demo = create_interface()
     demo.queue()
-    demo.launch()
+    demo.launch()