|
|
|
import json |
|
import os |
|
os.system("pip install basicsr==1.3.4.8") |
|
import sys |
|
import os |
|
import io |
|
import argparse |
|
import uuid |
|
import base64 |
|
import logging |
|
import time |
|
import copy |
|
import cv2 |
|
import insightface |
|
import numpy as np |
|
from typing import List, Union |
|
from PIL import Image |
|
from restoration import * |
|
from flask import Flask, request, jsonify, make_response |
|
from waitress import serve |
|
import os |
|
# Point Matplotlib's config directory at a location under /tmp.
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"

# Verbosity of the root logger configured below.
LOG_LEVEL = logging.DEBUG

# Directory where uploaded images are stored while a request is processed.
TMP_PATH = '/tmp/inswapper'

# Directory containing this script; model paths are resolved against it.
script_dir = os.path.dirname(os.path.abspath(__file__))

# On Linux the log file goes to /var/log/, elsewhere to the working directory.
log_path = '/var/log/' if sys.platform == 'linux' else ''

logging.basicConfig(
    level=LOG_LEVEL,
    format='%(asctime)s : %(levelname)s : %(message)s',
    filename=f'{log_path}inswapper.log'
)

# Mirror every log record to stdout in addition to the log file.
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
|
|
|
|
|
def process_request(request_obj):
    """Run a face swap described by ``request_obj`` and build a response dict.

    ``request_obj`` must contain ``source_image`` and ``target_image``; both
    are handed to ``face_swap`` unchanged (``face_swap`` opens them as image
    paths). The remaining options are optional and fall back to the same
    defaults the ``/faceswap`` endpoint applies.

    Returns:
        ``{'status': 'ok', 'image': <base64 string>}`` on success, or
        ``{'status': 'error', 'msg': ..., 'detail': ...}`` on any failure.
        Exceptions are never propagated to the caller.
    """
    try:
        logging.debug('Swapping face')
        face_swap_timer = Timer()

        # face_swap takes ten positional arguments; passing only the two
        # images always raised TypeError, so supply every option explicitly.
        result_image = face_swap(
            request_obj['source_image'],
            request_obj['target_image'],
            request_obj.get('source_indexes', '-1'),
            request_obj.get('target_indexes', '-1'),
            request_obj.get('background_enhance', True),
            request_obj.get('face_restore', True),
            request_obj.get('face_upsample', True),
            request_obj.get('upscale', 1),
            request_obj.get('codeformer_fidelity', 0.5),
            request_obj.get('output_format', 'JPEG')
        )

        face_swap_time = face_swap_timer.get_elapsed_time()
        logging.info(f'Time taken to swap face: {face_swap_time} seconds')

        response = {
            'status': 'ok',
            'image': result_image
        }
    except Exception as e:
        # Boundary handler: log with traceback and report the failure.
        logging.exception(e)
        response = {
            'status': 'error',
            'msg': 'Face swap failed',
            'detail': str(e)
        }

    return response
|
|
|
|
|
class Timer:
    """Simple stopwatch that reports elapsed seconds since it was (re)started.

    The reference point is taken from ``time.monotonic()`` so that the
    measured duration is not distorted by system clock adjustments.
    """

    def __init__(self):
        # Reference point for elapsed-time calculations (monotonic seconds).
        self.start = time.monotonic()

    def restart(self):
        """Reset the reference point to the current moment."""
        self.start = time.monotonic()

    def get_elapsed_time(self):
        """Return the seconds elapsed since start/restart, rounded to 0.1."""
        end = time.monotonic()
        return round(end - self.start, 1)
|
|
|
|
|
def get_args():
    """Parse the command-line options of the REST API server.

    Returns:
        An ``argparse.Namespace`` with ``port`` (int, default 8090) and
        ``host`` (str, default '0.0.0.0').
    """
    cli = argparse.ArgumentParser(description='Inswapper REST API')

    cli.add_argument('-p', '--port', type=int, default=8090, help='Port to listen on')
    cli.add_argument('-H', '--host', default='0.0.0.0', help='Host to bind to')

    parsed = cli.parse_args()
    return parsed
|
|
|
|
|
def determine_file_extension(image_data):
    """Guess a file extension from the leading characters of base64 image data.

    Args:
        image_data: Base64 text of an image.

    Returns:
        '.jpg' when the text starts with '/9j/', otherwise '.png'. '.png' is
        also returned when ``image_data`` is not a string (for example
        ``None`` or ``bytes``).
    """
    try:
        if image_data.startswith('/9j/'):
            return '.jpg'
    except (AttributeError, TypeError):
        # Not a str: fall through to the default extension.
        pass

    # The 'iVBORw0Kg' prefix and every unrecognised prefix map to PNG.
    return '.png'
