Spaces:
Running
Running
import os | |
import gradio as gr | |
import json | |
import logging | |
from PIL import Image | |
from huggingface_hub import ModelCard, HfFileSystem # Keep ModelCard, HfFileSystem for add_custom_lora | |
from huggingface_hub import InferenceClient # Added for API inference | |
import copy | |
import random | |
import time | |
import re # Keep for add_custom_lora URL parsing and potentially trigger word finding | |
# --- Inference Client Setup ---
# The API key is read from environment variables. On Hugging Face Spaces,
# repository secrets are exposed to the running app as environment variables,
# so keys stored as Space secrets are picked up here too. (Gradio provides no
# `gr.secrets` API, so no separate secrets lookup is attempted.)
# HF_API_KEY takes precedence; HF_TOKEN, the variable huggingface_hub itself
# reads, is accepted as a fallback. Empty values are normalised to None.
HF_API_KEY = os.getenv("HF_API_KEY") or os.getenv("HF_TOKEN") or None
if not HF_API_KEY:
    logging.warning(
        "HF_API_KEY (or HF_TOKEN) not found in environment variables. "
        "Inference API calls will likely fail."
    )
    # Leave the client unset; run_lora checks for None and reports a clear
    # error to the user instead of failing inside the API call.
    client = None
else:
    # Provider choice depends on where the target models are hosted/supported
    # for inference.
    client = InferenceClient(provider="fal-ai", token=HF_API_KEY)
# Load the model/LoRA catalogue shown in the gallery. Entries are dicts that the
# code in this file reads via keys such as "repo", "title", "image",
# "trigger_word", "trigger_position" and "aspect". JSON text is UTF-8, so decode
# it explicitly instead of relying on the platform's default encoding.
with open("loras.json", "r", encoding="utf-8") as f:
    loras = json.load(f)

# Inclusive upper bound for randomly drawn seeds (largest unsigned 32-bit value).
MAX_SEED = 2**32 - 1
class calculateDuration:
    """Context manager that prints how long its ``with`` body took.

    Usage::

        with calculateDuration("my step") as timer:
            ...
        timer.elapsed_time  # seconds

    Attributes set on the instance:
        activity_name: label included in the printed message (may be "").
        start_time / end_time: wall-clock timestamps (``time.time()``) taken
            on entry and exit, kept for backward compatibility.
        elapsed_time: duration in seconds, measured with the monotonic
            ``time.perf_counter()`` clock so system clock adjustments cannot
            skew the result or make it negative.
    """

    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        self._perf_start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time = time.perf_counter() - self._perf_start
        label = f" for {self.activity_name}" if self.activity_name else ""
        print(f"Elapsed time{label}: {self.elapsed_time:.6f} seconds")
        # Returning None (falsy) lets any exception raised in the body propagate.
        return None
# Gallery click handler: remember which model was chosen and size the output for it.
def update_selection(evt: gr.SelectData):
    """Handle a click on the model gallery.

    Args:
        evt: Gradio selection event; ``evt.index`` is used as the position of
            the clicked item in ``loras`` (the gallery is built from ``loras``
            in order).

    Returns:
        A 5-tuple of updates for: the prompt textbox (new placeholder), the
        selection markdown, the selected-index state, and the width and height
        states.
    """
    chosen_index = evt.index
    entry = loras[chosen_index]
    placeholder_update = gr.update(placeholder=f"Type a prompt for {entry['title']}")
    repo_id = entry["repo"]

    # The repo ID is used directly as the model identifier for the API call.
    info_markdown = f"### Selected: [{repo_id}](https://huggingface.co./{repo_id}) ✨ (Model ID: `{repo_id}`)"

    # Optional "aspect" hint picks the output size; no hint (or any other
    # value) means a square 1024x1024 image.
    aspect = entry.get("aspect")
    if aspect == "portrait":
        out_width, out_height = 768, 1024
    elif aspect == "landscape":
        out_width, out_height = 1024, 768
    else:
        out_width, out_height = 1024, 1024

    return (
        placeholder_update,
        info_markdown,
        chosen_index,
        gr.update(value=out_width),
        gr.update(value=out_height),
    )
