"""Gradio front-end for a patch-based GAN watermark remover.

At import time the generator checkpoint is downloaded from the Hugging Face
Hub and loaded.  The UI accepts an image URL or an uploaded file, shows the
chosen input, runs the model on it, and offers the result for download.
"""
import io
import os
import time
import traceback
import uuid
from pathlib import Path

import gradio as gr
import requests
import torch  # noqa: F401  (kept from the original; not referenced directly here)
from huggingface_hub import hf_hub_download
from PIL import Image, UnidentifiedImageError

# Assuming these are correctly imported from your processing script
from run_on_patches_online import (
    load_model,
    process_image_from_data,
    Generator,  # noqa: F401  Make sure Generator is imported if needed by load_model
    DEVICE,
    # CHECKPOINT_GEN,
    PATCH_KERNEL_SIZE,
    PATCH_STRIDE,
)

# --- Global Variables & Model Loading ---
HF_REPO_ID = "b-aryan/WM-rem-epoch-42"
HF_FILENAME = "gen_epoch_42.pth.tar"
CHECKPOINT_GEN = HF_FILENAME

MODEL = None
MODEL_LOAD_ERROR = None
DOWNLOADED_CHECKPOINT_PATH = None  # Store the path after download

# Largest response body accepted when downloading a user-supplied image URL.
MAX_DOWNLOAD_BYTES = 20 * 1024 * 1024

image_paths = ["1.png", "2.png", "4.png", "5.png", "6.png", "7.png"]


def _load_sample_images(paths):
    """Decode the gallery sample images up front, skipping unreadable files.

    ``Image.open`` is lazy and keeps the file open; calling ``load()`` decodes
    the pixels now, and Pillow closes the file of a single-frame image once it
    is loaded.  A missing or corrupt sample is reported and skipped instead of
    aborting the whole app at import time.
    """
    loaded = []
    for path in paths:
        try:
            img = Image.open(path)
            img.load()
        except OSError as e:  # FileNotFoundError and UnidentifiedImageError are OSErrors
            print(f"Warning: could not load sample image '{path}': {e}")
            continue
        loaded.append(img)
    return loaded


images = _load_sample_images(image_paths)
N = len(images)

print(f"Attempting to download/load model '{HF_FILENAME}' from repo '{HF_REPO_ID}' onto device '{DEVICE}'...")
try:
    # 1. Download the model checkpoint from Hugging Face Hub
    print(f"Downloading checkpoint '{HF_FILENAME}' from '{HF_REPO_ID}'...")
    DOWNLOADED_CHECKPOINT_PATH = hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=HF_FILENAME,
        # cache_dir can be specified if needed, otherwise uses default HF cache
    )
    print(f"Checkpoint downloaded successfully to: {DOWNLOADED_CHECKPOINT_PATH}")

    # 2. Load the model using the downloaded path
    if not os.path.exists(DOWNLOADED_CHECKPOINT_PATH):
        # This should ideally not happen if hf_hub_download succeeded
        raise FileNotFoundError(f"Downloaded checkpoint file not found at: {DOWNLOADED_CHECKPOINT_PATH}")

    MODEL = load_model(DOWNLOADED_CHECKPOINT_PATH, DEVICE)
    print("Model loaded successfully for Gradio app.")
except Exception as e:
    # Keep the app running so the UI can display the load error to the user.
    MODEL_LOAD_ERROR = f"Failed to download or load model '{HF_FILENAME}' from '{HF_REPO_ID}'. Error: {e}"
    print(f"Error: {MODEL_LOAD_ERROR}")
    traceback.print_exc()

# Processed outputs are written here and served through the download button.
# NOTE(review): files are never deleted, so this directory grows without bound.
TEMP_DIR = Path("temp")
TEMP_DIR.mkdir(exist_ok=True)


# --- Helper Function: Download Image (Simplified from run_on_patches_online) ---
def _read_capped(response, max_bytes):
    """Read a streamed ``requests`` response body, aborting once it exceeds *max_bytes*."""
    buffer = io.BytesIO()
    for chunk in response.iter_content(chunk_size=64 * 1024):
        buffer.write(chunk)
        if buffer.tell() > max_bytes:
            raise gr.Error(f"Image file size exceeds the {max_bytes / 1024 / 1024:g} MB limit.")
    return buffer.getvalue()


