import os
import gradio as gr
import json
import logging
from PIL import Image
from huggingface_hub import ModelCard, HfFileSystem  # used by add_custom_lora / get_model_details
from huggingface_hub import InferenceClient  # added for remote API inference
import copy
import random
import time
import re  # currently unused; kept in case URL parsing or trigger-word extraction needs it again

# --- Inference Client Setup ---
# It's recommended to load the API key from environment variables or Gradio secrets
HF_API_KEY = os.getenv("HF_API_KEY")
if not HF_API_KEY:
    # Try to get from Gradio secrets if running in a Space
    try:
        HF_API_KEY = gr.secrets.get("HF_API_KEY")
    except (AttributeError, KeyError):
        HF_API_KEY = None  # Set to None if not found

if not HF_API_KEY:
    logging.warning("HF_API_KEY not found in environment variables or Gradio secrets. Inference API calls will likely fail.")
    # Optionally, raise an error or provide a default behavior
    # raise ValueError("Missing Hugging Face API Key (HF_API_KEY) for InferenceClient")
    client = None  # Initialize client as None if no key
else:
    client = InferenceClient(provider="fal-ai", token=HF_API_KEY)
    # Note: Provider choice depends on where the target models are hosted/supported for inference.
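
# huggingface_hub supports several inference providers besides "fal-ai"
# (e.g. "replicate", "together"); a sketch of swapping the provider, assuming
# the target models are actually deployed there:
#
#     client = InferenceClient(provider="together", token=HF_API_KEY)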

# Load LoRAs from JSON file
with open('loras.json', 'r') as f:
    loras = json.load(f)

# Removed diffusers model initialization block
MAX_SEED = 2**32 - 1
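
# A sketch of the entry shape this app expects in loras.json, inferred from how
# fields are read below (only "repo" is strictly required; the rest have fallbacks):
#
# [
#   {
#     "image": "https://.../preview.png",   # optional preview image URL
#     "title": "My LoRA",                   # optional display name
#     "repo": "user/model-name",            # repo ID, used as the API model identifier
#     "trigger_word": "xyz style",          # optional; mixed into the prompt
#     "trigger_position": "prepend",        # optional; "prepend" (default) or "append"
#     "aspect": "portrait"                  # optional; "portrait" | "landscape", else 1024x1024
#   }
# ]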


class calculateDuration:
    """Context manager that prints the wall-clock time a block took."""

    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")


# Updated function signature: removed width/height inputs (they now live in gr.State)
def update_selection(evt: gr.SelectData):
    selected_lora = loras[evt.index]
    title = selected_lora.get("title", selected_lora["repo"].split("/")[-1])
    new_placeholder = f"Type a prompt for {title}"
    lora_repo = selected_lora["repo"]
    # Use the repo ID directly as the model identifier for the API call
    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co./{lora_repo}) ✨ (Model ID: `{lora_repo}`)"

    # Default width/height
    width = 1024
    height = 1024
    # Update width/height based on aspect ratio if defined
    if "aspect" in selected_lora:
        if selected_lora["aspect"] == "portrait":
            width = 768
            height = 1024
        elif selected_lora["aspect"] == "landscape":
            width = 1024
            height = 768
        # else keep 1024x1024

    # Return updates for prompt, selection info, index, and width/height states
    # (raw values, not gr.update, for the gr.State outputs)
    return (
        gr.update(placeholder=new_placeholder),
        updated_text,
        evt.index,
        width,
        height,
    )


def run_lora(prompt, selected_index, current_seed, current_width, current_height):
    global client  # Access the module-level client
    if client is None:
        raise gr.Error("InferenceClient could not be initialized. Missing HF_API_KEY.")
    if selected_index is None:
        raise gr.Error("You must select a LoRA/Model before proceeding.")

    # --- Hardcoded defaults (removed from the UI) ---
    cfg_scale = 7.0
    steps = 30
    # lora_scale = 0.95  # Might not be applicable/used by the API
    randomize_seed = True  # Always randomize in this simplified version
    # Removed image_input_path, image_strength -- no img2img in this version

    selected_lora = loras[selected_index]
    # The 'repo' field now directly serves as the model identifier for the API
    model_id = selected_lora["repo"]
    trigger_word = selected_lora.get("trigger_word", "")  # Use .get for safety

    # --- Prompt construction ---
    if trigger_word:
        trigger_position = selected_lora.get("trigger_position", "prepend")  # Default: prepend
        if trigger_position == "prepend":
            prompt_mash = f"{trigger_word} {prompt}"
        else:  # append
            prompt_mash = f"{prompt} {trigger_word}"
    else:
        prompt_mash = prompt

    # --- Seed handling ---
    seed_to_use = current_seed  # Use the state value by default
    if randomize_seed:
        seed_to_use = random.randint(0, MAX_SEED)

    # --- API call (text-to-image only) ---
    final_image = None
    try:
        with calculateDuration(f"API Inference (txt2img) for {model_id}"):
            print(f"Running Text-to-Image for Model: {model_id}")
            final_image = client.text_to_image(
                prompt=prompt_mash,
                model=model_id,
                guidance_scale=cfg_scale,
                num_inference_steps=steps,
                seed=seed_to_use,
                width=current_width,    # Use width from state
                height=current_height,  # Use height from state
                # lora_scale might need to be passed via 'parameters' if supported
                # parameters={"lora_scale": lora_scale}
            )
    except Exception as e:
        print(f"Error during API call: {e}")
        # Friendlier messages for common API failure modes
        if "authorization" in str(e).lower() or "401" in str(e):
            raise gr.Error(f"Authorization error calling the Inference API. Please ensure your HF_API_KEY is valid and has the necessary permissions. Error: {e}")
        elif "model is currently loading" in str(e).lower() or "503" in str(e):
            raise gr.Error(f"Model '{model_id}' is currently loading or unavailable. Please try again in a few moments. Error: {e}")
        else:
            raise gr.Error(f"Failed to generate image using the API. Model: {model_id}. Error: {e}")

    # Return the final image, the seed actually used, and hide the progress bar
    return final_image, seed_to_use, gr.update(visible=False)
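
# For reference, the generation above reduces to a single InferenceClient call;
# a minimal standalone sketch (assumes HF_API_KEY is set and "user/model-name"
# is a hypothetical text-to-image repo available on the provider):
#
#     from huggingface_hub import InferenceClient
#     client = InferenceClient(provider="fal-ai", token=os.environ["HF_API_KEY"])
#     image = client.text_to_image(
#         "a cat in a spacesuit",
#         model="user/model-name",
#         guidance_scale=7.0,
#         num_inference_steps=30,
#         width=1024,
#         height=1024,
#     )
#     image.save("out.png")  # text_to_image returns a PIL.Image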


# Removed the get_huggingface_safetensors function, as we no longer download safetensors
def parse_hf_link(link):
    """Parses a Hugging Face link or repo ID string into a 'user/model' repo ID."""
    if link.startswith("https://huggingface.co./"):
        link = link.replace("https://huggingface.co./", "")
    elif link.startswith("www.huggingface.co/"):
        link = link.replace("www.huggingface.co/", "")
    # Basic validation for "user/model" format
    if "/" not in link or len(link.split("/")) != 2:
        raise ValueError("Invalid Hugging Face repository ID format. Expected 'user/model'.")
    return link.strip()
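
# Examples of inputs parse_hf_link accepts (all normalize to "user/model-name"):
#
#     parse_hf_link("https://huggingface.co./user/model-name")  # full URL
#     parse_hf_link("www.huggingface.co/user/model-name")   # URL without scheme
#     parse_hf_link("user/model-name")                      # bare repo ID
#
# Anything that doesn't split into exactly two "/"-separated parts raises ValueError.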


def get_model_details(repo_id):
    """Fetches model card details (title, trigger word, preview image) if possible."""
    try:
        model_card = ModelCard.load(repo_id)
        image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
        trigger_word = model_card.data.get("instance_prompt", "")  # Common key for trigger words
        # Try another common key if the first fails
        if not trigger_word:
            trigger_word = model_card.data.get("trigger_words", [""])[0]
        image_url = f"https://huggingface.co./{repo_id}/resolve/main/{image_path}" if image_path else None

        # Fallback: check repo files for an image if none was found in the card's widget data
        if not image_url:
            fs = HfFileSystem()
            files = fs.ls(repo_id, detail=False)
            image_extensions = (".jpg", ".jpeg", ".png", ".webp")
            for file in files:
                filename = file.split("/")[-1]
                if filename.lower().endswith(image_extensions):
                    image_url = f"https://huggingface.co./{repo_id}/resolve/main/{filename}"
                    break  # Take the first image found
        # Use the repo name as the title if not specified elsewhere
        title = model_card.data.get("model_display_name", repo_id.split('/')[-1])  # Example key; may vary
        return title, trigger_word, image_url
    except Exception as e:
        print(f"Could not fetch model card details for {repo_id}: {e}")
        # Fallback values: use the repo name part as the title
        return repo_id.split('/')[-1], "", None
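
# For reference, the card metadata this parser looks for typically sits in the
# model card's YAML front matter, roughly like the (hedged) sketch below; exact
# keys vary between repos, hence the fallbacks above:
#
#     instance_prompt: "xyz style"
#     widget:
#       - output:
#           url: images/example.png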


def add_custom_lora(custom_lora_input):
    global loras
    if not custom_lora_input:
        # Clear the custom LoRA section if the input is empty
        return gr.update(visible=False, value=""), gr.update(visible=False), gr.update(), "", None, ""
    try:
        repo_id = parse_hf_link(custom_lora_input)
        print(f"Attempting to add custom model: {repo_id}")

        # Check whether the model already exists in the list
        existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo_id), None)
        if existing_item_index is not None:
            print(f"Model {repo_id} already exists in the list.")
            # Re-use the existing entry's details
            selected_lora = loras[existing_item_index]
            title = selected_lora.get('title', repo_id.split('/')[-1])
            image = selected_lora.get('image', None)  # Use the existing image if available
            trigger_word = selected_lora.get('trigger_word', '')
        else:
            # Fetch details for the new model
            title, trigger_word, image = get_model_details(repo_id)
            print(f"Adding new model: {repo_id}, Title: {title}, Trigger: '{trigger_word}', Image: {image}")
            new_item = {
                "image": image,  # Store the image URL (can be None)
                "title": title,
                "repo": repo_id,  # Store the repo ID used for API calls
                # "weights": path,  # No longer needed
                "trigger_word": trigger_word,  # Store the trigger word if found
            }
            loras.append(new_item)
            existing_item_index = len(loras) - 1  # Index of the newly added item

        # Generate an HTML card for display
        card = f'''
        <div class="custom_lora_card">
            <span>Loaded custom model:</span>
            <div class="card_internal">
                {f'<img src="{image}" alt="{title} preview"/>' if image else '<div class="no-image">No Image</div>'}
                <div>
                    <h3>{title}</h3>
                    <small>Model ID: <code>{repo_id}</code><br></small>
                    <small>{"Using trigger word: <code><b>" + trigger_word + "</b></code>" if trigger_word else "No specific trigger word found in card. Include it manually if needed."}<br></small>
                </div>
            </div>
        </div>
        '''

        # Update the gallery to include the new item (or reflect changes if re-added)
        updated_gallery_items = [(item.get("image"), item.get("title", item["repo"].split('/')[-1])) for item in loras]
        # Update UI elements: show info card, show remove button, update gallery (selecting the
        # added/found item), set selection info, set selected index, update prompt placeholder
        return (
            gr.update(visible=True, value=card),
            gr.update(visible=True),
            gr.Gallery(value=updated_gallery_items, selected_index=existing_item_index),
            f"### Selected: [{repo_id}](https://huggingface.co./{repo_id}) ✨ (Model ID: `{repo_id}`)",
            existing_item_index,
            gr.update(placeholder=f"Type a prompt for {title}"),
        )
    except ValueError as e:  # Catch parsing errors
        gr.Warning(f"Invalid input: {e}")
        return gr.update(visible=True, value=f"Invalid input: {e}"), gr.update(visible=False), gr.update(), "", None, ""
    except Exception as e:  # Catch other errors (e.g., network issues during card fetch)
        gr.Warning(f"Error adding custom model: {e}")
        # Show the error in the info box, hide the remove button, don't change gallery/selection
        return gr.update(visible=True, value=f"Error adding custom model: {e}"), gr.update(visible=False), gr.update(), "", None, ""


def remove_custom_lora():
    # This function might need adjustment if we want to remove the *last added* custom LoRA.
    # For now it only clears the custom-model display and selection; it does NOT remove
    # the item from the global `loras` list.
    # Outputs: info card, remove button, gallery (deselect), selection info, selected index, custom_lora textbox
    return gr.update(visible=False, value=""), gr.update(visible=False), gr.update(selected_index=None), "", None, gr.update(value="")


# run_lora.zerogpu = True  # Removed, as inference now runs remotely

css = '''
#gen_btn{height: 100%}
#gen_column{align-self: stretch}
#title{text-align: center}
#title h1{font-size: 3em; display:inline-flex; align-items:center}
#title img{width: 100px; margin-right: 0.5em}
#gallery .grid-wrap{height: 10vh}
#lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
.card_internal{display: flex;height: 100px;margin-top: .5em; align-items: center;}
.card_internal img{margin-right: 1em; height: 100%; width: auto; object-fit: cover;}
.card_internal .no-image { width: 100px; height: 100px; background-color: #eee; display: flex; align-items: center; justify-content: center; color: #aaa; margin-right: 1em; font-size: small;}
.styler{--form-gap-width: 0px !important}
#progress{height:30px}
#progress .generating{display:none}
/* Keep progress bar CSS for potential future use or remove if definitely not needed */
.progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}
.progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.5s ease-in-out}
'''

font = [gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"]

with gr.Blocks(theme=gr.themes.Soft(font=font), css=css, delete_cache=(60, 60)) as app:
    title = gr.HTML(
        """<h1><img src="https://huggingface.co./spaces/reach-vb/Blazingly-fast-LoRA/resolve/main/flux_lora.png" alt="LoRA"> <a href="https://huggingface.co./docs/inference-providers/en/index">Blazingly Fast LoRA by Fal & HF</a> 🤗</h1>""",
        elem_id="title",
    )

    # --- States for parameters previously in Advanced Settings ---
    selected_index = gr.State(None)
    width = gr.State(1024)  # Default width
    height = gr.State(1024)  # Default height
    seed = gr.State(0)  # Default seed (will be randomized by run_lora)
    # input_image = gr.State(None)  # State for the input image, if img2img were kept

    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA/Model")
        with gr.Column(scale=1, elem_id="gen_column"):
            generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")

    with gr.Row():
        with gr.Column():
            selected_info = gr.Markdown("Select a base model or add a custom one below.")
            gallery = gr.Gallery(
                # Ensure items have 'image' and 'title' keys; fall back to the repo name
                [(item.get("image"), item.get("title", item["repo"].split('/')[-1])) for item in loras],
                label="Model Gallery",
                allow_preview=False,
                columns=3,
                elem_id="gallery",
                show_share_button=False,
            )
            with gr.Group():
                custom_lora = gr.Textbox(
                    label="Custom Model",
                    info="Hugging Face model ID (e.g., user/model-name) or URL",
                    placeholder="stabilityai/stable-diffusion-xl-base-1.0",
                )
                gr.Markdown(
                    "[Check Hugging Face Models](https://huggingface.co./models?inference_provider=fal-ai&pipeline_tag=text-to-image&sort=trending)",
                    elem_id="lora_list",
                )
            custom_lora_info = gr.HTML(visible=False)
            custom_lora_button = gr.Button("Clear custom model info", visible=False)
        with gr.Column():
            # The progress bar is only shown briefly if the API is slow, then hidden by run_lora's return
            progress_bar = gr.Markdown(elem_id="progress", visible=False, value="Generating...")
            result = gr.Image(label="Generated Image")
            # Display the seed used for the generation
            used_seed_display = gr.Textbox(label="Seed Used", value="0", interactive=False)

    # --- Removed the Advanced Settings accordion (width/height/seed now live in gr.State) ---

    gallery.select(
        update_selection,
        inputs=[],  # No direct inputs needed; uses the SelectData event
        # Updates: prompt placeholder, selection text, selected index state, width/height states
        outputs=[prompt, selected_info, selected_index, width, height],
    )

    # Use the Textbox submit event to trigger add_custom_lora
    custom_lora.submit(
        add_custom_lora,
        inputs=[custom_lora],
        # Outputs: info card, remove button, gallery, selection text, selected index state, prompt placeholder
        outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, prompt],
    )

    custom_lora_button.click(
        remove_custom_lora,
        # Outputs mirror add_custom_lora, plus clearing the custom_lora textbox
        outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora],
    )

    gr.on(
        triggers=[generate_button.click, prompt.submit],
        fn=run_lora,
        # Width, height, and seed come from state variables rather than UI controls
        inputs=[prompt, selected_index, seed, width, height],
        # Outputs: result image, seed state (updated with the seed actually used), progress bar update
        outputs=[result, seed, progress_bar],
    ).then(
        # Update the displayed seed value after run_lora completes
        lambda s: gr.update(value=s),
        inputs=[seed],
        outputs=[used_seed_display],
    )

app.queue()
app.launch()