File size: 15,951 Bytes
be320a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf5043b
28bca13
 
 
be320a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
446b3da
 
 
be320a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
446b3da
be320a7
 
 
 
 
446b3da
 
 
 
be320a7
 
 
 
 
 
446b3da
 
 
 
 
be320a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
446b3da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be320a7
 
28bca13
 
 
 
 
 
 
 
 
 
 
be320a7
 
 
 
 
28bca13
be320a7
 
 
446b3da
be320a7
 
 
 
 
446b3da
 
 
 
 
 
 
 
 
 
be320a7
 
446b3da
 
be320a7
446b3da
 
 
 
 
 
 
 
 
be320a7
 
9a3aec3
446b3da
31d7cbf
 
446b3da
 
be320a7
 
 
446b3da
3c1267b
446b3da
 
 
 
 
 
 
 
 
 
be320a7
446b3da
9a3aec3
be320a7
 
 
446b3da
 
 
 
 
 
be320a7
446b3da
ae54b7c
be320a7
 
 
 
446b3da
 
 
 
 
 
be320a7
446b3da
 
 
be320a7
 
446b3da
be320a7
 
446b3da
be320a7
 
 
 
 
 
 
446b3da
 
be320a7
 
 
 
 
 
446b3da
 
be320a7
 
 
446b3da
be320a7
 
 
 
 
 
 
446b3da
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
import gradio as gr
import torch
import time
from PIL import Image, UnidentifiedImageError
import os
import requests
import io
import uuid
from pathlib import Path
from huggingface_hub import hf_hub_download


# Assuming these are correctly imported from your processing script
from run_on_patches_online import (
    load_model,
    process_image_from_data,
    Generator, # Make sure Generator is imported if needed by load_model
    DEVICE,
    #CHECKPOINT_GEN,
    PATCH_KERNEL_SIZE,
    PATCH_STRIDE
)

# --- Global Variables & Model Loading ---
HF_REPO_ID = "b-aryan/WM-rem-epoch-42"
HF_FILENAME = "gen_epoch_42.pth.tar"
# NOTE(review): presumably stands in for the CHECKPOINT_GEN import commented out
# above; it is not referenced in the visible part of this file.
CHECKPOINT_GEN = HF_FILENAME
MODEL = None  # populated by the loading step below; None means the model is unavailable
MODEL_LOAD_ERROR = None  # human-readable load failure, surfaced in the UI and in run_processing
DOWNLOADED_CHECKPOINT_PATH = None  # local path returned by hf_hub_download
# Sample images for the carousel, resolved relative to the working directory.
image_paths = ["1.png", "2.png", "4.png", "5.png", "6.png", "7.png"]


def _load_sample_image(path):
    """Return a fully decoded copy of the image at *path*, detached from its file.

    ``Image.open`` is lazy: it leaves the file open and defers decoding until the
    pixels are first needed. Copying inside the ``with`` forces the decode now and
    releases the file handle immediately. The carousel images therefore do not
    keep files open until first displayed. They are also not lazily decoded by
    several concurrent Gradio requests at once.
    """
    with Image.open(path) as img:
        return img.copy()


images = [_load_sample_image(path) for path in image_paths]
N = len(images)


print(f"Attempting to download/load model '{HF_FILENAME}' from repo '{HF_REPO_ID}' onto device '{DEVICE}'...")

try:
    # Step 1: resolve the checkpoint through the Hugging Face Hub; the call
    # returns the local path of the cached file.
    print(f"Downloading checkpoint '{HF_FILENAME}' from '{HF_REPO_ID}'...")
    DOWNLOADED_CHECKPOINT_PATH = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_FILENAME)
    print(f"Checkpoint downloaded successfully to: {DOWNLOADED_CHECKPOINT_PATH}")

    # Step 2: defensive existence check (should not trigger after a successful
    # download), then build the model on DEVICE.
    checkpoint_present = os.path.exists(DOWNLOADED_CHECKPOINT_PATH)
    if not checkpoint_present:
        raise FileNotFoundError(f"Downloaded checkpoint file not found at: {DOWNLOADED_CHECKPOINT_PATH}")

    MODEL = load_model(DOWNLOADED_CHECKPOINT_PATH, DEVICE)
    print("Model loaded successfully for Gradio app.")
