File size: 1,901 Bytes
368407f
e4e5d71
 
 
368407f
e4e5d71
 
 
 
e94cd94
e4e5d71
368407f
c677a45
e4e5d71
c677a45
e4e5d71
 
204fc4b
e4e5d71
 
204fc4b
e4e5d71
 
204fc4b
 
 
e4e5d71
 
204fc4b
 
e4e5d71
 
204fc4b
e4e5d71
 
204fc4b
e4e5d71
368407f
e4e5d71
368407f
 
 
 
 
 
6b52712
e4e5d71
6b52712
368407f
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load the MagicPrompt-Stable-Diffusion tokenizer and causal-LM weights once,
# at import time. The model is parked on the CPU here; expand_prompt moves it
# to the GPU only while it is serving a request.
model_name = "Gustavosta/MagicPrompt-Stable-Diffusion"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to("cpu")

@spaces.GPU
def expand_prompt(prompt, num_variants=5, max_length=100):
    """
    Generate expanded prompts using a specialized model fine-tuned for Stable Diffusion.

    Args:
        prompt: The basic text prompt to expand.
        num_variants: Number of independently sampled expansions to return.
        max_length: Maximum total length, in tokens, of each generated
            sequence. It is passed to ``model.generate`` as ``max_length``,
            so it counts the prompt's own tokens as well as the new ones.

    Returns:
        The decoded expansions, each stripped of surrounding whitespace,
        joined with a blank line between consecutive variants.

    Raises:
        gr.Error: If ``prompt`` is empty or contains only whitespace.
    """
    # Reject empty input before touching the GPU: an empty prompt tokenizes to
    # zero tokens, leaving generate() nothing to condition on.
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt to expand.")

    # Move model to GPU only for the duration of this call.
    model.to("cuda")
    try:
        # Tokenize the prompt and place the tensors on the same device as the model.
        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

        # Sample several continuations of the prompt in a single call.
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            do_sample=True,
            num_return_sequences=num_variants,
            # Reuse EOS as the padding id (presumably the tokenizer defines no
            # dedicated pad token -- confirm before changing).
            pad_token_id=tokenizer.eos_token_id,
        )

        # Decode generated prompts.
        expanded_prompts = [
            tokenizer.decode(output, skip_special_tokens=True).strip()
            for output in outputs
        ]
    finally:
        # Always move the model back to the CPU, even if tokenization or
        # generation raised, so every call starts from the same device state.
        model.to("cpu")

    return "\n\n".join(expanded_prompts)

# Create a Gradio Interface: one text box in, the expanded variants out.
# Only the prompt is wired to an input component, so expand_prompt always runs
# with its default num_variants and max_length.
iface = gr.Interface(
    fn=expand_prompt,
    inputs=gr.Textbox(lines=2, placeholder="Enter your basic prompt here...", label="Basic Prompt"),
    outputs=gr.Textbox(lines=10, label="Expanded Prompts"),
    title="Prompt Expansion Generator",
    # Adjacent string literals are concatenated, so each sentence must carry
    # its own trailing space/punctuation.
    description=(
        "Enter a basic prompt and receive multiple expanded prompt variants optimized for Stable Diffusion. "
        "Using the ZeroGPU feature, results take 3 seconds. "
        "This tool uses a specialized model fine-tuned on Stable Diffusion prompts. "
        "Simply copy the output for use with your image-generation pipeline. "
        "Thanks to https://huggingface.co/Gustavosta/MagicPrompt-Stable-Diffusion for the model!"
    ),
)

# Start the web server only when this file is executed directly.
if __name__ == "__main__":
    iface.launch()