import io
import os
import sys
from collections import OrderedDict
from urllib.parse import urlparse

import albumentations as A
import numpy as np
import requests
import torch
import torch.nn.functional as F
from albumentations.pytorch import ToTensorV2
from PIL import Image, UnidentifiedImageError
from tqdm import tqdm

from generator_model import Generator

# --- Constants ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CHECKPOINT_GEN = "gen_epoch_42.pth.tar"  # Keep your checkpoint name
PATCH_KERNEL_SIZE = 256
PATCH_STRIDE = 64
# Extensions kept as-is when naming saved test outputs (anything else gets ".png").
SUPPORTED_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

# uint8 RGB (H, W, 3) -> float tensor (3, H, W) scaled to [-1, 1]:
# Normalize computes (x - 0.5 * 255) / (0.5 * 255), ToTensorV2 moves channels first.
test_transform = A.Compose([
    A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
    ToTensorV2(),
])


def load_model(checkpoint_path: str, device: str) -> Generator:
    """Build a Generator, load weights from ``checkpoint_path`` and switch it to eval mode.

    Args:
        checkpoint_path: File written by ``torch.save``: either a dict holding the
            weights under ``"state_dict"`` or a bare state dict.
        device: Torch device string the model and the loaded tensors are mapped to.

    Returns:
        The Generator on ``device`` in evaluation mode.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise for a missing or
        incompatible checkpoint.
    """
    print(f"Loading model from: {checkpoint_path} onto device: {device}")
    model = Generator(in_channels=3, features=64).to(device)
    # NOTE: weights_only=False unpickles arbitrary objects - only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint

    # DataParallel/DDP checkpoints prefix every key with "module.". Strip only that
    # leading prefix; str.replace would also mangle inner names containing "module.".
    new_state_dict = OrderedDict(
        (key.removeprefix("module."), value) for key, value in state_dict.items()
    )
    model.load_state_dict(new_state_dict)
    model.eval()  # Set model to evaluation mode
    print("Model loaded successfully.")
    return model


def calculate_padding(img_h: int, img_w: int, kernel_size: int, stride: int) -> tuple[int, int]:
    """Return the (bottom, right) padding that makes ``kernel_size`` patches at ``stride`` tile the image.

    A side shorter than ``kernel_size`` is padded up to exactly one patch. Otherwise
    it is padded to the next length L with ``(L - kernel_size) % stride == 0``.
    """
    def _pad(length: int) -> int:
        if length < kernel_size:
            return kernel_size - length
        return (stride - (length - kernel_size) % stride) % stride

    return _pad(img_h), _pad(img_w)


def download_image(url: str, timeout: int = 15) -> Image.Image | None:
    """Download an image from ``url`` and return it as an RGB PIL Image.

    Returns None (after printing the reason) when the request fails or times out,
    when the response's Content-Type is not ``image/*``, or when PIL cannot decode
    the bytes.
    """
    print(f"Attempting to download image from: {url}")
    try:
        headers = {'User-Agent': 'Gradio-Image-Processor/1.0'}  # Be a good net citizen
        # The context manager releases the streamed connection on every exit path,
        # including the early return for non-image content.
        with requests.get(url, stream=True, timeout=timeout, headers=headers) as response:
            response.raise_for_status()  # Raise an exception for 4xx / 5xx status codes
            content_type = response.headers.get('Content-Type', '').lower()
            if not content_type.startswith('image/'):
                print(f"Error: URL content type ({content_type}) is not an image.")
                return None
            image_bytes = response.content
        pil_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')  # Ensure RGB format
        print(f"Image downloaded successfully ({pil_image.width}x{pil_image.height}).")
        return pil_image
    except requests.exceptions.Timeout:
        print(f"Error: Request timed out after {timeout} seconds.")
        return None
    except requests.exceptions.RequestException as e:
        print(f"Error downloading image: {e}")
        return None
    except UnidentifiedImageError:
        print("Error: Could not identify image file. The URL might not point to a valid image.")
        return None
    except Exception as e:
        print(f"An unexpected error occurred during download: {e}")
        return None


def _pad_bottom_right(tensor: torch.Tensor, pad_h: int, pad_w: int) -> torch.Tensor:
    """Pad a (C, H, W) tensor with ``pad_h`` rows at the bottom and ``pad_w`` columns at the right.

    Reflect padding is preferred, but it requires every pad to be smaller than the
    matching dimension. That fails for images at most half a patch tall or wide,
    so those fall back to edge replication instead of raising.
    """
    _, height, width = tensor.shape
    mode = 'reflect' if pad_h < height and pad_w < width else 'replicate'
    return F.pad(tensor.unsqueeze(0), (0, pad_w, 0, pad_h), mode=mode).squeeze(0)


