Commit 575f433 · Parent(s): e70b261
remove clear gpu memory
app.py
CHANGED
@@ -22,31 +22,6 @@ DEFAULT_MAX_SEQUENCE_LENGTH = 512
 GENERATION_SEED = 0 # could use a random number generator to set this, for more variety
 HF_TOKEN = os.environ.get("HF_ACCESS_TOKEN")
 
-def clear_gpu_memory(*args):
-    allocated_before = torch.cuda.memory_allocated(0) / 1024**3 if DEVICE == "cuda" else 0
-    reserved_before = torch.cuda.memory_reserved(0) / 1024**3 if DEVICE == "cuda" else 0
-    print(f"Before clearing: Allocated={allocated_before:.2f} GB, Reserved={reserved_before:.2f} GB")
-
-    deleted_types = []
-    for arg in args:
-        if arg is not None:
-            deleted_types.append(str(type(arg)))
-            del arg
-
-    if deleted_types:
-        print(f"Deleted objects of types: {', '.join(deleted_types)}")
-    else:
-        print("No objects passed to clear_gpu_memory.")
-
-    gc.collect()
-    if DEVICE == "cuda":
-        torch.cuda.empty_cache()
-
-    allocated_after = torch.cuda.memory_allocated(0) / 1024**3 if DEVICE == "cuda" else 0
-    reserved_after = torch.cuda.memory_reserved(0) / 1024**3 if DEVICE == "cuda" else 0
-    print(f"After clearing: Allocated={allocated_after:.2f} GB, Reserved={reserved_after:.2f} GB")
-    print("-" * 20)
-
 CACHED_PIPES = {}
 def load_bf16_pipeline():
     """Loads the original FLUX.1-dev pipeline in BF16 precision."""
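This commit drops the manual memory-clearing helper; the CACHED_PIPES dict kept above suggests the Space instead loads each pipeline once and reuses it. A minimal sketch of such a cached loader, assuming memoization under a "bf16" key and the standard FLUX.1-dev repo id (the cache key, model id, and from_pretrained arguments are illustrative, not the Space's exact code):

import os
import torch
from diffusers import FluxPipeline

HF_TOKEN = os.environ.get("HF_ACCESS_TOKEN")
CACHED_PIPES = {}

def load_bf16_pipeline():
    """Load the FLUX.1-dev pipeline in BF16 precision, reusing a cached copy if present."""
    if "bf16" in CACHED_PIPES:
        return CACHED_PIPES["bf16"]  # already loaded once; no need to clear and reload
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",  # assumed repo id for FLUX.1-dev
        torch_dtype=torch.bfloat16,
        token=HF_TOKEN,
    )
    if torch.cuda.is_available():
        pipe.to("cuda")
    CACHED_PIPES["bf16"] = pipe
    return pipe

Keeping pipelines cached trades VRAM for load time, which is presumably why the per-request clearing above became unnecessary.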
@@ -120,7 +95,7 @@ def load_bnb_4bit_pipeline():
 
 @spaces.GPU(duration=240)
 def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm=True)):
-    """Loads original and selected quantized model, generates one image each,
+    """Loads original and selected quantized model, generates one image each, shuffles results."""
     if not prompt:
         return None, {}, gr.update(value="Please enter a prompt.", interactive=False), gr.update(choices=[], value=None)
 
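The updated docstring notes that the two generated images are shuffled before display. Purely as an illustration of that idea (the labels and file paths below are hypothetical, not taken from the Space):

import random

# Hypothetical (label, image_path) pairs produced by the two pipelines.
results = [("Original BF16", "bf16_output.png"), ("Quantized", "quantized_output.png")]
random.shuffle(results)  # hide which pipeline produced which image
gallery_items = [(img, f"Image {i + 1}") for i, (label, img) in enumerate(results)]
correct_mapping = {f"Image {i + 1}": label for i, (label, img) in enumerate(results)}
print(gallery_items, correct_mapping)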
@@ -161,12 +136,6 @@ def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm=True)):
         print(f"\n--- Loading {label} Model ---")
         load_start_time = time.time()
         try:
-            # Ensure previous pipe is cleared *before* loading the next
-            # if current_pipe:
-            #     print(f"--- Clearing memory before loading {label} Model ---")
-            #     clear_gpu_memory(current_pipe)
-            #     current_pipe = None
-
             current_pipe = load_func()
             load_end_time = time.time()
             print(f"{label} model loaded in {load_end_time - load_start_time:.2f} seconds.")
@@ -184,22 +153,11 @@ def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm=True)):
 
         except Exception as e:
             print(f"Error during {label} model processing: {e}")
-            # Attempt cleanup
-            if current_pipe:
-                print(f"--- Clearing memory after error with {label} Model ---")
-                clear_gpu_memory(current_pipe)
-                current_pipe = None
             # Return error state to Gradio - update all outputs
             return None, {}, gr.update(value=f"Error processing {label} model: {e}", interactive=False), gr.update(choices=[], value=None)
 
     # No finally block needed here, cleanup happens before next load or after loop
 
-    # Final cleanup after the loop finishes successfully
-    # if current_pipe:
-    #     print(f"--- Clearing memory after last model ({label}) ---")
-    #     clear_gpu_memory(current_pipe)
-    #     current_pipe = None
-
     if len(results) != len(model_configs):
         print("Generation did not complete for all models.")
         # Update all outputs
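For reference, the cleanup path removed in this hunk amounted to dropping the pipeline reference and asking PyTorch to release its cached CUDA blocks. A self-contained sketch of that pattern (the function name is illustrative; the Space now keeps pipelines alive in CACHED_PIPES rather than freeing them):

import gc
import torch

def free_gpu_memory():
    """Run the garbage collector and release PyTorch's cached CUDA blocks."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        allocated = torch.cuda.memory_allocated(0) / 1024**3  # GB currently allocated
        reserved = torch.cuda.memory_reserved(0) / 1024**3    # GB held by the caching allocator
        print(f"Allocated={allocated:.2f} GB, Reserved={reserved:.2f} GB")

# Usage: drop the last reference first, then reclaim the memory.
# current_pipe = None
# free_gpu_memory()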
@@ -275,7 +233,7 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as demo:
         generate_button = gr.Button("Generate & Compare", variant="primary", scale=1)
 
     output_gallery = gr.Gallery(
-        label="Generated Images
+        label="Generated Images",
         columns=2,
         height=512,
         object_fit="contain",
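The Gallery fix above is plain Gradio; a minimal runnable sketch of a Blocks layout with the same Gallery settings (the prompt box, status box, and click handler are simplified stand-ins for the Space's actual wiring):

import gradio as gr

def fake_generate(prompt):
    # Stand-in for generate_images(): returns an empty gallery plus a status string.
    return [], f"Would generate and compare images for: {prompt!r}"

with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as demo:
    prompt_box = gr.Textbox(label="Prompt")
    generate_button = gr.Button("Generate & Compare", variant="primary", scale=1)
    output_gallery = gr.Gallery(
        label="Generated Images",
        columns=2,
        height=512,
        object_fit="contain",
    )
    status_box = gr.Textbox(label="Status", interactive=False)
    generate_button.click(fake_generate, inputs=prompt_box, outputs=[output_gallery, status_box])

if __name__ == "__main__":
    demo.launch()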
@@ -324,5 +282,5 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as demo:
 
 if __name__ == "__main__":
     # queue()
-    # demo.queue().launch()
-    demo.launch()
+    # demo.queue().launch()
+    demo.launch(share=True)