Spaces:
Build error
Build error
import gradio as gr | |
import os | |
import cv2 | |
import shutil | |
import sys | |
from subprocess import call | |
import torch | |
import numpy as np | |
from skimage import color | |
import torchvision.transforms as transforms | |
from PIL import Image | |
import torch | |
# One-time environment setup performed at import time. Best-effort: a failure
# is reported but does not stop the app from starting.
# Install via the interpreter running this script so dlib lands in the same
# environment (a bare `pip` on PATH may belong to a different Python).
if call([sys.executable, "-m", "pip", "install", "dlib"]) != 0:
    print("Warning: failed to install dlib")
if call(["bash", "setup.sh"]) != 0:
    print("Warning: setup.sh exited with a non-zero status")
def lab2rgb(L, AB):
    """Turn normalized Lab tensors back into an RGB numpy image.

    Args:
        L: lightness tensor scaled to [-1, 1]; concatenated with AB along dim 1.
        AB: a/b chroma tensor scaled to [-1, 1].

    Returns:
        Float numpy array in HWC layout with values in [0, 255]. Only the first
        element of the batch is converted.
    """
    lightness = (L + 1.0) * 50.0   # [-1, 1] -> [0, 100]
    chroma = AB * 110.0            # [-1, 1] -> [-110, 110]
    stacked = torch.cat([lightness, chroma], dim=1)
    first_item = stacked[0].data.cpu().float().numpy().astype(np.float64)
    hwc = np.transpose(first_item, (1, 2, 0))  # CHW -> HWC, as skimage expects
    return color.lab2rgb(hwc) * 255
def get_transform(params=None, grayscale=False, method=Image.BICUBIC, *,
                  preprocess='resize_and_crop', load_size=256, crop_size=256):
    """Build the preprocessing pipeline applied before Lab conversion.

    Args:
        params: crop parameters. Only ``None`` is handled: it adds a random
            ``crop_size`` crop. Any other value currently skips cropping.
        grayscale: if true, convert to a single grayscale channel first.
        method: interpolation passed to ``transforms.Resize``.
        preprocess: pipeline spec; the substrings ``'resize'`` and ``'crop'``
            enable the corresponding steps.
        load_size: side length of the square resize.
        crop_size: side length of the square crop.

    Returns:
        A ``transforms.Compose``. With the defaults it resizes to 256x256, and
        the 256 crop that follows is then a no-op.
    """
    steps = []
    if grayscale:
        steps.append(transforms.Grayscale(1))
    if 'resize' in preprocess:
        steps.append(transforms.Resize([load_size, load_size], method))
    if 'crop' in preprocess and params is None:
        steps.append(transforms.RandomCrop(crop_size))
    return transforms.Compose(steps)
# torch.hub entry point for each supported colorization model name.
_COLORIZATION_HUB_ENTRIES = {
    "Pix2Pix_resnet9b": "pix2pixColorization_resnet9b",
    "Pix2Pix_unet256": "pix2pixColorization_unet256",
    "Deoldify": "DeOldifyColorization",
}
# Models loaded so far, keyed by model name, so each is built only once per process.
# NOTE(review): assumes the hub models keep no per-call state; confirm.
_COLORIZATION_MODEL_CACHE = {}


def _load_colorization_model(model_name):
    """Return the (cached) torch.hub model for *model_name*; ValueError if unknown."""
    try:
        entry = _COLORIZATION_HUB_ENTRIES[model_name]
    except KeyError:
        raise ValueError(
            f"Unknown colorization model {model_name!r}; "
            f"expected one of {sorted(_COLORIZATION_HUB_ENTRIES)}"
        ) from None
    if model_name not in _COLORIZATION_MODEL_CACHE:
        _COLORIZATION_MODEL_CACHE[model_name] = torch.hub.load(
            'manhkhanhad/ImageRestorationInfer', entry)
    return _COLORIZATION_MODEL_CACHE[model_name]


def _colorize_deoldify(model, img):
    """Feed a grayscale image normalized to [-1, 1] to *model* and return its output as PIL."""
    to_input = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    gray = torch.unsqueeze(to_input(img.convert('L')), 0)
    with torch.no_grad():
        result = model(gray)[0].detach()
    # Assumes the model output lies in [-1, 1]; map it to [0, 1] for ToPILImage.
    result = (result + 1) / 2.0
    return transforms.ToPILImage()(result)


