# Source: Hugging Face Space file page — author reach-vb (HF Staff),
# commit 69d8b3b "Update app.py" (verified).
import copy
import html
import json
import logging
import os
import random
import re  # Keep for add_custom_lora URL parsing and potentially trigger word finding
import time

import gradio as gr
from huggingface_hub import ModelCard, HfFileSystem  # Keep ModelCard, HfFileSystem for add_custom_lora
from huggingface_hub import InferenceClient  # Added for API inference
from PIL import Image
# --- Inference Client Setup ---
# The API key comes from the environment first; Gradio secrets are only consulted
# as a fallback. Without a key, `client` stays None and run_lora refuses to run.
HF_API_KEY = os.getenv("HF_API_KEY")
if not HF_API_KEY:
    # NOTE(review): `gr.secrets` does not appear to be a public Gradio attribute; the
    # AttributeError branch below covers that case, so this usually leaves the key unset.
    try:
        HF_API_KEY = gr.secrets.get("HF_API_KEY")
    except (AttributeError, KeyError):
        HF_API_KEY = None

client = None  # remains None when no API key is available
if HF_API_KEY:
    # Provider choice depends on where the target models are hosted/supported for inference.
    client = InferenceClient(provider="fal-ai", token=HF_API_KEY)
else:
    logging.warning("HF_API_KEY not found in environment variables or Gradio secrets. Inference API calls will likely fail.")
# Load the model/LoRA catalogue shown in the gallery. Keys read from each entry
# elsewhere in this file: "repo" (required; doubles as the API model id), and
# optionally "title", "image", "trigger_word", "trigger_position", "aspect".
# An explicit encoding keeps loading independent of the platform's default locale.
with open('loras.json', 'r', encoding='utf-8') as f:
    loras = json.load(f)

# Inclusive upper bound for randomly drawn seeds (largest unsigned 32-bit value).
MAX_SEED = 2**32 - 1
class calculateDuration:
    """Context manager that prints how long its ``with`` body took.

    Attributes set on the instance:
        activity_name: label included in the printed message (may be empty).
        start_time / end_time: ``time.perf_counter()`` readings taken on enter / exit.
        elapsed_time: ``end_time - start_time`` in seconds.

    Exceptions raised inside the ``with`` body are not suppressed.
    """

    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # (even backwards) when the system clock is adjusted.
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.perf_counter()
        self.elapsed_time = self.end_time - self.start_time
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
        return False  # never swallow exceptions from the with-body
# Canvas size (width, height) for each supported value of an entry's optional "aspect" key.
_ASPECT_SIZES = {
    "portrait": (768, 1024),
    "landscape": (1024, 768),
}
_DEFAULT_SIZE = (1024, 1024)


def update_selection(evt: gr.SelectData):
    """Handle a gallery click: remember the chosen entry and size the canvas for it.

    Args:
        evt: Gradio selection event; ``evt.index`` is used as the position in ``loras``.

    Returns:
        A 5-tuple for the outputs ``[prompt, selected_info, selected_index, width, height]``:
        a placeholder update for the prompt box, the selection markdown, the selected index,
        and width/height updates (768x1024 for "portrait", 1024x768 for "landscape",
        1024x1024 otherwise).
    """
    selected_lora = loras[evt.index]
    lora_repo = selected_lora["repo"]
    # Same title fallback as the gallery captions: entries without a "title" show the repo name.
    title = selected_lora.get("title", lora_repo.split('/')[-1])
    new_placeholder = f"Type a prompt for {title}"
    # The repo id is used directly as the model identifier for the API call.
    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co./{lora_repo}) ✨ (Model ID: `{lora_repo}`)"
    width, height = _ASPECT_SIZES.get(selected_lora.get("aspect"), _DEFAULT_SIZE)
    return (
        gr.update(placeholder=new_placeholder),
        updated_text,
        evt.index,
        gr.update(value=width),
        gr.update(value=height),
    )
def _build_prompt(prompt, trigger_word, trigger_position):
    """Combine the user prompt with a LoRA trigger word.

    The trigger word goes first when ``trigger_position == "prepend"`` and last
    otherwise; empty pieces are skipped so no stray leading/trailing space is added.
    """
    if not trigger_word:
        return prompt
    if trigger_position == "prepend":
        pieces = (trigger_word, prompt)
    else:
        pieces = (prompt, trigger_word)
    return " ".join(piece for piece in pieces if piece)


