Spaces:
Paused
Paused
File size: 6,766 Bytes
cae212d 164e0dd cae212d 164e0dd cae212d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import streamlit as st
import cv2
import os
import numpy as np
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from gfpgan import GFPGANer
# Function to load the model
def load_model(model_name, model_path, denoise_strength, tile, tile_pad, pre_pad, fp32, gpu_id):
    """Build a ``RealESRGANer`` upsampler for one of the supported Real-ESRGAN models.

    Args:
        model_name: 'RealESRGAN_x4plus', 'RealESRGAN_x4plus_anime_6B' or
            'RealESRGAN_x2plus'.
        model_path: Path to the weights file, or None to use
            ``weights/<model_name>.pth``; in the None case the weights are
            downloaded into ``<cwd>/weights`` when that file is missing.
        denoise_strength: DNI blend weight. It is only consulted for
            'realesr-general-x4v3', which this function does not currently
            build, so it has no effect for the supported models.
        tile: Tile size forwarded to ``RealESRGANer``.
        tile_pad: Tile padding forwarded to ``RealESRGANer``.
        pre_pad: Pre-padding forwarded to ``RealESRGANer``.
        fp32: If True, ``half=False`` is passed; otherwise ``half=True``.
        gpu_id: Forwarded unchanged to ``RealESRGANer``.

    Returns:
        The configured ``RealESRGANer`` instance.

    Raises:
        ValueError: if ``model_name`` is not one of the supported models.
    """
    # Pick the network architecture, its native scale, and where to fetch its weights.
    if model_name == 'RealESRGAN_x4plus':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
    elif model_name == 'RealESRGAN_x4plus_anime_6B':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
    elif model_name == 'RealESRGAN_x2plus':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        netscale = 2
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
    else:
        # Fail fast with a clear message instead of an UnboundLocalError further down.
        raise ValueError(f"Unsupported model name: {model_name!r}")

    # Determine model paths: an explicit path wins; otherwise use the default
    # location and download the weights there if they are not present yet.
    if model_path is None:
        model_path = os.path.join('weights', model_name + '.pth')
        if not os.path.isfile(model_path):
            for url in file_url:
                # model_path is updated to wherever the download landed
                model_path = load_file_from_url(
                    url=url, model_dir=os.path.join(os.getcwd(), 'weights'), progress=True, file_name=model_name + '.pth')

    # Use DNI to control the denoise strength (blend the regular and the
    # "wdn" weights). This only applies to 'realesr-general-x4v3', which the
    # branches above do not build, so it is currently unreachable.
    dni_weight = None
    if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
        wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
        model_path = [model_path, wdn_model_path]
        dni_weight = [denoise_strength, 1 - denoise_strength]

    # Restorer
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=tile,
        tile_pad=tile_pad,
        pre_pad=pre_pad,
        half=not fp32,
        gpu_id=gpu_id)
    return upsampler
# Function to download model weights if not present
def ensure_model_weights(model_name):
    """Return the local path of ``model_name``'s weights, downloading them if missing.

    Weights are expected at ``weights/<model_name>.pth`` (relative to the
    current working directory). The ``weights`` directory is created if needed.

    Args:
        model_name: 'RealESRGAN_x4plus', 'RealESRGAN_x4plus_anime_6B' or
            'RealESRGAN_x2plus'. Any other name is accepted only when its
            ``.pth`` file already exists locally.

    Returns:
        The path to the weights file. After a download, this is the value
        returned by ``load_file_from_url``.

    Raises:
        ValueError: if the weights file is missing and no download URL is
            known for ``model_name``.
    """
    release_urls = {
        'RealESRGAN_x4plus': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
        'RealESRGAN_x4plus_anime_6B': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth',
        'RealESRGAN_x2plus': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
    }
    weights_dir = 'weights'
    model_file = f"{model_name}.pth"
    model_path = os.path.join(weights_dir, model_file)
    os.makedirs(weights_dir, exist_ok=True)
    if not os.path.isfile(model_path):
        file_url = release_urls.get(model_name)
        if file_url is None:
            # Previously this surfaced as an UnboundLocalError on file_url.
            raise ValueError(
                f"No download URL known for model {model_name!r}; place {model_path} manually.")
        model_path = load_file_from_url(
            url=file_url, model_dir=weights_dir, progress=True, file_name=model_file)
    return model_path
# Streamlit app
st.title("Real-ESRGAN Image Enhancement")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])

# User selects model name, denoise strength, and other parameters
model_name = st.selectbox("Model Name", ['RealESRGAN_x4plus', 'RealESRGAN_x4plus_anime_6B', 'RealESRGAN_x2plus'])
# NOTE: load_model only uses denoise_strength (via DNI) for 'realesr-general-x4v3',
# which is not offered above, so this slider currently does not affect the output.
denoise_strength = st.slider("Denoise Strength", 0.0, 1.0, 0.5)
outscale = st.slider("Output Scale", 1, 4, 2)  # Reduce output scale to 2
tile = st.slider("Tile Size", 0, 512, 256)  # Add tile size slider
tile_pad = 10
pre_pad = 0
face_enhance = st.checkbox("Face Enhance")
fp32 = st.checkbox("Use FP32 Precision")
gpu_id = None  # or set to 0, 1, etc. if you have multiple GPUs

if uploaded_file is not None:
    col1, col2 = st.columns(2)
    with col1:
        st.write("### Original Image")
        st.image(uploaded_file, use_column_width=True)
    run_button = st.button("Run")
    if not run_button:
        st.warning("Click the 'Run' button to start the enhancement process.")
    if run_button:
        # Ensure model weights are downloaded, then build the upsampler.
        model_path = ensure_model_weights(model_name)
        upsampler = load_model(model_name, model_path, denoise_strength, tile, tile_pad, pre_pad, fp32, gpu_id)

        # Decode from the full upload buffer. getvalue() returns all bytes
        # regardless of the stream position, unlike read(), which only returns
        # what is left after any earlier read of the same object.
        img = cv2.imdecode(np.frombuffer(uploaded_file.getvalue(), np.uint8), cv2.IMREAD_UNCHANGED)
        if img is None:
            st.error("Error loading image. Please try again.")
        else:
            try:
                if face_enhance:
                    face_enhancer = GFPGANer(
                        model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
                        upscale=outscale,
                        arch='clean',
                        channel_multiplier=2,
                        bg_upsampler=upsampler)
                    _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
                else:
                    output, _ = upsampler.enhance(img, outscale=outscale)
            except RuntimeError as error:
                st.error(f"Error: {error}")
                st.error('If you encounter CUDA out of memory, try to set a smaller tile size.')
            else:
                # Encode to PNG in memory instead of writing a fixed temp file that
                # concurrent sessions would overwrite. cv2 expects the BGR(A) layout
                # it decoded, so colors and alpha round-trip correctly.
                encoded_ok, png_buffer = cv2.imencode(".png", output)
                if not encoded_ok:
                    st.error("Error encoding the enhanced image.")
                else:
                    png_bytes = png_buffer.tobytes()
                    with col2:
                        st.write("### Enhanced Image")
                        st.image(png_bytes, use_column_width=True)
                    st.download_button("Download Enhanced Image", data=png_bytes, file_name="output_image.png", mime="image/png")