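# Page-header residue captured together with the source (apparently a hosting
# status banner reading "Spaces: Sleeping"); it is not part of the module.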
import json | |
import random | |
import gradio as gr | |
from pages.summarization_playground import custom_css, get_model_batch_generation | |
from utils.data import dataset | |
from utils.multiple_stream import stream_data | |
def random_data_selection():
    """Pick a random record from ``dataset`` and format it as one text block.

    Returns:
        The record's ``"section_text"`` followed by a blank line, the literal
        heading ``Dialogue:`` and then the record's ``"dialogue"``.
    """
    record = random.choice(dataset)
    separator = "\n\nDialogue:\n"
    return separator.join((record["section_text"], record["dialogue"]))
def _record_prompt_win(prompt_path, prompt_id):
    """Add one win to the prompt with ``prompt_id`` in the file at ``prompt_path``.

    The file is re-read immediately before it is rewritten, so the increment
    is applied to the counts currently on disk instead of to a snapshot taken
    when the UI was built. Record order in the file is left unchanged.

    Args:
        prompt_path: Path of the JSON file holding the list of prompt records.
        prompt_id: Value of the ``"id"`` key of the winning prompt.

    Raises:
        ValueError: If no record in the file has the given id.
    """
    with open(prompt_path, "r", encoding="utf-8") as file:
        stored_prompts = json.load(file)

    for prompt in stored_prompts:
        if prompt["id"] == prompt_id:
            prompt["metric"]["winning_number"] += 1
            break
    else:
        raise ValueError(f"No prompt of id {prompt_id}")

    with open(prompt_path, "w", encoding="utf-8") as file:
        json.dump(stored_prompts, file)


def create_arena():
    """Build the Gradio Blocks UI that compares prompts side by side.

    The prompt records are loaded from ``prompt/prompt.json`` and up to three
    of them are drawn at random once, while the UI is being built. The page
    shows a data textbox, a button that swaps in another random datapoint, a
    button that streams one response per drawn prompt into its own column,
    and a radio/submit pair that records which response won.

    Returns:
        The assembled ``gr.Blocks`` app. It is neither queued nor launched
        here.
    """
    prompt_path = "prompt/prompt.json"
    num_candidates = 3

    with open(prompt_path, "r", encoding="utf-8") as file:
        prompts = json.load(file)

    with gr.Blocks(css=custom_css) as demo:
        with gr.Group():
            datapoint = random_data_selection()
            gr.Markdown(
                "This arena is designed to compare different prompts. "
                "Click the button to stream responses from randomly shuffled "
                "prompts. Each column represents a response generated from "
                "one randomly selected prompt.\n"
                "Once the streaming is complete, you can choose the best "
                "response.\u2764\ufe0f"
            )
            data_textbox = gr.Textbox(
                label="Data",
                lines=10,
                placeholder="Datapoints to test...",
                value=datapoint,
            )
            with gr.Row():
                random_selection_button = gr.Button("Change Data")
                stream_button = gr.Button("✨ Click to Streaming ✨")

            random_selection_button.click(
                fn=random_data_selection, inputs=[], outputs=[data_textbox]
            )

            # Draw the candidates without reordering the loaded list; fewer
            # than ``num_candidates`` prompts simply yields fewer columns.
            random_selected_prompts = random.sample(
                prompts, k=min(num_candidates, len(prompts))
            )
            num_responses = len(random_selected_prompts)

            # Keep the drawn prompts in a state component so the handlers
            # receive them as an input.
            state_random_selected_prompts = gr.State(value=random_selected_prompts)

            with gr.Row():
                columns = [
                    gr.Textbox(label=f"Prompt {i+1}", lines=10)
                    for i in range(num_responses)
                ]

            # NOTE(review): presumably loads the model once per UI build;
            # confirm against get_model_batch_generation.
            model = get_model_batch_generation("Qwen/Qwen2-1.5B-Instruct")

            def start_streaming(data, random_selected_prompts):
                """Yield one tuple of column updates per item from ``stream_data``."""
                content_list = [
                    prompt["prompt"] + "\n{" + data + "}\n\nsummary:"
                    for prompt in random_selected_prompts
                ]
                for response_data in stream_data(content_list, model):
                    # response_data is indexed by column position.
                    yield tuple(
                        gr.update(value=response_data[i]) for i in range(num_responses)
                    )

            stream_button.click(
                fn=start_streaming,
                inputs=[data_textbox, state_random_selected_prompts],
                outputs=columns,
                show_progress=False,
            )

            # One radio option per response column, so the two cannot drift
            # apart when fewer than three prompts are available.
            response_labels = [f"Response {i + 1}" for i in range(num_responses)]
            choice = gr.Radio(
                label="Choose the best response:",
                choices=response_labels,
            )
            submit_button = gr.Button("Submit")
            output = gr.Textbox(label="You selected:", visible=False)

            def update_prompt_metrics(selected_choice, random_selected_prompts):
                """Record the winning prompt and lock the voting controls.

                Raises:
                    ValueError: If ``selected_choice`` is not one of the radio
                        options (including when nothing is selected), or if
                        the winning prompt id is not in the prompt file.
                """
                try:
                    position = response_labels.index(selected_choice)
                except ValueError:
                    raise ValueError(
                        f"No corresponding response of {selected_choice}"
                    ) from None
                prompt_id = random_selected_prompts[position]["id"]

                # Update the counts on disk rather than a per-session copy,
                # which would overwrite votes written by other sessions.
                _record_prompt_win(prompt_path, prompt_id)

                return (
                    gr.update(value=f"You selected: {selected_choice}", visible=True),
                    gr.update(interactive=False),
                    gr.update(interactive=False),
                )

            submit_button.click(
                fn=update_prompt_metrics,
                inputs=[choice, state_random_selected_prompts],
                outputs=[output, choice, submit_button],
            )

    return demo
if __name__ == "__main__":
    # Script entry point: build the arena UI, keep it bound to ``demo`` at
    # module level, then enable the request queue and start the server.
    demo = create_arena()
    demo.queue().launch()