def download_image_for_gradio(
    url: str,
    timeout: int = 20,
    max_bytes: int = MAX_DOWNLOAD_BYTES,
) -> Image.Image | None:
    """Download an image from *url* and return it as an RGB PIL image.

    Args:
        url: HTTP(S) address of the image; surrounding whitespace is ignored.
        timeout: Timeout in seconds passed to ``requests.get``.
        max_bytes: Largest body accepted.  Checked against ``Content-Length``
            and again while streaming, so servers that omit the header cannot
            bypass the limit.

    Returns:
        The decoded image converted to RGB, or ``None`` when *url* is empty or
        not an http(s) URL (the caller decides whether that is an error).

    Raises:
        gr.Error: on timeouts, network/HTTP errors, non-image content types,
            oversized bodies, or undecodable image data.
    """
    print(f"Attempting to download image from: {url}")
    url = (url or "").strip()
    if not url.startswith(("http://", "https://")):
        # Don't raise error here, let the caller decide if URL is optional
        print("Invalid or empty URL provided.")
        return None  # Indicate failure without raising error immediately

    try:
        headers = {"User-Agent": "Gradio-Image-Processor/1.1"}
        # The context manager releases the streamed connection on every exit path.
        with requests.get(url, stream=True, timeout=timeout, headers=headers) as response:
            response.raise_for_status()

            content_type = response.headers.get("Content-Type", "").lower()
            if not content_type.startswith("image/"):
                raise gr.Error(f"URL content type ({content_type}) is not recognized as an image.")

            # Fast rejection when the server announces an oversized body.
            content_length = response.headers.get("Content-Length", "")
            if content_length.isdigit() and int(content_length) > max_bytes:
                raise gr.Error(
                    f"Image file size ({int(content_length) / 1024 / 1024:.1f} MB) "
                    f"exceeds the {max_bytes / 1024 / 1024:g} MB limit."
                )

            image_bytes = _read_capped(response, max_bytes)

        with Image.open(io.BytesIO(image_bytes)) as raw:
            pil_image = raw.convert("RGB")  # convert() forces a full decode into a new image
        print(f"Image downloaded successfully ({pil_image.width}x{pil_image.height}).")

        # Optional: Add image dimension limits if needed
        # max_dim = 2048
        # if pil_image.width > max_dim or pil_image.height > max_dim:
        #     raise gr.Error(f"Image dimensions ({pil_image.width}x{pil_image.height}) exceed the maximum allowed ({max_dim}x{max_dim}).")
        return pil_image
    except gr.Error:
        # Our own user-facing errors must reach the UI unchanged, not be masked
        # by the generic handler below.
        raise
    except requests.exceptions.Timeout as e:
        raise gr.Error(f"Request timed out after {timeout} seconds trying to download the image.") from e
    except requests.exceptions.RequestException as e:
        raise gr.Error(f"Error downloading image: {e}") from e
    except UnidentifiedImageError as e:
        raise gr.Error("Could not identify image file. The URL might not point to a valid image.") from e
    except Exception as e:
        print(f"An unexpected error occurred during download: {e}")  # Log for server admin
        raise gr.Error("An unexpected error occurred during image download.") from e


# --- Processing Function (Handles the ML part with progress) ---
def run_processing(input_pil_image: Image.Image, progress=gr.Progress(track_tqdm=True)):
    """Run the watermark-removal model on *input_pil_image*.

    Args:
        input_pil_image: Image shown in the "original" display; may be None.
        progress: Gradio progress tracker; ``track_tqdm`` mirrors the tqdm bar
            that ``process_image_from_data`` is asked to use.

    Returns:
        ``(output_image, saved_path_str)``, or ``(None, None)`` if there is no input.

    Raises:
        gr.Error: if the model is not loaded or processing/saving fails.
    """
    if MODEL is None:
        # Include the more specific error message if loading failed
        if MODEL_LOAD_ERROR:
            error_msg = f"Model is not loaded. Cannot process image. Load error: {MODEL_LOAD_ERROR}"
        else:
            error_msg = "Model is not loaded. Cannot process image."
        raise gr.Error(error_msg)

    if input_pil_image is None:
        # Can happen if the display is cleared before processing starts.
        print("run_processing called with None input.")
        return None, None

    start_time = time.perf_counter()
    print("Starting image processing...")
    progress(0, desc="Preparing for processing...")

    try:
        # Ensure image is RGB before processing
        if input_pil_image.mode != "RGB":
            print(f"Converting input image from {input_pil_image.mode} to RGB.")
            input_pil_image = input_pil_image.convert("RGB")

        output_pil_image = process_image_from_data(
            input_pil_image=input_pil_image,
            model=MODEL,
            device=DEVICE,
            kernel_size=PATCH_KERNEL_SIZE,
            stride=PATCH_STRIDE,
            use_tqdm=True,
        )
        if output_pil_image is None:
            raise gr.Error("Image processing failed internally. Check server logs.")

        # Save the processed image with a unique filename
        save_path = TEMP_DIR / f"processed_{uuid.uuid4().hex}.jpg"
        # JPEG cannot store alpha/palette modes; the output mode of
        # process_image_from_data is not visible here, so normalise defensively.
        to_save = output_pil_image if output_pil_image.mode in ("RGB", "L") else output_pil_image.convert("RGB")
        to_save.save(save_path)
        print(f"Saved processed image to: {save_path}")
    except gr.Error:
        raise  # already user-facing; don't re-wrap
    except Exception as e:
        print(f"Exception during processing: {e}")
        traceback.print_exc()
        raise gr.Error(f"An error occurred during model processing: {e}") from e

    print(f"Processing time: {time.perf_counter() - start_time:.2f} seconds")
    return output_pil_image, str(save_path)