def _describe_api_error(exc, model_id):
    """Translate an inference exception into a user-facing error message."""
    # Prefer the HTTP status when the exception carries a response object;
    # only fall back to scanning the message for "401"/"503" when it does not,
    # so unrelated numbers in the text do not cause misclassification.
    status = getattr(getattr(exc, "response", None), "status_code", None)
    text = str(exc).lower()
    if status == 401 or "authorization" in text or (status is None and "401" in text):
        return f"Authorization error calling the Inference API. Please ensure your HF_API_KEY is valid and has the necessary permissions. Error: {exc}"
    if status == 503 or "model is currently loading" in text or (status is None and "503" in text):
        return f"Model '{model_id}' is currently loading or unavailable. Please try again in a few moments. Error: {exc}"
    return f"Failed to generate image using the API. Model: {model_id}. Error: {exc}"


def run_lora(prompt, selected_index, current_seed, current_width, current_height):
    """Generate an image for the selected catalogue entry via the Inference API.

    Args:
        prompt: user prompt; the entry's trigger word (if any) is added to it.
        selected_index: index into ``loras``; None when nothing is selected.
        current_seed: seed from UI state; currently unused because a fresh seed is always drawn.
        current_width, current_height: output size forwarded to the API.

    Returns:
        ``(image, seed_used, progress_update)`` where ``image`` is whatever
        ``client.text_to_image`` returns and ``progress_update`` hides the progress bar.

    Raises:
        gr.Error: if no client/API key is configured, nothing is selected, or the API call fails.
    """
    if client is None:
        raise gr.Error("InferenceClient could not be initialized. Missing HF_API_KEY.")
    if selected_index is None:
        raise gr.Error("You must select a LoRA/Model before proceeding.")

    # Fixed generation settings (no longer exposed in the UI).
    cfg_scale = 7.0
    steps = 30
    randomize_seed = True  # this simplified version always draws a fresh seed

    selected_lora = loras[selected_index]
    model_id = selected_lora["repo"]  # the repo id doubles as the API model id
    prompt_mash = _build_prompt(
        prompt,
        selected_lora.get("trigger_word", ""),
        selected_lora.get("trigger_position", "prepend"),
    )

    seed_to_use = random.randint(0, MAX_SEED) if randomize_seed else current_seed

    try:
        with calculateDuration(f"API Inference (txt2img) for {model_id}"):
            print(f"Running Text-to-Image for Model: {model_id}")
            # NOTE: no LoRA scale is sent; the API may not accept one.
            final_image = client.text_to_image(
                prompt=prompt_mash,
                model=model_id,
                guidance_scale=cfg_scale,
                num_inference_steps=steps,
                seed=seed_to_use,
                width=current_width,
                height=current_height,
            )
    except Exception as e:
        print(f"Error during API call: {e}")
        raise gr.Error(_describe_api_error(e, model_id)) from e

    return final_image, seed_to_use, gr.update(visible=False)
# Prefixes that may precede "user/model" in a pasted link. The canonical host and
# the trailing-dot host form used elsewhere in this file are both accepted.
_HF_LINK_PREFIXES = (
    "https://huggingface.co./",
    "https://huggingface.co/",
    "http://huggingface.co./",
    "http://huggingface.co/",
    "https://www.huggingface.co/",
    "http://www.huggingface.co/",
    "www.huggingface.co/",
    "huggingface.co/",
)


def parse_hf_link(link):
    """Parse a Hugging Face model URL or bare repo id into ``"user/model"``.

    Surrounding whitespace, a known URL prefix and leading/trailing slashes are removed.

    Raises:
        ValueError: if what remains is not exactly two non-empty path segments.
    """
    candidate = (link or "").strip()
    for prefix in _HF_LINK_PREFIXES:
        if candidate.startswith(prefix):
            candidate = candidate[len(prefix):]
            break
    # Tolerate a trailing slash copied along with a URL.
    candidate = candidate.strip("/")
    parts = candidate.split("/")
    if len(parts) != 2 or not all(part.strip() for part in parts):
        raise ValueError("Invalid Hugging Face repository ID format. Expected 'user/model'.")
    return candidate
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".webp")