except Exception as load_failure:
    import traceback

    # Do not crash at import time: record the failure so the UI can display it
    # and run_processing can refuse requests with a meaningful message.
    MODEL_LOAD_ERROR = f"Failed to download or load model '{HF_FILENAME}' from '{HF_REPO_ID}'. Error: {load_failure}"
    print(f"Error: {MODEL_LOAD_ERROR}")
    traceback.print_exc()

# Directory (relative to the working directory) where processed outputs are
# saved before being offered for download; created once at startup.
TEMP_DIR = Path("temp")
os.makedirs(TEMP_DIR, exist_ok=True)

# --- Helper Function: Download Image (Simplified from run_on_patches_online) ---
def download_image_for_gradio(url: str, timeout: int = 20, *, max_bytes: int = 20 * 1024 * 1024) -> Image.Image | None:
    """Download an image from *url* and return it as an RGB PIL image.

    Args:
        url: Absolute ``http://`` or ``https://`` URL of the image.
        timeout: Seconds passed to ``requests.get`` as its timeout.
        max_bytes: Maximum accepted payload size. It is enforced against the
            ``Content-Length`` header (early rejection) and against the bytes
            actually received, so a missing or wrong header cannot bypass it.

    Returns:
        The decoded image converted to RGB. Returns ``None`` when *url* is empty
        or not http(s); the caller decides whether that is an error.

    Raises:
        gr.Error: on timeouts, HTTP/network errors, non-image content types,
            oversized payloads, undecodable data, or any other unexpected failure.
    """
    print(f"Attempting to download image from: {url}")
    if not url or not url.startswith(('http://', 'https://')):
        # Not raised here: the caller decides whether a missing URL is fatal.
        print("Invalid or empty URL provided.")
        return None

    limit_mb = max_bytes / 1024 / 1024
    headers = {'User-Agent': 'Gradio-Image-Processor/1.1'}

    try:
        # The context manager releases the streamed connection even when we
        # bail out early (wrong content type, oversized payload).
        with requests.get(url, stream=True, timeout=timeout, headers=headers) as response:
            response.raise_for_status()

            content_type = response.headers.get('Content-Type', '').lower()
            if not content_type.startswith('image/'):
                raise gr.Error(f"URL content type ({content_type}) is not recognized as an image.")

            # Cheap early rejection when the server advertises an over-limit size.
            # A malformed header is ignored; the streaming check below still applies.
            content_length = response.headers.get('Content-Length')
            if content_length and content_length.isdigit() and int(content_length) > max_bytes:
                raise gr.Error(f"Image file size ({int(content_length)/1024/1024:.1f} MB) exceeds the {limit_mb:g} MB limit.")

            # Authoritative limit: count the bytes actually received.
            buffer = io.BytesIO()
            for chunk in response.iter_content(chunk_size=64 * 1024):
                buffer.write(chunk)
                if buffer.tell() > max_bytes:
                    raise gr.Error(f"Image file size exceeds the {limit_mb:g} MB limit.")

        buffer.seek(0)
        pil_image = Image.open(buffer).convert('RGB')
        print(f"Image downloaded successfully ({pil_image.width}x{pil_image.height}).")

        # Optional: Add image dimension limits if needed
        # max_dim = 2048
        # if pil_image.width > max_dim or pil_image.height > max_dim:
        #     raise gr.Error(f"Image dimensions ({pil_image.width}x{pil_image.height}) exceed the maximum allowed ({max_dim}x{max_dim}).")

        return pil_image

    except gr.Error:
        # Our own user-facing errors must reach the UI unchanged. Without this
        # clause the generic handler below would replace them with a vague message.
        raise
    except requests.exceptions.Timeout as e:
        raise gr.Error(f"Request timed out after {timeout} seconds trying to download the image.") from e
    except requests.exceptions.RequestException as e:
        raise gr.Error(f"Error downloading image: {e}") from e
    except UnidentifiedImageError as e:
        raise gr.Error("Could not identify image file. The URL might not point to a valid image.") from e
    except Exception as e:
        print(f"An unexpected error occurred during download: {e}")  # Log for server admin
        raise gr.Error("An unexpected error occurred during image download.") from e


