Reality123b committed
Commit bf20e5c · verified · 1 Parent(s): d833693

Update app.py

Files changed (1):
  1. app.py +46 -36
app.py CHANGED
@@ -3,26 +3,24 @@ from huggingface_hub import InferenceClient
 from PIL import Image
 import time
 import os
+import base64
+from io import BytesIO
 
-# Get the Hugging Face token from the environment variable, or a secret if available.
 HF_TOKEN = os.environ.get("HF_TOKEN")
 
-# Check if HF_TOKEN is set
 if not HF_TOKEN:
     HF_TOKEN_ERROR = "Hugging Face API token (HF_TOKEN) not found. Please set it as an environment variable or Gradio secret."
 else:
     HF_TOKEN_ERROR = None
 
 client = InferenceClient(token=HF_TOKEN)
-PROMPT_IMPROVER_MODEL = "TheBloke/zephyr-7B-beta-AWQ"  # A good general-purpose text model. AWQ for speed.
+PROMPT_IMPROVER_MODEL = "HuggingFaceH4/zephyr-7b-beta"
 
 def improve_prompt(original_prompt):
-    """Improves the user's prompt using a language model."""
     if HF_TOKEN_ERROR:
         raise gr.Error(HF_TOKEN_ERROR)
 
     try:
-        # Construct a prompt for the language model.
         system_prompt = "You are a helpful assistant that improves text prompts for image generation models. Make the prompt more descriptive, detailed, and artistic, while keeping the user's original intent."
         prompt_for_llm = f"""<|system|>
{system_prompt}</s>
@@ -34,32 +32,29 @@ Improve this prompt: {original_prompt}
         improved_prompt = client.text_generation(
             prompt=prompt_for_llm,
             model=PROMPT_IMPROVER_MODEL,
-            max_new_tokens=128,  # Limit the length of the improved prompt
+            max_new_tokens=128,
             temperature=0.7,
             top_p=0.9,
-            repetition_penalty=1.2,  # Encourage diverse output
-            stop_sequences=["</s>"],  # stop at end of sentence
+            repetition_penalty=1.2,
+            stop_sequences=["</s>"],
         )
 
-        return improved_prompt.strip()  # Remove leading/trailing whitespace
-
+        return improved_prompt.strip()
 
     except Exception as e:
-        print(f"Error improving prompt: {e}")  # Log the error for debugging
-        return original_prompt  # Return the original prompt if there's an error
+        print(f"Error improving prompt: {e}")
+        return original_prompt
 
 
 def generate_image(prompt, progress=gr.Progress()):
-    """Generates an image using the InferenceClient and provides progress updates."""
-
     if HF_TOKEN_ERROR:
         raise gr.Error(HF_TOKEN_ERROR)
 
     progress(0, desc="Improving prompt...")
     improved_prompt = improve_prompt(prompt)
 
-    progress(0.2, desc="Sending request to Hugging Face...")  # More granular progress
+    progress(0.2, desc="Sending request to Hugging Face...")
     try:
         image = client.text_to_image(improved_prompt, model="black-forest-labs/FLUX.1-schnell")
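Note: the prompt-improvement path above boils down to one `text_generation` call against a Zephyr-style chat template. A minimal standalone sketch of that call, assuming a valid HF_TOKEN and the model id this commit switches to; the helper name `improve_once` is hypothetical and not part of app.py:

# Sketch only: exercises the prompt-improver call outside Gradio.
import os
from huggingface_hub import InferenceClient

client = InferenceClient(token=os.environ["HF_TOKEN"])

def improve_once(original_prompt: str) -> str:
    # Zephyr chat-template layout, mirroring prompt_for_llm in app.py.
    prompt_for_llm = (
        "<|system|>\nYou are a helpful assistant that improves image prompts.</s>\n"
        f"<|user|>\nImprove this prompt: {original_prompt}</s>\n<|assistant|>\n"
    )
    return client.text_generation(
        prompt_for_llm,
        model="HuggingFaceH4/zephyr-7b-beta",
        max_new_tokens=128,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.2,
        stop_sequences=["</s>"],
    ).strip()

print(improve_once("A dog"))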
 
@@ -69,7 +64,7 @@ def generate_image(prompt, progress=gr.Progress()):
         progress(0.8, desc="Processing image...")
         time.sleep(0.5)
         progress(1.0, desc="Done!")
-        return image, improved_prompt  # Return both image and improved prompt
+        return image, improved_prompt
     except Exception as e:
         if "rate limit" in str(e).lower():
             error_message = f"Rate limit exceeded. Please try again later. Error: {e}"
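For reference, the generation step itself is a single `text_to_image` call that returns a PIL image. A minimal sketch outside the app, assuming a valid HF_TOKEN; the example prompt and output filename are arbitrary:

# Sketch only: one-off generation against the same FLUX.1-schnell endpoint.
import os
from huggingface_hub import InferenceClient

client = InferenceClient(token=os.environ["HF_TOKEN"])
image = client.text_to_image(
    "A dog, golden hour, detailed fur, cinematic lighting",
    model="black-forest-labs/FLUX.1-schnell",
)
image.save("dog.png")  # text_to_image returns a PIL.Image.Image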
@@ -77,10 +72,13 @@ def generate_image(prompt, progress=gr.Progress()):
             error_message = f"An error occurred: {e}"
         raise gr.Error(error_message)
 
-
+def pil_to_base64(img):
+    buffered = BytesIO()
+    img.save(buffered, format="PNG")
+    img_str = base64.b64encode(buffered.getvalue()).decode()
+    return f"data:image/png;base64,{img_str}"
 
 css = """
-/* ... (Rest of your CSS, unchanged, from the previous response) ... */
 .container {
     max-width: 800px;
     margin: auto;
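The new `pil_to_base64` helper wraps the generated image in a PNG data URI for the download link. A quick round-trip sketch (not part of the commit) showing the encoding is reversible:

# Sketch only: encode a PIL image to a data URI and decode it back.
import base64
from io import BytesIO
from PIL import Image

def pil_to_base64(img):
    buffered = BytesIO()
    img.save(buffered, format="PNG")
    return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode()

uri = pil_to_base64(Image.new("RGB", (8, 8), "red"))
restored = Image.open(BytesIO(base64.b64decode(uri.split(",", 1)[1])))
assert restored.size == (8, 8)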
@@ -94,7 +92,7 @@ css = """
     font-size: 2.5em;
     margin-bottom: 0.5em;
     color: #333;
-    font-family: 'Arial', sans-serif; /* More readable font */
+    font-family: 'Arial', sans-serif;
 }
 .description {
     text-align: center;
@@ -106,15 +104,14 @@ css = """
     margin-bottom: 1.5em;
 }
 .output-section img {
-    display: block; /* Ensure image takes full width of container */
-    margin: auto; /* Center the image horizontally */
-    max-width: 100%; /* Prevent image overflow */
-    height: auto; /* Maintain aspect ratio */
-    border-radius: 8px; /* Rounded corners for the image */
-    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1); /* Subtle shadow */
+    display: block;
+    margin: auto;
+    max-width: 100%;
+    height: auto;
+    border-radius: 8px;
+    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
 }
 
-/* Animation for the image appearance - subtle fade-in */
 @keyframes fadeIn {
     from { opacity: 0; transform: translateY(20px); }
     to { opacity: 1; transform: translateY(0); }
@@ -123,7 +120,6 @@ css = """
     animation: fadeIn 0.8s ease-out;
 }
 
-/* Improve button style */
 .submit-button {
     display: block;
     margin: auto;
@@ -140,7 +136,6 @@ css = """
     background-color: #367c39;
 }
 
-/* Style the error messages */
 .error-message {
     color: red;
     text-align: center;
@@ -148,9 +143,9 @@ css = """
     font-weight: bold;
 }
 label{
-    font-weight: bold; /* Make labels bold */
-    display: block; /* Each label on its own line */
-    margin-bottom: 0.5em; /* Space between label and input */
+    font-weight: bold;
+    display: block;
+    margin-bottom: 0.5em;
 }
 
 .improved-prompt-display {
@@ -162,6 +157,18 @@ label{
     font-style: italic;
     color: #444;
 }
+.download-link {
+    display: block;
+    text-align: center;
+    margin-top: 10px;
+    color: #4CAF50;
+    text-decoration: none;
+    font-weight: bold;
+}
+
+.download-link:hover{
+    text-decoration: underline;
+}
 """
 
@@ -169,7 +176,6 @@ with gr.Blocks(css=css) as demo:
     gr.Markdown(
         """
         # Xylaria Iris Image Generator
-        Enter a text prompt, and we'll enhance it before generating an image!
         """,
         elem_classes="title"
     )
@@ -181,18 +187,22 @@
             generate_button = gr.Button("Generate Image", elem_classes="submit-button")
         with gr.Column():
             with gr.Group(elem_classes="output-section") as output_group:
-                image_output = gr.Image(label="Generated Image", show_download_button=False, interactive=False)  # No SVG, not interactive
+                image_output = gr.Image(label="Generated Image", interactive=False)
                 improved_prompt_output = gr.Textbox(label="Improved Prompt", interactive=False, elem_classes="improved-prompt-display")
+                download_link = gr.HTML(visible=False)
 
     def on_generate_click(prompt):
         output_group.elem_classes = ["output-section", "animate"]
         image, improved_prompt = generate_image(prompt)
         output_group.elem_classes = ["output-section"]
-        return image, improved_prompt
+        image_b64 = pil_to_base64(image)
+        download_html = f'<a class="download-link" href="{image_b64}" download="generated_image.png">Download Image</a>'
+
+        return image, improved_prompt, download_html
 
-    generate_button.click(on_generate_click, inputs=prompt_input, outputs=[image_output, improved_prompt_output])
-    prompt_input.submit(on_generate_click, inputs=prompt_input, outputs=[image_output, improved_prompt_output])
+    generate_button.click(on_generate_click, inputs=prompt_input, outputs=[image_output, improved_prompt_output, download_link])
+    prompt_input.submit(on_generate_click, inputs=prompt_input, outputs=[image_output, improved_prompt_output, download_link])
 
     gr.Examples(
         [["A dog"],