import os
import torch
import gradio as gr
# from tqdm import tqdm  # Not used directly in the simplified inference function
from PIL import Image
import torch.nn.functional as F  # Not used directly in the simplified inference function
from torchvision import transforms as tfms  # Not used directly in the simplified inference function
# from transformers import CLIPTextModel, CLIPTokenizer, logging  # Not used directly
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel, DiffusionPipeline
from contextlib import nullcontext
import warnings

# Suppress specific warnings if needed (optional)
# logging.set_verbosity_error()
warnings.filterwarnings("ignore", category=FutureWarning)  # Example: ignore FutureWarnings

# --- Device Setup ---
if torch.cuda.is_available():
    torch_device = "cuda"
    print("Using CUDA (GPU)")
elif torch.backends.mps.is_available():
    torch_device = "mps"
    os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"
    print("Using MPS (Apple Silicon GPU)")
else:
    torch_device = "cpu"
    print("Using CPU")

# --- Configuration ---
model_path = "CompVis/stable-diffusion-v1-4"
# Use float16 for faster inference and lower memory on CUDA
use_fp16 = torch_device == "cuda"
dtype = torch.float16 if use_fp16 else torch.float32
print(f"Using dtype: {dtype}")

# --- Load the Pipeline ---
print(f"Loading Stable Diffusion pipeline from {model_path}...")
try:
    sd_pipeline = DiffusionPipeline.from_pretrained(
        model_path,
        # revision="fp16" if use_fp16 else "main",  # Use the fp16 revision if available and using fp16
        torch_dtype=dtype,
        # low_cpu_mem_usage=True,  # Useful for large models, might slightly slow down loading
    ).to(torch_device)
    print("Pipeline loaded successfully.")
except Exception as e:
    print(f"Error loading pipeline: {e}")
    print("Ensure you have enough RAM/VRAM and are authenticated if required (huggingface-cli login).")
    raise SystemExit(1)  # Exit with a non-zero status if the pipeline fails to load

# --- Enable xFormers (Optional Speed/Memory Optimization) ---
try:
    import xformers
    sd_pipeline.enable_xformers_memory_efficient_attention()
    print("xFormers enabled for memory efficient attention.")
except ImportError:
    print("xFormers not found. For a potential speedup, install it with: pip install xformers")

# --- Load Textual Inversions ---
print("Loading textual inversions...")
try:
    # Define paths or URLs - using Hugging Face Hub concepts-library paths
    inversions = {
        "illustration-style": "sd-concepts-library/illustration-style",
        "line-art": "sd-concepts-library/line-art",
        "hitokomoru-style-nao": "sd-concepts-library/hitokomoru-style-nao",
        "style-of-marc-allante": "sd-concepts-library/style-of-marc-allante",  # Placeholder name, likely needs adjustment
        "midjourney-style": "sd-concepts-library/midjourney-style",
        "hanfu-anime-style": "sd-concepts-library/hanfu-anime-style",
        "birb-style": "sd-concepts-library/birb-style",
    }
    for name, path in inversions.items():
        print(f"  Loading: {name} ({path})")
        sd_pipeline.load_textual_inversion(path)  # Assumes the weights are downloaded or accessible
    print("Textual inversions loaded.")

    # Update the style token dictionary based on the loaded concepts.
    # Ensure the placeholder names match the actual tokens learned during TI training.
    style_token_dict = {
        "Illustration Style": '<illustration-style>',
        "Line Art": '<line-art>',
        "Hitokomoru Style": '<hitokomoru-style-nao>',
        "Marc Allante": '<style-of-marc-allante>',  # Corrected placeholder based on repo name convention
        "Midjourney": '<midjourney-style>',
        "Hanfu Anime": '<hanfu-anime-style>',
        "Birb Style": '<birb-style>'
    }
except Exception as e:
    print(f"Error loading textual inversions: {e}")
    print("Please ensure the concepts exist and the paths are correct.")
    # Continue without textual inversions, or exit, depending on the desired behavior
    style_token_dict = {"Default": ""}  # Fallback

# --- Helper functions (kept for potential future use with custom guidance) ---
# Note: these are not used in the current simplified 'generate_with_pipeline' approach.

# def set_timesteps(scheduler, num_inference_steps):
#     scheduler.set_timesteps(num_inference_steps)
#     scheduler.timesteps = scheduler.timesteps.to(torch.float32)

# def pil_to_latent(vae, input_im):
#     # VAE is part of sd_pipeline.vae
#     transform = tfms.Compose([
#         tfms.ToTensor(),
#         tfms.Normalize([0.5], [0.5])  # Important: normalize to [-1, 1]
#     ])
#     with torch.no_grad():
#         # Ensure the image is RGB
#         if input_im.mode != "RGB":
#             input_im = input_im.convert("RGB")
#         image = transform(input_im).unsqueeze(0).to(torch_device, dtype=vae.dtype)
#         latent = vae.encode(image)  # Note the scaling below
#     return 0.18215 * latent.latent_dist.sample()  # Magic number from SD

# def latents_to_pil(vae, latents):
#     # VAE is part of sd_pipeline.vae
#     latents = (1 / 0.18215) * latents  # Reverse the magic number
#     with torch.no_grad():
#         image = vae.decode(latents).sample
#     image = (image / 2 + 0.5).clamp(0, 1)  # Denormalize
#     image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
#     images = (image * 255).round().astype("uint8")
#     pil_images = [Image.fromarray(image) for image in images]
#     return pil_images

# --- Generation Functions ---
def generate_with_pipeline(prompt, num_inference_steps, guidance_scale, seed):
    """Generates an image using the main Diffusers pipeline."""
    generator = torch.Generator(device=torch_device)
    if seed > 0:
        generator.manual_seed(seed)
    else:
        generator.seed()  # Seed 0 requests a random seed
    try:
        # Offload the VAE if low VRAM causes issues (will slow down inference)
        # sd_pipeline.enable_vae_slicing()  # Alternative memory saving
        # sd_pipeline.enable_model_cpu_offload()  # If really low on VRAM
        print(f"\nGenerating with: Prompt='{prompt}', Steps={num_inference_steps}, Scale={guidance_scale}, Seed={seed}")
        # Autocast only matters for fp16 on CUDA; use a no-op context elsewhere
        amp_ctx = torch.autocast("cuda") if use_fp16 else nullcontext()
        with amp_ctx:
            image = sd_pipeline(
                prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=generator
            ).images[0]
        print("Generation complete.")
        return image
    except Exception as e:
        print(f"Error during generation: {e}")
        # Return a placeholder instead of raising, so the Gradio UI still gets an image
        return Image.new('RGB', (512, 512), color='grey')  # Placeholder grey image

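# Standalone usage (illustrative): the Gradio app calls this via inference() below,
# but the function can also be used directly, e.g.:
#   img = generate_with_pipeline("a red fox in the snow, watercolor", 25, 7.5, 42)
#   img.save("fox.png")
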
# --- Main Inference Function for Gradio ---
def inference(text, style, inference_step, guidance_scale, seed, guidance_method, loss_scale):
    """
    Gradio interface function. Currently generates one image using the standard pipeline.
    The guidance method and loss scale parameters are placeholders for a future implementation.
    """
    if style in style_token_dict:
        style_token = style_token_dict[style]
        # Handle a potentially empty token for 'Default' (or after a loading error)
        prompt = f"{text} {style_token}".strip()
    else:
        print(f"Warning: Style '{style}' not found in the token dictionary. Using the prompt without a style token.")
        prompt = text

    # Generate an image with the standard pipeline
    image_pipeline = generate_with_pipeline(prompt, inference_step, guidance_scale, seed)

    # --- Placeholder for Guided Image Generation ---
    # The code for custom guidance (Grayscale, Contrast, etc.) would go here.
    # This typically involves a custom diffusion loop, calculating a loss based on the
    # guidance method, and modifying the latents at each step. It is significantly
    # more complex and computationally intensive than the standard pipeline call
    # (a rough sketch of the idea follows this function).
    # For now, we return the same image as a placeholder for the second output.
    print(f"Guidance method '{guidance_method}' and loss scale '{loss_scale}' are currently placeholders.")
    image_guide = image_pipeline  # Placeholder

    return image_pipeline, image_guide

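# --- Illustrative sketch (not wired in): custom guidance via a latent loss ---
# A minimal, hypothetical sketch of the placeholder described above, assuming an
# LMS / k-diffusion style denoising loop where `latents`, `noise_pred`, `sigma`, and
# `loss_scale` are already defined at each step. Names here are illustrative only.
#
# def grayscale_loss(images):
#     """Penalize colorfulness: variance across the RGB channels, per pixel."""
#     return images.var(dim=1).mean()
#
# Inside each denoising step, after predicting `noise_pred`:
#
#     latents = latents.detach().requires_grad_(True)
#     denoised = latents - sigma * noise_pred                    # rough estimate of x0
#     decoded = sd_pipeline.vae.decode(denoised / 0.18215).sample
#     loss = grayscale_loss(decoded) * loss_scale
#     grad = torch.autograd.grad(loss, latents)[0]
#     latents = latents.detach() - grad * sigma ** 2             # nudge latents to reduce the loss
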
# --- Gradio Interface Definition ---
title = "Generative Art with Textual Inversion Styles"
description = """
A Gradio interface to generate images using Stable Diffusion v1.4 with Textual Inversion styles.
Select a style, enter a prompt, and adjust the generation parameters.

*Note:* The 'Generated art with guidance' output currently shows the same image as the first; custom guidance logic (Grayscale, Contrast, etc.) is not yet implemented. Lower inference steps speed up generation. Enable the queue if timeouts occur.
"""

examples = [
    ["A majestic castle on a floating island, detailed, fantasy art", 'Illustration Style', 25, 7.5, 1001, 'Grayscale', 200],
    ["A cyberpunk cityscape at night, neon lights, rain, cinematic", 'Midjourney', 30, 8.0, 42, 'Contrast', 300],
    ["Portrait of a woman in traditional Chinese dress, anime style", "Hanfu Anime", 30, 7.0, 1234, 'Saturation', 250],
    ["Cute cartoon bird character sitting on a branch", "Birb Style", 20, 7.5, 5678, 'Symmetry', 150]
]

demo = gr.Interface(
    inference,
    inputs=[
        gr.Textbox(label="Prompt", info="Describe the image you want to create.", type="text"),
        gr.Dropdown(label="Style", info="Select an art style (requires a loaded textual inversion).",
                    choices=list(style_token_dict.keys()), value=list(style_token_dict.keys())[0]),  # Default to the first available style
        gr.Slider(10, 50, 25, step=1, label="Inference steps", info="More steps can improve detail but take longer."),
        gr.Slider(1.0, 15.0, 7.5, step=0.1, label="Guidance scale (CFG)", info="How strongly the prompt guides the image. Higher values follow the prompt more closely."),
        gr.Slider(0, 100000, 42, step=1, label="Seed", info="Same seed + prompt = same image. 0 for a random seed."),
        gr.Dropdown(label="Guidance method (Placeholder)", info="Custom guidance method (not implemented yet).", choices=['Grayscale', 'Bright', 'Contrast', 'Symmetry', 'Saturation'], value="Grayscale"),
        gr.Slider(100, 10000, 200, step=100, label="Loss scale (Placeholder)", info="Strength of the custom guidance (not implemented yet).")
    ],
    outputs=[
        gr.Image(width=512, height=512, label="Generated Art (Standard Pipeline)"),
        gr.Image(width=512, height=512, label="Generated Art with Guidance (Placeholder)")
    ],
    title=title,
    description=description,
    examples=examples,
    allow_flagging='never'  # Disable flagging if not needed
)

# --- Launch the Interface with the Queue Enabled ---
# Use .queue() to handle long inference times and prevent timeouts
print("Launching Gradio interface...")
demo.queue().launch(share=False)  # Set share=True to get a public link (use with caution)