def _blend_window(kernel_size: int, device: str) -> torch.Tensor:
    """Return a strictly positive 2D Hann window of shape (1, kernel_size, kernel_size)."""
    # A symmetric Hann window of length kernel_size is exactly 0 at both ends. That
    # gave the outermost image rows/columns zero total weight, so they came out flat
    # gray. The interior of a length kernel_size + 2 window keeps the taper while
    # making every weight > 0.
    window_1d = torch.hann_window(kernel_size + 2, periodic=False, device=device)[1:-1]
    return torch.outer(window_1d, window_1d).unsqueeze(0)


def _infer_patches(model: Generator, patches: torch.Tensor, batch_size: int,
                   use_tqdm: bool) -> list[torch.Tensor]:
    """Run ``model`` over ``patches`` (N, C, k, k) in chunks and return per-patch outputs on the CPU."""
    outputs: list[torch.Tensor] = []
    with tqdm(total=len(patches), desc=" Inferring patches", unit="patch",
              leave=False, disable=not use_tqdm) as progress, torch.no_grad():
        for batch in torch.split(patches, max(1, int(batch_size))):
            # Move to CPU immediately so GPU memory does not grow with the patch count.
            outputs.extend(model(batch).cpu())
            progress.update(len(batch))
    return outputs


def _blend_patches(output_patches: torch.Tensor, grid_h: int, grid_w: int, out_h: int, out_w: int,
                   kernel_size: int, stride: int, device: str, use_tqdm: bool) -> torch.Tensor:
    """Overlap-add (N, C, k, k) patches laid out row-major on a grid_h x grid_w grid.

    Each patch is weighted by a Hann window. The sum is then divided by the
    accumulated window weights, giving a (C, out_h, out_w) weighted average.
    """
    window = _blend_window(kernel_size, device)
    channels = output_patches.shape[1]
    output_tensor = torch.zeros((channels, out_h, out_w), device=device, dtype=output_patches.dtype)
    # The window is identical for every channel, so one weight plane is enough (it broadcasts).
    weight_tensor = torch.zeros((1, out_h, out_w), device=device, dtype=window.dtype)

    with tqdm(total=grid_h * grid_w, desc=" Reconstructing", unit="patch",
              leave=False, disable=not use_tqdm) as progress:
        for i in range(grid_h):
            top = i * stride
            for j in range(grid_w):
                left = j * stride
                patch = output_patches[i * grid_w + j].to(device)
                output_tensor[:, top:top + kernel_size, left:left + kernel_size] += patch * window
                weight_tensor[:, top:top + kernel_size, left:left + kernel_size] += window
                progress.update(1)

    # With stride <= kernel_size every padded pixel is covered and the window is
    # strictly positive, so weights are > 0. The clamp only guards against
    # division by zero (e.g. stride > kernel_size).
    return output_tensor / weight_tensor.clamp(min=torch.finfo(weight_tensor.dtype).tiny)


