# Captured from a Hugging Face Spaces page; the Space status shown at
# capture time was "Runtime error".
import torch
import gradio as gr
from diffusers import StableDiffusionXLPipeline
# --- Settings and paths ---
# Repo id handed to `from_pretrained` as the base SDXL pipeline.
# NOTE(review): "ByteDance/Hyper-SD" looks like a repo of acceleration
# LoRA/UNet checkpoints rather than a complete pipeline -- confirm that it
# actually loads with `from_pretrained`, or point this at the SDXL base
# model you want to use.
BASE_MODEL = "ByteDance/Hyper-SD"
# Repo id (or local path) of the LoRA weights applied on top of the base
# model (assumed to be in a format that Diffusers can load).
LORA_PATH = "fofr/sdxl-emoji"

# --- Load the base pipeline ---
# Decide the device first so the weight precision can follow it: half
# precision on CUDA (less memory, faster), full precision on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

pipe = StableDiffusionXLPipeline.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch_dtype,
    variant="fp16",  # selects the fp16 weight files; they load as `torch_dtype`
    safety_checker=None,
)
pipe.to(device)
# --- Enable fast attention if available ---
def _enable_xformers(pipeline):
    """Try to switch `pipeline` to xFormers attention; report on failure.

    This is best effort: any exception raised by the pipeline is printed
    and swallowed so start-up continues with the default attention.
    """
    try:
        pipeline.enable_xformers_memory_efficient_attention()
    except Exception as xformers_error:
        print("xFormers not enabled:", xformers_error)


_enable_xformers(pipe)
# --- Apply the LoRA weights ---
# Use the pipeline-level loader instead of `pipe.unet.load_attn_procs`:
# `load_lora_weights` is the supported LoRA entry point and applies the
# weights to the UNet and, when the checkpoint contains them, to the text
# encoders as well.
# NOTE(review): if LORA_PATH holds several weight files, a `weight_name=`
# argument may be required -- confirm against the repo contents.
try:
    pipe.load_lora_weights(LORA_PATH)
except Exception as e:
    # Best effort: keep serving the base model if the LoRA cannot be applied.
    print("Error applying LoRA weights:", e)
# --- Define the image generation function ---
def generate_image(prompt: str, steps: int = 30, guidance: float = 7.5):
    """Generate an image from a text prompt using the module-level `pipe`.

    Args:
        prompt: The text prompt.
        steps: Number of inference steps. Coerced to ``int`` before it is
            passed on, since UI widgets may deliver a float.
        guidance: Guidance scale (higher values encourage the image to
            follow the prompt). Coerced to ``float``.

    Returns:
        The first image of the pipeline result (``result.images[0]``).
    """
    from contextlib import nullcontext

    # Enter CUDA autocast only when a GPU is actually present; on CPU-only
    # hosts run the pipeline under a no-op context. `torch.autocast("cuda")`
    # replaces `torch.cuda.amp.autocast()`, which newer PyTorch releases
    # deprecate.
    if torch.cuda.is_available():
        precision_ctx = torch.autocast("cuda")
    else:
        precision_ctx = nullcontext()

    with precision_ctx:
        result = pipe(
            prompt,
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
        )
    return result.images[0]
# --- Build the Gradio interface ---
# Three inputs, in the order `generate_image` takes them: prompt, inference
# steps, guidance scale. The single output is the generated PIL image.
demo = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt"),
        # The default has to lie inside [minimum, maximum]; the previous
        # default of 30 was above the slider's maximum of 8.
        gr.Slider(minimum=1, maximum=8, step=1, value=8, label="Inference Steps"),
        gr.Slider(minimum=1.0, maximum=15.0, step=0.5, value=7.5, label="Guidance Scale"),
    ],
    outputs=gr.Image(type="pil", label="Generated Image"),
    title="Super Fast SDXL-Emoji Generator",
    description=(
        "This demo uses a Stable Diffusion XL model enhanced with a custom LoRA "
        "to generate images quickly. Adjust the prompt and settings below, then hit 'Submit'!"
    ),
)

# --- Launch the demo ---
demo.launch()