# Author: abiabidali
# Commit 873696c (verified): "Update app.py"
import csv
import io
import os
import tempfile
import time
import zipfile

import gradio as gr
import numpy as np
import torch
from PIL import Image
from RealESRGAN import RealESRGAN
from transformers import pipeline
# Run inference on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Load RealESRGAN model with specified scale
def load_model(scale):
    """Build a RealESRGAN upscaler for ``scale`` on ``device`` and load its weights.

    Weights are read from ``weights/RealESRGAN_x{scale}.pth``. Loading is first
    attempted with ``download=True``. If that raises, the error is printed and
    loading is retried once with ``download=False``. An error from the retry
    propagates to the caller.

    Returns:
        The RealESRGAN model instance with weights loaded.
    """
    checkpoint = f'weights/RealESRGAN_x{scale}.pth'
    upscaler = RealESRGAN(device, scale=scale)
    try:
        upscaler.load_weights(checkpoint, download=True)
    except Exception as err:
        print(f"Error loading weights for scale {scale}: {err}")
        upscaler.load_weights(checkpoint, download=False)
    else:
        print(f"Weights for scale {scale} loaded successfully.")
    return upscaler
# Load models for different scales.
# One upscaler per supported factor, loaded once at import time in the order
# 2x, 4x, 8x; enhance_image chooses among them per request.
model2, model4, model8 = (load_model(factor) for factor in (2, 4, 8))

# Hugging Face image description pipeline.
# Its caption becomes each image's title and keyword source in process_images.
description_generator = pipeline(
    task="image-to-text",
    model="nlpconnect/vit-gpt2-image-captioning",
)
# Enhance image based on selected scale
def enhance_image(image, scale):
    """Upscale a PIL image with the preloaded RealESRGAN model matching ``scale``.

    ``'2x'`` selects ``model2`` and ``'4x'`` selects ``model4``. Any other value
    falls through to ``model8``. This is a best-effort step: on any failure the
    error is printed and the input image is returned unchanged.

    Returns:
        A new PIL image built from the model output, or ``image`` on error.
    """
    try:
        rgb_pixels = np.array(image.convert('RGB'))
        upscaler = model2 if scale == '2x' else model4 if scale == '4x' else model8
        upscaled = upscaler.predict(rgb_pixels)
        return Image.fromarray(np.uint8(upscaled))
    except Exception as exc:
        print(f"Error enhancing image: {exc}")
        return image
# Generate image description
def generate_description(image):
    """Caption ``image`` with the module-level ``description_generator`` pipeline.

    Returns the ``'generated_text'`` of the pipeline's first result. If the
    pipeline call or the lookup raises, the error is printed and the
    placeholder ``"Description unavailable."`` is returned instead.
    """
    try:
        outputs = description_generator(image)
        caption = outputs[0]['generated_text']
    except Exception as exc:
        print(f"Error generating description: {exc}")
        return "Description unavailable."
    return caption
# Adjust DPI
def muda_dpi(input_image, dpi):
    """Return a copy of ``input_image`` tagged with ``dpi`` dots per inch.

    The pixels are encoded as JPEG with a ``dpi`` header and decoded again.
    The returned image therefore carries ``info['dpi']``, but it has also been
    through one round of lossy JPEG compression, as before.

    Args:
        input_image: NumPy array convertible to 8-bit RGB (assumed H x W x 3).
        dpi: Target resolution. Floats (as produced by ``gr.Number``) are
            rounded to the nearest integer.

    Returns:
        A fully loaded ``PIL.Image.Image``.
    """
    dpi_value = int(round(dpi))
    image = Image.fromarray(input_image.astype('uint8'), 'RGB')
    # Encode in memory rather than via NamedTemporaryFile(delete=False). The
    # temp files were never removed, and the lazily opened result kept a file
    # handle open on each one.
    buffer = io.BytesIO()
    image.save(buffer, format='JPEG', dpi=(dpi_value, dpi_value))
    buffer.seek(0)
    result = Image.open(buffer)
    result.load()  # decode now so later use does not depend on the buffer
    return result
# Resize an image
def resize_image(input_image, width, height):
    """Resize an RGB pixel array to ``width`` x ``height``.

    Args:
        input_image: NumPy array convertible to 8-bit RGB (assumed H x W x 3).
        width, height: Target size in pixels. ``gr.Number`` delivers floats
            such as ``512.0``, which ``PIL.Image.resize`` rejects, so both are
            rounded to ``int``.

    Returns:
        The resized ``PIL.Image.Image`` (mode ``RGB``).

    Raises:
        ValueError: if either dimension is not positive after rounding.
    """
    target_size = (int(round(width)), int(round(height)))
    if min(target_size) <= 0:
        raise ValueError(f"Width and height must be positive, got {target_size}")
    image = Image.fromarray(input_image.astype('uint8'), 'RGB')
    # Return the resized image directly. The previous JPEG round-trip through
    # NamedTemporaryFile(delete=False) leaked a temp file per call and added
    # an extra lossy compression pass before the final save.
    return image.resize(target_size)
# Process images and generate a ZIP file with images and CSV
def _keywords_from(description, max_keywords=45):
    """Return up to ``max_keywords`` unique words of ``description``, comma-joined, in first-seen order."""
    # dict.fromkeys dedupes while keeping order; a set gave a random order per
    # process. The old slice counted characters, cutting words mid-way.
    unique_words = list(dict.fromkeys(description.split()))
    return ", ".join(unique_words[:max_keywords])