def _build_prompt(prompt, selected_lora):
    """Combine the user's prompt with the model's trigger word, if any.

    The trigger word is prepended unless the entry's ``"trigger_position"`` is
    set to something other than ``"prepend"``, in which case it is appended.
    Empty parts are skipped so no stray leading/trailing spaces are produced.
    """
    user_text = (prompt or "").strip()
    trigger_word = str(selected_lora.get("trigger_word") or "").strip()
    if selected_lora.get("trigger_position", "prepend") == "prepend":
        parts = (trigger_word, user_text)
    else:
        parts = (user_text, trigger_word)
    return " ".join(part for part in parts if part)


def _api_error(exc, model_id):
    """Map an Inference API exception to a user-facing ``gr.Error``.

    Detection is best-effort substring matching on the exception text:
    "authorization"/"401" -> authorization hint, "model is currently
    loading"/"503" -> retry hint, anything else -> generic failure message.
    """
    message = str(exc)
    lowered = message.lower()
    if "authorization" in lowered or "401" in message:
        return gr.Error(f"Authorization error calling the Inference API. Please ensure your HF_API_KEY is valid and has the necessary permissions. Error: {exc}")
    if "model is currently loading" in lowered or "503" in message:
        return gr.Error(f"Model '{model_id}' is currently loading or unavailable. Please try again in a few moments. Error: {exc}")
    return gr.Error(f"Failed to generate image using the API. Model: {model_id}. Error: {exc}")


def run_lora(prompt, selected_index, current_seed, current_width, current_height):
    """Generate an image for the selected model through the Inference API.

    Args:
        prompt: User prompt text.
        selected_index: Index into ``loras`` of the selected model, or None.
        current_seed: Seed held in UI state. Currently unused because the seed
            is always re-randomised (``randomize_seed`` is hard-coded to True).
        current_width, current_height: Output size held in UI state.

    Returns:
        ``(image, seed_used, update)`` where ``update`` hides the progress bar.

    Raises:
        gr.Error: if the client is unavailable, no model is selected, the
            combined prompt is empty, or the API call fails (chained to the
            original exception).
    """
    if client is None:
        raise gr.Error("InferenceClient could not be initialized. Missing HF_API_KEY.")
    if selected_index is None:
        raise gr.Error("You must select a LoRA/Model before proceeding.")

    # --- Hard-coded generation settings (no longer exposed in the UI) ---
    cfg_scale = 7.0
    steps = 30
    randomize_seed = True  # always draw a fresh seed in this simplified version

    selected_lora = loras[selected_index]
    # The catalogue's 'repo' field doubles as the API model identifier.
    model_id = selected_lora["repo"]

    prompt_mash = _build_prompt(prompt, selected_lora)
    if not prompt_mash:
        raise gr.Error("Please enter a prompt.")

    seed_to_use = random.randint(0, MAX_SEED) if randomize_seed else current_seed

    # --- API call (text-to-image only) ---
    try:
        with calculateDuration(f"API Inference (txt2img) for {model_id}"):
            print(f"Running Text-to-Image for Model: {model_id}")
            final_image = client.text_to_image(
                prompt=prompt_mash,
                model=model_id,
                guidance_scale=cfg_scale,
                num_inference_steps=steps,
                seed=seed_to_use,
                width=current_width,
                height=current_height,
            )
    except Exception as e:
        print(f"Error during API call: {e}")
        raise _api_error(e, model_id) from e

    # Return the image, the seed actually used, and hide the progress bar.
    return final_image, seed_to_use, gr.update(visible=False)
def parse_hf_link(link):
    """Normalise a Hugging Face model URL or repo ID to ``"user/model"``.

    Accepted forms include ``user/model``, ``huggingface.co/user/model``,
    ``www.huggingface.co/user/model`` and ``hf.co/user/model``. Each may have
    an ``http://``/``https://`` scheme, a trailing-dot host
    (``huggingface.co./``), surrounding whitespace or a trailing slash.

    Raises:
        ValueError: if what remains is not exactly two non-empty
            ``/``-separated parts.
    """
    repo_id = link.strip()
    # Remove only a *leading* scheme + host; str.replace would also rewrite
    # matches in the middle of the string.
    repo_id = re.sub(
        r"^(?:https?://)?(?:www\.)?(?:huggingface\.co|hf\.co)\.?/",
        "",
        repo_id,
        flags=re.IGNORECASE,
    ).strip("/")
    parts = repo_id.split("/")
    if len(parts) != 2 or not all(parts):
        raise ValueError("Invalid Hugging Face repository ID format. Expected 'user/model'.")
    return repo_id
