use 2-step unet
app.py CHANGED
@@ -18,7 +18,7 @@ from sfast.compilers.diffusion_pipeline_compiler import compile, CompilationConf
 BASE = "stabilityai/stable-diffusion-xl-base-1.0"
 REPO = "ByteDance/SDXL-Lightning"
 # 1-step
-CHECKPOINT = "
+CHECKPOINT = "sdxl_lightning_2step_unet.safetensors"
 
 # {
 #     "1-Step": ["sdxl_lightning_1step_unet_x0.safetensors", 1],
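The commented-out mapping above suggests how the checkpoint filename and step count travel together. A minimal sketch of that selection logic, assuming the 4-step and 8-step filenames follow the same naming pattern as the 2-step one (only the 1-step and 2-step names actually appear in this diff):

```py
# Hypothetical step -> (checkpoint, steps) table; the 4- and 8-step
# filenames are assumptions based on the naming pattern above.
CHECKPOINTS = {
    "1-Step": ("sdxl_lightning_1step_unet_x0.safetensors", 1),
    "2-Step": ("sdxl_lightning_2step_unet.safetensors", 2),
    "4-Step": ("sdxl_lightning_4step_unet.safetensors", 4),  # assumed
    "8-Step": ("sdxl_lightning_8step_unet.safetensors", 8),  # assumed
}
CHECKPOINT, NUM_STEPS = CHECKPOINTS["2-Step"]
```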
@@ -39,11 +39,13 @@ print(f"TORCH_COMPILE: {TORCH_COMPILE}")
 print(f"device: {device}")
 
 
+unet = UNet2DConditionModel.from_config(BASE, subfolder="unet").to(
+    "cuda", torch.float16
+)
+unet.load_state_dict(load_file(hf_hub_download(REPO, CHECKPOINT), device="cuda"))
 pipe = StableDiffusionXLPipeline.from_pretrained(
-    BASE, torch_dtype=torch.float16, variant="fp16"
+    BASE, unet=unet, torch_dtype=torch.float16, variant="fp16"
 ).to("cuda")
-pipe.load_lora_weights(hf_hub_download(REPO, CHECKPOINT))
-pipe.fuse_lora()
 
 # Ensure sampler uses "trailing" timesteps.
 pipe.scheduler = EulerDiscreteScheduler.from_config(
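SDXL-Lightning's few-step checkpoints are trained to start denoising at the final timestep, which is why the scheduler must use "trailing" spacing. A quick sketch to inspect what the 2-step schedule looks like, assuming SDXL's default 1000 training timesteps:

```py
from diffusers import EulerDiscreteScheduler

# Sketch: inspect the 2-step schedule under "trailing" spacing.
sched = EulerDiscreteScheduler.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="scheduler",
    timestep_spacing="trailing",
)
sched.set_timesteps(2)
print(sched.timesteps)  # expected: tensor([999., 499.]), i.e. sampling starts at the last timestep
```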
@@ -82,7 +84,6 @@ def predict(prompt, seed=1231231):
         guidance_scale=0.0,
         # width=768,
         # height=768,
-        # original_inference_steps=params.lcm_steps,
         output_type="pil",
     )
     print(f"Pipe took {time.time() - last_time} seconds")
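The hunk header shows predict taking a seed argument. A hedged sketch of one way to make runs reproducible with diffusers' standard generator parameter; the wiring below is an assumption for illustration, not the app's actual code, and it assumes pipe is the 2-step pipeline built above:

```py
import torch

# Sketch: pass the demo's seed through a torch.Generator for reproducibility.
def predict(prompt, seed=1231231):
    generator = torch.Generator("cuda").manual_seed(int(seed))
    return pipe(
        prompt,
        num_inference_steps=2,  # must match the loaded 2-step checkpoint
        guidance_scale=0.0,     # Lightning checkpoints expect CFG disabled
        generator=generator,
        output_type="pil",
    ).images[0]
```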
@@ -133,23 +134,24 @@ with gr.Blocks(css=css) as demo:
 """## Running SDXL-Lightning with `diffusers`
 ```py
 import torch
-from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
+from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
 from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
 
 base = "stabilityai/stable-diffusion-xl-base-1.0"
 repo = "ByteDance/SDXL-Lightning"
-ckpt = "
+ckpt = "sdxl_lightning_2step_unet.safetensors" # Use the correct ckpt for your step setting!
 
 # Load model.
-
-
-pipe.
+unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
+unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
+pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
 
 # Ensure sampler uses "trailing" timesteps.
 pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
 
 # Ensure using the same inference steps as the loaded model and CFG set to 0.
-pipe("A girl smiling", num_inference_steps=
+pipe("A girl smiling", num_inference_steps=2, guidance_scale=0).images[0].save("output.png")
 ```
 """
 )
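For comparison, the removed load_lora_weights/fuse_lora path pairs with SDXL-Lightning's LoRA distillations rather than the full UNet checkpoints this commit switches to. A sketch of that alternative, where the 2-step LoRA filename is an assumption following the repo's naming pattern:

```py
import torch
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download

BASE = "stabilityai/stable-diffusion-xl-base-1.0"
REPO = "ByteDance/SDXL-Lightning"

# Sketch of the LoRA path this commit replaces; the 2-step LoRA filename
# below is assumed, not taken from the diff.
pipe = StableDiffusionXLPipeline.from_pretrained(
    BASE, torch_dtype=torch.float16, variant="fp16"
).to("cuda")
pipe.load_lora_weights(hf_hub_download(REPO, "sdxl_lightning_2step_lora.safetensors"))
pipe.fuse_lora()
```

Fusing the LoRA keeps inference speed identical to the base UNet, while the full-UNet route loaded above avoids any LoRA approximation error, which is presumably why the commit prefers it for the 2-step demo.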