# --- Wrapper Function for Button Click (Handles URL or Upload) ---
def handle_input(url, uploaded_image):
    """Pick the input source (upload wins over URL) and prepare the image.

    Returns a 4-tuple for ``[original_image, processed_image, image_url,
    file_upload_input]``: the chosen image, ``None`` to clear the previous
    result, the URL box value ("" when the upload was used), and ``None`` to
    clear the upload widget.

    Raises:
        gr.Error: if neither input is given, the URL is invalid, or download fails.
    """
    print("Handle input triggered.")  # Debug print

    if uploaded_image is not None:
        print("Processing uploaded image.")
        return uploaded_image, None, "", None  # Clear URL field if upload is used

    if url and url.strip():
        print("Processing URL.")
        try:
            input_image = download_image_for_gradio(url)
        except gr.Error as e:
            print(f"Download failed: {e}")  # Debug print
            raise
        except Exception as e:
            # Catch other unexpected errors during the wrapper logic
            print(f"Unexpected error in handle_input (download branch): {e}")  # Debug print
            traceback.print_exc()
            raise gr.Error(f"An unexpected error occurred preparing the image from URL: {e}") from e

        if input_image is None:
            # download_image_for_gradio returns None for non-http/https URLs
            raise gr.Error("Invalid URL provided. Please enter a valid HTTP or HTTPS image URL.")
        return input_image, None, url, None  # Keep the url in the box on success

    # Neither URL nor upload provided
    raise gr.Error("Please provide an image URL or upload an image file.")


def show_image(index):
    """Return ``(sample_image, wrapped_index)`` for the gallery; ``(None, 0)`` if empty."""
    if N == 0:
        return None, 0
    index = index % N
    return images[index], index


def next_img(index):
    """Advance the gallery by one sample (wrapping around)."""
    return show_image(index + 1)


def prev_img(index):
    """Step the gallery back by one sample (wrapping around)."""
    return show_image(index - 1)


def _point_download_button(path):
    """Make the download button serve *path* and highlight it when a file exists.

    ``DownloadButton`` downloads whatever its value is at click time, so the
    value must be set as soon as processing finishes, not inside its own click.
    """
    return gr.update(value=path, variant="primary" if path else "secondary")


def _sync_download_button(img):
    """Reset the download button when the processed output is cleared."""
    if img is None:
        return gr.update(value=None, variant="secondary")
    return gr.update(variant="primary")