def _unique_stem(raw_name, used_stems):
    """Sanitize an upload's filename into a stem not yet in ``used_stems`` (case-insensitive) and record it."""
    stem, _ = os.path.splitext(os.path.basename(raw_name))
    stem = ''.join(ch for ch in stem if ch.isalnum() or ch in (' ', '_', '-')).strip().replace(' ', '_')
    stem = stem or "image"  # names made only of stripped characters would become ".jpg"
    candidate, counter = stem, 1
    while candidate.casefold() in used_stems:
        candidate = f"{stem}_{counter}"
        counter += 1
    used_stems.add(candidate.casefold())
    return candidate


def process_images(image_files, enhance, scale, adjust_dpi, dpi, resize, width, height):
    """Run the enhance / DPI / resize / caption pipeline over uploaded images.

    Each upload goes through these steps in order:

    1. Optional RealESRGAN upscaling.
    2. Optional DPI tagging.
    3. Optional resizing.
    4. Captioning.

    Each result is saved as JPEG. All of them are bundled into a ZIP archive,
    together with a CSV of ``Filename, Title, Keywords`` rows.

    Args:
        image_files: Uploads from ``gr.Files``. These are either file-like
            objects exposing ``.name`` or plain path strings, depending on the
            Gradio version. ``None`` is treated as no uploads.
        enhance: Whether to upscale with ``enhance_image``.
        scale: Model choice forwarded to ``enhance_image`` (``'2x'``/``'4x'``/``'8x'``).
        adjust_dpi: Whether to tag the saved JPEGs with ``dpi``.
        dpi: Target DPI; rounded to an integer.
        resize: Whether to resize to ``width`` x ``height``.
        width, height: Target size in pixels; rounded to integers.

    Returns:
        ``(processed_images, zip_file_path, descriptions)``: the final PIL
        images, the path of the ZIP archive, and one caption per image.
    """
    processed_images = []
    file_paths = []
    descriptions = []
    used_stems = set()

    # Use a fresh directory per call, so concurrent requests cannot overwrite
    # each other's images, CSV or ZIP in the shared temp dir.
    output_dir = tempfile.mkdtemp(prefix="processed_images_")
    csv_file_path = os.path.join(output_dir, "image_descriptions.csv")

    # JPEG only records DPI when it is passed to save(). Without this the
    # adjustment was dropped at the final save, and resizing discards it anyway.
    save_kwargs = {}
    if adjust_dpi:
        dpi_value = int(round(dpi))
        save_kwargs["dpi"] = (dpi_value, dpi_value)

    with open(csv_file_path, mode="w", newline="", encoding="utf-8") as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(["Filename", "Title", "Keywords"])
        for image_file in image_files or []:
            # Newer Gradio passes path strings; older versions pass tempfile wrappers.
            source_path = getattr(image_file, "name", image_file)
            with Image.open(source_path) as opened:
                final_image = opened.convert('RGB')  # loads pixels before the file closes

            if enhance:
                final_image = enhance_image(final_image, scale)
            if adjust_dpi:
                final_image = muda_dpi(np.array(final_image), dpi)
            if resize:
                final_image = resize_image(
                    np.array(final_image), int(round(width)), int(round(height))
                )

            description = generate_description(final_image)
            title = description  # the caption doubles as the title
            keywords = _keywords_from(description)

            file_name = _unique_stem(source_path, used_stems)
            output_path = os.path.join(output_dir, f"{file_name}.jpg")
            final_image.save(output_path, format='JPEG', **save_kwargs)

            writer.writerow([file_name, title, keywords])
            processed_images.append(final_image)
            file_paths.append(output_path)
            descriptions.append(description)

    # Bundle every saved image plus the CSV into a single download.
    zip_file_path = os.path.join(output_dir, "processed_images.zip")
    with zipfile.ZipFile(zip_file_path, 'w') as zipf:
        for file_path in file_paths:
            zipf.write(file_path, arcname=os.path.basename(file_path))
        zipf.write(csv_file_path, arcname="image_descriptions.csv")
    return processed_images, zip_file_path, descriptions
# Gradio interface
iface = gr.Interface(
    fn=process_images,
    inputs=[
        gr.Files(label="Upload Image Files", file_types=["image"]),
        gr.Checkbox(label="Enhance Images (ESRGAN)"),
        gr.Radio(['2x', '4x', '8x'], type="value", value='2x', label='Resolution model'),
        gr.Checkbox(label="Adjust DPI"),
        # precision=0 makes Gradio deliver ints. With the default (None) these
        # arrive as floats such as 512.0, which PIL's resize() rejects.
        gr.Number(label="DPI", value=300, precision=0),
        gr.Checkbox(label="Resize"),
        gr.Number(label="Width", value=512, precision=0),
        gr.Number(label="Height", value=512, precision=0),
    ],
    outputs=[
        gr.Gallery(label="Final Images"),
        gr.File(label="Download ZIP of Images and Descriptions"),
        gr.Textbox(label="Image Descriptions", lines=5),
    ],
    title="Multi-Image Enhancer with Hugging Face Descriptions",
    description="Upload multiple images, enhance, adjust DPI, resize, generate descriptions, and download the results and a ZIP archive."
)
iface.launch(debug=True, share=True)