def _colorize_pix2pix(model, img):
    """Predict ab channels from the L channel with a pix2pix model and return an RGB PIL image."""
    # Use the *resized* image: fixed-size generators such as unet256 need 256x256 input.
    # convert('RGB') guarantees the 3 channels rgb2lab requires (e.g. for RGBA or 'L' inputs).
    resized = get_transform()(img.convert('RGB'))
    lab = color.rgb2lab(np.array(resized)).astype(np.float32)
    lab_t = transforms.ToTensor()(lab)
    # Scale L from [0, 100] to [-1, 1]; lab2rgb applies the inverse mapping.
    L = torch.unsqueeze(lab_t[[0], ...] / 50.0 - 1.0, 0)
    with torch.no_grad():
        ab = model(L)
    rgb = lab2rgb(L, ab).astype(np.uint8)
    return Image.fromarray(rgb)


def inferColorization(img, model_name):
    """Colorize a PIL image with the named model and return a PIL image.

    Args:
        img: input PIL image (any mode; it is converted as each model needs).
        model_name: one of "Pix2Pix_resnet9b", "Pix2Pix_unet256", "Deoldify".

    Raises:
        ValueError: if *model_name* is not a supported model.
    """
    print(model_name)
    model = _load_colorization_model(model_name)
    if model_name == "Deoldify":
        return _colorize_deoldify(model, img)
    return _colorize_pix2pix(model, img)
def colorizaition(image, model_name):
    """Wrap a numpy image array as PIL and colorize it with the named model.

    Returns whatever ``inferColorization`` returns for that image and model.
    """
    return inferColorization(Image.fromarray(image), model_name)
def run_cmd(command):
    """Run *command* through the shell and return its exit status.

    A non-zero status is reported on stdout rather than raised, so the call
    stays best-effort. Ctrl-C while the command runs exits the whole process
    with status 1. Only pass trusted command strings (``shell=True``).
    """
    try:
        returncode = call(command, shell=True)
    except KeyboardInterrupt:
        print("Process interrupted")
        sys.exit(1)
    if returncode != 0:
        print(f"Command failed with exit status {returncode}: {command}")
    return returncode
def run(image):
    """Restore an old photo with run.py, then colorize the result with DeOldify.

    Args:
        image: input image as a numpy array, as passed by the Gradio "image"
            input (presumably RGB uint8; TODO confirm).

    Returns:
        The colorized PIL image.

    Raises:
        ValueError: if no image was provided.
        RuntimeError: if run.py did not produce the expected output file.
    """
    if image is None:
        raise ValueError("No input image provided")
    work_dir = "Temp"
    input_dir = os.path.join(work_dir, "input")
    output_path = os.path.join(work_dir, "final_output", "input_img.png")
    if os.path.isdir(work_dir):
        shutil.rmtree(work_dir)
    os.makedirs(input_dir)
    try:
        # Save with PIL, not cv2.imwrite: cv2 interprets the array as BGR and
        # would swap the red and blue channels of the RGB input.
        Image.fromarray(image).save(os.path.join(input_dir, "input_img.png"))
        command = ("python run.py --input_folder " + input_dir
                   + " --output_folder " + work_dir
                   + " --GPU -1 --with_scratch")
        run_cmd(command)
        if not os.path.isfile(output_path):
            raise RuntimeError(f"Restoration failed: {output_path} was not produced")
        # Image.open is lazy; copy() forces the pixels into memory and the
        # context manager closes the file before the directory is removed.
        with Image.open(output_path) as restored:
            result_restoration = restored.copy()
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)
    return inferColorization(result_restoration, "Deoldify")
# Sample inputs offered in the demo UI (paths relative to the working directory).
examples = [['example/1.jpeg'], ['example/2.jpg'], ['example/3.jpg'], ['example/4.jpg']]
# Bind the Interface itself (not launch()'s return value) so it can be reused.
iface = gr.Interface(fn=run, inputs="image", outputs="image", examples=examples)
iface.launch(debug=True, share=False)