File size: 2,648 Bytes
6675f35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import gradio as gr
from transformers import Qwen2ForCausalLM, AutoTokenizer
import torch



# Inference only: turn autograd off globally so generate() builds no graphs.
torch.set_grad_enabled(False)
model = Qwen2ForCausalLM.from_pretrained("Thouph/tag2prompt-qwen2-0.5b-v0.1")
# Clear the `max_new_tokens` default carried by the model's generation config,
# because `post()` bounds generation with `max_length` instead. Otherwise
# transformers emits this warning on every call:
#   Both `max_new_tokens` (=2048) and `max_length`(=512) seem to have been set.
#   `max_new_tokens` will take precedence. Please refer to the documentation
#   for more information.
#   (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
model.generation_config.max_new_tokens = None
# NOTE(review): the tokenizer is loaded from the base Qwen2-0.5B repo, not the
# fine-tuned one; presumably the fine-tune kept the base vocabulary — confirm.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")

def post(
        input_text,
        temperature,
        top_p,
        top_k,
        output_num_words,
):
    """Turn a list of tags into a natural-language prompt with the tag2prompt model.

    Args:
        input_text: Tags in underscore format, as entered in the UI.
        temperature: Sampling temperature. A value <= 0 selects greedy
            decoding instead, because transformers rejects a non-positive
            temperature when ``do_sample=True``.
        top_p: Nucleus-sampling cumulative-probability cutoff (sampling only).
        top_k: Number of highest-probability tokens kept (sampling only).
        output_num_words: Requested output length; written into the prompt on
            its own line after the tags.

    Returns:
        Only the newly generated text, decoded without special tokens.
    """
    # NOTE(review): the UI sliders may hand numbers over as floats. Render a
    # whole-number float as "100" rather than "100.0" so the prompt keeps the
    # integer format (presumably the one the model was trained on).
    if isinstance(output_num_words, float) and output_num_words.is_integer():
        output_num_words = int(output_num_words)

    prompt = f"{input_text}\n{output_num_words}\n\n"
    inputs = tokenizer(
        prompt,
        padding="do_not_pad",
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )

    temperature = float(temperature)
    if temperature > 0:
        # transformers' top-k warper requires an int top_k.
        decoding = {
            "do_sample": True,
            "temperature": temperature,
            "top_p": float(top_p),
            "top_k": int(top_k),
        }
    else:
        # Sampling needs temperature > 0 (the slider allows 0). Fall back to
        # greedy decoding and omit the sampling knobs, which would be unused.
        decoding = {"do_sample": False}

    # max_length=512 bounds prompt + generated tokens together.
    generate_ids = model.generate(**inputs, max_length=512, **decoding)

    # Strip the prompt by token count rather than by len(prompt) characters:
    # if the prompt was truncated, or decoding does not round-trip exactly, a
    # character offset would cut into the generated text or leave prompt text
    # behind.
    prompt_length = inputs["input_ids"].shape[-1]
    new_token_ids = generate_ids[:, prompt_length:]
    return tokenizer.batch_decode(
        new_token_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]

def main():
    """Build the tag2prompt Gradio interface and launch it.

    The left column holds the tag text area and the generation sliders. The
    right column holds the output box and the button that calls `post` with
    the current control values.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            # Left: tag input plus the generation controls, in display order.
            with gr.Column(scale=1):
                tags_box = gr.TextArea(label="Tags (in Underscore Format)")
                temperature_slider = gr.Slider(
                    minimum=0., maximum=1., value=0.5, label='Temperature'
                )
                top_p_slider = gr.Slider(
                    minimum=0.1, maximum=1., value=0.8, label='Top P'
                )
                top_k_slider = gr.Slider(
                    minimum=1, maximum=100, value=20, step=1, label='Top K'
                )
                num_words_slider = gr.Slider(
                    minimum=1, maximum=512, value=100, step=1,
                    label='Output Num Words',
                )
            # Right: result box and trigger button (a column nested in a column).
            with gr.Column(scale=1), gr.Column():
                output_box = gr.Textbox(lines=1, label="Output")
                run_button = gr.Button(
                    value="Run tag2prompt", interactive=True, variant="primary"
                )

        # Input order must match post()'s parameter order.
        run_button.click(
            fn=post,
            inputs=[
                tags_box,
                temperature_slider,
                top_p_slider,
                top_k_slider,
                num_words_slider,
            ],
            outputs=[output_box],
        )

    demo.launch()




# Build and launch the UI only when this file is run directly, not on import.
if __name__ == "__main__":
    main()