retwpay committed on
Commit 6cdbf05 · verified · 1 Parent(s): 4f5fe11

Update app.py

Files changed (1)
  1. app.py +139 -9
app.py CHANGED
@@ -7,6 +7,7 @@ import random
 from diffusers import StableDiffusionXLPipeline
 from diffusers import EulerAncestralDiscreteScheduler
 import torch
+import re
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -29,22 +30,151 @@ pipe.unet.to(torch.float16)
 
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1216
+
+# Function to parse weighted prompts
+def parse_prompt_attention(text):
+    """
+    Parses a prompt with attention weights
+    Examples:
+        "a (red:1.5) dress" -> weight "red" with 1.5
+        "a ((blue)) sky" -> weight "blue" with 2.0
+    """
+    re_attention = r'\((\()?([^:]+)(\))?(?::([\d\.]+))?\)'
+    res = []
+
+    for match in re.finditer(re_attention, text):
+        double_paren, content, _, weight = match.groups()
+        weight = float(weight) if weight is not None else 1.0
+        if double_paren:
+            weight = weight * 1.1  # Optional: make (()) slightly higher than ()
+
+        res.append((match.start(), match.end(), content, weight))
+
+    return res
+
+# Function to process prompts with attention weights
+def get_weighted_text_embeddings(
+    pipe,
+    prompt,
+    negative_prompt=None
+):
+    """
+    Processes prompts with attention weights and handles long prompts
+    by chunking, applying weights, and combining embeddings
+    """
+    max_length = pipe.tokenizer.model_max_length
+
+    # Process the input prompt with attention weights
+    parsed_attention = parse_prompt_attention(prompt)
+
+    # Handle long prompts by chunking them appropriately
+    if len(prompt.split()) > 60:  # Rough estimate of potentially exceeding token limit
+        print(f"Long prompt detected. Will process in chunks.")
+
+        # Remove and store attention weights for processing
+        text_chunks = []
+        current_length = 0
+        current_chunk = ""
+
+        words = prompt.split()
+
+        for word in words:
+            if current_length + len(word.split()) + 1 > 60:  # Start a new chunk
+                text_chunks.append(current_chunk)
+                current_chunk = word
+                current_length = len(word.split())
+            else:
+                if current_chunk:
+                    current_chunk += " " + word
+                else:
+                    current_chunk = word
+                current_length += len(word.split())
+
+        if current_chunk:
+            text_chunks.append(current_chunk)
+
+        print(f"Split into {len(text_chunks)} chunks: {text_chunks}")
+
+        # Process each chunk with the tokenizer and get embedding
+        prompt_embeds_list = []
+        pooled_prompt_embeds_list = []
+
+        for text_chunk in text_chunks:
+            text_input = pipe.tokenizer(
+                text_chunk,
+                padding="max_length",
+                max_length=pipe.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            text_input = text_input.to(device)
+
+            # Get text embeddings for both encoders
+            prompt_embeds = pipe.text_encoder(text_input.input_ids)[0]
+            pooled_prompt_embeds = pipe.text_encoder_2(text_input.input_ids)[0]
+
+            prompt_embeds_list.append(prompt_embeds)
+            pooled_prompt_embeds_list.append(pooled_prompt_embeds)
+
+        # Average the embeddings from all chunks (alternatively could use max pooling or other methods)
+        prompt_embeds = torch.stack(prompt_embeds_list).mean(dim=0)
+        pooled_prompt_embeds = torch.stack(pooled_prompt_embeds_list).mean(dim=0)
+
+    else:
+        # For shorter prompts, just use the standard pipeline processing
+        text_input = pipe.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=pipe.tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+        text_input = text_input.to(device)
+
+        prompt_embeds = pipe.text_encoder(text_input.input_ids)[0]
+        pooled_prompt_embeds = pipe.text_encoder_2(text_input.input_ids)[0]
+
+    # Process negative prompt if provided
+    if negative_prompt is None:
+        negative_prompt = ""
+
+    uncond_input = pipe.tokenizer(
+        negative_prompt,
+        padding="max_length",
+        max_length=pipe.tokenizer.model_max_length,
+        truncation=True,
+        return_tensors="pt",
+    )
+    uncond_input = uncond_input.to(device)
+    negative_prompt_embeds = pipe.text_encoder(uncond_input.input_ids)[0]
+    negative_pooled_prompt_embeds = pipe.text_encoder_2(uncond_input.input_ids)[0]
+
+    # Combine positive and negative embeddings
+    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+    pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
 
+    return prompt_embeds, pooled_prompt_embeds
+
+# Customized version of the generation function
 @spaces.GPU
 def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
-    # Check and truncate prompt if too long (CLIP can only handle 77 tokens)
-    if len(prompt.split()) > 60:  # Rough estimate to avoid exceeding token limit
-        print("Warning: Prompt may be too long and will be truncated by the model")
-
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
 
     generator = torch.Generator(device=device).manual_seed(seed)
 
     try:
+        # Get embeddings with special handling for long prompts
+        prompt_embeds, pooled_prompt_embeds = get_weighted_text_embeddings(
+            pipe,
+            prompt,
+            negative_prompt
+        )
+
+        # Use the custom embeddings to generate the image
         output_image = pipe(
-            prompt=prompt,
-            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
             guidance_scale=guidance_scale,
             num_inference_steps=num_inference_steps,
             width=width,
@@ -75,8 +205,8 @@ with gr.Blocks(css=css) as demo:
             prompt = gr.Text(
                 label="Prompt",
                 show_label=False,
-                max_lines=1,
-                placeholder="Enter your prompt (keep it under 60 words for best results)",
+                max_lines=3,  # Increased to allow longer prompts to be visible
+                placeholder="Enter your prompt (supports long prompts now)",
                 container=False,
             )
 
@@ -88,7 +218,7 @@ with gr.Blocks(css=css) as demo:
 
             negative_prompt = gr.Text(
                 label="Negative prompt",
-                max_lines=1,
+                max_lines=2,  # Also increased for consistency
                 placeholder="Enter a negative prompt",
                 value="nsfw, (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn"
            )
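
For reference, a standalone sketch of the attention-weight regex that the new parse_prompt_attention helper is built around; the sample prompt and the printed values below are illustrative, not taken from the Space.

import re

# Same pattern as in the commit: optional double paren, content, optional ":weight"
re_attention = r'\((\()?([^:]+)(\))?(?::([\d\.]+))?\)'

prompt = "a (red:1.5) dress, (lace trim:0.8)"
for match in re.finditer(re_attention, prompt):
    double_paren, content, _, weight = match.groups()
    weight = float(weight) if weight is not None else 1.0
    print(content, weight)
# red 1.5
# lace trim 0.8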
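
A standalone sketch of the word-count chunking loop in get_weighted_text_embeddings; the chunk_prompt wrapper name, the limit parameter, and the 130-word test prompt are illustrative, while the 60-word threshold and the loop logic mirror the commit.

def chunk_prompt(prompt, limit=60):
    # Split a long prompt into roughly limit-sized word chunks, as in the commit
    text_chunks = []
    current_length = 0
    current_chunk = ""
    for word in prompt.split():
        if current_length + len(word.split()) + 1 > limit:  # start a new chunk
            text_chunks.append(current_chunk)
            current_chunk = word
            current_length = len(word.split())
        else:
            current_chunk = (current_chunk + " " + word) if current_chunk else word
            current_length += len(word.split())
    if current_chunk:
        text_chunks.append(current_chunk)
    return text_chunks

long_prompt = " ".join(f"word{i}" for i in range(130))
chunks = chunk_prompt(long_prompt)
print([len(c.split()) for c in chunks])  # [59, 59, 12]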
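
A minimal sketch of the chunk-averaging step: stacking the per-chunk embeddings and taking the mean over the new chunk axis returns a tensor with a single chunk's shape. The tensor sizes below are placeholders, not the encoders' actual output shapes.

import torch

chunk_embeds = [torch.randn(1, 77, 768) for _ in range(3)]  # one embedding per chunk (placeholder shape)
prompt_embeds = torch.stack(chunk_embeds).mean(dim=0)        # average over the chunk axis
print(prompt_embeds.shape)  # torch.Size([1, 77, 768])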