|
|
|
|
|
def write_base64_to_disk(file_b64: str, file_path: str):
    """Decode ``file_b64`` from base64 and write the bytes to ``file_path``.

    The destination is opened in binary write mode, so an existing file is
    overwritten.
    """
    with open(file_path, 'wb') as out_handle:
        raw_bytes = base64.b64decode(file_b64)
        out_handle.write(raw_bytes)
|
|
|
|
|
def get_face_swap_model(model_path: str):
    """Load the face-swap model at ``model_path`` via insightface's model zoo."""
    return insightface.model_zoo.get_model(model_path)
|
|
|
|
|
def get_face_analyser(model_path: str,
                      det_size=(320, 320)):
    """Create and prepare an insightface ``FaceAnalysis`` instance.

    Args:
        model_path: Accepted for interface compatibility; this function does
            not read it. The analyser is always built with the "buffalo_l"
            model pack rooted at "./checkpoints".
        det_size: Detection size forwarded to ``prepare``.

    Returns:
        The prepared analyser (``prepare`` is called with ``ctx_id=0``).
    """
    analyser = insightface.app.FaceAnalysis(root="./checkpoints", name="buffalo_l")
    analyser.prepare(det_size=det_size, ctx_id=0)
    return analyser
|
|
|
|
|
def get_one_face(face_analyser,
                 frame: np.ndarray):
    """Return the left-most detected face in ``frame``.

    Faces are compared by the first element of their ``bbox``. ``None`` is
    returned when picking the minimum raises ``ValueError`` (which is what
    happens when no face was detected).
    """
    def _left_edge(candidate):
        return candidate.bbox[0]

    detected = face_analyser.get(frame)
    try:
        leftmost = min(detected, key=_left_edge)
    except ValueError:
        leftmost = None
    return leftmost
|
|
|
|
|
def get_many_faces(face_analyser,
                   frame: np.ndarray):
    """Return every face detected in ``frame`` ordered from left to right.

    Ordering uses the first element of each face's ``bbox``. When nothing is
    detected the result is an empty list; ``None`` is returned only if an
    ``IndexError`` is raised during detection or sorting.
    """
    def _left_edge(candidate):
        return candidate.bbox[0]

    try:
        ordered = list(face_analyser.get(frame))
        ordered.sort(key=_left_edge)
    except IndexError:
        return None
    return ordered
|
|
|
|
|
def swap_face(face_swapper,
              source_faces,
              target_faces,
              source_index,
              target_index,
              temp_frame):
    """Paste one source face over one target face in ``temp_frame``.

    The faces are selected by ``source_index`` / ``target_index`` and handed
    to ``face_swapper.get`` with ``paste_back=True``; its result is returned.
    """
    chosen_target = target_faces[target_index]
    chosen_source = source_faces[source_index]
    swapped_frame = face_swapper.get(
        temp_frame,
        chosen_target,
        chosen_source,
        paste_back=True
    )
    return swapped_frame
|
|
|
|
|
def _to_bgr(image):
    """Convert a PIL image (RGB channel order) to an OpenCV BGR array."""
    return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)