def process_image_from_data(
    input_pil_image: Image.Image,
    model: Generator,
    device: str,
    kernel_size: int,
    stride: int,
    use_tqdm: bool = True,  # Optional: Control progress bar visibility
    batch_size: int = 1,
) -> Image.Image | None:
    """Run ``model`` over ``input_pil_image`` patch by patch and return the stitched output image.

    Pipeline:
      1. Normalize to [-1, 1].
      2. Pad bottom/right so ``kernel_size`` patches at ``stride`` tile the image.
      3. Run the model on each patch.
      4. Blend the outputs with a Hann window (overlap-add divided by summed weights).
      5. Crop back to the input size and convert to uint8.

    Args:
        input_pil_image: Source image. Non-RGB modes are converted to RGB first.
        model: Generator already on ``device``. Expected to be in eval mode (as
            returned by ``load_model``).
        device: Torch device string used for inference and reconstruction.
        kernel_size: Side length of the square patches fed to the model.
        stride: Step between neighbouring patches. A stride smaller than
            ``kernel_size`` makes patches overlap.
        use_tqdm: Show progress bars for inference and reconstruction.
        batch_size: Patches per forward pass. The default of 1 reproduces
            one-patch-at-a-time inference.

    Returns:
        The output image at the input's size, or None if any error occurred. The
        traceback is printed in that case.
    """
    print("\nProcessing image data...")
    try:
        # The transform and the 3-channel model need RGB; Gradio/PIL may hand us L, P or RGBA.
        if input_pil_image.mode != 'RGB':
            input_pil_image = input_pil_image.convert('RGB')
        image_np = np.array(input_pil_image)
        H, W, _ = image_np.shape
        print(f" Input dimensions: {W}x{H}")

        input_tensor = test_transform(image=image_np)['image'].to(device)  # (C, H, W)
        C = input_tensor.shape[0]

        pad_h, pad_w = calculate_padding(H, W, kernel_size, stride)
        print(f" Calculated padding (H, W): ({pad_h}, {pad_w})")
        padded_tensor = _pad_bottom_right(input_tensor, pad_h, pad_w)
        _, H_pad, W_pad = padded_tensor.shape
        print(f" Padded dimensions: {W_pad}x{H_pad}")

        # unfold twice -> (C, grid_h, grid_w, k, k); flatten the grid row-major and move
        # the patch axis first -> (grid_h * grid_w, C, k, k).
        patches = padded_tensor.unfold(1, kernel_size, stride).unfold(2, kernel_size, stride)
        num_patches_h, num_patches_w = patches.shape[1], patches.shape[2]
        num_patches_total = num_patches_h * num_patches_w
        print(f" Extracted {num_patches_total} patches ({num_patches_h} H x {num_patches_w} W)")
        patches = (patches.contiguous()
                   .view(C, -1, kernel_size, kernel_size)
                   .permute(1, 0, 2, 3)
                   .contiguous())

        stacked_cpu = torch.stack(_infer_patches(model, patches, batch_size, use_tqdm))
        # Reconstruction is faster on the device; keep the CPU copy if it does not fit.
        try:
            output_patches = stacked_cpu.to(device)
            print(f" Output patches moved to {device} for reconstruction.")
        except RuntimeError as e:  # torch.cuda.OutOfMemoryError subclasses RuntimeError
            print(f" Warning: Could not move all output patches to {device} ({e}). Reconstruction might be slower on CPU.")
            output_patches = stacked_cpu

        output_averaged = _blend_patches(output_patches, num_patches_h, num_patches_w, H_pad, W_pad,
                                         kernel_size, stride, device, use_tqdm)
        output_cropped = output_averaged[:, :H, :W]  # Drop the bottom/right padding
        print(f" Final output dimensions: {output_cropped.shape[2]}x{output_cropped.shape[1]}")

        # (C, H, W) -> (H, W, C) NumPy. Denormalization assumes the generator outputs
        # values in [-1, 1], mirroring the input normalization.
        output_numpy = output_cropped.permute(1, 2, 0).cpu().numpy()
        output_numpy = (output_numpy * 0.5 + 0.5) * 255.0
        output_numpy = output_numpy.clip(0, 255).astype(np.uint8)
        output_image = Image.fromarray(output_numpy)
        print(" Image processing complete.")
        return output_image
    except Exception as e:
        print(f"Error during image processing: {e}")
        import traceback
        traceback.print_exc()  # Print detailed traceback for debugging
        return None


if __name__ == "__main__":
    print("--- Testing Phase 1 Refactoring ---")
    print(f"Using device: {DEVICE}")
    print(f"Using patch kernel size: {PATCH_KERNEL_SIZE}")
    print(f"Using patch stride: {PATCH_STRIDE}")
    print(f"Using model checkpoint: {CHECKPOINT_GEN}")

    # 1. Load the model (as it would be done globally in Gradio app)
    try:
        model = load_model(CHECKPOINT_GEN, DEVICE)
    except Exception as e:
        print(f"Failed to load model. Exiting test. Error: {e}")
        sys.exit(1)  # Non-zero status so callers/CI notice the failure

    # 2. Test URL download (replace with any valid image URL)
    # test_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg"
    test_url = "https://www.shutterstock.com/shutterstock/photos/2501926843/display_1500/stock-photo-brunette-woman-laying-on-couch-cuddling-light-brown-dog-and-brown-tabby-cat-happy-2501926843.jpg"
    input_pil = download_image(test_url)

    if input_pil:
        print(f"\nDownloaded image type: {type(input_pil)}, size: {input_pil.size}")

        # 3. Test processing the downloaded image
        output_pil = process_image_from_data(
            input_pil_image=input_pil,
            model=model,
            device=DEVICE,
            kernel_size=PATCH_KERNEL_SIZE,
            stride=PATCH_STRIDE,
            use_tqdm=True,  # Show progress bars during test
        )

        if output_pil:
            print(f"\nProcessed image type: {type(output_pil)}, size: {output_pil.size}")
            # Save the output locally for verification during testing
            try:
                os.makedirs("test_outputs", exist_ok=True)
                # Derive the name from the URL path only, so query strings/fragments
                # (which may themselves contain '/') do not leak into the filename.
                output_filename = "test_output_" + os.path.basename(urlparse(test_url).path)
                if not output_filename.lower().endswith(SUPPORTED_EXTENSIONS):
                    output_filename += ".png"  # Ensure it has an extension
                save_path = os.path.join("test_outputs", output_filename)
                output_pil.save(save_path)
                print(f"Saved test output to: {save_path}")
            except Exception as e:
                print(f"Error saving test output: {e}")
        else:
            print("\nImage processing failed.")
    else:
        print("\nImage download failed.")

    print("\n--- Phase 1 Testing Complete ---")