JunhaoZhuang committed (verified)
Commit 2227ec5 · 1 Parent(s): 6fc8df6

Update app.py

Files changed (1):
  1. app.py +101 -108
app.py CHANGED
@@ -178,123 +178,116 @@ global pipeline
 global MultiResNetModel
 global cur_style
 
-@spaces.GPU
-def load_ckpt():
-    global pipeline
-    global MultiResNetModel
-    global cur_style
-    cur_style = 'line + shadow'
-    weight_dtype = torch.float16
-
-    block_out_channels = [128, 128, 256, 512, 512]
-    MultiResNetModel = MultiHiddenResNetModel(block_out_channels, len(block_out_channels))
-    MultiResNetModel.load_state_dict(torch.load(os.path.join(model_global_path, 'shadow_GSRP', 'MultiResNetModel.bin'), map_location='cpu'), strict=True)
-    MultiResNetModel.to('cuda', dtype=weight_dtype)
-
-    # image transforms
-    transform = transforms.Compose([
-        transforms.ToTensor(),  # convert the PIL image to a tensor
-        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # normalize to [-1, 1]
-    ])
-    # seed = 43
-    lora_rank = 128
-    pretrained_model_name_or_path = "PixArt-alpha/PixArt-XL-2-1024-MS"
-
-    transformer = PixArtTransformer2DModel.from_pretrained(
-        pretrained_model_name_or_path, subfolder="transformer", revision=None, variant=None
-    )
-    pixart_config = get_pixart_config()
-    causal_dit = CausalSparseDiTModel(
-        num_attention_heads=pixart_config.get("num_attention_heads"),
-        attention_head_dim=pixart_config.get("attention_head_dim"),
-        in_channels=pixart_config.get("in_channels"),
-        out_channels=pixart_config.get("out_channels"),
-        num_layers=pixart_config.get("num_layers"),
-        dropout=pixart_config.get("dropout"),
-        norm_num_groups=pixart_config.get("norm_num_groups"),
-        cross_attention_dim=pixart_config.get("cross_attention_dim"),
-        attention_bias=pixart_config.get("attention_bias"),
-        sample_size=pixart_config.get("sample_size"),
-        patch_size=pixart_config.get("patch_size"),
-        activation_fn=pixart_config.get("activation_fn"),
-        num_embeds_ada_norm=pixart_config.get("num_embeds_ada_norm"),
-        upcast_attention=pixart_config.get("upcast_attention"),
-        norm_type=pixart_config.get("norm_type"),
-        norm_elementwise_affine=pixart_config.get("norm_elementwise_affine"),
-        norm_eps=pixart_config.get("norm_eps"),
-        caption_channels=pixart_config.get("caption_channels"),
-        attention_type=pixart_config.get("attention_type"))
-
-    causal_dit = init_causal_dit(causal_dit, transformer)
-    print('loaded causal_dit')
-    controlnet = CausalSparseDiTControlModel(
-        num_attention_heads=pixart_config.get("num_attention_heads"),
-        attention_head_dim=pixart_config.get("attention_head_dim"),
-        in_channels=pixart_config.get("in_channels"),
-        cond_chanels=9,
-        out_channels=pixart_config.get("out_channels"),
-        num_layers=pixart_config.get("num_layers"),
-        dropout=pixart_config.get("dropout"),
-        norm_num_groups=pixart_config.get("norm_num_groups"),
-        cross_attention_dim=pixart_config.get("cross_attention_dim"),
-        attention_bias=pixart_config.get("attention_bias"),
-        sample_size=pixart_config.get("sample_size"),
-        patch_size=pixart_config.get("patch_size"),
-        activation_fn=pixart_config.get("activation_fn"),
-        num_embeds_ada_norm=pixart_config.get("num_embeds_ada_norm"),
-        upcast_attention=pixart_config.get("upcast_attention"),
-        norm_type=pixart_config.get("norm_type"),
-        norm_elementwise_affine=pixart_config.get("norm_elementwise_affine"),
-        norm_eps=pixart_config.get("norm_eps"),
-        caption_channels=pixart_config.get("caption_channels"),
-        attention_type=pixart_config.get("attention_type"))
-    # controlnet = init_controlnet(controlnet, causal_dit)
-    del transformer
-    transformer_lora_config = LoraConfig(
-        r=lora_rank,
-        lora_alpha=lora_rank,
-        # use_dora=True,
-        init_lora_weights="gaussian",
-        target_modules=["to_k", "to_q", "to_v", "to_out.0",
-                        "proj_in", "proj_out",
-                        "ff.net.0.proj", "ff.net.2",
-                        "proj", "linear", "linear_1", "linear_2"],  # ff.net.0.proj ff.net.2
-    )
-    causal_dit.add_adapter(transformer_lora_config)
-
-    lora_state_dict = torch.load(os.path.join(model_global_path, 'shadow_ckpt', 'transformer_lora_pos.bin'), map_location='cpu')
-    causal_dit.load_state_dict(lora_state_dict, strict=False)
-    controlnet_state_dict = torch.load(os.path.join(model_global_path, 'shadow_ckpt', 'controlnet.bin'), map_location='cpu')
-    controlnet.load_state_dict(controlnet_state_dict, strict=True)
-
-    causal_dit.to('cuda', dtype=weight_dtype)
-    controlnet.to('cuda', dtype=weight_dtype)
-
-    pipeline = CobraPixArtAlphaPipeline.from_pretrained(
-        pretrained_model_name_or_path,
-        transformer=causal_dit,
-        controlnet=controlnet,
-        safety_checker=None,
-        revision=None,
-        variant=None,
-        torch_dtype=weight_dtype,
-    )
-
-    pipeline = pipeline.to("cuda")
-
-    print('loaded pipeline')
-
-load_ckpt()
-
+cur_style = 'line + shadow'
+weight_dtype = torch.float16
+
+block_out_channels = [128, 128, 256, 512, 512]
+MultiResNetModel = MultiHiddenResNetModel(block_out_channels, len(block_out_channels))
+MultiResNetModel.load_state_dict(torch.load(os.path.join(model_global_path, 'shadow_GSRP', 'MultiResNetModel.bin'), map_location='cpu'), strict=True)
+MultiResNetModel.to('cuda', dtype=weight_dtype)
+
+# image transforms
+transform = transforms.Compose([
+    transforms.ToTensor(),  # convert the PIL image to a tensor
+    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # normalize to [-1, 1]
+])
+# seed = 43
+lora_rank = 128
+pretrained_model_name_or_path = "PixArt-alpha/PixArt-XL-2-1024-MS"
+
+transformer = PixArtTransformer2DModel.from_pretrained(
+    pretrained_model_name_or_path, subfolder="transformer", revision=None, variant=None
+)
+pixart_config = get_pixart_config()
+causal_dit = CausalSparseDiTModel(
+    num_attention_heads=pixart_config.get("num_attention_heads"),
+    attention_head_dim=pixart_config.get("attention_head_dim"),
+    in_channels=pixart_config.get("in_channels"),
+    out_channels=pixart_config.get("out_channels"),
+    num_layers=pixart_config.get("num_layers"),
+    dropout=pixart_config.get("dropout"),
+    norm_num_groups=pixart_config.get("norm_num_groups"),
+    cross_attention_dim=pixart_config.get("cross_attention_dim"),
+    attention_bias=pixart_config.get("attention_bias"),
+    sample_size=pixart_config.get("sample_size"),
+    patch_size=pixart_config.get("patch_size"),
+    activation_fn=pixart_config.get("activation_fn"),
+    num_embeds_ada_norm=pixart_config.get("num_embeds_ada_norm"),
+    upcast_attention=pixart_config.get("upcast_attention"),
+    norm_type=pixart_config.get("norm_type"),
+    norm_elementwise_affine=pixart_config.get("norm_elementwise_affine"),
+    norm_eps=pixart_config.get("norm_eps"),
+    caption_channels=pixart_config.get("caption_channels"),
+    attention_type=pixart_config.get("attention_type"))
+
+causal_dit = init_causal_dit(causal_dit, transformer)
+print('loaded causal_dit')
+controlnet = CausalSparseDiTControlModel(
+    num_attention_heads=pixart_config.get("num_attention_heads"),
+    attention_head_dim=pixart_config.get("attention_head_dim"),
+    in_channels=pixart_config.get("in_channels"),
+    cond_chanels=9,
+    out_channels=pixart_config.get("out_channels"),
+    num_layers=pixart_config.get("num_layers"),
+    dropout=pixart_config.get("dropout"),
+    norm_num_groups=pixart_config.get("norm_num_groups"),
+    cross_attention_dim=pixart_config.get("cross_attention_dim"),
+    attention_bias=pixart_config.get("attention_bias"),
+    sample_size=pixart_config.get("sample_size"),
+    patch_size=pixart_config.get("patch_size"),
+    activation_fn=pixart_config.get("activation_fn"),
+    num_embeds_ada_norm=pixart_config.get("num_embeds_ada_norm"),
+    upcast_attention=pixart_config.get("upcast_attention"),
+    norm_type=pixart_config.get("norm_type"),
+    norm_elementwise_affine=pixart_config.get("norm_elementwise_affine"),
+    norm_eps=pixart_config.get("norm_eps"),
+    caption_channels=pixart_config.get("caption_channels"),
+    attention_type=pixart_config.get("attention_type"))
+# controlnet = init_controlnet(controlnet, causal_dit)
+del transformer
+transformer_lora_config = LoraConfig(
+    r=lora_rank,
+    lora_alpha=lora_rank,
+    # use_dora=True,
+    init_lora_weights="gaussian",
+    target_modules=["to_k", "to_q", "to_v", "to_out.0",
+                    "proj_in", "proj_out",
+                    "ff.net.0.proj", "ff.net.2",
+                    "proj", "linear", "linear_1", "linear_2"],  # ff.net.0.proj ff.net.2
+)
+causal_dit.add_adapter(transformer_lora_config)
+
+lora_state_dict = torch.load(os.path.join(model_global_path, 'shadow_ckpt', 'transformer_lora_pos.bin'), map_location='cpu')
+causal_dit.load_state_dict(lora_state_dict, strict=False)
+controlnet_state_dict = torch.load(os.path.join(model_global_path, 'shadow_ckpt', 'controlnet.bin'), map_location='cpu')
+controlnet.load_state_dict(controlnet_state_dict, strict=True)
+
+causal_dit.to('cuda', dtype=weight_dtype)
+controlnet.to('cuda', dtype=weight_dtype)
+
+pipeline = CobraPixArtAlphaPipeline.from_pretrained(
+    pretrained_model_name_or_path,
+    transformer=causal_dit,
+    controlnet=controlnet,
+    safety_checker=None,
+    revision=None,
+    variant=None,
+    torch_dtype=weight_dtype,
+)
+
+pipeline = pipeline.to("cuda")
+
+print('loaded pipeline')
+
 @spaces.GPU
 def change_ckpt(style):
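
A side note on the loading pattern in this hunk: every checkpoint is deserialized on CPU via torch.load(..., map_location='cpu') and only then moved to CUDA in float16. A minimal sketch of that pattern, assuming a generic torch module; the helper name load_fp16 is illustrative and not part of the commit:

import torch
import torch.nn as nn

def load_fp16(module: nn.Module, ckpt_path: str, strict: bool = True) -> nn.Module:
    # map_location='cpu' keeps deserialization in host memory, so GPU memory
    # is not touched until the weights are actually moved there.
    state_dict = torch.load(ckpt_path, map_location='cpu')
    module.load_state_dict(state_dict, strict=strict)
    # A single .to() call both moves the module to CUDA and casts it to half precision.
    return module.to('cuda', dtype=torch.float16)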
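The @spaces.GPU decorator that remains on change_ckpt comes from Hugging Face's spaces package: on a ZeroGPU Space, a GPU is attached for the duration of the decorated call and released afterwards, which is why this commit can move the one-time loading code to module level while keeping the decorator only on the request handler. A minimal sketch of the decorator's use; the generate function and its call into pipeline are illustrative, not from this file:

import spaces

@spaces.GPU
def generate(prompt):
    # A GPU is requested when this call starts and released when it returns;
    # only the decorated function is guaranteed to run with CUDA attached.
    return pipeline(prompt=prompt).images[0]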