import json
import random

import gradio as gr

from pages.summarization_playground import custom_css, get_model_batch_generation
from utils.data import dataset
from utils.multiple_stream import stream_data


def random_data_selection():
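    """Pick a random datapoint from the dataset and return its section text followed by its dialogue."""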
    datapoint = random.choice(dataset)
    datapoint = datapoint["section_text"] + "\n\nDialogue:\n" + datapoint["dialogue"]
    return datapoint


def create_arena():
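    """Build the Gradio Blocks arena: stream responses from three randomly
    selected prompts and record which one the user prefers."""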
    with open("prompt/prompt.json", "r") as file:
        prompts = json.load(file)

    with gr.Blocks(css=custom_css) as demo:
        with gr.Group():
            datapoint = random_data_selection()
            gr.Markdown(
                """This arena is designed to compare different prompts. Click the button to stream responses generated from randomly selected prompts; each column shows the response produced by one prompt.

Once streaming is complete, choose the best response. \u2764\ufe0f"""
            )

            data_textbox = gr.Textbox(
                label="Data",
                lines=10,
                placeholder="Datapoints to test...",
                value=datapoint,
            )
            with gr.Row():
                random_selection_button = gr.Button("Change Data")
                stream_button = gr.Button("✨ Click to Stream ✨")

            random_selection_button.click(
                fn=random_data_selection, inputs=[], outputs=[data_textbox]
            )

            # Shuffle the prompt pool and sample three prompts to compare in this session.
            random.shuffle(prompts)
            random_selected_prompts = prompts[:3]

            # Store prompts in state components
            state_prompts = gr.State(value=prompts)
            state_random_selected_prompts = gr.State(value=random_selected_prompts)

            with gr.Row():
                columns = [
                    gr.Textbox(label=f"Prompt {i+1}", lines=10)
                    for i in range(len(random_selected_prompts))
                ]

            # Model used to generate the responses for all three columns.
            model = get_model_batch_generation("Qwen/Qwen2-1.5B-Instruct")

            def start_streaming(data, random_selected_prompts):
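                """Fill each selected prompt with the data and yield the
                incrementally streamed response for every column."""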
                content_list = [
                    prompt["prompt"] + "\n{" + data + "}\n\nsummary:"
                    for prompt in random_selected_prompts
                ]
                for response_data in stream_data(content_list, model):
                    updates = [
                        gr.update(value=response_data[i]) for i in range(len(columns))
                    ]
                    yield tuple(updates)

            stream_button.click(
                fn=start_streaming,
                inputs=[data_textbox, state_random_selected_prompts],
                outputs=columns,
                show_progress=False,
            )

            choice = gr.Radio(
                label="Choose the best response:",
                choices=["Response 1", "Response 2", "Response 3"],
            )

            submit_button = gr.Button("Submit")

            output = gr.Textbox(label="You selected:", visible=False)

            def update_prompt_metrics(
                selected_choice, prompts, random_selected_prompts
            ):
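                """Record a win for the prompt behind the selected response and
                persist the updated metrics to prompt/prompt.json."""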
                if selected_choice == "Response 1":
                    prompt_id = random_selected_prompts[0]["id"]
                elif selected_choice == "Response 2":
                    prompt_id = random_selected_prompts[1]["id"]
                elif selected_choice == "Response 3":
                    prompt_id = random_selected_prompts[2]["id"]
                else:
                    raise ValueError(f"No corresponding response for {selected_choice}")

                for prompt in prompts:
                    if prompt["id"] == prompt_id:
                        prompt["metric"]["winning_number"] += 1
                        break
                else:
                    raise ValueError(f"No prompt with id {prompt_id}")

                # Persist the updated win counts back to the prompt file.
                with open("prompt/prompt.json", "w") as f:
                    json.dump(prompts, f)

                return (
                    gr.update(value=f"You selected: {selected_choice}", visible=True),
                    gr.update(interactive=False),
                    gr.update(interactive=False),
                )

            submit_button.click(
                fn=update_prompt_metrics,
                inputs=[choice, state_prompts, state_random_selected_prompts],
                outputs=[output, choice, submit_button],
            )

    return demo


if __name__ == "__main__":
    demo = create_arena()
    demo.queue()  # enable queuing, which the streamed (generator) outputs rely on
    demo.launch()