def process(source_img: Union[Image.Image, List],
            target_img: Image.Image,
            source_indexes: str,
            target_indexes: str,
            model: str):
    """Swap faces from the source image(s) onto the target image.

    Three configurations are supported:

    * As many source images as there are target faces: the left-most face of
      source image ``i`` replaces target face ``i`` (left to right).
    * One source image and ``target_indexes == "-1"``: target faces are
      replaced in order, reusing the single source face when there is only
      one, otherwise pairing faces up until either side runs out.
    * One source image and explicit ``target_indexes``: comma-separated
      source/target indexes are paired one to one (``source_indexes == "-1"``
      means "all source faces").

    Args:
        source_img: A single image or a list of images providing faces.
        target_img: The image whose faces are replaced.
        source_indexes: Comma-separated source face indexes, or "-1".
        target_indexes: Comma-separated target face indexes, or "-1".
        model: Path of the face-swap model.

    Returns:
        A new ``PIL.Image.Image`` (RGB) with the swapped faces.

    Raises:
        Exception: No face found in the target or a source image, more
            indexes than faces, or an unsupported configuration.
        ValueError: An index is not an integer or exceeds the face count.
    """
    # The type hint allows a bare image; wrap it so len()/indexing work.
    if isinstance(source_img, Image.Image):
        source_img = [source_img]

    # Callers may pass indexes as numbers; the logic below works on strings.
    source_indexes = str(source_indexes)
    target_indexes = str(target_indexes)

    face_analyser = get_face_analyser(model)

    model_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), model)
    face_swapper = get_face_swap_model(model_path)

    target_img = _to_bgr(target_img)
    target_faces = get_many_faces(face_analyser, target_img)

    # get_many_faces yields an empty list (or None) when nothing is detected;
    # check before calling len() or indexing.
    if not target_faces:
        logging.error('No target faces found')
        raise Exception('No target faces found!')

    num_target_faces = len(target_faces)
    num_source_images = len(source_img)
    temp_frame = copy.deepcopy(target_img)

    if isinstance(source_img, list) and num_source_images == num_target_faces:
        logging.debug('Replacing the faces in the target image from left to right by order')

        for i in range(num_target_faces):
            source_faces = get_many_faces(face_analyser, _to_bgr(source_img[i]))

            if not source_faces:
                raise Exception('No source faces found!')

            # source_faces belongs to source image i only, so take its
            # left-most face; indexing it with i fails for single-face images.
            temp_frame = swap_face(
                face_swapper,
                source_faces,
                target_faces,
                0,
                i,
                temp_frame
            )
    elif num_source_images == 1:
        source_faces = get_many_faces(face_analyser, _to_bgr(source_img[0]))

        if not source_faces:
            raise Exception('No source faces found!')

        num_source_faces = len(source_faces)
        logging.debug(f'Source faces: {num_source_faces}')
        logging.debug(f'Target faces: {num_target_faces}')

        if target_indexes == "-1":
            if num_source_faces == 1:
                logging.debug('Replacing all faces in target image with the same face from the source image')
                num_iterations = num_target_faces
            elif num_source_faces < num_target_faces:
                logging.debug('There are less faces in the source image than the target image, replacing as many as we can')
                num_iterations = num_source_faces
            elif num_target_faces < num_source_faces:
                logging.debug('There are less faces in the target image than the source image, replacing as many as we can')
                num_iterations = num_target_faces
            else:
                logging.debug('Replacing all faces in the target image with the faces from the source image')
                num_iterations = num_target_faces

            for i in range(num_iterations):
                temp_frame = swap_face(
                    face_swapper,
                    source_faces,
                    target_faces,
                    0 if num_source_faces == 1 else i,
                    i,
                    temp_frame
                )
        else:
            # NOTE(review): an earlier revision had a branch here guarded by
            # `source_indexes == '-1' and target_indexes == '-1'` that applied
            # source face 0 to every listed target index. It could never run
            # because target_indexes == "-1" is handled above; confirm the
            # intended condition before reinstating it.
            logging.debug('Replacing specific face(s) in the target image with specific face(s) from the source image')

            if source_indexes == "-1":
                source_indexes = ','.join(str(i) for i in range(num_source_faces))

            source_index_list = source_indexes.split(',')
            target_index_list = target_indexes.split(',')

            if len(source_index_list) > num_source_faces:
                raise Exception('Number of source indexes is greater than the number of faces in the source image')

            if len(target_index_list) > num_target_faces:
                raise Exception('Number of target indexes is greater than the number of faces in the target image')

            # Only one-to-one pairing of source and target indexes is supported.
            if len(source_index_list) != len(target_index_list):
                logging.error('Unsupported face configuration')
                raise Exception('Unsupported face configuration')

            for raw_source, raw_target in zip(source_index_list, target_index_list):
                source_index = int(raw_source)
                target_index = int(raw_target)

                if source_index > num_source_faces - 1:
                    raise ValueError(f'Source index {source_index} is higher than the number of faces in the source image')

                if target_index > num_target_faces - 1:
                    raise ValueError(f'Target index {target_index} is higher than the number of faces in the target image')

                temp_frame = swap_face(
                    face_swapper,
                    source_faces,
                    target_faces,
                    source_index,
                    target_index,
                    temp_frame
                )
    else:
        logging.error('Unsupported face configuration')
        raise Exception('Unsupported face configuration')

    # Convert back from OpenCV's BGR order to an RGB PIL image.
    return Image.fromarray(cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB))
