DerekLiu35 committed
Commit 575f433 · 1 Parent(s): e70b261

remove clear gpu memory

Files changed (1): app.py +4 -46
app.py CHANGED
@@ -22,31 +22,6 @@ DEFAULT_MAX_SEQUENCE_LENGTH = 512
 GENERATION_SEED = 0 # could use a random number generator to set this, for more variety
 HF_TOKEN = os.environ.get("HF_ACCESS_TOKEN")
 
-def clear_gpu_memory(*args):
-    allocated_before = torch.cuda.memory_allocated(0) / 1024**3 if DEVICE == "cuda" else 0
-    reserved_before = torch.cuda.memory_reserved(0) / 1024**3 if DEVICE == "cuda" else 0
-    print(f"Before clearing: Allocated={allocated_before:.2f} GB, Reserved={reserved_before:.2f} GB")
-
-    deleted_types = []
-    for arg in args:
-        if arg is not None:
-            deleted_types.append(str(type(arg)))
-            del arg
-
-    if deleted_types:
-        print(f"Deleted objects of types: {', '.join(deleted_types)}")
-    else:
-        print("No objects passed to clear_gpu_memory.")
-
-    gc.collect()
-    if DEVICE == "cuda":
-        torch.cuda.empty_cache()
-
-    allocated_after = torch.cuda.memory_allocated(0) / 1024**3 if DEVICE == "cuda" else 0
-    reserved_after = torch.cuda.memory_reserved(0) / 1024**3 if DEVICE == "cuda" else 0
-    print(f"After clearing: Allocated={allocated_after:.2f} GB, Reserved={reserved_after:.2f} GB")
-    print("-" * 20)
-
 CACHED_PIPES = {}
 def load_bf16_pipeline():
     """Loads the original FLUX.1-dev pipeline in BF16 precision."""
@@ -120,7 +95,7 @@ def load_bnb_4bit_pipeline():
 
 @spaces.GPU(duration=240)
 def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm=True)):
-    """Loads original and selected quantized model, generates one image each, clears memory, shuffles results."""
+    """Loads original and selected quantized model, generates one image each, shuffles results."""
     if not prompt:
         return None, {}, gr.update(value="Please enter a prompt.", interactive=False), gr.update(choices=[], value=None)
 
@@ -161,12 +136,6 @@ def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm
         print(f"\n--- Loading {label} Model ---")
         load_start_time = time.time()
         try:
-            # Ensure previous pipe is cleared *before* loading the next
-            # if current_pipe:
-            #     print(f"--- Clearing memory before loading {label} Model ---")
-            #     clear_gpu_memory(current_pipe)
-            #     current_pipe = None
-
             current_pipe = load_func()
             load_end_time = time.time()
             print(f"{label} model loaded in {load_end_time - load_start_time:.2f} seconds.")
@@ -184,22 +153,11 @@ def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm
 
         except Exception as e:
             print(f"Error during {label} model processing: {e}")
-            # Attempt cleanup
-            if current_pipe:
-                print(f"--- Clearing memory after error with {label} Model ---")
-                clear_gpu_memory(current_pipe)
-                current_pipe = None
             # Return error state to Gradio - update all outputs
             return None, {}, gr.update(value=f"Error processing {label} model: {e}", interactive=False), gr.update(choices=[], value=None)
 
         # No finally block needed here, cleanup happens before next load or after loop
 
-    # Final cleanup after the loop finishes successfully
-    # if current_pipe:
-    #     print(f"--- Clearing memory after last model ({label}) ---")
-    #     clear_gpu_memory(current_pipe)
-    #     current_pipe = None
-
     if len(results) != len(model_configs):
         print("Generation did not complete for all models.")
         # Update all outputs
@@ -275,7 +233,7 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as d
         generate_button = gr.Button("Generate & Compare", variant="primary", scale=1)
 
     output_gallery = gr.Gallery(
-        label="Generated Images (Original vs. Quantized)",
+        label="Generated Images",
        columns=2,
        height=512,
        object_fit="contain",
@@ -324,5 +282,5 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as d
 
 if __name__ == "__main__":
     # queue()
-    # demo.queue().launch() # Set share=True to create public link if needed
-    demo.launch()
+    # demo.queue().launch()
+    demo.launch(share=True)
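
For context, the deleted clear_gpu_memory helper wrapped the standard PyTorch release idiom in some logging; with pipelines now kept alive in CACHED_PIPES, explicit cleanup between loads is no longer needed. A minimal sketch of that idiom, not part of this commit (pipe is a stand-in for the caller's pipeline variable):

    import gc
    import torch

    # Dropping the caller's own reference is what actually frees the model:
    # `del arg` inside a helper only removes the helper's local name.
    pipe = None
    gc.collect()                  # reclaim the now-unreachable pipeline objects
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return cached CUDA blocks to the driver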