def _extract_trigger_word(card_data):
    """Return the trigger word from model-card metadata, or "".

    Checks ``instance_prompt`` first, then ``trigger_words``. ``trigger_words``
    may be a list (its first entry is used) or a plain string (used as-is).
    """
    trigger_word = card_data.get("instance_prompt") or ""
    if trigger_word:
        return trigger_word
    trigger_words = card_data.get("trigger_words")
    if isinstance(trigger_words, str):
        return trigger_words
    if isinstance(trigger_words, (list, tuple)) and trigger_words:
        return trigger_words[0] or ""
    return ""


def _extract_widget_image(card_data):
    """Return the ``output.url`` of the first card widget that has one, or None."""
    widgets = card_data.get("widget") or []
    if isinstance(widgets, dict):
        widgets = [widgets]
    if not isinstance(widgets, (list, tuple)):
        return None
    for widget in widgets:
        output = widget.get("output") if isinstance(widget, dict) else None
        if isinstance(output, dict) and output.get("url"):
            return output["url"]
    return None


def _resolve_image_url(repo_id, image_path):
    """Turn a widget image path into a URL; absolute URLs are returned unchanged."""
    if not image_path or not isinstance(image_path, str):
        return None
    if image_path.startswith(("http://", "https://")):
        return image_path
    return f"https://huggingface.co./{repo_id}/resolve/main/{image_path}"


def _find_repo_image(repo_id):
    """Return a URL for the first image among ``HfFileSystem().ls(repo_id)``, or None.

    Listing errors are logged and treated as "no image found" so they do not
    discard details already read from the model card.
    """
    image_extensions = (".jpg", ".jpeg", ".png", ".webp")
    try:
        paths = HfFileSystem().ls(repo_id, detail=False)
    except Exception as e:
        print(f"Could not list files for {repo_id}: {e}")
        return None
    for path in paths:
        filename = path.split("/")[-1]
        if filename.lower().endswith(image_extensions):
            return f"https://huggingface.co./{repo_id}/resolve/main/{filename}"
    return None


def get_model_details(repo_id):
    """Fetch display details (title, trigger word, preview image) for a repo.

    Args:
        repo_id: Repository ID in ``"user/model"`` form.

    Returns:
        ``(title, trigger_word, image_url)``. If the model card cannot be
        loaded or parsed, returns ``(<last part of repo_id>, "", None)``.
        ``image_url`` is None when no preview image is found.
    """
    fallback_title = repo_id.split("/")[-1]
    try:
        card_data = ModelCard.load(repo_id).data
        trigger_word = _extract_trigger_word(card_data)
        image_url = _resolve_image_url(repo_id, _extract_widget_image(card_data))
        # Optional key; fall back to the repo name when absent or empty.
        title = card_data.get("model_display_name") or fallback_title
    except Exception as e:
        print(f"Could not fetch model card details for {repo_id}: {e}")
        return fallback_title, "", None

    if not image_url:
        # No widget preview in the card: look for an image file in the repo.
        image_url = _find_repo_image(repo_id)
    return title, trigger_word, image_url
def _custom_lora_card_html(title, repo_id, image, trigger_word):
    """Return the HTML info card shown after a custom model is loaded.

    All interpolated values are HTML-escaped: titles, trigger words and image
    URLs come from user input or from a remote model card.
    """
    import html

    safe_title = html.escape(str(title))
    if image:
        preview = f'<img src="{html.escape(str(image))}" alt="{safe_title} preview"/>'
    else:
        preview = '<div class="no-image">No Image</div>'
    if trigger_word:
        # Close tags in reverse order of opening: </b> before </code>.
        trigger_note = f"Using trigger word: <code><b>{html.escape(str(trigger_word))}</b></code>"
    else:
        trigger_note = "No specific trigger word found in card. Include if needed."
    return f'''
<div class="custom_lora_card">
  <span>Loaded custom model:</span>
  <div class="card_internal">
    {preview}
    <div>
      <h3>{safe_title}</h3>
      <small>Model ID: <code>{html.escape(repo_id)}</code><br></small>
      <small>{trigger_note}<br></small>
    </div>
  </div>
</div>
'''


