# Chris-lab / pages /arena.py
# kz209
# update format
# 031841d
import json
import random
import gradio as gr
from pages.summarization_playground import custom_css, get_model_batch_generation
from utils.data import dataset
from utils.multiple_stream import stream_data
def random_data_selection():
    """Pick one random entry from ``dataset`` and format it as arena input.

    Returns:
        The entry's ``section_text``, followed by a blank line, a
        ``Dialogue:`` heading, and the entry's ``dialogue``.
    """
    record = random.choice(dataset)
    # Joining on the separator yields section_text + separator + dialogue.
    return "\n\nDialogue:\n".join((record["section_text"], record["dialogue"]))
_PROMPT_FILE = "prompt/prompt.json"
# Number of prompts compared side by side in the arena.
_ARENA_SIZE = 3


def _load_prompts():
    """Read and return the list of prompt records stored in the prompt file."""
    with open(_PROMPT_FILE, "r", encoding="utf-8") as file:
        return json.load(file)


def _record_win(prompt_id):
    """Increment the persisted ``winning_number`` of the prompt ``prompt_id``.

    The prompt file is re-read immediately before the update so the counter is
    applied to what is currently on disk, not to a copy loaded earlier.

    Raises:
        ValueError: If no prompt in the file has the given id. The file is
            left untouched in that case.
    """
    prompts = _load_prompts()
    for prompt in prompts:
        if prompt["id"] == prompt_id:
            prompt["metric"]["winning_number"] += 1
            break
    else:
        raise ValueError(f"No prompt of id {prompt_id}")

    with open(_PROMPT_FILE, "w", encoding="utf-8") as file:
        json.dump(prompts, file)


def create_arena():
    """Build the prompt-comparison arena as a Gradio Blocks app.

    The page shows a data textbox (pre-filled with a random datapoint), a
    button to draw another datapoint, and a button that streams one response
    per selected prompt into side-by-side textboxes. After streaming, the
    user picks the best response and the winning prompt's counter in
    ``prompt/prompt.json`` is incremented.

    The prompts to compare (at most ``_ARENA_SIZE``) are drawn at random once,
    when this function runs.

    Returns:
        The assembled ``gr.Blocks`` instance. It is not queued or launched
        here.
    """
    prompts = _load_prompts()

    with gr.Blocks(css=custom_css) as demo:
        with gr.Group():
            datapoint = random_data_selection()
            gr.Markdown(
                """This arena is designed to compare different prompts. Click the button to stream responses from randomly shuffled prompts. Each column represents a response generated from one randomly selected prompt.
Once the streaming is complete, you can choose the best response.\u2764\ufe0f"""
            )
            data_textbox = gr.Textbox(
                label="Data",
                lines=10,
                placeholder="Datapoints to test...",
                value=datapoint,
            )
            with gr.Row():
                random_selection_button = gr.Button("Change Data")
                stream_button = gr.Button("✨ Click to Streaming ✨")

            random_selection_button.click(
                fn=random_data_selection, inputs=[], outputs=[data_textbox]
            )

            # Sample instead of shuffling in place: the loaded list is left
            # untouched, and fewer than _ARENA_SIZE prompts is handled.
            random_selected_prompts = random.sample(
                prompts, min(_ARENA_SIZE, len(prompts))
            )

            # Keep the selected prompts in a state component so both event
            # handlers receive them as an input.
            state_random_selected_prompts = gr.State(value=random_selected_prompts)

            with gr.Row():
                columns = [
                    gr.Textbox(label=f"Prompt {i + 1}", lines=10)
                    for i in range(len(random_selected_prompts))
                ]

            # One radio label per response column, so the choices always
            # match the number of prompts actually shown.
            response_labels = [f"Response {i + 1}" for i in range(len(columns))]

            model = get_model_batch_generation("Qwen/Qwen2-1.5B-Instruct")

            def start_streaming(data, random_selected_prompts):
                """Stream responses for every selected prompt into the columns.

                Builds one model input per prompt from the prompt text and the
                datapoint, then yields a tuple of textbox updates (one per
                column) for each item produced by ``stream_data``.
                """
                content_list = [
                    prompt["prompt"] + "\n{" + data + "}\n\nsummary:"
                    for prompt in random_selected_prompts
                ]
                column_count = len(columns)
                for response_data in stream_data(content_list, model):
                    yield tuple(
                        gr.update(value=response_data[index])
                        for index in range(column_count)
                    )

            stream_button.click(
                fn=start_streaming,
                inputs=[data_textbox, state_random_selected_prompts],
                outputs=columns,
                show_progress=False,
            )

            choice = gr.Radio(
                label="Choose the best response:",
                choices=response_labels,
            )
            submit_button = gr.Button("Submit")
            output = gr.Textbox(label="You selected:", visible=False)

            def update_prompt_metrics(selected_choice, random_selected_prompts):
                """Record a win for the prompt behind ``selected_choice``.

                Returns updates that reveal the confirmation textbox and
                disable the radio and the submit button.

                Raises:
                    ValueError: If ``selected_choice`` is not one of the
                        offered labels (including no selection), or the
                        prompt id is missing from the prompt file.
                """
                if selected_choice not in response_labels:
                    raise ValueError(
                        f"No corresponding response of {selected_choice}"
                    )
                position = response_labels.index(selected_choice)
                prompt_id = random_selected_prompts[position]["id"]

                # Update the counts on disk rather than a snapshot loaded at
                # build time: writing such a snapshot back would discard wins
                # recorded since startup (e.g. by other sessions).
                _record_win(prompt_id)

                return (
                    gr.update(value=f"You selected: {selected_choice}", visible=True),
                    gr.update(interactive=False),
                    gr.update(interactive=False),
                )

            submit_button.click(
                fn=update_prompt_metrics,
                inputs=[choice, state_random_selected_prompts],
                outputs=[output, choice, submit_button],
            )

    return demo
if __name__ == "__main__":
    # Script entry point: build the arena, enable its queue, then launch it.
    demo = create_arena()
    demo.queue()
    demo.launch()