|
import gradio as gr |
|
from transformers import Qwen2ForCausalLM, AutoTokenizer |
|
import torch |
|
|
|
|
|
|
|
# Inference only: disable autograd so no computation graph is built.
# NOTE(review): PyTorch grad mode is thread-local, so this call only affects
# the importing thread. Gradio may invoke handlers on worker threads — confirm
# that generation there does not depend on this global setting.
torch.set_grad_enabled(False)

# Hugging Face Hub repositories. The tokenizer is loaded from the base
# Qwen2-0.5B repo rather than from the model repo.
MODEL_REPO = "Thouph/tag2prompt-qwen2-0.5b-v0.1"
TOKENIZER_REPO = "Qwen/Qwen2-0.5B"

model = Qwen2ForCausalLM.from_pretrained(MODEL_REPO)

# Clear the checkpoint's default ``max_new_tokens`` so that the
# ``max_length=512`` passed to ``generate()`` is the limit actually applied.
# If both are set, transformers lets ``max_new_tokens`` take precedence and
# warns:
#   Both `max_new_tokens` (=2048) and `max_length`(=512) seem to have been set.
#   `max_new_tokens` will take precedence. Please refer to the documentation
#   for more information.
#   (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
model.generation_config.max_new_tokens = None

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_REPO)
|
|
|
def post(
    input_text,
    temperature,
    top_p,
    top_k,
    output_num_words,
):
    """Generate text from a tag list using the module-level ``model``.

    The prompt is the tags, a newline, the requested word count, and a blank
    line. Only the tokens produced after that prompt are decoded and returned.

    Args:
        input_text: The tag text entered by the user.
        temperature: Sampling temperature. A value <= 0 switches to greedy
            decoding, because transformers rejects a non-positive temperature
            when ``do_sample=True``.
        top_p: Nucleus-sampling cumulative-probability cutoff.
        top_k: Number of highest-probability tokens kept when sampling; cast to
            ``int`` because transformers requires an integer here.
        output_num_words: Requested output length, embedded in the prompt as an
            integer.

    Returns:
        The newly generated text, with the prompt excluded.
    """
    # Render the word count as an integer so a float slider value such as
    # 100.0 does not appear in the prompt as "100.0".
    prompt = f"{input_text}\n{int(output_num_words)}\n\n"

    inputs = tokenizer(
        prompt,
        padding="do_not_pad",
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )
    # Number of prompt tokens actually fed to the model (after truncation).
    prompt_length = inputs["input_ids"].shape[-1]

    if temperature > 0:
        decoding = {
            "do_sample": True,
            "temperature": float(temperature),
            "top_p": top_p,
            "top_k": int(top_k),
        }
    else:
        # Temperature 0 means "no randomness": decode greedily instead of
        # letting transformers raise on a zero sampling temperature.
        decoding = {"do_sample": False}

    # ``max_length`` bounds prompt tokens plus generated tokens.
    generate_ids = model.generate(**inputs, max_length=512, **decoding)

    # Keep only the tokens produced after the prompt. Slicing the decoded
    # string by len(prompt) is wrong when truncation shortened the prompt or
    # the decode round-trip does not reproduce the prompt text exactly.
    new_token_ids = generate_ids[:, prompt_length:]
    return tokenizer.batch_decode(
        new_token_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
|
|
|
def main():
    """Assemble the tag2prompt Gradio UI, wire it to ``post``, and serve it."""
    with gr.Blocks() as iface:
        with gr.Row():
            # Left half: tag input plus the generation controls.
            with gr.Column(scale=1):
                tags_box = gr.TextArea(label="Tags (in Underscore Format)")
                temperature_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    label='Temperature',
                )
                top_p_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.8,
                    label='Top P',
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=20,
                    step=1,
                    label='Top K',
                )
                word_count_slider = gr.Slider(
                    minimum=1,
                    maximum=512,
                    value=100,
                    step=1,
                    label='Output Num Words',
                )

            # Right half: result box and the button that triggers generation.
            with gr.Column(scale=1):
                with gr.Column():
                    result_box = gr.Textbox(lines=1, label="Output")
                    run_button = gr.Button(
                        value="Run tag2prompt",
                        interactive=True,
                        variant="primary",
                    )
                    # Argument order must match post()'s positional parameters.
                    control_inputs = [
                        tags_box,
                        temperature_slider,
                        top_p_slider,
                        top_k_slider,
                        word_count_slider,
                    ]
                    run_button.click(
                        fn=post,
                        inputs=control_inputs,
                        outputs=[result_box],
                    )

    iface.launch()
|
|
|
|
|
|
|
|
|
# Build and launch the UI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|