def _card_widget_image_path(card_data):
    """Return the first widget's ``output.url`` from model-card metadata, or None."""
    widgets = card_data.get("widget") or []
    if not isinstance(widgets, list) or not widgets:
        return None
    first = widgets[0]
    output = first.get("output") if isinstance(first, dict) else None
    return output.get("url") if isinstance(output, dict) else None


def _card_trigger_word(card_data):
    """Return ``instance_prompt``, else the first ``trigger_words`` entry, else ""."""
    trigger_word = card_data.get("instance_prompt", "")
    if trigger_word:
        return trigger_word
    candidates = card_data.get("trigger_words") or []
    if isinstance(candidates, str):
        # A single string must not be indexed (that would yield only its first character).
        return candidates
    return candidates[0] if candidates else ""


def _first_repo_image_url(repo_id):
    """Return a resolve URL for the first image file listed by ``HfFileSystem.ls(repo_id)``, or None."""
    fs = HfFileSystem()
    for path in fs.ls(repo_id, detail=False):
        filename = path.split("/")[-1]
        if filename.lower().endswith(_IMAGE_EXTENSIONS):
            return f"https://huggingface.co./{repo_id}/resolve/main/{filename}"
    return None


def get_model_details(repo_id):
    """Fetch display details for ``repo_id`` from its model card (best effort).

    Returns:
        ``(title, trigger_word, image_url)``. The title is the card's
        ``model_display_name`` or the repo name; the image comes from the card's
        widget metadata or, failing that, the first image file in the repo listing.
        If the card cannot be loaded, returns ``(repo name, "", None)``.
    """
    fallback_title = repo_id.split('/')[-1]
    try:
        model_card = ModelCard.load(repo_id)
        card_data = model_card.data
        image_path = _card_widget_image_path(card_data)
        trigger_word = _card_trigger_word(card_data)
        title = card_data.get("model_display_name") or fallback_title
    except Exception as e:
        print(f"Could not fetch model card details for {repo_id}: {e}")
        return fallback_title, "", None

    if not image_path:
        image_url = None
    elif image_path.startswith(("http://", "https://")):
        image_url = image_path  # already absolute
    else:
        image_url = f"https://huggingface.co./{repo_id}/resolve/main/{image_path}"

    if not image_url:
        # Separate network call: its failure must not discard the title/trigger word above.
        try:
            image_url = _first_repo_image_url(repo_id)
        except Exception as e:
            print(f"Could not list files for {repo_id}: {e}")
    return title, trigger_word, image_url
def _render_custom_lora_card(title, repo_id, image, trigger_word):
    """Return the HTML info card shown after a custom model is loaded.

    Every interpolated value is HTML-escaped: the title, image URL and trigger word
    may come from a third-party model card, and the repo id from user input.
    """
    safe_title = html.escape(str(title))
    if image:
        preview = f'<img src="{html.escape(str(image))}" alt="{safe_title} preview"/>'
    else:
        preview = '<div class="no-image">No Image</div>'
    if trigger_word:
        trigger_note = f"Using trigger word: <code><b>{html.escape(str(trigger_word))}</b></code>"
    else:
        trigger_note = "No specific trigger word found in card. Include if needed."
    return f'''
<div class="custom_lora_card">
<span>Loaded custom model:</span>
<div class="card_internal">
{preview}
<div>
<h3>{safe_title}</h3>
<small>Model ID: <code>{html.escape(repo_id)}</code><br></small>
<small>{trigger_note}<br></small>
</div>
</div>
</div>
'''


