dezzman commited on
Commit
e61c05b
·
verified ·
1 Parent(s): 124feae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -12
app.py CHANGED
@@ -8,15 +8,15 @@ import torch
8
  from typing import Optional
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
- model_repo_id_default = "CompVis/stable-diffusion-v1-4"
12
 
13
  if torch.cuda.is_available():
14
  torch_dtype = torch.float16
15
  else:
16
  torch_dtype = torch.float32
17
 
18
- pipe = DiffusionPipeline.from_pretrained(model_repo_id_default, torch_dtype=torch_dtype)
19
- pipe = pipe.to(device)
20
 
21
  MAX_SEED = np.iinfo(np.int32).max
22
  MAX_IMAGE_SIZE = 1024
@@ -36,19 +36,24 @@ def infer(
36
  ):
37
  generator = torch.Generator().manual_seed(seed)
38
 
 
 
 
 
 
 
 
 
 
 
39
  if model_id != model_repo_id_default:
40
  pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
41
  pipe = pipe.to(device)
 
 
 
42
 
43
- image = pipe(
44
- prompt=prompt,
45
- negative_prompt=negative_prompt,
46
- guidance_scale=guidance_scale,
47
- num_inference_steps=num_inference_steps,
48
- width=width,
49
- height=height,
50
- generator=generator,
51
- ).images[0]
52
 
53
  return image, pipe.name_or_path
54
 
 
8
  from typing import Optional
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
+ model_id_default = "CompVis/stable-diffusion-v1-4"
12
 
13
  if torch.cuda.is_available():
14
  torch_dtype = torch.float16
15
  else:
16
  torch_dtype = torch.float32
17
 
18
+ pipe_default = DiffusionPipeline.from_pretrained(model_id_default, torch_dtype=torch_dtype)
19
+ pipe_default = pipe_default.to(device)
20
 
21
  MAX_SEED = np.iinfo(np.int32).max
22
  MAX_IMAGE_SIZE = 1024
 
36
  ):
37
  generator = torch.Generator().manual_seed(seed)
38
 
39
+ params = {
40
+ 'prompt': prompt,
41
+ 'negative_prompt': negative_prompt,
42
+ 'guidance_scale': guidance_scale,
43
+ 'num_inference_steps': num_inference_steps,
44
+ 'width': width,
45
+ 'height': height,
46
+ 'generator': generator,
47
+ }
48
+
49
  if model_id != model_repo_id_default:
50
  pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
51
  pipe = pipe.to(device)
52
+ image = pipe(**params).images[0]
53
+ else:
54
+ image = pipe_default(**params).images[0]
55
 
56
+
 
 
 
 
 
 
 
 
57
 
58
  return image, pipe.name_or_path
59