import os
from PIL import Image, UnidentifiedImageError
import numpy as np
import torch
import torch.nn.functional as F
import albumentations as A
from albumentations.pytorch import ToTensorV2
from collections import OrderedDict
from tqdm import tqdm
import requests
import io
from generator_model import Generator
# --- Constants ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CHECKPOINT_GEN = "gen_epoch_42.pth.tar" # Keep your checkpoint name
PATCH_KERNEL_SIZE = 256
PATCH_STRIDE = 64
# DEFAULT_INPUT_DIR = "test/inputs" # No longer needed for Gradio URL input
# DEFAULT_OUTPUT_DIR = "test/outputs" # Output handled by Gradio
SUPPORTED_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff') # Still useful for local testing if needed
test_transform = A.Compose([
    A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
    ToTensorV2()
])
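# Note: Normalize with mean=std=0.5 maps pixel values from [0, 255] to [-1, 1]
# (e.g. 0 -> -1.0, 255 -> 1.0), which is the inverse of the
# (x * 0.5 + 0.5) * 255 denormalization applied to the model output below.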
def load_model(checkpoint_path: str, device: str) -> Generator:
    print(f"Loading model from: {checkpoint_path} onto device: {device}")
    model = Generator(in_channels=3, features=64).to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
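    # Note: weights_only=False lets torch.load unpickle arbitrary Python
    # objects stored alongside the weights (epoch counters, optimizer state,
    # etc.), so only load checkpoints from a source you trust.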
    new_state_dict = OrderedDict()
    has_module_prefix = any(k.startswith("module.") for k in checkpoint["state_dict"])
    for k, v in checkpoint["state_dict"].items():
        name = k.replace("module.", "") if has_module_prefix else k
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
    model.eval()  # Set model to evaluation mode
    print("Model loaded successfully.")
    return model
def calculate_padding(img_h: int, img_w: int, kernel_size: int, stride: int) -> tuple[int, int]:
    """Returns (pad_h, pad_w): the bottom/right padding needed so the sliding window tiles the image exactly."""
    pad_h = kernel_size - img_h if img_h < kernel_size else (stride - (img_h - kernel_size) % stride) % stride
    pad_w = kernel_size - img_w if img_w < kernel_size else (stride - (img_w - kernel_size) % stride) % stride
    return pad_h, pad_w
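# Worked example (illustrative numbers): for a 300x500 image with kernel size
# 256 and stride 64, calculate_padding(300, 500, 256, 64) returns (20, 12):
# 300 + 20 = 320 and 500 + 12 = 512, so (dim - 256) is divisible by 64 on both
# axes and the 256-pixel window slides across the padded image in exact steps.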
def download_image(url: str, timeout: int = 15) -> Image.Image | None:
    """Downloads an image from a URL and returns it as a PIL Image object."""
    print(f"Attempting to download image from: {url}")
    try:
        headers = {'User-Agent': 'Gradio-Image-Processor/1.0'}  # Be a good net citizen
        response = requests.get(url, stream=True, timeout=timeout, headers=headers)
        response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)
        content_type = response.headers.get('Content-Type', '').lower()
        if not content_type.startswith('image/'):
            print(f"Error: URL content type ({content_type}) is not an image.")
            return None
        image_bytes = response.content
        pil_image = Image.open(io.BytesIO(image_bytes))
        pil_image = pil_image.convert('RGB')  # Ensure image is in RGB format
        print(f"Image downloaded successfully ({pil_image.width}x{pil_image.height}).")
        return pil_image
    except requests.exceptions.Timeout:
        print(f"Error: Request timed out after {timeout} seconds.")
        return None
    except requests.exceptions.RequestException as e:
        print(f"Error downloading image: {e}")
        return None
    except UnidentifiedImageError:
        print("Error: Could not identify image file. The URL might not point to a valid image.")
        return None
    except Exception as e:
        print(f"An unexpected error occurred during download: {e}")
        return None
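# Usage sketch (hypothetical URL), mirroring the None-check contract above:
#   img = download_image("https://example.com/photo.jpg")
#   if img is not None:
#       img.save("downloaded.png")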
def process_image_from_data(
    input_pil_image: Image.Image,
    model: Generator,
    device: str,
    kernel_size: int,
    stride: int,
    use_tqdm: bool = True  # Optional: Control progress bar visibility
) -> Image.Image | None:
    """
    Processes an input PIL image using the patch-based method and returns the output PIL image.
    Returns None if an error occurs during processing.
    """
    print("\nProcessing image data...")
    try:
        image_np = np.array(input_pil_image)  # Convert PIL Image to NumPy array
        H, W, _ = image_np.shape
        print(f" Input dimensions: {W}x{H}")
        # Apply transformations
        transformed = test_transform(image=image_np)
        input_tensor = transformed['image'].to(device)  # Shape: (C, H, W)
        C = input_tensor.shape[0]
        # Calculate and apply padding
        pad_h, pad_w = calculate_padding(H, W, kernel_size, stride)
        print(f" Calculated padding (H, W): ({pad_h}, {pad_w})")
        padded_tensor = F.pad(input_tensor.unsqueeze(0), (0, pad_w, 0, pad_h), mode='reflect').squeeze(0)
        _, H_pad, W_pad = padded_tensor.shape
        print(f" Padded dimensions: {W_pad}x{H_pad}")
        # Extract patches
        patches = padded_tensor.unfold(1, kernel_size, stride).unfold(2, kernel_size, stride)
        num_patches_h = patches.shape[1]
        num_patches_w = patches.shape[2]
        num_patches_total = num_patches_h * num_patches_w
        print(f" Extracted {num_patches_total} patches ({num_patches_h} H x {num_patches_w} W)")
        patches = patches.contiguous().view(C, -1, kernel_size, kernel_size)
        # Permute to (num_patches_total, C, kernel_size, kernel_size)
        patches = patches.permute(1, 0, 2, 3).contiguous()
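        # Shape walk-through (illustrative): a 320x512 padded RGB image with
        # kernel 256 and stride 64 unfolds into a (3, 2, 5, 256, 256) tensor,
        # which the view/permute above flattens to (10, 3, 256, 256) -- one
        # entry per patch, ready for the model.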
        output_patches = []
        # Set up tqdm iterator if enabled
        patch_iterator = tqdm(patches, total=num_patches_total, desc=" Inferring patches", unit="patch", leave=False, disable=not use_tqdm)
        # --- Inference Loop ---
        with torch.no_grad():
            for patch in patch_iterator:
                # Add batch dimension, run model, remove batch dimension
                output_patch = model(patch.unsqueeze(0)).squeeze(0)
                # Move to CPU immediately to save GPU memory during inference loop
                output_patches.append(output_patch.cpu())
        # Stack output patches back together. If GPU memory allows, move them
        # back to the device for faster reconstruction; otherwise keep them on
        # the CPU.
        try:
            output_patches = torch.stack(output_patches).to(device)
            print(f" Output patches moved to {device} for reconstruction.")
        except Exception as e:  # Catch potential OOM on device
            print(f" Warning: Could not move all output patches to {device} ({e}). Reconstruction might be slower on CPU.")
            output_patches = torch.stack(output_patches)  # Keep on CPU
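        # A possible optimization (not implemented here): run patches through
        # the model in small batches, e.g. model(patches[i:i+8]), trading GPU
        # memory for fewer forward passes.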
        # --- Reconstruction ---
        # Generate 2D Hann window for blending
        window_1d = torch.hann_window(kernel_size, periodic=False, device=device)  # periodic=False is common
        window_2d = torch.outer(window_1d, window_1d)
        window_2d = window_2d.unsqueeze(0).to(device)  # Add channel dim and ensure on device
        # Initialize output tensor and weight tensor (for weighted averaging)
        output_tensor = torch.zeros((C, H_pad, W_pad), device=device, dtype=output_patches.dtype)
        weight_tensor = torch.zeros((C, H_pad, W_pad), device=device, dtype=window_2d.dtype)
        patch_idx = 0
        reconstruct_iterator = tqdm(total=num_patches_total, desc=" Reconstructing", unit="patch", leave=False, disable=not use_tqdm)
        for i in range(num_patches_h):
            for j in range(num_patches_w):
                h_start = i * stride
                w_start = j * stride
                h_end = h_start + kernel_size
                w_end = w_start + kernel_size
                # Get current patch (ensure it's on the correct device)
                current_patch = output_patches[patch_idx].to(device)
                weighted_patch = current_patch * window_2d  # Apply window
                # Add weighted patch to output tensor
                output_tensor[:, h_start:h_end, w_start:w_end] += weighted_patch
                # Accumulate weights
                weight_tensor[:, h_start:h_end, w_start:w_end] += window_2d
                patch_idx += 1
                reconstruct_iterator.update(1)
        reconstruct_iterator.close()  # Close the inner tqdm bar
        # Perform weighted averaging - clamp weights to avoid division by zero
        output_averaged = output_tensor / weight_tensor.clamp(min=1e-6)
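        # Note: with periodic=False the Hann window is exactly zero at its
        # edges, so pixels on the border of the padded image accumulate
        # near-zero weight; the clamp above keeps that division finite, while
        # interior pixels are blended across multiple overlapping patches.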
        # Crop to original dimensions
        output_cropped = output_averaged[:, :H, :W]
        print(f" Final output dimensions: {output_cropped.shape[2]}x{output_cropped.shape[1]}")
        # --- Convert to Output Format ---
        # Permute C, H, W -> H, W, C; move to CPU; convert to NumPy
        output_numpy = output_cropped.permute(1, 2, 0).cpu().numpy()
        # Denormalize: assuming input was normalized to [-1, 1]
        output_numpy = (output_numpy * 0.5 + 0.5) * 255.0
        # Clip values to [0, 255] and convert to uint8
        output_numpy = output_numpy.clip(0, 255).astype(np.uint8)
        # Convert NumPy array back to PIL Image
        output_image = Image.fromarray(output_numpy)
        print(" Image processing complete.")
        return output_image
    except Exception as e:
        print(f"Error during image processing: {e}")
        import traceback
        traceback.print_exc()  # Print detailed traceback for debugging
        return None
if __name__ == "__main__":
    print("--- Testing Phase 1 Refactoring ---")
    print(f"Using device: {DEVICE}")
    print(f"Using patch kernel size: {PATCH_KERNEL_SIZE}")
    print(f"Using patch stride: {PATCH_STRIDE}")
    print(f"Using model checkpoint: {CHECKPOINT_GEN}")
    # 1. Load the model (as it would be done globally in the Gradio app)
    try:
        model = load_model(CHECKPOINT_GEN, DEVICE)
    except Exception as e:
        print(f"Failed to load model. Exiting test. Error: {e}")
        exit()
    # 2. Test URL download
    # Replace with a valid image URL for testing
    # test_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg"
    test_url = "https://www.shutterstock.com/shutterstock/photos/2501926843/display_1500/stock-photo-brunette-woman-laying-on-couch-cuddling-light-brown-dog-and-brown-tabby-cat-happy-2501926843.jpg"  # Example image URL for testing
    input_pil = download_image(test_url)
    if input_pil:
        print(f"\nDownloaded image type: {type(input_pil)}, size: {input_pil.size}")
        # 3. Test processing the downloaded image
        output_pil = process_image_from_data(
            input_pil_image=input_pil,
            model=model,
            device=DEVICE,
            kernel_size=PATCH_KERNEL_SIZE,
            stride=PATCH_STRIDE,
            use_tqdm=True  # Show progress bars during test
        )
        if output_pil:
            print(f"\nProcessed image type: {type(output_pil)}, size: {output_pil.size}")
            # Save the output locally for verification during testing
            try:
                os.makedirs("test_outputs", exist_ok=True)
                output_filename = "test_output_" + os.path.basename(test_url).split('?')[0]  # Basic filename extraction
                if not output_filename.lower().endswith(SUPPORTED_EXTENSIONS):
                    output_filename += ".png"  # Ensure it has an extension
                save_path = os.path.join("test_outputs", output_filename)
                output_pil.save(save_path)
                print(f"Saved test output to: {save_path}")
            except Exception as e:
                print(f"Error saving test output: {e}")
        else:
            print("\nImage processing failed.")
    else:
        print("\nImage download failed.")
    print("\n--- Phase 1 Testing Complete ---")