import copy
import json
import logging
import os
import random
import re  # Keep for add_custom_lora URL parsing and potentially trigger word finding
import time

import gradio as gr
from huggingface_hub import ModelCard, HfFileSystem  # Keep ModelCard, HfFileSystem for add_custom_lora
from huggingface_hub import InferenceClient  # Added for API inference
from PIL import Image

# --- Inference Client Setup ---
# The API key is read from the HF_API_KEY environment variable first; when it
# is missing or empty, a lookup through `gr.secrets` is attempted as a second
# source. Any AttributeError/KeyError from that lookup means "no key".
HF_API_KEY = os.getenv("HF_API_KEY")
if not HF_API_KEY:
    try:
        HF_API_KEY = gr.secrets.get("HF_API_KEY")
    except (AttributeError, KeyError):
        HF_API_KEY = None

# `client` stays None when no key is available, so that callers can detect the
# missing configuration and report it instead of failing at import time.
if HF_API_KEY:
    client = InferenceClient(provider="fal-ai", token=HF_API_KEY)
else:
    logging.warning("HF_API_KEY not found in environment variables or Gradio secrets. Inference API calls will likely fail.")
    client = None
# Note: Provider choice depends on where the target models are hosted/supported for inference.
# Load LoRAs from JSON file with open('loras.json', 'r') as f: loras = json.load(f) # Removed diffusers model initialization block MAX_SEED = 2**32-1 class calculateDuration: def __init__(self, activity_name=""): self.activity_name = activity_name def __enter__(self): self.start_time = time.time() return self def __exit__(self, exc_type, exc_value, traceback): self.end_time = time.time() self.elapsed_time = self.end_time - self.start_time if self.activity_name: print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds") else: print(f"Elapsed time: {self.elapsed_time:.6f} seconds") # Updated function signature: Removed width, height inputs def update_selection(evt: gr.SelectData): selected_lora = loras[evt.index] new_placeholder = f"Type a prompt for {selected_lora['title']}" lora_repo = selected_lora["repo"] # Use the repo ID directly as the model identifier for the API call updated_text = f"### Selected: [{lora_repo}](https://huggingface.co./{lora_repo}) ✨ (Model ID: `{lora_repo}`)" # Default width/height width = 1024 height = 1024 # Update width/height based on aspect ratio if defined if "aspect" in selected_lora: if selected_lora["aspect"] == "portrait": width = 768 height = 1024 elif selected_lora["aspect"] == "landscape": width = 1024 height = 768 # else keep 1024x1024 # Return updates for prompt, selection info, index, and width/height states return ( gr.update(placeholder=new_placeholder), updated_text, evt.index, gr.update(value=width), gr.update(value=height), ) def run_lora(prompt, selected_index, current_seed, current_width, current_height): global client # Access the global client if client is None: raise gr.Error("InferenceClient could not be initialized. 
Missing HF_API_KEY.") if selected_index is None: raise gr.Error("You must select a LoRA/Model before proceeding.") # --- Hardcoded Defaults (Removed from UI) --- cfg_scale = 7.0 steps = 30 # lora_scale = 0.95 # Might not be applicable/used by API randomize_seed = True # Always randomize in this simplified version # Removed image_input_path, image_strength - No img2img in this version selected_lora = loras[selected_index] # The 'repo' field now directly serves as the model identifier for the API model_id = selected_lora["repo"] trigger_word = selected_lora.get("trigger_word", "") # Use .get for safety # --- Prompt Construction --- if trigger_word: trigger_position = selected_lora.get("trigger_position", "prepend") # Default prepend if trigger_position == "prepend": prompt_mash = f"{trigger_word} {prompt}" else: # Append prompt_mash = f"{prompt} {trigger_word}" else: prompt_mash = prompt # --- Seed Handling --- seed_to_use = current_seed # Use the state value by default if randomize_seed: seed_to_use = random.randint(0, MAX_SEED) # Optional: Keep timer if desired # with calculateDuration("Randomizing seed"): # pass # --- API Call (Text-to-Image only) --- final_image = None try: with calculateDuration(f"API Inference (txt2img) for {model_id}"): print(f"Running Text-to-Image for Model: {model_id}") final_image = client.text_to_image( prompt=prompt_mash, model=model_id, guidance_scale=cfg_scale, num_inference_steps=steps, seed=seed_to_use, width=current_width, # Use width from state height=current_height, # Use height from state # lora_scale might need to be passed via 'parameters' if supported # parameters={"lora_scale": lora_scale} ) except Exception as e: print(f"Error during API call: {e}") # Improved error message for common API key issues if "authorization" in str(e).lower() or "401" in str(e): raise gr.Error(f"Authorization error calling the Inference API. Please ensure your HF_API_KEY is valid and has the necessary permissions. 
Error: {e}") elif "model is currently loading" in str(e).lower() or "503" in str(e): raise gr.Error(f"Model '{model_id}' is currently loading or unavailable. Please try again in a few moments. Error: {e}") else: raise gr.Error(f"Failed to generate image using the API. Model: {model_id}. Error: {e}") # Return final image, the seed used, and hide progress bar return final_image, seed_to_use, gr.update(visible=False) # Removed get_huggingface_safetensors function as we don't download safetensors def parse_hf_link(link): """Parses a Hugging Face link or repo ID string.""" if link.startswith("https://huggingface.co./"): link = link.replace("https://huggingface.co./", "") elif link.startswith("www.huggingface.co/"): link = link.replace("www.huggingface.co/", "") # Basic validation for "user/model" format if "/" not in link or len(link.split("/")) != 2: raise ValueError("Invalid Hugging Face repository ID format. Expected 'user/model'.") return link.strip() def get_model_details(repo_id): """Fetches model card details (image, trigger word) if possible.""" try: model_card = ModelCard.load(repo_id) image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None) trigger_word = model_card.data.get("instance_prompt", "") # Common key for trigger words # Try another common key if the first fails if not trigger_word: trigger_word = model_card.data.get("trigger_words", [""])[0] image_url = f"https://huggingface.co./{repo_id}/resolve/main/{image_path}" if image_path else None # Fallback: Check repo files for an image if not in card widget data if not image_url: fs = HfFileSystem() files = fs.ls(repo_id, detail=False) image_extensions = (".jpg", ".jpeg", ".png", ".webp") for file in files: filename = file.split("/")[-1] if filename.lower().endswith(image_extensions): image_url = f"https://huggingface.co./{repo_id}/resolve/main/{filename}" break # Take the first image found # Use repo name as title if not specified elsewhere title = 
model_card.data.get("model_display_name", repo_id.split('/')[-1]) # Example key, might vary return title, trigger_word, image_url except Exception as e: print(f"Could not fetch model card details for {repo_id}: {e}") # Fallback values return repo_id.split('/')[-1], "", None # Use repo name part as title def add_custom_lora(custom_lora_input): global loras if not custom_lora_input: # Clear the custom LoRA section if input is empty return gr.update(visible=False, value=""), gr.update(visible=False), gr.update(), "", None, "" try: repo_id = parse_hf_link(custom_lora_input) print(f"Attempting to add custom model: {repo_id}") # Check if model already exists in the list existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo_id), None) if existing_item_index is not None: print(f"Model {repo_id} already exists in the list.") # Optionally re-select the existing one or just show info selected_lora = loras[existing_item_index] title = selected_lora.get('title', repo_id.split('/')[-1]) image = selected_lora.get('image', None) # Use existing image if available trigger_word = selected_lora.get('trigger_word', '') else: # Fetch details for the new model title, trigger_word, image = get_model_details(repo_id) print(f"Adding new model: {repo_id}, Title: {title}, Trigger: '{trigger_word}', Image: {image}") new_item = { "image": image, # Store image URL (can be None) "title": title, "repo": repo_id, # Store the repo ID used for API calls # "weights": path, # No longer needed "trigger_word": trigger_word # Store trigger word if found } loras.append(new_item) existing_item_index = len(loras) - 1 # Index of the newly added item # Generate HTML card for display card = f'''
{repo_id}
"+trigger_word+"
" if trigger_word else "No specific trigger word found in card. Include if needed."}