def add_custom_lora(custom_lora_input):
    """Add (or re-select) a model given as a repo id or URL, and refresh the UI.

    Returns:
        A 6-tuple for the outputs
        ``[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, prompt]``.
        On empty input the info card and button are hidden; on a parse or other
        error the error text is shown in the info card and a warning toast is raised.
    """
    if not custom_lora_input:
        return gr.update(visible=False, value=""), gr.update(visible=False), gr.update(), "", None, ""
    try:
        repo_id = parse_hf_link(custom_lora_input)
        print(f"Attempting to add custom model: {repo_id}")
        existing_item_index = next(
            (index for index, item in enumerate(loras) if item['repo'] == repo_id),
            None,
        )
        if existing_item_index is not None:
            print(f"Model {repo_id} already exists in the list.")
            selected_lora = loras[existing_item_index]
            title = selected_lora.get('title', repo_id.split('/')[-1])
            image = selected_lora.get('image', None)
            trigger_word = selected_lora.get('trigger_word', '')
        else:
            title, trigger_word, image = get_model_details(repo_id)
            print(f"Adding new model: {repo_id}, Title: {title}, Trigger: '{trigger_word}', Image: {image}")
            # NOTE(review): `loras` is a module-level list, so this addition is visible
            # to every session served by this process.
            loras.append({
                "image": image,  # may be None
                "title": title,
                "repo": repo_id,  # used as the model id for API calls
                "trigger_word": trigger_word,
            })
            existing_item_index = len(loras) - 1

        card = _render_custom_lora_card(title, repo_id, image, trigger_word)
        updated_gallery_items = [(item.get("image"), item.get("title", item["repo"].split('/')[-1])) for item in loras]
        return (
            gr.update(visible=True, value=card),
            gr.update(visible=True),
            gr.Gallery(value=updated_gallery_items, selected_index=existing_item_index),
            f"### Selected: [{repo_id}](https://huggingface.co./{repo_id}) ✨ (Model ID: `{repo_id}`)",
            existing_item_index,
            gr.update(placeholder=f"Type a prompt for {title}"),
        )
    except ValueError as e:  # parse errors from parse_hf_link
        gr.Warning(f"Invalid Input: {e}")
        return gr.update(visible=True, value=f"Invalid input: {html.escape(str(e))}"), gr.update(visible=False), gr.update(), "", None, ""
    except Exception as e:  # anything else, e.g. malformed catalogue entries
        gr.Warning(f"Error adding custom model: {e}")
        return gr.update(visible=True, value=f"Error adding custom model: {html.escape(str(e))}"), gr.update(visible=False), gr.update(), "", None, ""
def remove_custom_lora():
    """Reset the custom-model UI widgets.

    Returns a 6-tuple for
    ``[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora]``:
    the info card and button are hidden, the gallery selection is cleared, the selection
    text and index are reset, and the custom model textbox is emptied.

    NOTE: nothing is removed from the global ``loras`` list; previously added
    custom models stay in the gallery data.
    """
    hidden_info = gr.update(visible=False, value="")
    hidden_button = gr.update(visible=False)
    deselected_gallery = gr.update(selected_index=None)
    cleared_textbox = gr.update(value="")
    return hidden_info, hidden_button, deselected_gallery, "", None, cleared_textbox
