File size: 5,268 Bytes
99745bb
a3f7eca
6f68b9e
99745bb
 
 
 
 
 
6f68b9e
9af3c99
 
a125be2
 
 
 
99745bb
e8c9f0d
34c6239
99745bb
 
a125be2
 
99745bb
a125be2
 
 
bf54c2a
9af3c99
 
 
 
 
 
 
 
 
e8c9f0d
9af3c99
a125be2
 
 
 
99745bb
a125be2
 
 
 
 
 
 
 
 
 
 
 
9af3c99
 
 
 
 
 
 
 
 
 
 
 
 
e8c9f0d
709fad3
 
 
 
 
e8c9f0d
 
 
 
 
 
 
 
9af3c99
a125be2
 
 
 
 
 
 
 
 
 
 
 
e8c9f0d
 
2b8d1f6
9b027fd
2b8d1f6
9b027fd
aad4f65
2b8d1f6
2fd1616
aad4f65
3c6dec0
9b027fd
 
 
 
e8c9f0d
 
9b027fd
e8c9f0d
 
 
 
 
 
 
 
 
 
 
 
2b8d1f6
 
 
 
aad4f65
bf54c2a
 
de96860
 
6f68b9e
aad4f65
 
9af3c99
2b8d1f6
 
 
 
 
 
 
e8c9f0d
bf54c2a
 
 
431d917
bf54c2a
 
 
e8c9f0d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import os
from install import install

if "HF_DEMO" in os.environ:
    # Only when deploying on a Hugging Face Space: run install() before the
    # remaining imports below (gradio, test, PIL) are executed — presumably it
    # sets up what those imports need; confirm against install.py.
    # Module-level code runs once per import of this module, so no separate
    # "already installed" check is needed here.
    install()
    INSTALLED = True  # kept so any code inspecting this flag still sees it

import gradio as gr
import tempfile
import hashlib
import io
import pickle
import sys
from test import process_face
from PIL import Image

# Directory holding pickled results from enhance_face_gradio. The path is
# relative to the process's current working directory, and the directory is
# created (if missing) as soon as this module is imported.
INPUT_CACHE_DIR = "./cache"
os.makedirs(INPUT_CACHE_DIR, exist_ok=True)

def get_image_hash(img):
    """Return the hex MD5 digest of ``img`` re-encoded as PNG.

    The hash is derived from the encoded image content rather than object
    identity, so identical uploads yield the same key (used to build the
    result-cache file name).
    """
    buffer = io.BytesIO()
    img.save(buffer, format='PNG')
    png_bytes = buffer.getvalue()
    digest = hashlib.md5(png_bytes)
    return digest.hexdigest()

def enhance_face_gradio(input_image, ref_image):
    """
    Wrapper function for process_face that works with Gradio.

    Results are memoized on disk: the output for a given (input, reference)
    pair is pickled into INPUT_CACHE_DIR under a file name built from the MD5
    hashes of both images, and later requests for the same pair return it
    without re-running process_face.

    Args:
        input_image: Input image from Gradio (PIL Image, or None when the
            user has not uploaded one)
        ref_image: Reference face image from Gradio (PIL Image, or None when
            the user has not uploaded one)

    Returns:
        PIL Image: Enhanced image

    Raises:
        gr.Error: If either image is missing or face processing fails;
            Gradio shows the message to the user.
    """
    # Without this guard, get_image_hash(None) fails with an opaque AttributeError.
    if input_image is None or ref_image is None:
        raise gr.Error("Please upload both a target image and a reference face image.")

    # Generate hashes for both images
    input_hash = get_image_hash(input_image)
    ref_hash = get_image_hash(ref_image)
    combined_hash = f"{input_hash}_{ref_hash}"
    cache_path = os.path.join(INPUT_CACHE_DIR, f"{combined_hash}.pkl")

    # Check if result exists in cache.
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. This is only acceptable because INPUT_CACHE_DIR is written solely
    # by this function — never point it at a user-writable location.
    if os.path.exists(cache_path):
        try:
            with open(cache_path, 'rb') as f:
                result_img = pickle.load(f)
            print(f"Returning cached result for images with hash {combined_hash}")
            return result_img
        except (pickle.PickleError, EOFError, OSError) as e:
            # EOFError is what an empty/truncated cache file raises; like the
            # other errors it just means "recompute".
            print(f"Error loading from cache: {e}")

    # Reserve temporary file names. The handles are closed at the end of the
    # `with`; delete=False keeps the files on disk so they can be reopened by
    # path (required on Windows). They are removed in the `finally` below.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as input_file, \
         tempfile.NamedTemporaryFile(suffix=".png", delete=False) as ref_file, \
         tempfile.NamedTemporaryFile(suffix=".png", delete=False) as output_file:
        input_path = input_file.name
        ref_path = ref_file.name
        output_path = output_file.name

    try:
        # Save uploaded images to temporary files (inside the try so the
        # files are still cleaned up if saving fails).
        input_image.save(input_path)
        ref_image.save(ref_path)

        try:
            process_face(
                input_path=input_path,
                ref_path=ref_path,
                output_path=output_path
            )
        except Exception as e:
            print(f"Error processing face: {e}")
            # A gr.Image output cannot display a plain string; gr.Error shows
            # the message in the UI instead.
            raise gr.Error("An error occurred while processing the face. Please try again.") from e

        # Image.open is lazy and keeps the file open; copy() reads the pixels
        # into memory so the temporary output file can be deleted below.
        with Image.open(output_path) as opened:
            result_img = opened.copy()
    finally:
        # Remove every temporary file, including the output, on success and
        # on failure. Cleanup is best-effort: a failure is logged, not raised,
        # so it cannot mask the original error.
        for path in (input_path, ref_path, output_path):
            try:
                os.unlink(path)
            except OSError as e:
                print(f"Error removing temporary file {path}: {e}")

    # Cache the result
    try:
        with open(cache_path, 'wb') as f:
            pickle.dump(result_img, f)
            print(f"Cached result for images with hash {combined_hash}")
    except (pickle.PickleError, OSError) as e:
        print(f"Error caching result: {e}")

    return result_img