def add_custom_lora(custom_lora_input):
    """Load a model from a Hugging Face repo ID/URL and select it.

    If the repo is already in ``loras`` it is re-selected. Otherwise its
    details are fetched with ``get_model_details`` and a new entry is appended
    to ``loras``.

    Args:
        custom_lora_input: Text from the "Custom Model" textbox.

    Returns:
        A 6-tuple of updates for, in order: the custom-model info HTML, the
        clear button, the gallery, the selection markdown, the selected-index
        state and the prompt textbox.
    """
    import html

    global loras  # the shared catalogue is extended in place below

    def _reset(info_update):
        # Non-success result: hide the clear button, clear the selection in
        # the gallery, the markdown and the index state (so they stay in
        # sync), and leave the user's prompt text untouched.
        return (
            info_update,
            gr.update(visible=False),
            gr.update(selected_index=None),
            "",
            None,
            gr.update(),
        )

    if not custom_lora_input:
        # Empty input: hide the custom-model info card.
        return _reset(gr.update(visible=False, value=""))

    try:
        repo_id = parse_hf_link(custom_lora_input)
        print(f"Attempting to add custom model: {repo_id}")

        existing_item_index = next(
            (index for index, item in enumerate(loras) if item["repo"] == repo_id),
            None,
        )
        if existing_item_index is not None:
            print(f"Model {repo_id} already exists in the list.")
            selected_lora = loras[existing_item_index]
            title = selected_lora.get("title", repo_id.split("/")[-1])
            image = selected_lora.get("image", None)
            trigger_word = selected_lora.get("trigger_word", "")
        else:
            title, trigger_word, image = get_model_details(repo_id)
            print(f"Adding new model: {repo_id}, Title: {title}, Trigger: '{trigger_word}', Image: {image}")
            loras.append({
                "image": image,  # preview URL, may be None
                "title": title,
                "repo": repo_id,  # also used as the API model ID
                "trigger_word": trigger_word,
            })
            existing_item_index = len(loras) - 1

        card = _custom_lora_card_html(title, repo_id, image, trigger_word)
        # Rebuild the gallery from the (possibly extended) catalogue.
        updated_gallery_items = [
            (item.get("image"), item.get("title", item["repo"].split("/")[-1]))
            for item in loras
        ]
        return (
            gr.update(visible=True, value=card),
            gr.update(visible=True),
            gr.Gallery(value=updated_gallery_items, selected_index=existing_item_index),
            f"### Selected: [{repo_id}](https://huggingface.co./{repo_id}) ✨ (Model ID: `{repo_id}`)",
            existing_item_index,
            gr.update(placeholder=f"Type a prompt for {title}"),
        )
    except ValueError as e:  # parse_hf_link rejected the input
        gr.Warning(f"Invalid Input: {e}")
        return _reset(gr.update(visible=True, value=html.escape(f"Invalid input: {e}")))
    except Exception as e:  # e.g. network problems while fetching details
        gr.Warning(f"Error adding custom model: {e}")
        return _reset(gr.update(visible=True, value=html.escape(f"Error adding custom model: {e}")))
def remove_custom_lora():
    """Reset the custom-model widgets.

    Hides the info card and the clear button, clears the gallery selection,
    the selection markdown, the selected-index state and the custom-model
    textbox. Entries previously appended to ``loras`` are not removed, so
    they remain in the gallery.
    """
    hidden_info = gr.update(visible=False, value="")
    hidden_button = gr.update(visible=False)
    deselected_gallery = gr.update(selected_index=None)
    cleared_textbox = gr.update(value="")
    return hidden_info, hidden_button, deselected_gallery, "", None, cleared_textbox