|
|
|
|
|
def face_swap(src_img_path,
              target_img_path,
              source_indexes,
              target_indexes,
              background_enhance,
              face_restore,
              face_upsample,
              upscale,
              codeformer_fidelity,
              output_format):
    """Swap faces between image files and return the result as base64 text.

    Args:
        src_img_path: Path of the source image; several paths may be joined
            with ';'.
        target_img_path: Path of the target image.
        source_indexes: Source face selection forwarded to ``process``.
        target_indexes: Target face selection forwarded to ``process``.
        background_enhance: Forwarded to ``face_restoration``.
        face_restore: When truthy, run CodeFormer restoration on the result.
        face_upsample: Forwarded to ``face_restoration``.
        upscale: Forwarded to ``face_restoration``.
        codeformer_fidelity: Forwarded to ``face_restoration``.
        output_format: Pillow format name used to encode the result.

    Returns:
        The encoded result image as a base64 string.

    Raises:
        Any exception raised while loading, swapping, restoring or encoding
        is propagated unchanged.
    """
    # Load every image fully and close its file handle straight away so the
    # caller can delete the files afterwards. Converting to RGB also makes
    # palette/grayscale inputs usable by the RGB->BGR conversion in process().
    source_img = []
    for img_path in src_img_path.split(';'):
        with Image.open(img_path) as opened_image:
            source_img.append(opened_image.convert('RGB'))

    with Image.open(target_img_path) as opened_image:
        target_img = opened_image.convert('RGB')

    model = os.path.join(script_dir, 'checkpoints/inswapper_128.onnx')
    logging.debug(f'Face swap model: {model}')

    logging.debug('Performing face swap')
    result_image = process(
        source_img,
        target_img,
        source_indexes,
        target_indexes,
        model
    )
    logging.debug('Face swap complete')

    check_ckpts()

    if face_restore:
        logging.debug('Setting upsampler to RealESRGAN_x2plus')
        upsampler = set_realesrgan()

        torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logging.debug(f'Torch device: {torch_device.upper()}')
        device = torch.device(torch_device)

        codeformer_net = ARCH_REGISTRY.get('CodeFormer')(
            dim_embd=512,
            codebook_size=1024,
            n_head=8,
            n_layers=9,
            connect_list=['32', '64', '128', '256'],
        ).to(device)

        ckpt_path = os.path.join(script_dir, 'CodeFormer/CodeFormer/weights/CodeFormer/codeformer.pth')
        logging.debug(f'Loading CodeFormer model: {ckpt_path}')
        # map_location keeps loading working on CPU-only hosts even when the
        # checkpoint tensors were saved from a CUDA device.
        checkpoint = torch.load(ckpt_path, map_location=device)['params_ema']
        codeformer_net.load_state_dict(checkpoint)
        codeformer_net.eval()

        result_image = cv2.cvtColor(np.array(result_image), cv2.COLOR_RGB2BGR)
        logging.debug('Performing face restoration using CodeFormer')

        result_image = face_restoration(
            result_image,
            background_enhance,
            face_upsample,
            upscale,
            codeformer_fidelity,
            upsampler,
            codeformer_net,
            device
        )

        logging.debug('CodeFormer face restoration completed successfully')
        result_image = Image.fromarray(result_image)

    # Encode in memory and hand the bytes back as base64 text.
    output_buffer = io.BytesIO()
    result_image.save(output_buffer, format=output_format)
    image_data = output_buffer.getvalue()

    return base64.b64encode(image_data).decode('utf-8')
|
|
|
|
|
app = Flask(__name__)


@app.errorhandler(400)
def bad_request(error):
    """Return a JSON body for HTTP 400 errors.

    Named ``bad_request`` so it no longer shares the name ``not_found`` with
    the 404 handler defined further down in this module.
    """
    return make_response(jsonify(
        {
            'status': 'error',
            'msg': 'Bad Request',
            'detail': str(error)
        }
    ), 400)
|
|
|
|
|
@app.errorhandler(404)
def not_found(error):
    """Return a JSON body naming the requested URL for HTTP 404 errors."""
    body = {
        'status': 'error',
        'msg': f'{request.url} not found',
        'detail': str(error)
    }
    return make_response(jsonify(body), 404)
|
|
|
|
|
@app.errorhandler(500)
def internal_server_error(error):
    """Return a JSON body for HTTP 500 errors."""
    body = {
        'status': 'error',
        'msg': 'Internal Server Error',
        'detail': str(error)
    }
    return make_response(jsonify(body), 500)