# --- Gradio Blocks Interface Definition ---
print("Setting up Gradio Blocks interface...")
with gr.Blocks() as demo:
    saved_image_path = gr.State()
    index_state = gr.State(0)

    gr.Markdown("""
    # Watermark Remover
    **Disclaimer: This project was made to showcase my Deep Learning skills with no intention to cause harm to any business or infringe on any IP and it will be decommissioned as soon as I get a decent job offer. If you still have any issue, please reach out to me at [bhu.aryan.28@gmail.com](mailto:bhu.aryan.28@gmail.com).**
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(f"""
            This is a demo of a DL model which takes in an image with a watermark (either via URL or direct upload) and gives out the image with its watermark removed.
            The image is broken into overlapping patches of {PATCH_KERNEL_SIZE}x{PATCH_KERNEL_SIZE} pixels with a stride of {PATCH_STRIDE} before feeding them to the model, the model infers on each patch separately and then they are stitched together to form the whole image again to give the final output.
            """)
            if MODEL_LOAD_ERROR:
                gr.Markdown(f"**Model Loading Error:** {MODEL_LOAD_ERROR}")

            # Input Method 1: URL
            image_url = gr.Textbox(
                label="Image URL",
                placeholder="Paste image URL (e.g., https://www.shutterstock.com/shutterstock/photos/.../image.jpg)",
                info="Enter the direct link to a publicly accessible image file.",
            )

            # Input Method 2: File Upload
            file_upload_input = gr.Image(
                label="Or Upload Image File",
                type="pil",  # Keep as PIL for direct use
                sources=["upload"],  # Specify only upload source
                height=130,
            )

            submit_button = gr.Button("Process Image", variant="primary")

            gr.Examples(
                label="Click on the link to upload some test examples",
                examples=[
                    ["https://www.shutterstock.com/shutterstock/photos/2429728793/display_1500/stock-photo-monkey-funny-open-mouth-and-showing-teeth-crazy-nature-angry-short-hair-brown-grand-bassin-2429728793.jpg", None],  # Need None for the upload input
                    ["https://www.shutterstock.com/shutterstock/photos/2501926843/display_1500/stock-photo-brunette-woman-laying-on-couch-cuddling-light-brown-dog-and-brown-tabby-cat-happy-2501926843.jpg", None],
                ],
                inputs=[image_url, file_upload_input],  # Link examples to both inputs
            )

        with gr.Column(scale=2):
            gr.Markdown("""### Some Samples""")
            img_display = gr.Image(type="pil", label="Watermarked Image vs Processed Image")
            with gr.Row():
                prev_btn = gr.Button("<")
                next_btn = gr.Button(">")

            prev_btn.click(fn=prev_img, inputs=index_state, outputs=[img_display, index_state])
            next_btn.click(fn=next_img, inputs=index_state, outputs=[img_display, index_state])

            # Load the first image
            demo.load(fn=show_image, inputs=index_state, outputs=[img_display, index_state])

    with gr.Row():
        # original_image is display-only; it is filled by handle_input.
        original_image = gr.Image(label="Your Uploaded Image", type="pil", interactive=False)
        processed_image = gr.Image(label="Processed Output Image", type="pil", interactive=False)

    download_button = gr.DownloadButton("Download Processed Image", visible=True, variant="secondary")

    gr.Markdown("---")  # Separator
    gr.Markdown("""
    **A bit about the model:** In this project, I have trained a GAN network, with the Generator being inspired from Pix2Pix and Pix2PixHD architectures and the Discriminator is very similar to PatchGAN in Pix2Pix. For the loss, I have also added Perceptual Loss using VGG like in Pix2PixHD and SRGAN papers apart from the L1 and BCE loss.
    """)
    gr.Markdown(
        """If you liked this project, you can find my CV [here](https://drive.google.com/file/d/1uGpjvQWmhkNN6nVQw3_9276EYroKVe_U/view?usp=sharing) or reach me out at [bhu.aryan.28@gmail.com](mailto:bhu.aryan.28@gmail.com).""")

    # --- Event Handling Logic ---
    # 1. Button click: choose URL or upload, download if needed, show the chosen
    #    image in 'original_image', clear the old result and the unused input.
    submit_button.click(
        fn=handle_input,
        inputs=[image_url, file_upload_input],
        outputs=[original_image, processed_image, image_url, file_upload_input],
    )

    # 2. 'original_image' changes (after successful input handling): run the
    #    model (with progress bar); on success point the download button at the
    #    freshly saved file so the very next click downloads it.
    original_image.change(
        fn=run_processing,
        inputs=original_image,
        outputs=[processed_image, saved_image_path],
        # concurrency_limit=1  # Optional: Prevent multiple simultaneous processing runs if needed
    ).success(
        fn=_point_download_button,
        inputs=saved_image_path,
        outputs=download_button,
    )

    # 3. 'processed_image' changes: de-highlight and detach the download button
    #    when the output is cleared, so a stale file is never offered.
    processed_image.change(
        fn=_sync_download_button,
        inputs=processed_image,
        outputs=download_button,
    )


# --- Launch the Application ---
if __name__ == "__main__":
    print("Launching Gradio Blocks interface...")
    demo.queue()  # Use queue for better user experience with processing time
    demo.launch(share=False, server_name="0.0.0.0", show_api=False)