Spaces:
Sleeping
Sleeping
File size: 4,567 Bytes
d092d11 1921336 d092d11 1921336 d092d11 1921336 d092d11 9a1ab03 488c5c4 031841d 488c5c4 031841d de53991 34ffea3 80a8eaa 1921336 488c5c4 031841d 4ea28ea 031841d 4ea28ea 031841d ed67a17 488c5c4 dd681d0 488c5c4 031841d 488c5c4 c0a9946 64df9ac dd681d0 031841d 1921336 031841d 1921336 dd681d0 031841d dd681d0 031841d 1921336 031841d dd681d0 1921336 dd681d0 1921336 031841d 1921336 9a1ab03 f664ce2 3f9babb 031841d dd681d0 031841d dd681d0 031841d dd681d0 031841d dd681d0 031841d dd681d0 64df9ac dd681d0 031841d dd681d0 64df9ac de53991 1921336 031841d 1921336 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import json
import os
import random
import threading

import gradio as gr

from pages.summarization_playground import custom_css, get_model_batch_generation
from utils.data import dataset
from utils.multiple_stream import stream_data
def random_data_selection():
    """Draw one random sample from ``dataset`` and render it as text.

    Returns:
        The sample's ``section_text``, followed by a blank line, a
        ``Dialogue:`` header line and the sample's ``dialogue``.
    """
    sample = random.choice(dataset)
    section = sample["section_text"]
    return section + "\n\nDialogue:\n" + sample["dialogue"]
PROMPT_PATH = "prompt/prompt.json"

# Serializes read-modify-write cycles on the prompt file. The file is shared by
# every browser session, and handlers from different sessions may run
# concurrently depending on the queue configuration.
_PROMPT_FILE_LOCK = threading.Lock()


def _load_prompts(path):
    """Return the list of prompt records stored as JSON at *path*."""
    with open(path, "r", encoding="utf-8") as file:
        return json.load(file)


def _record_win(path, prompt_id):
    """Add one win to the prompt whose ``id`` equals *prompt_id* in *path*.

    The file is re-read under a lock right before updating. Wins recorded by
    other sessions since the app started are therefore kept instead of being
    overwritten by a stale in-memory copy. The new content is written to a
    sibling ``.tmp`` file and moved into place with ``os.replace``, so an
    interrupted write cannot leave a truncated prompt file behind.

    Args:
        path: Path of the prompt JSON file (a list of dicts carrying ``id``
            and ``metric.winning_number`` keys).
        prompt_id: ``id`` of the winning prompt.

    Raises:
        ValueError: If no prompt in the file has the given id.
    """
    with _PROMPT_FILE_LOCK:
        prompts = _load_prompts(path)
        for prompt in prompts:
            if prompt["id"] == prompt_id:
                prompt["metric"]["winning_number"] += 1
                break
        else:
            raise ValueError(f"No prompt of id {prompt_id}")

        tmp_path = f"{path}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as file:
            json.dump(prompts, file)
        os.replace(tmp_path, path)


def create_arena(
    *,
    prompt_path=PROMPT_PATH,
    num_prompts=3,
    model_name="Qwen/Qwen2-1.5B-Instruct",
):
    """Build the Gradio arena that compares responses from random prompts.

    Each page load draws ``num_prompts`` prompts at random from the prompt
    file. The stream button streams one response per prompt for the current
    data. Submitting a choice adds a win to the chosen prompt in the prompt
    file.

    Args:
        prompt_path: JSON file holding the prompt records and their win counts.
        num_prompts: How many prompts to compare side by side. This is capped
            at the number of prompts available.
        model_name: Model identifier passed to ``get_model_batch_generation``.

    Returns:
        The assembled (not yet launched) ``gr.Blocks`` demo.

    Raises:
        ValueError: If there is no prompt to compare.
    """
    prompts = _load_prompts(prompt_path)
    num_columns = min(num_prompts, len(prompts))
    if num_columns <= 0:
        raise ValueError(
            f"No prompts to compare (file {prompt_path!r}, num_prompts={num_prompts})"
        )
    # Radio labels and output columns share this count, so every label always
    # maps to an existing selected prompt.
    choice_labels = [f"Response {i}" for i in range(1, num_columns + 1)]

    def pick_prompts():
        """Return a fresh random selection of ``num_columns`` prompts."""
        return random.sample(prompts, num_columns)

    model = get_model_batch_generation(model_name)

    with gr.Blocks(css=custom_css) as demo:
        with gr.Group():
            gr.Markdown(
                "This arena is designed to compare different prompts. Click the button "
                "to stream responses from randomly shuffled prompts. Each column "
                "represents a response generated from one randomly selected prompt.\n"
                "Once the streaming is complete, you can choose the best response."
                "\u2764\ufe0f"
            )
            data_textbox = gr.Textbox(
                label="Data",
                lines=10,
                placeholder="Datapoints to test...",
                value=random_data_selection(),
            )
            with gr.Row():
                random_selection_button = gr.Button("Change Data")
                stream_button = gr.Button("✨ Click to Streaming ✨")

        random_selection_button.click(
            fn=random_data_selection, inputs=[], outputs=[data_textbox]
        )

        # Per-session selection. It is re-drawn on every page load (see
        # demo.load below), so visitors do not all compare one startup draw.
        state_random_selected_prompts = gr.State(value=pick_prompts())

        with gr.Row():
            columns = [
                gr.Textbox(label=f"Prompt {i}", lines=10)
                for i in range(1, num_columns + 1)
            ]

        def start_streaming(data, random_selected_prompts):
            """Stream one response per selected prompt into the output columns."""
            content_list = [
                prompt["prompt"] + "\n{" + data + "}\n\nsummary:"
                for prompt in random_selected_prompts
            ]
            for response_data in stream_data(content_list, model):
                yield tuple(
                    gr.update(value=response_data[i]) for i in range(len(columns))
                )

        stream_button.click(
            fn=start_streaming,
            inputs=[data_textbox, state_random_selected_prompts],
            outputs=columns,
            show_progress=False,
        )

        choice = gr.Radio(
            label="Choose the best response:",
            choices=choice_labels,
        )
        submit_button = gr.Button("Submit")
        output = gr.Textbox(label="You selected:", visible=False)

        def update_prompt_metrics(selected_choice, random_selected_prompts):
            """Record a win for the chosen prompt and lock the voting widgets."""
            if selected_choice not in choice_labels:
                # gr.Error shows this message to the user instead of a
                # generic error toast.
                raise gr.Error("Please choose the best response before submitting.")
            winner = random_selected_prompts[choice_labels.index(selected_choice)]
            _record_win(prompt_path, winner["id"])
            return (
                gr.update(value=f"You selected: {selected_choice}", visible=True),
                gr.update(interactive=False),
                gr.update(interactive=False),
            )

        submit_button.click(
            fn=update_prompt_metrics,
            inputs=[choice, state_random_selected_prompts],
            outputs=[output, choice, submit_button],
        )

        demo.load(
            fn=pick_prompts, inputs=None, outputs=[state_random_selected_prompts]
        )

    return demo
if __name__ == "__main__":
    # Build the arena, enable request queuing on it and serve it; the name
    # `demo` is kept at module level.
    demo = create_arena()
    demo.queue().launch()
|