option to enable taesd sdxl
Browse files
app.py
CHANGED
@@ -2,6 +2,7 @@ from diffusers import (
|
|
2 |
StableDiffusionXLPipeline,
|
3 |
EulerDiscreteScheduler,
|
4 |
UNet2DConditionModel,
|
|
|
5 |
)
|
6 |
import torch
|
7 |
import os
|
@@ -19,6 +20,7 @@ BASE = "stabilityai/stable-diffusion-xl-base-1.0"
|
|
19 |
REPO = "ByteDance/SDXL-Lightning"
|
20 |
# 1-step
|
21 |
CHECKPOINT = "sdxl_lightning_2step_unet.safetensors"
|
|
|
22 |
|
23 |
# {
|
24 |
# "1-Step": ["sdxl_lightning_1step_unet_x0.safetensors", 1],
|
@@ -30,6 +32,8 @@ CHECKPOINT = "sdxl_lightning_2step_unet.safetensors"
|
|
30 |
|
31 |
SFAST_COMPILE = os.environ.get("SFAST_COMPILE", "0") == "1"
|
32 |
SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"
|
|
|
|
|
33 |
# check if MPS is available OSX only M1/M2/M3 chips
|
34 |
|
35 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
@@ -38,6 +42,7 @@ torch_dtype = torch.float16
|
|
38 |
|
39 |
print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
|
40 |
print(f"SFAST_COMPILE: {SFAST_COMPILE}")
|
|
|
41 |
print(f"device: {device}")
|
42 |
|
43 |
|
@@ -49,6 +54,12 @@ pipe = StableDiffusionXLPipeline.from_pretrained(
|
|
49 |
BASE, unet=unet, torch_dtype=torch.float16, variant="fp16", safety_checker=False
|
50 |
).to("cuda")
|
51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
# Ensure sampler uses "trailing" timesteps.
|
53 |
pipe.scheduler = EulerDiscreteScheduler.from_config(
|
54 |
pipe.scheduler.config, timestep_spacing="trailing"
|
|
|
2 |
StableDiffusionXLPipeline,
|
3 |
EulerDiscreteScheduler,
|
4 |
UNet2DConditionModel,
|
5 |
+
AutoencoderTiny,
|
6 |
)
|
7 |
import torch
|
8 |
import os
|
|
|
20 |
REPO = "ByteDance/SDXL-Lightning"
|
21 |
# 1-step
|
22 |
CHECKPOINT = "sdxl_lightning_2step_unet.safetensors"
|
23 |
+
taesd_model = "madebyollin/taesdxl"
|
24 |
|
25 |
# {
|
26 |
# "1-Step": ["sdxl_lightning_1step_unet_x0.safetensors", 1],
|
|
|
32 |
|
33 |
SFAST_COMPILE = os.environ.get("SFAST_COMPILE", "0") == "1"
|
34 |
SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"
|
35 |
+
USE_TAESD = os.environ.get("USE_TAESD", "0") == "1"
|
36 |
+
|
37 |
# check if MPS is available OSX only M1/M2/M3 chips
|
38 |
|
39 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
42 |
|
43 |
print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
|
44 |
print(f"SFAST_COMPILE: {SFAST_COMPILE}")
|
45 |
+
print(f"USE_TAESD: {USE_TAESD}")
|
46 |
print(f"device: {device}")
|
47 |
|
48 |
|
|
|
54 |
BASE, unet=unet, torch_dtype=torch.float16, variant="fp16", safety_checker=False
|
55 |
).to("cuda")
|
56 |
|
57 |
+
if USE_TAESD:
|
58 |
+
pipe.vae = AutoencoderTiny.from_pretrained(
|
59 |
+
taesd_model, torch_dtype=torch_dtype, use_safetensors=True
|
60 |
+
).to(device)
|
61 |
+
|
62 |
+
|
63 |
# Ensure sampler uses "trailing" timesteps.
|
64 |
pipe.scheduler = EulerDiscreteScheduler.from_config(
|
65 |
pipe.scheduler.config, timestep_spacing="trailing"
|