# --- Processing Function (Handles the ML part with progress) ---
def run_processing(input_pil_image: Image.Image, progress=gr.Progress(track_tqdm=True)):
    """Run the watermark-removal model on *input_pil_image* and save the result.

    Args:
        input_pil_image: The image shown in ``original_image``. It may be
            ``None`` if that component is cleared before processing starts.
        progress: Progress tracker injected by Gradio. ``track_tqdm=True``
            forwards tqdm bars to the UI; ``use_tqdm=True`` below presumably
            makes ``process_image_from_data`` emit one (TODO confirm).

    Returns:
        ``(output_image, saved_path_str)`` on success, or ``(None, None)`` when
        no input image was supplied.

    Raises:
        gr.Error: if the model failed to load at startup, if processing returns
            nothing, or if processing/saving raises any exception.
    """
    if MODEL is None:
        # Include the more specific error message if loading failed.
        error_msg = f"Model is not loaded. Cannot process image. Load error: {MODEL_LOAD_ERROR}" if MODEL_LOAD_ERROR else "Model is not loaded. Cannot process image."
        raise gr.Error(error_msg)
    if input_pil_image is None:
        # Can happen if the input component is cleared before processing starts.
        print("run_processing called with None input.")
        return None, None

    # perf_counter is monotonic, unlike time.time, so it is the right clock for durations.
    start_time = time.perf_counter()
    print("Starting image processing...")
    progress(0, desc="Preparing for processing...")

    try:
        if input_pil_image.mode != 'RGB':
            print(f"Converting input image from {input_pil_image.mode} to RGB.")
            input_pil_image = input_pil_image.convert('RGB')

        output_pil_image = process_image_from_data(
            input_pil_image=input_pil_image,
            model=MODEL,
            device=DEVICE,
            kernel_size=PATCH_KERNEL_SIZE,
            stride=PATCH_STRIDE,
            use_tqdm=True
        )
        if output_pil_image is None:
            raise gr.Error("Image processing failed internally. Check server logs.")

        # A unique filename keeps concurrent sessions from overwriting each other.
        # NOTE(review): nothing in this file ever deletes files from TEMP_DIR;
        # consider periodic cleanup on long-running deployments.
        filename = f"processed_{uuid.uuid4().hex}.jpg"
        save_path = TEMP_DIR / filename
        output_pil_image.save(save_path)
        print(f"Saved processed image to: {save_path}")

    except gr.Error:
        # Already user-facing. Re-wrapping would prefix it with
        # "An error occurred during model processing:".
        raise
    except Exception as e:
        import traceback
        print(f"Exception during processing: {e}")
        traceback.print_exc()
        raise gr.Error(f"An error occurred during model processing: {e}") from e

    processing_time = time.perf_counter() - start_time
    print(f"Processing time: {processing_time:.2f} seconds")

    return output_pil_image, str(save_path)