# run_lora.zerogpu = True # Removed as inference is remote
# Custom stylesheet passed to gr.Blocks(css=...). Most selectors match elem_id values
# given to components in this file (#title, #gen_column, #gen_btn, #gallery, #lora_list,
# #progress) or classes in the HTML card built by add_custom_lora (.card_internal, .no-image).
# NOTE(review): .styler, .progress-container and .progress-bar are not referenced by any
# component in this file; they look like leftovers and may be removable.
css = '''
#gen_btn{height: 100%}
#gen_column{align-self: stretch}
#title{text-align: center}
#title h1{font-size: 3em; display:inline-flex; align-items:center}
#title img{width: 100px; margin-right: 0.5em}
#gallery .grid-wrap{height: 10vh}
#lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
.card_internal{display: flex;height: 100px;margin-top: .5em; align-items: center;}
.card_internal img{margin-right: 1em; height: 100%; width: auto; object-fit: cover;}
.card_internal .no-image { width: 100px; height: 100px; background-color: #eee; display: flex; align-items: center; justify-content: center; color: #aaa; margin-right: 1em; font-size: small;}
.styler{--form-gap-width: 0px !important}
#progress{height:30px}
#progress .generating{display:none}
/* Keep progress bar CSS for potential future use or remove if definitely not needed */
.progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}
.progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.5s ease-in-out}
'''
# Font stack for the Soft theme: Google Fonts "Source Sans Pro" with local fallbacks.
font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"]
# Build the UI. Component creation order inside the `with` contexts defines the layout.
# NOTE(review): delete_cache=(60, 60) asks Gradio to periodically clean cached files;
# confirm the meaning of the two tuple values against the Gradio version in use.
with gr.Blocks(theme=gr.themes.Soft(font=font), css=css, delete_cache=(60, 60)) as app:
    title = gr.HTML(
        """<h1><img src="https://huggingface.co./spaces/reach-vb/Blazingly-fast-LoRA/resolve/main/flux_lora.png" alt="LoRA"> <a href="https://huggingface.co./docs/inference-providers/en/index">Blazingly Fast LoRA by Fal & HF</a> 🤗</h1>""",
        elem_id="title",
    )
    # --- Per-session state replacing the former Advanced Settings controls ---
    selected_index = gr.State(None)  # index into `loras`; None until a model is chosen
    width = gr.State(1024) # Default width; overwritten by update_selection
    height = gr.State(1024) # Default height; overwritten by update_selection
    seed = gr.State(0) # run_lora always draws a random seed and writes the one it used back here
    # input_image = gr.State(None) # State for input image if img2img was kept

    # Top row: prompt box and Generate button side by side.
    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA/Model")
        with gr.Column(scale=1, elem_id="gen_column"):
            generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")

    # Second row: model selection (left column) and output (right column).
    with gr.Row():
        with gr.Column():
            selected_info = gr.Markdown("Select a base model or add a custom one below.") # Selection status text
            gallery = gr.Gallery(
                # (image, caption) pairs; entries without a "title" are captioned with the repo name
                [(item.get("image"), item.get("title", item["repo"].split('/')[-1])) for item in loras],
                label="Model Gallery",
                allow_preview=False,
                columns=3,
                elem_id="gallery",
                show_share_button=False
            )
            with gr.Group():
                custom_lora = gr.Textbox(label="Custom Model", info="Hugging Face model ID (e.g., user/model-name) or URL", placeholder="stabilityai/stable-diffusion-xl-base-1.0") # Submitting triggers add_custom_lora
                gr.Markdown("[Check Hugging Face Models](https://huggingface.co./models?inference_provider=fal-ai&pipeline_tag=text-to-image&sort=trending)", elem_id="lora_list") # Link to fal-ai text-to-image models
                custom_lora_info = gr.HTML(visible=False)
                custom_lora_button = gr.Button("Clear custom model info", visible=False) # Triggers remove_custom_lora
        with gr.Column():
            # NOTE(review): no handler in this file ever makes this visible; run_lora only hides it.
            progress_bar = gr.Markdown(elem_id="progress", visible=False, value="Generating...")
            result = gr.Image(label="Generated Image")
            # Mirrors the `seed` state after each generation (see the .then() below).
            used_seed_display = gr.Textbox(label="Seed Used", value=0, interactive=False) # Display seed used

    # The former "Advanced Settings" accordion was removed; its values are fixed in run_lora.

    # Gallery click -> update_selection's 5 return values map onto these outputs in order.
    gallery.select(
        update_selection,
        inputs=[], # No direct inputs needed, uses evt
        outputs=[prompt, selected_info, selected_index, width, height]
    )
    # Enter in the custom model box -> add_custom_lora's 6 return values map onto these outputs.
    custom_lora.submit(
        add_custom_lora,
        inputs=[custom_lora],
        outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, prompt]
    )
    # Clear button -> remove_custom_lora; its last return value empties the custom_lora textbox.
    custom_lora_button.click(
        remove_custom_lora,
        outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora] # Clear textbox too
    )
    # Generate (button click or Enter in the prompt) -> run_lora returns
    # (image, seed used, progress update), then the seed is copied into the display box.
    gr.on(
        triggers=[generate_button.click, prompt.submit],
        fn=run_lora,
        inputs=[prompt, selected_index, seed, width, height],
        outputs=[result, seed, progress_bar]
    ).then(
        # Update the displayed seed value after run_lora completes
        lambda s: gr.update(value=s),
        inputs=[seed],
        outputs=[used_seed_display]
    )
# Enable request queuing, then start the server (Blocks.queue() returns the Blocks itself).
app.queue().launch()