import gradio as gr
from transformers import Qwen2ForCausalLM, AutoTokenizer
import torch
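
# Inference only: disable autograd globally.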
torch.set_grad_enabled(False)
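
# Load the tag2prompt Qwen2-0.5B checkpoint.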
model = Qwen2ForCausalLM.from_pretrained("Thouph/tag2prompt-qwen2-0.5b-v0.1")
# Clear max_new_tokens so that only `max_length` controls generation length;
# otherwise you will get this warning:
#   Both `max_new_tokens` (=2048) and `max_length`(=512) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co./docs/transformers/main/en/main_classes/text_generation)
model.generation_config.max_new_tokens = None
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
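

# Gradio callback: build the prompt from the tags and the requested word count,
# then sample a continuation from the model.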
def post(
    input_text,
    temperature,
    top_p,
    top_k,
    output_num_words,
):
    # Prompt format: the tag string on one line, then the requested word count.
    prompt = f"{input_text}\n{output_num_words}\n\n"
    inputs = tokenizer(
        prompt,
        padding="do_not_pad",
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )
    # Sample a continuation using the user-controlled decoding parameters.
    generate_ids = model.generate(
        **inputs,
        max_length=512,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
    )
    generated_text = tokenizer.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    # Drop the echoed prompt so only the newly generated text is returned.
    generated_text = generated_text[len(prompt):]
    return generated_text
def main():
    with gr.Blocks() as iface:
        with gr.Row():
            # Left column: tag input and decoding controls.
            with gr.Column(scale=1):
                text_input = gr.TextArea(label="Tags (in Underscore Format)")
                temperature = gr.Slider(maximum=1., value=0.5, minimum=0., label='Temperature')
                top_p = gr.Slider(maximum=1., value=0.8, minimum=0.1, label='Top P')
                top_k = gr.Slider(maximum=100, value=20, minimum=1, step=1, label='Top K')
                output_num_words = gr.Slider(maximum=512, value=100, minimum=1, step=1, label='Output Num Words')
            # Right column: generated output and the run button.
            with gr.Column(scale=1):
                with gr.Column():
                    caption_output = gr.Textbox(lines=1, label="Output")
                    caption_button = gr.Button(
                        value="Run tag2prompt", interactive=True, variant="primary"
                    )
                    caption_button.click(
                        post,
                        [
                            text_input,
                            temperature,
                            top_p,
                            top_k,
                            output_num_words,
                        ],
                        [caption_output],
                    )
    iface.launch()


if __name__ == "__main__":
    main()