retwpay committed
Commit 87959d7 · verified · 1 Parent(s): 6cdbf05

Update app.py

Files changed (1):
  app.py +9 -139
app.py CHANGED
@@ -7,7 +7,6 @@ import random
 from diffusers import StableDiffusionXLPipeline
 from diffusers import EulerAncestralDiscreteScheduler
 import torch
-import re
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -30,151 +29,22 @@ pipe.unet.to(torch.float16)
 
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1216
-
-# Function to parse weighted prompts
-def parse_prompt_attention(text):
-    """
-    Parses a prompt with attention weights
-    Examples:
-        "a (red:1.5) dress" -> weight "red" with 1.5
-        "a ((blue)) sky" -> weight "blue" with 2.0
-    """
-    re_attention = r'\((\()?([^:]+)(\))?(?::([\d\.]+))?\)'
-    res = []
-
-    for match in re.finditer(re_attention, text):
-        double_paren, content, _, weight = match.groups()
-        weight = float(weight) if weight is not None else 1.0
-        if double_paren:
-            weight = weight * 1.1  # Optional: make (()) slightly higher than ()
-
-        res.append((match.start(), match.end(), content, weight))
-
-    return res
-
-# Function to process prompts with attention weights
-def get_weighted_text_embeddings(
-    pipe,
-    prompt,
-    negative_prompt=None
-):
-    """
-    Processes prompts with attention weights and handles long prompts
-    by chunking, applying weights, and combining embeddings
-    """
-    max_length = pipe.tokenizer.model_max_length
-
-    # Process the input prompt with attention weights
-    parsed_attention = parse_prompt_attention(prompt)
-
-    # Handle long prompts by chunking them appropriately
-    if len(prompt.split()) > 60:  # Rough estimate of potentially exceeding the token limit
-        print(f"Long prompt detected. Will process in chunks.")
-
-        # Split the prompt into word chunks for separate encoding
-        text_chunks = []
-        current_length = 0
-        current_chunk = ""
-
-        words = prompt.split()
-
-        for word in words:
-            if current_length + len(word.split()) + 1 > 60:  # Start a new chunk
-                text_chunks.append(current_chunk)
-                current_chunk = word
-                current_length = len(word.split())
-            else:
-                if current_chunk:
-                    current_chunk += " " + word
-                else:
-                    current_chunk = word
-                current_length += len(word.split())
-
-        if current_chunk:
-            text_chunks.append(current_chunk)
-
-        print(f"Split into {len(text_chunks)} chunks: {text_chunks}")
-
-        # Process each chunk with the tokenizer and get embeddings
-        prompt_embeds_list = []
-        pooled_prompt_embeds_list = []
-
-        for text_chunk in text_chunks:
-            text_input = pipe.tokenizer(
-                text_chunk,
-                padding="max_length",
-                max_length=pipe.tokenizer.model_max_length,
-                truncation=True,
-                return_tensors="pt",
-            )
-            text_input = text_input.to(device)
-
-            # Get text embeddings from both encoders
-            prompt_embeds = pipe.text_encoder(text_input.input_ids)[0]
-            pooled_prompt_embeds = pipe.text_encoder_2(text_input.input_ids)[0]
-
-            prompt_embeds_list.append(prompt_embeds)
-            pooled_prompt_embeds_list.append(pooled_prompt_embeds)
-
-        # Average the embeddings from all chunks (alternatively could use max pooling or other methods)
-        prompt_embeds = torch.stack(prompt_embeds_list).mean(dim=0)
-        pooled_prompt_embeds = torch.stack(pooled_prompt_embeds_list).mean(dim=0)
-
-    else:
-        # For shorter prompts, just use the standard pipeline processing
-        text_input = pipe.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=pipe.tokenizer.model_max_length,
-            truncation=True,
-            return_tensors="pt",
-        )
-        text_input = text_input.to(device)
-
-        prompt_embeds = pipe.text_encoder(text_input.input_ids)[0]
-        pooled_prompt_embeds = pipe.text_encoder_2(text_input.input_ids)[0]
-
-    # Process the negative prompt (empty string if none provided)
-    if negative_prompt is None:
-        negative_prompt = ""
-
-    uncond_input = pipe.tokenizer(
-        negative_prompt,
-        padding="max_length",
-        max_length=pipe.tokenizer.model_max_length,
-        truncation=True,
-        return_tensors="pt",
-    )
-    uncond_input = uncond_input.to(device)
-    negative_prompt_embeds = pipe.text_encoder(uncond_input.input_ids)[0]
-    negative_pooled_prompt_embeds = pipe.text_encoder_2(uncond_input.input_ids)[0]
-
-    # Combine positive and negative embeddings
-    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-    pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
 
-    return prompt_embeds, pooled_prompt_embeds
-
-# Customized version of the generation function
 @spaces.GPU
 def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
+    # Warn if the prompt is likely too long; CLIP keeps only the first 77 tokens and the model truncates the rest
+    if len(prompt.split()) > 60:  # Rough estimate to avoid exceeding the token limit
+        print("Warning: Prompt may be too long and will be truncated by the model")
+
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
 
     generator = torch.Generator(device=device).manual_seed(seed)
 
     try:
-        # Get embeddings with special handling for long prompts
-        prompt_embeds, pooled_prompt_embeds = get_weighted_text_embeddings(
-            pipe,
-            prompt,
-            negative_prompt
-        )
-
-        # Use the custom embeddings to generate the image
         output_image = pipe(
-            prompt_embeds=prompt_embeds,
-            pooled_prompt_embeds=pooled_prompt_embeds,
+            prompt=prompt,
+            negative_prompt=negative_prompt,
             guidance_scale=guidance_scale,
             num_inference_steps=num_inference_steps,
             width=width,
@@ -205,8 +75,8 @@ with gr.Blocks(css=css) as demo:
         prompt = gr.Text(
             label="Prompt",
             show_label=False,
-            max_lines=3,  # Increased to allow longer prompts to be visible
-            placeholder="Enter your prompt (supports long prompts now)",
+            max_lines=1,
+            placeholder="Enter your prompt (keep it under 60 words for best results)",
             container=False,
         )
 
@@ -218,7 +88,7 @@ with gr.Blocks(css=css) as demo:
 
         negative_prompt = gr.Text(
             label="Negative prompt",
-            max_lines=2,  # Also increased for consistency
+            max_lines=1,
             placeholder="Enter a negative prompt",
             value="nsfw, (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn"
         )
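For reference, the "(text:weight)" / "((text))" syntax handled by the removed parse_prompt_attention helper can be exercised on its own. The snippet below is an illustration only: it re-declares the helper outside app.py with the same regex and weighting rules and runs it on the docstring's own examples. The actual behavior is a weight of 1.5 for "(red:1.5)", but a captured stray ')' and a 1.1x boost (not the 2.0 the docstring suggests) for "((blue))".

import re

def parse_prompt_attention(text):
    # Same regex and weighting rules as the helper removed in this commit
    re_attention = r'\((\()?([^:]+)(\))?(?::([\d\.]+))?\)'
    res = []
    for match in re.finditer(re_attention, text):
        double_paren, content, _, weight = match.groups()
        weight = float(weight) if weight is not None else 1.0
        if double_paren:
            weight = weight * 1.1
        res.append((match.start(), match.end(), content, weight))
    return res

print(parse_prompt_attention("a (red:1.5) dress"))  # [(2, 11, 'red', 1.5)]
print(parse_prompt_attention("a ((blue)) sky"))     # [(2, 10, 'blue)', 1.1)]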
 
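The removed long-prompt path split the prompt into chunks of at most roughly 60 words, encoded each chunk separately, and merged the per-chunk embeddings with torch.stack(...).mean(dim=0). A toy sketch of just that averaging step, using random tensors as stand-ins for the text-encoder outputs, shows that the merged embedding keeps the shape of a single chunk:

import torch

# Stand-ins for per-chunk text-encoder outputs: (batch, sequence_length, hidden_dim)
chunk_embeds = [torch.randn(1, 77, 768) for _ in range(3)]

# Average across chunks, as the removed get_weighted_text_embeddings did
merged = torch.stack(chunk_embeds).mean(dim=0)
print(merged.shape)  # torch.Size([1, 77, 768]), same shape as one chunk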
 
 
 
 
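The new version keeps only a rough 60-word heuristic and lets the pipeline truncate anything longer. For an exact check, one option is to count the actual CLIP tokens with the pipeline's own tokenizer, whose limit is pipe.tokenizer.model_max_length (77, including the start and end special tokens). This is a minimal sketch, not part of the commit: it assumes the pipe and prompt already defined in app.py, and count_clip_tokens is a hypothetical helper name.

def count_clip_tokens(text: str) -> int:
    # Tokenize without padding or truncation so the true length is visible;
    # assumes pipe is the StableDiffusionXLPipeline loaded earlier in app.py.
    return len(pipe.tokenizer(text, padding=False, truncation=False).input_ids)

n_tokens = count_clip_tokens(prompt)
if n_tokens > pipe.tokenizer.model_max_length:
    print(f"Warning: prompt is {n_tokens} CLIP tokens; "
          f"only the first {pipe.tokenizer.model_max_length} will be used.")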