# --- NEW Wrapper Function for Button Click (Handles URL or Upload) ---
def handle_input(url, uploaded_image):
    """Pick the input source (upload or URL) and prepare the image to display.

    An uploaded image takes precedence over a URL. Surrounding whitespace in the
    URL is ignored, and a whitespace-only URL counts as "no URL".

    Args:
        url: Text from the URL textbox (may be ``None`` or empty).
        uploaded_image: PIL image from the upload component, or ``None``.

    Returns:
        A 4-tuple for ``[original_image, processed_image, image_url, file_upload_input]``:
        the chosen image, ``None`` (clears the previous result), the URL to keep
        in the textbox (``""`` when an upload was used), and ``None`` (clears
        the upload component).

    Raises:
        gr.Error: when no input is given, the URL is not http(s), or the download fails.
    """
    print("Handle input triggered.")  # Debug print
    # Pasted URLs often carry stray spaces/newlines that would fail the
    # http(s) prefix check in download_image_for_gradio.
    url = (url or "").strip()

    if uploaded_image is not None:
        print("Processing uploaded image.")
        # Clear the URL box so it is clear which input was used.
        return uploaded_image, None, "", None

    if not url:
        raise gr.Error("Please provide an image URL or upload an image file.")

    print("Processing URL.")
    try:
        input_image = download_image_for_gradio(url)
    except gr.Error as e:
        # Already user-facing; log and propagate unchanged.
        print(f"Download failed: {e}")  # Debug print
        raise
    except Exception as e:
        import traceback
        print(f"Unexpected error in handle_input (download branch): {e}")  # Debug print
        traceback.print_exc()
        raise gr.Error(f"An unexpected error occurred preparing the image from URL: {e}") from e

    if input_image is None:
        # download_image_for_gradio returns None for non-http(s) URLs.
        raise gr.Error("Invalid URL provided. Please enter a valid HTTP or HTTPS image URL.")

    # Keep the URL in the box on success.
    return input_image, None, url, None


def show_image(index):
    """Return ``(sample_image, wrapped_index)`` for the carousel, wrapping *index* modulo N."""
    position = index % N
    return images[position], position


def _shift(index, delta):
    """Move *delta* steps from *index* through the sample carousel."""
    return show_image(index + delta)


def next_img(index):
    """Advance to the next sample, wrapping from the last back to the first."""
    return _shift(index, 1)


def prev_img(index):
    """Go back to the previous sample, wrapping from the first to the last."""
    return _shift(index, -1)


# --- Gradio Blocks Interface Definition ---
print("Setting up Gradio Blocks interface...")

