# WM-Remover / run_on_patches_online.py
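"""Patch-based watermark-removal inference.

Splits an input image into overlapping patches, runs each patch through the
Generator, and blends the outputs back together with a 2-D Hann window so the
patch seams are not visible. Includes a URL-download helper for Gradio use.
"""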
import os
from PIL import Image, UnidentifiedImageError
import numpy as np
import torch
import torch.nn.functional as F
import albumentations as A
from albumentations.pytorch import ToTensorV2
from collections import OrderedDict
from tqdm import tqdm
import requests
import io
from generator_model import Generator
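# The Generator architecture lives in generator_model.py (a local module in this repo).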
# --- Constants ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CHECKPOINT_GEN = "gen_epoch_42.pth.tar" # Keep your checkpoint name
PATCH_KERNEL_SIZE = 256
PATCH_STRIDE = 64
# DEFAULT_INPUT_DIR = "test/inputs" # No longer needed for Gradio URL input
# DEFAULT_OUTPUT_DIR = "test/outputs" # Output handled by Gradio
SUPPORTED_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff') # Still useful for local testing if needed
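# Normalize to [-1, 1] (mean 0.5 / std 0.5 per channel); assumed to match the
# normalization the Generator was trained with.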
test_transform = A.Compose([
A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
ToTensorV2()
])
def load_model(checkpoint_path: str, device: str) -> Generator:
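    """Load Generator weights from `checkpoint_path` and return the model in eval mode.

    Handles checkpoints saved from nn.DataParallel by stripping the "module."
    prefix from state-dict keys.
    """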
print(f"Loading model from: {checkpoint_path} onto device: {device}")
model = Generator(in_channels=3, features=64).to(device)
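    # weights_only=False: the .pth.tar checkpoint likely pickles more than plain
    # tensors (e.g. optimizer state), which weights_only=True would reject.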
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    # Strip a leading "module." prefix (left by nn.DataParallel), if present;
    # startswith avoids mangling keys that merely contain "module." elsewhere.
    new_state_dict = OrderedDict()
    for k, v in checkpoint["state_dict"].items():
        name = k[len("module."):] if k.startswith("module.") else k
        new_state_dict[name] = v
model.load_state_dict(new_state_dict)
model.eval() # Set model to evaluation mode
print("Model loaded successfully.")
return model
def calculate_padding(img_h: int, img_w: int, kernel_size: int, stride: int) -> tuple[int, int]:
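    """Return (pad_h, pad_w) so each padded dimension is >= kernel_size and
    (padded_dim - kernel_size) is a multiple of stride, letting patches tile it exactly.

    Example: img_h=300, kernel_size=256, stride=64 -> pad_h = (64 - 44 % 64) % 64 = 20,
    so the padded height 320 yields two patch rows.
    """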
pad_h = kernel_size - img_h if img_h < kernel_size else (stride - (img_h - kernel_size) % stride) % stride
pad_w = kernel_size - img_w if img_w < kernel_size else (stride - (img_w - kernel_size) % stride) % stride
return pad_h, pad_w
def download_image(url: str, timeout: int = 15) -> Image.Image | None:
"""Downloads an image from a URL and returns it as a PIL Image object."""
print(f"Attempting to download image from: {url}")
try:
headers = {'User-Agent': 'Gradio-Image-Processor/1.0'} # Be a good net citizen
response = requests.get(url, stream=True, timeout=timeout, headers=headers)
response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx)
content_type = response.headers.get('Content-Type', '').lower()
if not content_type.startswith('image/'):
print(f"Error: URL content type ({content_type}) is not an image.")
return None
image_bytes = response.content
pil_image = Image.open(io.BytesIO(image_bytes))
pil_image = pil_image.convert('RGB') # Ensure image is in RGB format
print(f"Image downloaded successfully ({pil_image.width}x{pil_image.height}).")
return pil_image
except requests.exceptions.Timeout:
print(f"Error: Request timed out after {timeout} seconds.")
return None
except requests.exceptions.RequestException as e:
print(f"Error downloading image: {e}")
return None
except UnidentifiedImageError:
print("Error: Could not identify image file. The URL might not point to a valid image.")
return None
except Exception as e:
print(f"An unexpected error occurred during download: {e}")
return None
def process_image_from_data(
input_pil_image: Image.Image,
model: Generator,
device: str,
kernel_size: int,
stride: int,
use_tqdm: bool = True # Optional: Control progress bar visibility
) -> Image.Image | None:
"""
Processes an input PIL image using the patch-based method and returns the output PIL image.
Returns None if an error occurs during processing.
"""
print(f"\nProcessing image data...")
try:
        image_np = np.array(input_pil_image.convert("RGB"))  # force 3 channels; RGBA/grayscale input would break the unpack below
H, W, _ = image_np.shape
print(f" Input dimensions: {W}x{H}")
# Apply transformations
transformed = test_transform(image=image_np)
input_tensor = transformed['image'].to(device) # Shape: (C, H, W)
C = input_tensor.shape[0]
# Calculate and apply padding
pad_h, pad_w = calculate_padding(H, W, kernel_size, stride)
print(f" Calculated padding (H, W): ({pad_h}, {pad_w})")
padded_tensor = F.pad(input_tensor.unsqueeze(0), (0, pad_w, 0, pad_h), mode='reflect').squeeze(0)
_, H_pad, W_pad = padded_tensor.shape
print(f" Padded dimensions: {W_pad}x{H_pad}")
# Extract patches
patches = padded_tensor.unfold(1, kernel_size, stride).unfold(2, kernel_size, stride)
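        # After both unfolds: (C, num_patches_h, num_patches_w, kernel_size, kernel_size)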
num_patches_h = patches.shape[1]
num_patches_w = patches.shape[2]
num_patches_total = num_patches_h * num_patches_w
print(f" Extracted {num_patches_total} patches ({num_patches_h} H x {num_patches_w} W)")
patches = patches.contiguous().view(C, -1, kernel_size, kernel_size)
# Permute to (num_patches_total, C, kernel_size, kernel_size)
patches = patches.permute(1, 0, 2, 3).contiguous()
output_patches = []
# Set up tqdm iterator if enabled
patch_iterator = tqdm(patches, total=num_patches_total, desc=" Inferring patches", unit="patch", leave=False, disable=not use_tqdm)
# --- Inference Loop ---
with torch.no_grad():
for patch in patch_iterator:
# Add batch dimension, run model, remove batch dimension
output_patch = model(patch.unsqueeze(0)).squeeze(0)
# Move to CPU immediately to save GPU memory during inference loop
output_patches.append(output_patch.cpu())
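        # The loop above trades speed for low peak GPU memory (one patch per forward
        # pass). A hypothetical batched variant (untested; chunk size 16 is a guess)
        # could replace both the loop and the stack below:
        #     with torch.no_grad():
        #         output_patches = torch.cat(
        #             [model(chunk).cpu() for chunk in patches.split(16)])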
        # Stack the per-patch outputs. Reconstruction is faster on the device, so
        # try moving them back; fall back to CPU if that fails (e.g. out of memory).
try:
output_patches = torch.stack(output_patches).to(device)
print(f" Output patches moved to {device} for reconstruction.")
except Exception as e: # Catch potential OOM on device
print(f" Warning: Could not move all output patches to {device} ({e}). Reconstruction might be slower on CPU.")
output_patches = torch.stack(output_patches) # Keep on CPU
# --- Reconstruction ---
# Generate 2D Hann window for blending
window_1d = torch.hann_window(kernel_size, periodic=False, device=device) # periodic=False is common
window_2d = torch.outer(window_1d, window_1d)
window_2d = window_2d.unsqueeze(0).to(device) # Add channel dim and ensure on device
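        # Caveat: with periodic=False the Hann window is exactly zero at its endpoints,
        # so canvas pixels covered only by a patch edge get near-zero total weight and
        # rely on the clamp(min=1e-6) below; this can leave faint artifacts on the
        # outermost rows/columns.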
# Initialize output tensor and weight tensor (for weighted averaging)
output_tensor = torch.zeros((C, H_pad, W_pad), device=device, dtype=output_patches.dtype)
weight_tensor = torch.zeros((C, H_pad, W_pad), device=device, dtype=window_2d.dtype)
patch_idx = 0
reconstruct_iterator = tqdm(total=num_patches_total, desc=" Reconstructing", unit="patch", leave=False, disable=not use_tqdm)
for i in range(num_patches_h):
for j in range(num_patches_w):
h_start = i * stride
w_start = j * stride
h_end = h_start + kernel_size
w_end = w_start + kernel_size
# Get current patch (ensure it's on the correct device)
current_patch = output_patches[patch_idx].to(device)
weighted_patch = current_patch * window_2d # Apply window
# Add weighted patch to output tensor
output_tensor[:, h_start:h_end, w_start:w_end] += weighted_patch
# Accumulate weights
weight_tensor[:, h_start:h_end, w_start:w_end] += window_2d
patch_idx += 1
reconstruct_iterator.update(1)
reconstruct_iterator.close() # Close the inner tqdm bar
# Perform weighted averaging - clamp weights to avoid division by zero
output_averaged = output_tensor / weight_tensor.clamp(min=1e-6)
# Crop to original dimensions
output_cropped = output_averaged[:, :H, :W]
print(f" Final output dimensions: {output_cropped.shape[2]}x{output_cropped.shape[1]}")
# --- Convert to Output Format ---
# Permute C, H, W -> H, W, C ; Move to CPU ; Convert to NumPy
output_numpy = output_cropped.permute(1, 2, 0).cpu().numpy()
        # Denormalize from [-1, 1] (inverse of the Normalize transform) back to [0, 255]
output_numpy = (output_numpy * 0.5 + 0.5) * 255.0
# Clip values to [0, 255] and convert to uint8
output_numpy = output_numpy.clip(0, 255).astype(np.uint8)
# Convert NumPy array back to PIL Image
output_image = Image.fromarray(output_numpy)
print(" Image processing complete.")
return output_image
except Exception as e:
print(f"Error during image processing: {e}")
import traceback
traceback.print_exc() # Print detailed traceback for debugging
return None
if __name__ == "__main__":
print("--- Testing Phase 1 Refactoring ---")
print(f"Using device: {DEVICE}")
print(f"Using patch kernel size: {PATCH_KERNEL_SIZE}")
print(f"Using patch stride: {PATCH_STRIDE}")
print(f"Using model checkpoint: {CHECKPOINT_GEN}")
# 1. Load the model (as it would be done globally in Gradio app)
try:
model = load_model(CHECKPOINT_GEN, DEVICE)
except Exception as e:
print(f"Failed to load model. Exiting test. Error: {e}")
        raise SystemExit(1)  # non-zero exit so failures are visible to callers
# 2. Test URL download
# Replace with a valid image URL for testing
# test_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg"
test_url = "https://www.shutterstock.com/shutterstock/photos/2501926843/display_1500/stock-photo-brunette-woman-laying-on-couch-cuddling-light-brown-dog-and-brown-tabby-cat-happy-2501926843.jpg" # A smaller known image
input_pil = download_image(test_url)
if input_pil:
print(f"\nDownloaded image type: {type(input_pil)}, size: {input_pil.size}")
# 3. Test processing the downloaded image
output_pil = process_image_from_data(
input_pil_image=input_pil,
model=model,
device=DEVICE,
kernel_size=PATCH_KERNEL_SIZE,
stride=PATCH_STRIDE,
use_tqdm=True # Show progress bars during test
)
if output_pil:
print(f"\nProcessed image type: {type(output_pil)}, size: {output_pil.size}")
# Save the output locally for verification during testing
try:
os.makedirs("test_outputs", exist_ok=True)
output_filename = "test_output_" + os.path.basename(test_url).split('?')[0] # Basic filename extraction
if not output_filename.lower().endswith(SUPPORTED_EXTENSIONS):
output_filename += ".png" # Ensure it has an extension
save_path = os.path.join("test_outputs", output_filename)
output_pil.save(save_path)
print(f"Saved test output to: {save_path}")
except Exception as e:
print(f"Error saving test output: {e}")
else:
print("\nImage processing failed.")
else:
print("\nImage download failed.")
print("\n--- Phase 1 Testing Complete ---")