Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -3,26 +3,24 @@ from huggingface_hub import InferenceClient
|
|
3 |
from PIL import Image
|
4 |
import time
|
5 |
import os
|
|
|
|
|
6 |
|
7 |
-
# Get the Hugging Face token from the environment variable, or a secret if available.
|
8 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
9 |
|
10 |
-
# Check if HF_TOKEN is set
|
11 |
if not HF_TOKEN:
|
12 |
HF_TOKEN_ERROR = "Hugging Face API token (HF_TOKEN) not found. Please set it as an environment variable or Gradio secret."
|
13 |
else:
|
14 |
HF_TOKEN_ERROR = None
|
15 |
|
16 |
client = InferenceClient(token=HF_TOKEN)
|
17 |
-
PROMPT_IMPROVER_MODEL = "
|
18 |
|
19 |
def improve_prompt(original_prompt):
|
20 |
-
"""Improves the user's prompt using a language model."""
|
21 |
if HF_TOKEN_ERROR:
|
22 |
raise gr.Error(HF_TOKEN_ERROR)
|
23 |
|
24 |
try:
|
25 |
-
# Construct a prompt for the language model.
|
26 |
system_prompt = "You are a helpful assistant that improves text prompts for image generation models. Make the prompt more descriptive, detailed, and artistic, while keeping the user's original intent."
|
27 |
prompt_for_llm = f"""<|system|>
|
28 |
{system_prompt}</s>
|
@@ -34,32 +32,29 @@ Improve this prompt: {original_prompt}
|
|
34 |
improved_prompt = client.text_generation(
|
35 |
prompt=prompt_for_llm,
|
36 |
model=PROMPT_IMPROVER_MODEL,
|
37 |
-
max_new_tokens=128,
|
38 |
temperature=0.7,
|
39 |
top_p=0.9,
|
40 |
-
repetition_penalty=1.2,
|
41 |
-
stop_sequences=["</s>"],
|
42 |
|
43 |
)
|
44 |
|
45 |
-
return improved_prompt.strip()
|
46 |
-
|
47 |
|
48 |
except Exception as e:
|
49 |
-
print(f"Error improving prompt: {e}")
|
50 |
-
return original_prompt
|
51 |
|
52 |
|
53 |
def generate_image(prompt, progress=gr.Progress()):
|
54 |
-
"""Generates an image using the InferenceClient and provides progress updates."""
|
55 |
-
|
56 |
if HF_TOKEN_ERROR:
|
57 |
raise gr.Error(HF_TOKEN_ERROR)
|
58 |
|
59 |
progress(0, desc="Improving prompt...")
|
60 |
improved_prompt = improve_prompt(prompt)
|
61 |
|
62 |
-
progress(0.2, desc="Sending request to Hugging Face...")
|
63 |
try:
|
64 |
image = client.text_to_image(improved_prompt, model="black-forest-labs/FLUX.1-schnell")
|
65 |
|
@@ -69,7 +64,7 @@ def generate_image(prompt, progress=gr.Progress()):
|
|
69 |
progress(0.8, desc="Processing image...")
|
70 |
time.sleep(0.5)
|
71 |
progress(1.0, desc="Done!")
|
72 |
-
return image, improved_prompt
|
73 |
except Exception as e:
|
74 |
if "rate limit" in str(e).lower():
|
75 |
error_message = f"Rate limit exceeded. Please try again later. Error: {e}"
|
@@ -77,10 +72,13 @@ def generate_image(prompt, progress=gr.Progress()):
|
|
77 |
error_message = f"An error occurred: {e}"
|
78 |
raise gr.Error(error_message)
|
79 |
|
80 |
-
|
|
|
|
|
|
|
|
|
81 |
|
82 |
css = """
|
83 |
-
/* ... (Rest of your CSS, unchanged, from the previous response) ... */
|
84 |
.container {
|
85 |
max-width: 800px;
|
86 |
margin: auto;
|
@@ -94,7 +92,7 @@ css = """
|
|
94 |
font-size: 2.5em;
|
95 |
margin-bottom: 0.5em;
|
96 |
color: #333;
|
97 |
-
font-family: 'Arial', sans-serif;
|
98 |
}
|
99 |
.description {
|
100 |
text-align: center;
|
@@ -106,15 +104,14 @@ css = """
|
|
106 |
margin-bottom: 1.5em;
|
107 |
}
|
108 |
.output-section img {
|
109 |
-
display: block;
|
110 |
-
margin: auto;
|
111 |
-
max-width: 100%;
|
112 |
-
height: auto;
|
113 |
-
border-radius: 8px;
|
114 |
-
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
|
115 |
}
|
116 |
|
117 |
-
/* Animation for the image appearance - subtle fade-in */
|
118 |
@keyframes fadeIn {
|
119 |
from { opacity: 0; transform: translateY(20px); }
|
120 |
to { opacity: 1; transform: translateY(0); }
|
@@ -123,7 +120,6 @@ css = """
|
|
123 |
animation: fadeIn 0.8s ease-out;
|
124 |
}
|
125 |
|
126 |
-
/* Improve button style */
|
127 |
.submit-button {
|
128 |
display: block;
|
129 |
margin: auto;
|
@@ -140,7 +136,6 @@ css = """
|
|
140 |
background-color: #367c39;
|
141 |
}
|
142 |
|
143 |
-
/* Style the error messages */
|
144 |
.error-message {
|
145 |
color: red;
|
146 |
text-align: center;
|
@@ -148,9 +143,9 @@ css = """
|
|
148 |
font-weight: bold;
|
149 |
}
|
150 |
label{
|
151 |
-
font-weight: bold;
|
152 |
-
display: block;
|
153 |
-
margin-bottom: 0.5em;
|
154 |
}
|
155 |
|
156 |
.improved-prompt-display {
|
@@ -162,6 +157,18 @@ label{
|
|
162 |
font-style: italic;
|
163 |
color: #444;
|
164 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
"""
|
166 |
|
167 |
|
@@ -169,7 +176,6 @@ with gr.Blocks(css=css) as demo:
|
|
169 |
gr.Markdown(
|
170 |
"""
|
171 |
# Xylaria Iris Image Generator
|
172 |
-
Enter a text prompt, and we'll enhance it before generating an image!
|
173 |
""",
|
174 |
elem_classes="title"
|
175 |
)
|
@@ -181,18 +187,22 @@ with gr.Blocks(css=css) as demo:
|
|
181 |
generate_button = gr.Button("Generate Image", elem_classes="submit-button")
|
182 |
with gr.Column():
|
183 |
with gr.Group(elem_classes="output-section") as output_group:
|
184 |
-
image_output = gr.Image(label="Generated Image",
|
185 |
improved_prompt_output = gr.Textbox(label="Improved Prompt", interactive=False, elem_classes="improved-prompt-display")
|
|
|
186 |
|
187 |
|
188 |
def on_generate_click(prompt):
|
189 |
output_group.elem_classes = ["output-section", "animate"]
|
190 |
image, improved_prompt = generate_image(prompt)
|
191 |
output_group.elem_classes = ["output-section"]
|
192 |
-
|
|
|
|
|
|
|
193 |
|
194 |
-
generate_button.click(on_generate_click, inputs=prompt_input, outputs=[image_output, improved_prompt_output])
|
195 |
-
prompt_input.submit(on_generate_click, inputs=prompt_input, outputs=[image_output, improved_prompt_output])
|
196 |
|
197 |
gr.Examples(
|
198 |
[["A dog"],
|
|
|
3 |
from PIL import Image
|
4 |
import time
|
5 |
import os
|
6 |
+
import base64
|
7 |
+
from io import BytesIO
|
8 |
|
|
|
9 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
10 |
|
|
|
11 |
if not HF_TOKEN:
|
12 |
HF_TOKEN_ERROR = "Hugging Face API token (HF_TOKEN) not found. Please set it as an environment variable or Gradio secret."
|
13 |
else:
|
14 |
HF_TOKEN_ERROR = None
|
15 |
|
16 |
client = InferenceClient(token=HF_TOKEN)
|
17 |
+
PROMPT_IMPROVER_MODEL = "HuggingFaceH4/zephyr-7b-beta"
|
18 |
|
19 |
def improve_prompt(original_prompt):
|
|
|
20 |
if HF_TOKEN_ERROR:
|
21 |
raise gr.Error(HF_TOKEN_ERROR)
|
22 |
|
23 |
try:
|
|
|
24 |
system_prompt = "You are a helpful assistant that improves text prompts for image generation models. Make the prompt more descriptive, detailed, and artistic, while keeping the user's original intent."
|
25 |
prompt_for_llm = f"""<|system|>
|
26 |
{system_prompt}</s>
|
|
|
32 |
improved_prompt = client.text_generation(
|
33 |
prompt=prompt_for_llm,
|
34 |
model=PROMPT_IMPROVER_MODEL,
|
35 |
+
max_new_tokens=128,
|
36 |
temperature=0.7,
|
37 |
top_p=0.9,
|
38 |
+
repetition_penalty=1.2,
|
39 |
+
stop_sequences=["</s>"],
|
40 |
|
41 |
)
|
42 |
|
43 |
+
return improved_prompt.strip()
|
|
|
44 |
|
45 |
except Exception as e:
|
46 |
+
print(f"Error improving prompt: {e}")
|
47 |
+
return original_prompt
|
48 |
|
49 |
|
50 |
def generate_image(prompt, progress=gr.Progress()):
|
|
|
|
|
51 |
if HF_TOKEN_ERROR:
|
52 |
raise gr.Error(HF_TOKEN_ERROR)
|
53 |
|
54 |
progress(0, desc="Improving prompt...")
|
55 |
improved_prompt = improve_prompt(prompt)
|
56 |
|
57 |
+
progress(0.2, desc="Sending request to Hugging Face...")
|
58 |
try:
|
59 |
image = client.text_to_image(improved_prompt, model="black-forest-labs/FLUX.1-schnell")
|
60 |
|
|
|
64 |
progress(0.8, desc="Processing image...")
|
65 |
time.sleep(0.5)
|
66 |
progress(1.0, desc="Done!")
|
67 |
+
return image, improved_prompt
|
68 |
except Exception as e:
|
69 |
if "rate limit" in str(e).lower():
|
70 |
error_message = f"Rate limit exceeded. Please try again later. Error: {e}"
|
|
|
72 |
error_message = f"An error occurred: {e}"
|
73 |
raise gr.Error(error_message)
|
74 |
|
75 |
+
def pil_to_base64(img):
|
76 |
+
buffered = BytesIO()
|
77 |
+
img.save(buffered, format="PNG")
|
78 |
+
img_str = base64.b64encode(buffered.getvalue()).decode()
|
79 |
+
return f"data:image/png;base64,{img_str}"
|
80 |
|
81 |
css = """
|
|
|
82 |
.container {
|
83 |
max-width: 800px;
|
84 |
margin: auto;
|
|
|
92 |
font-size: 2.5em;
|
93 |
margin-bottom: 0.5em;
|
94 |
color: #333;
|
95 |
+
font-family: 'Arial', sans-serif;
|
96 |
}
|
97 |
.description {
|
98 |
text-align: center;
|
|
|
104 |
margin-bottom: 1.5em;
|
105 |
}
|
106 |
.output-section img {
|
107 |
+
display: block;
|
108 |
+
margin: auto;
|
109 |
+
max-width: 100%;
|
110 |
+
height: auto;
|
111 |
+
border-radius: 8px;
|
112 |
+
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
|
113 |
}
|
114 |
|
|
|
115 |
@keyframes fadeIn {
|
116 |
from { opacity: 0; transform: translateY(20px); }
|
117 |
to { opacity: 1; transform: translateY(0); }
|
|
|
120 |
animation: fadeIn 0.8s ease-out;
|
121 |
}
|
122 |
|
|
|
123 |
.submit-button {
|
124 |
display: block;
|
125 |
margin: auto;
|
|
|
136 |
background-color: #367c39;
|
137 |
}
|
138 |
|
|
|
139 |
.error-message {
|
140 |
color: red;
|
141 |
text-align: center;
|
|
|
143 |
font-weight: bold;
|
144 |
}
|
145 |
label{
|
146 |
+
font-weight: bold;
|
147 |
+
display: block;
|
148 |
+
margin-bottom: 0.5em;
|
149 |
}
|
150 |
|
151 |
.improved-prompt-display {
|
|
|
157 |
font-style: italic;
|
158 |
color: #444;
|
159 |
}
|
160 |
+
.download-link {
|
161 |
+
display: block;
|
162 |
+
text-align: center;
|
163 |
+
margin-top: 10px;
|
164 |
+
color: #4CAF50;
|
165 |
+
text-decoration: none;
|
166 |
+
font-weight: bold;
|
167 |
+
}
|
168 |
+
|
169 |
+
.download-link:hover{
|
170 |
+
text-decoration: underline;
|
171 |
+
}
|
172 |
"""
|
173 |
|
174 |
|
|
|
176 |
gr.Markdown(
|
177 |
"""
|
178 |
# Xylaria Iris Image Generator
|
|
|
179 |
""",
|
180 |
elem_classes="title"
|
181 |
)
|
|
|
187 |
generate_button = gr.Button("Generate Image", elem_classes="submit-button")
|
188 |
with gr.Column():
|
189 |
with gr.Group(elem_classes="output-section") as output_group:
|
190 |
+
image_output = gr.Image(label="Generated Image", interactive=False)
|
191 |
improved_prompt_output = gr.Textbox(label="Improved Prompt", interactive=False, elem_classes="improved-prompt-display")
|
192 |
+
download_link = gr.HTML(visible=False)
|
193 |
|
194 |
|
195 |
def on_generate_click(prompt):
|
196 |
output_group.elem_classes = ["output-section", "animate"]
|
197 |
image, improved_prompt = generate_image(prompt)
|
198 |
output_group.elem_classes = ["output-section"]
|
199 |
+
image_b64 = pil_to_base64(image)
|
200 |
+
download_html = f'<a class="download-link" href="{image_b64}" download="generated_image.png">Download Image</a>'
|
201 |
+
|
202 |
+
return image, improved_prompt, download_html
|
203 |
|
204 |
+
generate_button.click(on_generate_click, inputs=prompt_input, outputs=[image_output, improved_prompt_output, download_link])
|
205 |
+
prompt_input.submit(on_generate_click, inputs=prompt_input, outputs=[image_output, improved_prompt_output, download_link])
|
206 |
|
207 |
gr.Examples(
|
208 |
[["A dog"],
|