# Custom CSS for the Blocks UI. Rules target elem_id values used in the layout
# (#gen_btn, #gen_column, #title, #gallery, #lora_list, #progress) and the
# classes in the custom-model card HTML (.card_internal, .no-image).
css = '''
#gen_btn{height: 100%}
#gen_column{align-self: stretch}
#title{text-align: center}
#title h1{font-size: 3em; display:inline-flex; align-items:center}
#title img{width: 100px; margin-right: 0.5em}
#gallery .grid-wrap{height: 10vh}
#lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
.card_internal{display: flex;height: 100px;margin-top: .5em; align-items: center;}
.card_internal img{margin-right: 1em; height: 100%; width: auto; object-fit: cover;}
.card_internal .no-image { width: 100px; height: 100px; background-color: #eee; display: flex; align-items: center; justify-content: center; color: #aaa; margin-right: 1em; font-size: small;}
.styler{--form-gap-width: 0px !important}
#progress{height:30px}
#progress .generating{display:none}
'''
# Theme font stack: "Source Sans Pro" from Google Fonts, then Arial, then any sans-serif.
font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"]
# Build the whole UI. `css` supplies the #id rules referenced via elem_id below.
# delete_cache=(60, 60) is passed straight to gr.Blocks; presumably Gradio's
# (frequency, age) temp-file cleanup in seconds — confirm for the installed version.
with gr.Blocks(theme=gr.themes.Soft(font=font), css=css, delete_cache=(60, 60)) as app:
    title = gr.HTML(
        """<h1><img src="https://huggingface.co./spaces/reach-vb/Blazingly-fast-LoRA/resolve/main/flux_lora.png" alt="LoRA"> <a href="https://huggingface.co./docs/inference-providers/en/index">Blazingly Fast LoRA by Fal & HF</a> 🤗</h1>""",
        elem_id="title",
    )
    # --- Per-session state (these values used to be "Advanced Settings" controls) ---
    selected_index = gr.State(None)  # index into `loras`; None until a model is chosen
    width = gr.State(1024) # Output width; set by update_selection on gallery clicks
    height = gr.State(1024) # Output height; set by update_selection on gallery clicks
    seed = gr.State(0) # Receives the seed run_lora used (run_lora always draws a new one)
    # input_image = gr.State(None) # State for input image if img2img was kept

    # Top row: prompt box next to the Generate button.
    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA/Model")
        with gr.Column(scale=1, elem_id="gen_column"):
            generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")

    # Main row: model picker on the left, output on the right.
    with gr.Row():
        with gr.Column():
            selected_info = gr.Markdown("Select a base model or add a custom one below.") # Shows which model is selected
            gallery = gr.Gallery(
                # One (image, caption) pair per catalogue entry; caption falls back to the repo name
                [(item.get("image"), item.get("title", item["repo"].split('/')[-1])) for item in loras],
                label="Model Gallery",
                allow_preview=False,
                columns=3,
                elem_id="gallery",
                show_share_button=False
            )
            with gr.Group():
                # Free-text repo ID / URL; submitting it runs add_custom_lora
                custom_lora = gr.Textbox(label="Custom Model", info="Hugging Face model ID (e.g., user/model-name) or URL", placeholder="stabilityai/stable-diffusion-xl-base-1.0")
                gr.Markdown("[Check Hugging Face Models](https://huggingface.co./models?inference_provider=fal-ai&pipeline_tag=text-to-image&sort=trending)", elem_id="lora_list") # Link to text-to-image models served by fal-ai
            custom_lora_info = gr.HTML(visible=False)  # info card / error text from add_custom_lora
            custom_lora_button = gr.Button("Clear custom model info", visible=False) # Runs remove_custom_lora
        with gr.Column():
            # NOTE(review): nothing in this file ever sets progress_bar to visible=True;
            # run_lora only returns an update that hides it, so it is currently never shown.
            progress_bar = gr.Markdown(elem_id="progress", visible=False, value="Generating...")
            result = gr.Image(label="Generated Image")
            # Read-only display of the seed run_lora used for the last generation
            used_seed_display = gr.Textbox(label="Seed Used", value=0, interactive=False)

    # --- Event wiring ---
    # Gallery click: update prompt placeholder, selection text, selected index and output size.
    gallery.select(
        update_selection,
        inputs=[], # update_selection reads the clicked index from the SelectData event
        outputs=[prompt, selected_info, selected_index, width, height]
    )
    # Enter in the custom-model box: add/select that model.
    # NOTE(review): width/height are not among these outputs, so after adding a custom
    # model the output size stays at whatever the last gallery click set.
    custom_lora.submit(
        add_custom_lora,
        inputs=[custom_lora],
        # Outputs: info card, clear button, gallery, selection text, selected index state, prompt
        outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, prompt]
    )
    # "Clear custom model info": hide the card/button and clear the selection and the textbox.
    custom_lora_button.click(
        remove_custom_lora,
        outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora]
    )
    # Generate on button click or Enter in the prompt box.
    gr.on(
        triggers=[generate_button.click, prompt.submit],
        fn=run_lora,
        # Width, height and seed come from the per-session state above
        inputs=[prompt, selected_index, seed, width, height],
        # Outputs: generated image, seed state (set to the seed used), progress bar update
        outputs=[result, seed, progress_bar]
    ).then(
        # Chained step: copy the seed state into the visible "Seed Used" textbox
        lambda s: gr.update(value=s),
        inputs=[seed],
        outputs=[used_seed_display]
    )
# Enable request queuing, then start the server. Blocks.queue() returns the
# Blocks instance itself, so the two calls chain.
app.queue().launch()