def create_gradio_interface():
    """Build the Face Enhancement Gradio UI and launch the server.

    Lays out the target and reference image inputs plus an 'Enhance Face'
    button beside the result image, wires the button to enhance_face_gradio,
    adds clickable examples, enables the request queue, then calls
    demo.launch(). Exits the process with status 1 if launching raises
    OSError.
    """
    with gr.Blocks(title="Face Enhancement") as demo:
        gr.Markdown("""
        # Face Enhance
        ### Instructions
        1. Upload the target image you want to enhance
        2. Upload a high-quality face image
        3. Click 'Enhance Face'

        Processing takes around 30 seconds.
        """, elem_id="instructions")

        gr.Markdown("---")

        # Left column: both inputs and the trigger button; right column: the
        # result. Components appear in the order they are created inside these
        # layout contexts, so statement order here defines the layout.
        with gr.Row():
            with gr.Column():
                # type="pil" makes Gradio hand PIL images to enhance_face_gradio.
                input_image = gr.Image(label="Target Image", type="pil")
                ref_image = gr.Image(label="Reference Face", type="pil")
                enhance_button = gr.Button("Enhance Face")
            
            with gr.Column():
                output_image = gr.Image(label="Enhanced Result")
        
        enhance_button.click(
            fn=enhance_face_gradio,
            inputs=[input_image, ref_image],
            outputs=output_image,
            queue=True  # Enable queue for sequential processing
        )
        gr.Markdown("""
        ## Examples
        Click on an example to load the images into the interface.
        """)
        # Each row is [target image path, reference face path], in the same
        # order as the `inputs` list passed to gr.Examples below (relative paths).
        example_inps = [
            ["examples/dany_gpt_1.png", "examples/dany_face.jpg"],
            ["examples/dany_gpt_2.png", "examples/dany_face.jpg"],
            ["examples/tim_gpt_1.png", "examples/tim_face.jpg"],
            ["examples/tim_gpt_2.png", "examples/tim_face.jpg"],
            ["examples/elon_gpt.png", "examples/elon_face.png"],
        ]
        # No fn is passed, so clicking an example only loads it into the input
        # components (as the Markdown above says); processing still requires
        # clicking 'Enhance Face'.
        gr.Examples(examples=example_inps, inputs=[input_image, ref_image], outputs=output_image)

        gr.Markdown("""
        ## Notes
        Check out the code [here](https://github.com/RishiDesai/FaceEnhance) and see my [blog post](https://rishidesai.github.io/posts/face-enhancement-techniques/) for more information.
        
        Due to the constraints of this demo, face cropping and upscaling are not applied to the reference image.
        """)

    # Launch the Gradio app with queue; at most 99 requests may wait in it.
    demo.queue(max_size=99)
    
    try:
        demo.launch()
    except OSError as e:
        # Presumably raised when the server cannot bind (e.g. no free port) —
        # confirm for the installed Gradio version.
        print(f"Error starting server: {e}")
        sys.exit(1)


if __name__ == "__main__":
    # Build the UI and start the server only when run as a script, not on import.
    create_gradio_interface()