with gr.Blocks() as demo:
    # Per-session path of the most recently saved processed image (written by run_processing).
    saved_image_path = gr.State()
    # Per-session position in the sample carousel.
    index_state = gr.State(0)
    gr.Markdown(f"""
        # Watermark Remover

        **Disclaimer: This project was made to showcase my Deep Learning skills with no intention to cause harm to any business or infringe on any IP and it will be decommissioned as soon as I get a decent job offer.
        If you still have any issue, please reach out to me at [[email protected]](mailto:[email protected]).**
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(f"""
                This is a demo of a DL model which takes in an image with a watermark (either via URL or direct upload) and gives out the image with its watermark removed.
                The image is broken into overlapping patches of {PATCH_KERNEL_SIZE}x{PATCH_KERNEL_SIZE} pixels with a stride of {PATCH_STRIDE} before feeding them to the model,
                the model infers on each patch separately and then they are stitched together to form the whole image again to give the final output.
            """)

            # Surface startup failures directly in the UI.
            if MODEL_LOAD_ERROR:
                gr.Markdown(f"<font color='red'>**Model Loading Error:** {MODEL_LOAD_ERROR}</font>")

            # Input Method 1: URL
            image_url = gr.Textbox(
                label="Image URL",
                placeholder="Paste image URL (e.g., https://www.shutterstock.com/shutterstock/photos/.../image.jpg)",
                info="Enter the direct link to a publicly accessible image file."
            )

            # Input Method 2: File Upload
            file_upload_input = gr.Image(
                label="Or Upload Image File",
                type="pil",  # Keep as PIL for direct use
                sources=["upload"],  # Specify only upload source
                height=130
            )

            submit_button = gr.Button("Process Image", variant="primary")
            gr.Examples(
                label="Click on the link to upload some test examples",
                examples=[
                    ["https://www.shutterstock.com/shutterstock/photos/2429728793/display_1500/stock-photo-monkey-funny-open-mouth-and-showing-teeth-crazy-nature-angry-short-hair-brown-grand-bassin-2429728793.jpg", None],  # Need None for the upload input
                    ["https://www.shutterstock.com/shutterstock/photos/2501926843/display_1500/stock-photo-brunette-woman-laying-on-couch-cuddling-light-brown-dog-and-brown-tabby-cat-happy-2501926843.jpg", None]
                ],
                inputs=[image_url, file_upload_input]  # Link examples to both inputs
            )

        with gr.Column(scale=2):
            gr.Markdown(f"""### Some Samples""")
            img_display = gr.Image(type="pil", label="Watermarked Image vs Processed Image")
            with gr.Row():
                prev_btn = gr.Button("<")
                next_btn = gr.Button(">")

            prev_btn.click(fn=prev_img, inputs=index_state, outputs=[img_display, index_state])
            next_btn.click(fn=next_img, inputs=index_state, outputs=[img_display, index_state])

            # Load the first image
            demo.load(fn=show_image, inputs=index_state, outputs=[img_display, index_state])

            with gr.Row():
                # original_image is display-only; it is filled by handle_input.
                original_image = gr.Image(label="Your Uploaded Image", type="pil", interactive=False)
                processed_image = gr.Image(label="Processed Output Image", type="pil", interactive=False)
            download_button = gr.DownloadButton("Download Processed Image", visible=True, variant="secondary")

    gr.Markdown("---")  # Separator
    gr.Markdown(f"""
                    **A bit about the model:** In this project, I have trained a GAN network,
                    with the Generator being inspired from Pix2Pix and Pix2PixHD architectures and the Discriminator is very similar to PatchGAN in Pix2Pix.
                    For the loss, I have also added Perceptual Loss using VGG like in Pix2PixHD and SRGAN papers apart from the L1 and BCE loss.
                """)

    gr.Markdown(
        f"""If you liked this project, you can find my CV [here](https://drive.google.com/file/d/1uGpjvQWmhkNN6nVQw3_9276EYroKVe_U/view?usp=sharing) or reach me out at [[email protected]](mailto:[email protected]).""")

    # --- Event Handling Logic ---

    # 1. Submit: handle_input picks the upload or downloads the URL. It then shows
    #    the chosen image, clears the previous result, clears the URL box if an
    #    upload was used, and clears the upload component.
    submit_button.click(
        fn=handle_input,
        inputs=[image_url, file_upload_input],
        outputs=[original_image, processed_image, image_url, file_upload_input]
    )

    # 2. A new original_image triggers the model (run_processing owns the progress bar).
    processing_event = original_image.change(
        fn=run_processing,
        inputs=original_image,
        outputs=[processed_image, saved_image_path]
        # concurrency_limit=1 # Optional: Prevent multiple simultaneous processing runs if needed
    )

    # 3. gr.DownloadButton downloads whatever `value` it holds *when clicked*.
    #    The file must therefore be assigned before the user clicks. Assigning it
    #    inside the button's own click handler made the first click download
    #    nothing and later clicks download the previous run's file.
    processing_event.success(
        fn=lambda path: gr.update(value=path, variant="primary" if path else "secondary"),
        inputs=saved_image_path,
        outputs=download_button
    )

    # 4. When processed_image is cleared (a new submission started), drop the stale
    #    file from the download button. When it is filled, only restyle it; the file
    #    itself is assigned by the .success() handler above.
    processed_image.change(
        fn=lambda img: gr.update(variant="primary") if img is not None else gr.update(value=None, variant="secondary"),
        inputs=processed_image,
        outputs=download_button
    )


# --- Launch the Application ---
if __name__ == "__main__":
    print("Launching Gradio Blocks interface...")
    # Enable the request queue (needed for progress updates and long-running jobs).
    # queue() returns the Blocks itself, so launch can be chained.
    # Bind to all interfaces so the app is reachable from outside a container.
    demo.queue().launch(share=False, server_name="0.0.0.0", show_api=False)