Maria committed on
Commit ada0ab1 · 1 Parent(s): 5cbab77
Files changed (2)
  1. app.py +42 -69
  2. infer.py +255 -0
app.py CHANGED
@@ -1,75 +1,10 @@
  import gradio as gr
  import numpy as np
- import random
- import os
-
- # import spaces #[uncomment to use ZeroGPU]
- from diffusers import DiffusionPipeline
- from peft import PeftModel, LoraConfig
- import torch
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
-
- if torch.cuda.is_available():
-     torch_dtype = torch.float16
- else:
-     torch_dtype = torch.float32
+ from infer import infer, CONTROLNET_MODE
 
  MAX_SEED = np.iinfo(np.int32).max
  MAX_IMAGE_SIZE = 1024
 
- LoRA_path = 'new_model'
-
- # @spaces.GPU #[uncomment to use ZeroGPU]
- def infer(
-     model_id,
-     prompt,
-     negative_prompt,
-     seed,
-     randomize_seed,
-     width,
-     height,
-     guidance_scale,
-     num_inference_steps,
-     progress=gr.Progress(track_tqdm=True),
- ):
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-
-     generator = torch.Generator().manual_seed(seed)
-
-     if model_id == 'Maria_Lashina_LoRA':
-         adapter_name = 'a cartoonish mouse'
-         unet_sub_dir = os.path.join(LoRA_path, "unet")
-         text_encoder_sub_dir = os.path.join(LoRA_path, "text_encoder")
-
-         pipe = DiffusionPipeline.from_pretrained('CompVis/stable-diffusion-v1-4', torch_dtype=torch_dtype).to(device)
-         pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
-         pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name)
-
-         if torch_dtype == torch.float16:
-             pipe.unet.half()
-             pipe.text_encoder.half()
-
-         pipe.to(device)
-
-     else:
-         pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype).to(device)
-
-     image = pipe(
-         prompt=prompt,
-         negative_prompt=negative_prompt,
-         guidance_scale=guidance_scale,
-         num_inference_steps=num_inference_steps,
-         width=width,
-         height=height,
-         generator=generator,
-     ).images[0]
-
-     return image, seed
-
-
  examples = [
      "The image of a cartoonish mouse eating from a red bowl of yellow triangle chips, her cheeks are full. The mouse is gray with big pink ears, small white eyes and a black pointed nose. It has a simple design, the background color is white. The style of the image is reminiscent of a sticker or a digital illustration.",
      "The image of a cartoonish mouse with red hearts instead of eyes meaning that the mouse is in love with something. The mouse is gray with big pink ears and a black pointed nose. It has a simple design, the background color is white. The style of the image is reminiscent of a sticker or a digital illustration.",
@@ -83,9 +18,15 @@ css = """
  }
  """
 
+ def on_checkbox_change(use_advanced):
+     visible = use_advanced
+     return (gr.update(visible=visible, interactive=visible),
+             gr.update(visible=visible, interactive=visible),
+             gr.update(visible=visible, interactive=visible))
+
  with gr.Blocks(css=css) as demo:
      with gr.Column(elem_id="col-container"):
-         gr.Markdown(" # Text-to-Image Gradio Template")
+         gr.Markdown(" # Maria Lashina Text-to-Image Rat Stickers Generation App")
 
          MODEL_LIST = [
              "CompVis/stable-diffusion-v1-4",
@@ -116,8 +57,33 @@ with gr.Blocks(css=css) as demo:
              label="Negative prompt",
              max_lines=1,
              placeholder="Enter a negative prompt",
-             visible=False,
+             visible=True,
+         )
+
+         use_controlnet = gr.Checkbox(label="Use ControlNet")
+         controlnet_strength = gr.Slider(
+             label="ControlNet strength",
+             minimum=0,
+             maximum=1,
+             step=0.01,
+             value=0.8,
+             visible=False
+         )
+         controlnet_mode = gr.Dropdown(CONTROLNET_MODE.keys(), label="ControlNet mode", visible=False)
+         controlnet_image = gr.Image(label="ControlNet image", visible=False)
+         use_controlnet.change(on_checkbox_change, use_controlnet, [controlnet_strength, controlnet_mode, controlnet_image])
+
+         use_ip_adapter = gr.Checkbox(label="Use IPAdapter")
+         ip_adapter_scale = gr.Slider(
+             label="IPAdapter scale",
+             minimum=0,
+             maximum=1,
+             step=0.01,
+             value=0.8,
+             visible=False
          )
+         ip_adapter_image = gr.Image(label="IPAdapter image", visible=False)
+         use_ip_adapter.change(lambda v: (gr.update(visible=v, interactive=v), gr.update(visible=v, interactive=v)), use_ip_adapter, [ip_adapter_scale, ip_adapter_image])
 
          seed = gr.Slider(
              label="Seed",
@@ -177,9 +143,16 @@ with gr.Blocks(css=css) as demo:
              height,
              guidance_scale,
              num_inference_steps,
+             use_controlnet,
+             controlnet_strength,
+             controlnet_mode,
+             controlnet_image,
+             use_ip_adapter,
+             ip_adapter_scale,
+             ip_adapter_image
          ],
         outputs=[result, seed],
     )
 
  if __name__ == "__main__":
-     demo.launch()
+     demo.launch(share=False, debug=True)
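Review note: the show/hide wiring above hinges on `gr.update` and `Checkbox.change`. A minimal, self-contained sketch of the same pattern (component names here are illustrative, not the app's):

```python
import gradio as gr

def toggle(visible):
    # One gr.update per output component, mirroring on_checkbox_change above.
    return gr.update(visible=visible, interactive=visible)

with gr.Blocks() as demo:
    use_extra = gr.Checkbox(label="Show advanced option")
    strength = gr.Slider(0, 1, value=0.8, step=0.01, label="Strength", visible=False)
    use_extra.change(toggle, use_extra, strength)

if __name__ == "__main__":
    demo.launch()
```

Each output component needs exactly one `gr.update` in the handler's return value, which is why the IPAdapter checkbox, with only two dependent components, uses a two-value handler rather than the three-output `on_checkbox_change`.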
 
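Review note: Gradio passes each component's current value positionally, so the `inputs=[...]` list must mirror `infer()`'s parameter order exactly. A stripped-down sketch of that contract (names illustrative):

```python
import gradio as gr

def infer_stub(prompt, use_controlnet, controlnet_strength):
    # Gradio passes each component's current value positionally,
    # so this signature must mirror the inputs=[...] list below.
    return f"{prompt!r}, controlnet={use_controlnet}, strength={controlnet_strength}"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    use_controlnet = gr.Checkbox(label="Use ControlNet")
    controlnet_strength = gr.Slider(0, 1, value=0.8, label="ControlNet strength")
    run = gr.Button("Run")
    out = gr.Textbox(label="Result")
    run.click(infer_stub, inputs=[prompt, use_controlnet, controlnet_strength], outputs=out)
```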
infer.py ADDED
@@ -0,0 +1,255 @@
+ import numpy as np
+ import torch
+ import cv2 as cv
+ import random
+ import os
+ import spaces
+ import gradio as gr
+
+ from PIL import Image
+ from transformers import pipeline
+ from controlnet_aux import MLSDdetector, HEDdetector, NormalBaeDetector, LineartDetector
+ from peft import PeftModel, LoraConfig
+ from diffusers import (
+     DiffusionPipeline,
+     StableDiffusionPipeline,
+     StableDiffusionControlNetPipeline,
+     StableDiffusionControlNetImg2ImgPipeline,
+     DPMSolverMultistepScheduler,
+     PNDMScheduler,
+     ControlNetModel
+ )
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg, retrieve_timesteps
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+ from diffusers.utils.torch_utils import randn_tensor
+ from diffusers.utils import load_image, make_image_grid
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ if torch.cuda.is_available():
+     torch_dtype = torch.float16
+ else:
+     torch_dtype = torch.float32
+
+ MAX_SEED = np.iinfo(np.int32).max
+
+ default_model = 'CompVis/stable-diffusion-v1-4'
+ LoRA_path = 'new_model'
+
+ CONTROLNET_MODE = {
+     "Canny Edge Detection": "lllyasviel/control_v11p_sd15_canny",
+     "Pixel to Pixel": "lllyasviel/control_v11e_sd15_ip2p",
+     "HED edge detection (soft edge)": "lllyasviel/control_sd15_hed",
+     "Midas depth estimation": "lllyasviel/control_v11f1p_sd15_depth",
+     "Surface Normal Estimation": "lllyasviel/control_v11p_sd15_normalbae",
+     "Scribble-Based Generation": "lllyasviel/control_v11p_sd15_scribble",
+     "Line Art Generation": "lllyasviel/control_v11p_sd15_lineart",
+ }
+
+ def get_pipe(
+     model_id,
+     use_controlnet,
+     controlnet_mode,
+     use_ip_adapter
+ ):
+     if use_controlnet and use_ip_adapter:
+         print('Pipe with ControlNet and IPAdapter')
+
+         controlnet = ControlNetModel.from_pretrained(
+             CONTROLNET_MODE[controlnet_mode],
+             cache_dir="./models_cache",
+             torch_dtype=torch_dtype
+         )
+
+         pipe = StableDiffusionControlNetPipeline.from_pretrained(
+             model_id if model_id != 'Maria_Lashina_LoRA' else default_model,
+             torch_dtype=torch_dtype,
+             controlnet=controlnet,
+             safety_checker=None,
+         ).to(device)
+
+         pipe.load_ip_adapter(
+             "h94/IP-Adapter",
+             subfolder="models",
+             weight_name="ip-adapter-plus_sd14.bin",
+         )
+
+     elif use_controlnet:
+         print('Pipe with ControlNet')
+
+         controlnet = ControlNetModel.from_pretrained(
+             CONTROLNET_MODE[controlnet_mode],
+             cache_dir="./models_cache",
+             torch_dtype=torch_dtype)
+
+         pipe = StableDiffusionControlNetPipeline.from_pretrained(
+             model_id if model_id != 'Maria_Lashina_LoRA' else default_model,
+             torch_dtype=torch_dtype,
+             controlnet=controlnet,
+             safety_checker=None,
+         ).to(device)
+
+     elif use_ip_adapter:
+         print('Pipe with IPAdapter')
+
+         pipe = StableDiffusionPipeline.from_pretrained(
+             model_id if model_id != 'Maria_Lashina_LoRA' else default_model,
+             torch_dtype=torch_dtype,
+             safety_checker=None,
+         ).to(device)
+
+         pipe.load_ip_adapter(
+             "h94/IP-Adapter",
+             subfolder="models",
+             weight_name="ip-adapter-plus_sd14.bin")
+
+     else:
+         print('Pipe with only SD')
+
+         pipe = StableDiffusionPipeline.from_pretrained(
+             model_id if model_id != 'Maria_Lashina_LoRA' else default_model,
+             torch_dtype=torch_dtype,
+             safety_checker=None,
+         ).to(device)
+
+     if model_id == 'Maria_Lashina_LoRA':
+         adapter_name = 'a cartoonish mouse'
+         unet_sub_dir = os.path.join(LoRA_path, "unet")
+         text_encoder_sub_dir = os.path.join(LoRA_path, "text_encoder")
+
+         pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
+         pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name)
+
+         if torch_dtype == torch.float16:
+             pipe.unet.half()
+             pipe.text_encoder.half()
+
+     return pipe
+
+ def prepare_controlnet_image(controlnet_image, mode):
+     if mode == "Canny Edge Detection":
+         image = cv.Canny(controlnet_image, 80, 160)
+         image = np.repeat(image[:, :, None], 3, axis=2)
+         image = Image.fromarray(image)
+
+     elif mode == "HED edge detection (soft edge)":
+         processor = HEDdetector.from_pretrained('lllyasviel/Annotators')
+         image = processor(controlnet_image)
+
+     elif mode == "Midas depth estimation":
+         depth_estimator = pipeline('depth-estimation')
+         image = depth_estimator(Image.fromarray(controlnet_image))['depth']
+         image = np.array(image)
+         image = image[:, :, None]
+         image = np.concatenate([image, image, image], axis=2)
+         image = Image.fromarray(image)
+
+     elif mode == "Surface Normal Estimation":
+         processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
+         image = processor(controlnet_image)
+
+     elif mode == "Scribble-Based Generation":
+         processor = HEDdetector.from_pretrained('lllyasviel/Annotators')
+         image = processor(controlnet_image, scribble=True)
+
+     elif mode == "Line Art Generation":
+         processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
+         image = processor(controlnet_image)
+
+     else:
+         image = controlnet_image
+
+     return image
+
+ # @spaces.GPU #[uncomment to use ZeroGPU]
+ def infer(
+     model_id,
+     prompt,
+     negative_prompt,
+     seed,
+     randomize_seed,
+     width,
+     height,
+     guidance_scale,
+     num_inference_steps,
+     use_controlnet,
+     controlnet_strength,
+     controlnet_mode,
+     controlnet_image,
+     use_ip_adapter,
+     ip_adapter_scale,
+     ip_adapter_image,
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+
+     generator = torch.Generator().manual_seed(seed)
+
+     if not use_controlnet and not use_ip_adapter:
+         pipe = get_pipe(model_id, use_controlnet, controlnet_mode, use_ip_adapter)
+
+         image = pipe(
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             guidance_scale=guidance_scale,
+             num_inference_steps=num_inference_steps,
+             width=width,
+             height=height,
+             generator=generator
+         ).images[0]
+
+     elif use_controlnet and not use_ip_adapter:
+         cn_image = prepare_controlnet_image(controlnet_image, controlnet_mode)
+         pipe = get_pipe(model_id, use_controlnet, controlnet_mode, use_ip_adapter)
+
+         image = pipe(
+             prompt,
+             cn_image,
+             negative_prompt=negative_prompt,
+             num_inference_steps=num_inference_steps,
+             controlnet_conditioning_scale=controlnet_strength,
+             generator=generator
+         ).images[0]
+
+     elif not use_controlnet and use_ip_adapter:
+         pipe = get_pipe(model_id, use_controlnet, controlnet_mode, use_ip_adapter)
+         pipe.set_ip_adapter_scale(ip_adapter_scale)
+
+         image = pipe(
+             prompt,
+             num_inference_steps=num_inference_steps,
+             guidance_scale=guidance_scale,
+             ip_adapter_image=ip_adapter_image,
+             generator=generator
+         ).images[0]
+
+     elif use_controlnet and use_ip_adapter:
+         cn_image = prepare_controlnet_image(controlnet_image, controlnet_mode)
+         pipe = get_pipe(model_id, use_controlnet, controlnet_mode, use_ip_adapter)
+         pipe.set_ip_adapter_scale(ip_adapter_scale)
+
+         image = pipe(
+             prompt,
+             cn_image,
+             num_inference_steps=num_inference_steps,
+             guidance_scale=guidance_scale,
+             height=height,
+             width=width,
+             controlnet_conditioning_scale=controlnet_strength,
+             ip_adapter_image=ip_adapter_image,
+             generator=generator,
+         ).images[0]
+
+     return image, seed
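Review note: `prepare_controlnet_image` hands the ControlNet a 3-channel PIL image. A small offline sketch of the Canny branch, runnable outside the Space (the input path is hypothetical):

```python
import cv2 as cv
import numpy as np
from PIL import Image

# Hypothetical input file; any RGB image works.
rgb = np.array(Image.open("sticker_sketch.png").convert("RGB"))

# Same recipe as the Canny branch above: edge map, replicated to
# three identical channels, converted to the PIL image the pipeline expects.
edges = cv.Canny(rgb, 80, 160)
control = Image.fromarray(np.repeat(edges[:, :, None], 3, axis=2))
control.save("canny_control.png")
```

The remaining branches delegate the same conversion to the `controlnet_aux` detectors, which already return PIL images.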