Spaces: Running on L40S
import os | |
from install import install | |
if "HF_DEMO" in os.environ:
    # Hugging Face Space deployment only. install() is called before the
    # gradio/test/PIL imports below — presumably because it sets up what those
    # imports need (TODO confirm against install.py).
    # There is no "already installed?" check: module-level code re-runs every
    # time this file is executed, so a flag initialised right here could never
    # be True at this point. INSTALLED is still set afterwards for any code
    # that reads it.
    install()
    INSTALLED = True
import gradio as gr | |
import tempfile | |
import hashlib | |
import io | |
import pickle | |
import sys | |
from test import process_face | |
from PIL import Image | |
# Directory holding the pickled results, keyed by the hashes of both input images.
INPUT_CACHE_DIR = "./cache"
os.makedirs(name=INPUT_CACHE_DIR, exist_ok=True)


def get_image_hash(img):
    """Return a hex MD5 digest identifying the content of *img*.

    The image is serialised to PNG in an in-memory buffer and the digest is
    taken over those bytes. It is used as a cache key, not for security.
    """
    with io.BytesIO() as png_buffer:
        img.save(png_buffer, format='PNG')
        png_bytes = png_buffer.getvalue()
    return hashlib.md5(png_bytes).hexdigest()
def _load_cached_result(cache_path, combined_hash):
    """Return the cached result stored at *cache_path*, or None on a cache miss.

    A missing entry is a silent miss. Any other failure to read or unpickle
    the entry (e.g. a truncated or corrupt file) is logged and also treated as
    a miss, so the caller recomputes the result instead of crashing.
    """
    try:
        # NOTE(review): pickle.load can execute arbitrary code from the file;
        # this is only safe while the cache directory is writable solely by
        # this app.
        with open(cache_path, 'rb') as f:
            result_img = pickle.load(f)
    except FileNotFoundError:
        return None
    except Exception as e:
        # Unpickling can raise EOFError, AttributeError, ImportError, ... —
        # not only PickleError/OSError — and the cache is best-effort.
        print(f"Error loading from cache: {e}")
        return None
    print(f"Returning cached result for images with hash {combined_hash}")
    return result_img


def _store_cached_result(cache_path, combined_hash, result_img):
    """Pickle *result_img* to *cache_path*; failures are logged, never raised.

    The pickle is first written to a temporary file in the same directory and
    then moved into place with os.replace, so an interrupted write cannot leave
    a truncated entry at *cache_path*.
    """
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(
            dir=os.path.dirname(cache_path) or ".", suffix=".tmp", delete=False
        ) as f:
            tmp_path = f.name
            pickle.dump(result_img, f)
        os.replace(tmp_path, cache_path)
        tmp_path = None  # now published as cache_path; nothing left to clean up
        print(f"Cached result for images with hash {combined_hash}")
    except (pickle.PickleError, OSError) as e:
        print(f"Error caching result: {e}")
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass  # best-effort removal of a temp file that was never published


def _run_process_face(input_image, ref_image):
    """Run process_face on temporary PNG copies of both images.

    Returns the output as an in-memory PIL image. All three temporary files
    (input, reference, output) are deleted before returning, also on failure.

    Raises:
        gr.Error: if process_face fails or its output cannot be opened.
    """
    temp_paths = []
    try:
        # Create and immediately close the files so they can be reopened by
        # path (by save() and process_face) on every platform.
        for _ in range(3):
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
                temp_paths.append(tmp.name)
        input_path, ref_path, output_path = temp_paths

        input_image.save(input_path)
        ref_image.save(ref_path)
        try:
            process_face(
                input_path=input_path,
                ref_path=ref_path,
                output_path=output_path
            )
            # Image.open is lazy; copy() loads the pixels so the returned image
            # no longer depends on output_path, which is deleted below.
            with Image.open(output_path) as output_img:
                return output_img.copy()
        except Exception as e:
            print(f"Error processing face: {e}")
            # gr.Error is shown to the user; returning a message string to an
            # Image output would be misread by Gradio as a file path.
            raise gr.Error(
                "An error occurred while processing the face. Please try again."
            ) from e
    finally:
        for path in temp_paths:
            try:
                os.unlink(path)
            except OSError as e:
                print(f"Error removing temporary file {path}: {e}")