|
|
|
|
|
@app.route('/', methods=['GET'])
def ping():
    """Health check: respond with ``{'status': 'ok'}`` and HTTP 200."""
    body = {'status': 'ok'}
    return make_response(jsonify(body), 200)
|
|
|
|
|
@app.route('/faceswap', methods=['POST'])
def face_swap_api():
    """Handle ``POST /faceswap``: swap faces between two base64 images.

    The JSON body must contain ``source_image`` and ``target_image`` (base64
    text). Optional keys and their defaults: ``source_indexes`` ('-1'),
    ``target_indexes`` ('-1'), ``background_enhance`` (True),
    ``face_restore`` (True), ``face_upsample`` (True), ``upscale`` (1),
    ``codeformer_fidelity`` (0.5), ``output_format`` ('JPEG').

    Returns:
        200 with ``{'status': 'ok', 'image': <base64>}`` on success,
        400 when the body is not a JSON object with both images,
        500 with an error description when the swap fails.
    """
    total_timer = Timer()
    logging.debug('Received face swap API request')
    payload = request.get_json()

    # Reject malformed requests up front instead of failing with a
    # KeyError/TypeError that would surface as a 500.
    if (not isinstance(payload, dict)
            or 'source_image' not in payload
            or 'target_image' not in payload):
        return make_response(jsonify(
            {
                'status': 'error',
                'msg': 'Bad Request',
                'detail': 'JSON body must contain source_image and target_image'
            }
        ), 400)

    # exist_ok avoids a race between concurrent requests creating the folder.
    os.makedirs(TMP_PATH, exist_ok=True)

    option_defaults = {
        'source_indexes': '-1',
        'target_indexes': '-1',
        'background_enhance': True,
        'face_restore': True,
        'face_upsample': True,
        'upscale': 1,
        'codeformer_fidelity': 0.5,
        'output_format': 'JPEG',
    }
    for option, default_value in option_defaults.items():
        payload.setdefault(option, default_value)

    unique_id = uuid.uuid4()
    source_image_data = payload['source_image']
    target_image_data = payload['target_image']

    source_file_extension = determine_file_extension(source_image_data)
    source_image_path = f'{TMP_PATH}/source_{unique_id}{source_file_extension}'

    target_file_extension = determine_file_extension(target_image_data)
    target_image_path = f'{TMP_PATH}/target_{unique_id}{target_file_extension}'

    try:
        # Persist the uploads; face_swap works on file paths.
        write_base64_to_disk(source_image_data, source_image_path)
        write_base64_to_disk(target_image_data, target_image_path)

        logging.debug(f'Source indexes: {payload["source_indexes"]}')
        logging.debug(f'Target indexes: {payload["target_indexes"]}')
        logging.debug(f'Background enhance: {payload["background_enhance"]}')
        logging.debug(f'Face Restoration: {payload["face_restore"]}')
        logging.debug(f'Face Upsampling: {payload["face_upsample"]}')
        logging.debug(f'Upscale: {payload["upscale"]}')
        logging.debug(f'Codeformer Fidelity: {payload["codeformer_fidelity"]}')
        logging.debug(f'Output Format: {payload["output_format"]}')

        result_image = face_swap(
            source_image_path,
            target_image_path,
            payload['source_indexes'],
            payload['target_indexes'],
            payload['background_enhance'],
            payload['face_restore'],
            payload['face_upsample'],
            payload['upscale'],
            payload['codeformer_fidelity'],
            payload['output_format']
        )

        status_code = 200
        response = {
            'status': 'ok',
            'image': result_image
        }
    except Exception as e:
        # Request boundary: log with traceback and report a JSON error.
        logging.exception(e)

        response = {
            'status': 'error',
            'msg': 'Face swap failed',
            'detail': str(e)
        }
        status_code = 500
    finally:
        # Always remove the temporary uploads, even when the swap failed or
        # one of the files was never written.
        for temp_path in (source_image_path, target_image_path):
            try:
                os.remove(temp_path)
            except FileNotFoundError:
                pass
            except OSError as cleanup_error:
                logging.warning(f'Could not remove temporary file {temp_path}: {cleanup_error}')

    total_time = total_timer.get_elapsed_time()
    logging.info(f'Total time taken for face swap API call {total_time} seconds')

    return make_response(jsonify(response), status_code)
|
|
|
|
|
if __name__ == '__main__':
    # Read host/port from the command line and serve the app with waitress.
    cli_args = get_args()
    serve(app, host=cli_args.host, port=cli_args.port)
|
|