def enhance_face_gradio(input_image, ref_image):
    """
    Wrapper function for process_face that works with Gradio.

    Results are cached on disk in INPUT_CACHE_DIR, keyed by the hashes of both
    images, so repeating a request returns the stored result.

    Args:
        input_image: Input image from Gradio (value of a type="pil" gr.Image;
            None when nothing was uploaded)
        ref_image: Reference face image from Gradio (same conventions)
    Returns:
        PIL Image: Enhanced image
    Raises:
        gr.Error: if an image is missing or processing fails; Gradio displays
            the message to the user.
    """
    if input_image is None or ref_image is None:
        raise gr.Error("Please upload both a target image and a reference face image.")

    # Cache key combines the content hashes of both images.
    combined_hash = f"{get_image_hash(input_image)}_{get_image_hash(ref_image)}"
    cache_path = os.path.join(INPUT_CACHE_DIR, f"{combined_hash}.pkl")

    result_img = _load_cached_result(cache_path, combined_hash)
    if result_img is not None:
        return result_img

    result_img = _run_process_face(input_image, ref_image)
    _store_cached_result(cache_path, combined_hash, result_img)
    return result_img
# Markdown rendered above the input widgets.
_INSTRUCTIONS_MD = """
# Face Enhance
### Instructions
1. Upload the target image you want to enhance
2. Upload a high-quality face image
3. Click 'Enhance Face'
Processing takes around 30 seconds.
"""

# Markdown heading for the examples gallery.
_EXAMPLES_MD = """
## Examples
Click on an example to load the images into the interface.
"""

# Markdown footer with links and caveats.
_NOTES_MD = """
## Notes
Check out the code [here](https://github.com/RishiDesai/FaceEnhance) and see my [blog post](https://rishidesai.github.io/posts/face-enhancement-techniques/) for more information.
Due to the constraints of this demo, face cropping and upscaling are not applied to the reference image.
"""

# (target image, reference face) pairs offered as clickable examples.
_EXAMPLE_PAIRS = (
    ("examples/dany_gpt_1.png", "examples/dany_face.jpg"),
    ("examples/dany_gpt_2.png", "examples/dany_face.jpg"),
    ("examples/tim_gpt_1.png", "examples/tim_face.jpg"),
    ("examples/tim_gpt_2.png", "examples/tim_face.jpg"),
    ("examples/elon_gpt.png", "examples/elon_face.png"),
)


def create_gradio_interface():
    """Build the Face Enhancement UI, enable its request queue, and launch it.

    Blocks while the server runs. If launch() raises OSError, the error is
    printed and the process exits with status 1.
    """
    with gr.Blocks(title="Face Enhancement") as demo:
        gr.Markdown(_INSTRUCTIONS_MD, elem_id="instructions")
        gr.Markdown("---")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Target Image", type="pil")
                ref_image = gr.Image(label="Reference Face", type="pil")
                enhance_button = gr.Button("Enhance Face")
            with gr.Column():
                output_image = gr.Image(label="Enhanced Result")

        # queue=True routes clicks through the app's request queue.
        enhance_button.click(
            fn=enhance_face_gradio,
            inputs=[input_image, ref_image],
            outputs=output_image,
            queue=True,
        )

        gr.Markdown(_EXAMPLES_MD)
        # Fresh nested lists each call, matching a locally built literal.
        example_rows = [list(pair) for pair in _EXAMPLE_PAIRS]
        gr.Examples(
            examples=example_rows,
            inputs=[input_image, ref_image],
            outputs=output_image,
        )
        gr.Markdown(_NOTES_MD)

    demo.queue(max_size=99)
    try:
        demo.launch()
    except OSError as e:
        print(f"Error starting server: {e}")
        sys.exit(1)


if __name__ == "__main__":
    create_gradio_interface()