import traceback |
|
import time |
|
import torch |
|
import os |
|
import argparse |
|
import mcubes |
|
import trimesh |
|
import numpy as np |
|
from PIL import Image |
|
from glob import glob |
|
from omegaconf import OmegaConf |
|
from tqdm.auto import tqdm |
|
from accelerate.logging import get_logger |
|
|
|
from lam.runners.infer.head_utils import prepare_motion_seqs, preprocess_image, prepare_gaga_motion_seqs |
|
|
|
|
|
from .base_inferrer import Inferrer |
|
from lam.datasets.cam_utils import build_camera_principle, build_camera_standard, surrounding_views_linspace, create_intrinsics |
|
from lam.utils.logging import configure_logger |
|
from lam.runners import REGISTRY_RUNNERS |
|
from lam.utils.video import images_to_video |
|
from lam.utils.hf_hub import wrap_model_hub |
|
from lam.models.modeling_lam import ModelLAM |
|
from safetensors.torch import load_file |
|
import moviepy.editor as mpy |
|
|
|
|
|
logger = get_logger(__name__) |
|
|
|
|
|
def parse_configs(): |
|
|
|
parser = argparse.ArgumentParser() |
|
parser.add_argument('--config', type=str) |
|
parser.add_argument('--infer', type=str) |
|
args, unknown = parser.parse_known_args() |
|
|
|
cfg = OmegaConf.create() |
|
cli_cfg = OmegaConf.from_cli(unknown) |
|
|
|
|
|
if os.environ.get('APP_INFER') is not None: |
|
args.infer = os.environ.get('APP_INFER') |
|
if os.environ.get('APP_MODEL_NAME') is not None: |
|
cli_cfg.model_name = os.environ.get('APP_MODEL_NAME') |
|
|
|
if args.config is not None: |
|
cfg = OmegaConf.load(args.config) |
|
cfg_train = OmegaConf.load(args.config) |
|
cfg.source_size = cfg_train.dataset.source_image_res |
|
cfg.render_size = cfg_train.dataset.render_image.high |
|
_relative_path = os.path.join(cfg_train.experiment.parent, cfg_train.experiment.child, os.path.basename(cli_cfg.model_name).split('_')[-1]) |
|
|
|
cfg.save_tmp_dump = os.path.join("exps", 'save_tmp', _relative_path) |
|
cfg.image_dump = os.path.join("exps", 'images', _relative_path) |
|
cfg.video_dump = os.path.join("exps", 'videos', _relative_path) |
|
cfg.mesh_dump = os.path.join("exps", 'meshes', _relative_path) |
|
|
|
if args.infer is not None: |
|
cfg_infer = OmegaConf.load(args.infer) |
|
cfg.merge_with(cfg_infer) |
|
cfg.setdefault("save_tmp_dump", os.path.join("exps", cli_cfg.model_name, 'save_tmp')) |
|
cfg.setdefault("image_dump", os.path.join("exps", cli_cfg.model_name, 'images')) |
|
cfg.setdefault('video_dump', os.path.join("dumps", cli_cfg.model_name, 'videos')) |
|
cfg.setdefault('mesh_dump', os.path.join("dumps", cli_cfg.model_name, 'meshes')) |
|
|
|
cfg.motion_video_read_fps = 6 |
|
cfg.merge_with(cli_cfg) |
|
|
|
""" |
|
[required] |
|
model_name: str |
|
image_input: str |
|
export_video: bool |
|
export_mesh: bool |
|
|
|
[special] |
|
source_size: int |
|
render_size: int |
|
video_dump: str |
|
mesh_dump: str |
|
|
|
[default] |
|
render_views: int |
|
render_fps: int |
|
mesh_size: int |
|
mesh_thres: float |
|
frame_size: int |
|
logger: str |
|
""" |
|
|
|
cfg.setdefault('logger', 'INFO') |
|
|
|
|
|
assert cfg.model_name is not None, "model_name is required" |
|
if not os.environ.get('APP_ENABLED', None): |
|
assert cfg.image_input is not None, "image_input is required" |
|
assert cfg.export_video or cfg.export_mesh, \ |
|
"At least one of export_video or export_mesh should be True" |
|
cfg.app_enabled = False |
|
else: |
|
cfg.app_enabled = True |
|
|
|
return cfg |
|
|
|
|
|
def count_parameters_excluding_modules(model, exclude_names=[]):
    """
    Counts the number of parameters in a PyTorch model, excluding specified modules by name.

    Parameters:
    - model (torch.nn.Module): The PyTorch model instance.
    - exclude_names (list of str): List of module names to exclude from the parameter count.

    Returns:
    - float: Total number of parameters (in millions), excluding specified modules.
    """
    total_params = 0
    total_size_bits = 0
    for name, module in model.named_modules():
        if any(exclude_name in name for exclude_name in exclude_names):
            continue
        # recurse=False so parameters owned by nested submodules are not counted multiple times.
        for param in module.parameters(recurse=False):
            total_params += param.numel()
            total_size_bits += param.numel() * param.element_size() * 8

    total_params_m = total_params / 1e6
    print("===" * 16 * 3, f"\nTotal number of parameters: {total_params_m:.2f}M", "\n" + "===" * 16 * 3)
    print(f"model size: {total_size_bits} bit | {total_size_bits / 8 / (1024 ** 2):.2f} MB")

    return total_params_m
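# Example usage (illustrative; the module names passed to exclude_names are placeholders
# and depend on the actual ModelLAM architecture):
#   model = ModelLAM(**cfg.model)
#   count_parameters_excluding_modules(model, exclude_names=["encoder"])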
|
|
|
|
|
@REGISTRY_RUNNERS.register('infer.lam') |
|
class LAMInferrer(Inferrer): |
|
|
|
EXP_TYPE: str = 'lam' |
|
|
|
def __init__(self): |
|
super().__init__() |
|
|
|
self.cfg = parse_configs() |
|
""" |
|
configure_logger( |
|
stream_level=self.cfg.logger, |
|
log_level=self.cfg.logger, |
|
) |
|
""" |
|
|
|
self.model: ModelLAM = self._build_model(self.cfg).to(self.device)
|
|
|
def _build_model(self, cfg): |
|
""" |
|
from lam.models import model_dict |
|
hf_model_cls = wrap_model_hub(model_dict[self.EXP_TYPE]) |
|
model = hf_model_cls.from_pretrained(cfg.model_name) |
|
""" |
|
from lam.models import ModelLAM |
|
model = ModelLAM(**cfg.model) |
|
|
|
|
|
|
|
resume = os.path.join(cfg.model_name, "model.safetensors") |
|
print("==="*16*3) |
|
print("loading pretrained weight from:", resume) |
|
if resume.endswith('safetensors'): |
|
ckpt = load_file(resume, device='cpu') |
|
else: |
|
ckpt = torch.load(resume, map_location='cpu') |
|
state_dict = model.state_dict() |
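# state_dict() returns references to the live parameter/buffer tensors, so copy_() below
# writes the checkpoint values into the model in place (no separate load_state_dict call needed).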
|
for k, v in ckpt.items(): |
|
if k in state_dict: |
|
if state_dict[k].shape == v.shape: |
|
state_dict[k].copy_(v) |
|
else: |
|
print(f"WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.") |
|
else: |
|
print(f"WARN] unexpected param {k}: {v.shape}") |
|
print("finish loading pretrained weight from:", resume) |
|
print("==="*16*3) |
|
return model |
|
|
|
def _default_source_camera(self, dist_to_center: float = 2.0, batch_size: int = 1, device: torch.device = torch.device('cpu')): |
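# Fixed canonical source camera: a hand-written extrinsic placed `dist_to_center` from the
# origin, paired with normalized pinhole intrinsics (f=0.75, c=0.5).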
|
|
|
canonical_camera_extrinsics = torch.tensor([[ |
|
[1, 0, 0, 0], |
|
[0, 0, -1, -dist_to_center], |
|
[0, 1, 0, 0], |
|
]], dtype=torch.float32, device=device) |
|
canonical_camera_intrinsics = create_intrinsics( |
|
f=0.75, |
|
c=0.5, |
|
device=device, |
|
).unsqueeze(0) |
|
source_camera = build_camera_principle(canonical_camera_extrinsics, canonical_camera_intrinsics) |
|
return source_camera.repeat(batch_size, 1) |
|
|
|
def _default_render_cameras(self, n_views: int, batch_size: int = 1, device: torch.device = torch.device('cpu')): |
|
|
|
render_camera_extrinsics = surrounding_views_linspace(n_views=n_views, device=device) |
|
render_camera_intrinsics = create_intrinsics( |
|
f=0.75, |
|
c=0.5, |
|
device=device, |
|
).unsqueeze(0).repeat(render_camera_extrinsics.shape[0], 1, 1) |
|
render_cameras = build_camera_standard(render_camera_extrinsics, render_camera_intrinsics) |
|
return render_cameras.unsqueeze(0).repeat(batch_size, 1, 1) |
|
|
|
def infer_planes(self, image: torch.Tensor, source_cam_dist: float): |
|
N = image.shape[0] |
|
source_camera = self._default_source_camera(dist_to_center=source_cam_dist, batch_size=N, device=self.device) |
|
planes = self.model.forward_planes(image, source_camera) |
|
assert N == planes.shape[0] |
|
return planes |
|
|
|
def infer_video(self, planes: torch.Tensor, frame_size: int, render_size: int, render_views: int, render_fps: int, dump_video_path: str): |
|
N = planes.shape[0] |
|
render_cameras = self._default_render_cameras(n_views=render_views, batch_size=N, device=self.device) |
|
render_anchors = torch.zeros(N, render_cameras.shape[1], 2, device=self.device) |
|
render_resolutions = torch.ones(N, render_cameras.shape[1], 1, device=self.device) * render_size |
|
render_bg_colors = torch.ones(N, render_cameras.shape[1], 1, device=self.device, dtype=torch.float32) * 0. |
|
|
|
frames = [] |
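# Render the view trajectory in chunks of `frame_size` cameras to bound peak GPU memory;
# the per-key outputs are concatenated back along the view dimension below.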
|
for i in range(0, render_cameras.shape[1], frame_size): |
|
frames.append( |
|
self.model.synthesizer( |
|
planes=planes, |
|
cameras=render_cameras[:, i:i+frame_size], |
|
anchors=render_anchors[:, i:i+frame_size], |
|
resolutions=render_resolutions[:, i:i+frame_size], |
|
bg_colors=render_bg_colors[:, i:i+frame_size], |
|
region_size=render_size, |
|
) |
|
) |
|
|
|
frames = { |
|
k: torch.cat([r[k] for r in frames], dim=1) |
|
for k in frames[0].keys() |
|
} |
|
|
|
os.makedirs(os.path.dirname(dump_video_path), exist_ok=True) |
|
for k, v in frames.items(): |
|
if k == 'images_rgb': |
|
images_to_video( |
|
images=v[0], |
|
output_path=dump_video_path, |
|
fps=render_fps, |
|
gradio_codec=self.cfg.app_enabled, |
|
) |
|
|
|
def infer_mesh(self, planes: torch.Tensor, mesh_size: int, mesh_thres: float, dump_mesh_path: str): |
|
grid_out = self.model.synthesizer.forward_grid( |
|
planes=planes, |
|
grid_size=mesh_size, |
|
) |
|
|
|
vtx, faces = mcubes.marching_cubes(grid_out['sigma'].squeeze(0).squeeze(-1).cpu().numpy(), mesh_thres) |
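# Marching cubes returns vertices in voxel-grid index coordinates; rescale them to [-1, 1]
# before querying per-vertex colors via forward_points.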
|
vtx = vtx / (mesh_size - 1) * 2 - 1 |
|
|
|
vtx_tensor = torch.tensor(vtx, dtype=torch.float32, device=self.device).unsqueeze(0) |
|
vtx_colors = self.model.synthesizer.forward_points(planes, vtx_tensor)['rgb'].squeeze(0).cpu().numpy() |
|
vtx_colors = (vtx_colors * 255).astype(np.uint8) |
|
|
|
mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors) |
|
|
|
|
|
os.makedirs(os.path.dirname(dump_mesh_path), exist_ok=True) |
|
mesh.export(dump_mesh_path) |
|
|
|
def save_imgs_2_video(self, imgs, v_pth, fps): |
|
img_lst = [imgs[i] for i in range(imgs.shape[0])] |
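# Each frame becomes a 0.1 s ImageClip; the concatenated clip is then written out at `fps`,
# so the display time per frame stays 0.1 s regardless of the requested fps.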
|
|
|
clips = [mpy.ImageClip(img).set_duration(0.1) for img in img_lst] |
|
|
|
|
|
video = mpy.concatenate_videoclips(clips, method="compose") |
|
|
|
|
|
video.write_videofile(v_pth, fps=fps) |
|
|
|
def infer_single(self, image_path: str, |
|
motion_seqs_dir, |
|
motion_img_dir, |
|
motion_video_read_fps, |
|
export_video: bool, |
|
export_mesh: bool, |
|
dump_tmp_dir:str, |
|
dump_image_dir:str, |
|
dump_video_path: str, |
|
dump_mesh_path: str, |
|
gaga_track_type: str): |
|
source_size = self.cfg.source_size |
|
render_size = self.cfg.render_size |
|
|
|
render_fps = self.cfg.render_fps |
|
|
|
|
|
|
|
|
|
aspect_standard = 1.0/1.0 |
|
motion_img_need_mask = self.cfg.get("motion_img_need_mask", False) |
|
vis_motion = self.cfg.get("vis_motion", False) |
|
save_ply = self.cfg.get("save_ply", False) |
|
save_img = self.cfg.get("save_img", False) |
|
|
|
rendered_bg = 1. |
|
ref_bg = 1. |
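# Foreground mask is assumed to live alongside the image, with "/images/" swapped for
# "/fg_masks/" and a ".png" suffix, e.g. ".../images/00000_00.jpg" -> ".../fg_masks/00000_00.png".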
|
mask_path = image_path.replace("/images/", "/fg_masks/").replace(".jpg", ".png") |
|
if ref_bg < 1.: |
|
if "VFHQ_TEST" in image_path: |
|
mask_path = image_path.replace("/VFHQ_TEST/", "/mask/").replace("/images/", "/mask/").replace(".png", ".jpg") |
|
else: |
|
mask_path = image_path.replace("/vfhq_test_nooffset_export/", "/mask/").replace("/images/", "/mask/").replace(".png", ".jpg") |
|
if not os.path.exists(mask_path): |
|
print("Warning: Mask path not exists:", mask_path) |
|
mask_path = None |
|
else: |
|
print("load mask from:", mask_path) |
|
|
|
|
|
if "hdtf" in image_path: |
|
uid = image_path.split('/')[-3] |
|
split0 = uid.replace(uid.split('_')[-1], '0') |
|
print("==="*16*3, "\n"+image_path, uid, split0) |
|
image_path = image_path.replace(uid, split0) |
|
mask_path = mask_path.replace(uid, split0) |
|
print(image_path, "\n"+"==="*16*3) |
|
print(mask_path, "\n"+"==="*16*3) |
|
if hasattr(self.cfg.model, "use_albedo_input") and (self.cfg.model.get("use_albedo_input", False)): |
|
image_path = image_path.replace("/images/", "/images_hydelight/") |
|
image, _, _, shape_param = preprocess_image(image_path, mask_path=mask_path, intr=None, pad_ratio=0, bg_color=ref_bg, |
|
max_tgt_size=None, aspect_standard=aspect_standard, enlarge_ratio=[1.0, 1.0], |
|
render_tgt_size=source_size, multiply=14, need_mask=True, get_shape_param=True) |
|
|
|
save_ref_img_path = os.path.join(dump_tmp_dir, "refer_" + os.path.basename(image_path)) |
|
vis_ref_img = (image[0].permute(1, 2, 0).cpu().detach().numpy() * 255).astype(np.uint8)
|
Image.fromarray(vis_ref_img).save(save_ref_img_path) |
|
|
|
test_sample=self.cfg.get("test_sample", True) |
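# Both branches below prepare the driving motion sequence; the rest of this function relies on
# the returned keys render_c2ws, render_intrs, render_bg_colors and flame_params
# (plus vis_motion_render when vis_motion is enabled).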
|
|
|
if gaga_track_type == "": |
|
print("==="*16*3, "\nuse vhap tracked results!", "\n"+"==="*16*3) |
|
src = image_path.split('/')[-3] |
|
driven = motion_seqs_dir.split('/')[-2] |
|
src_driven = [src, driven] |
|
motion_seq = prepare_motion_seqs(motion_seqs_dir, motion_img_dir, save_root=dump_tmp_dir, fps=motion_video_read_fps, |
|
bg_color=rendered_bg, aspect_standard=aspect_standard, enlarge_ratio=[1.0, 1.0],
|
render_image_res=render_size, multiply=16, |
|
need_mask=motion_img_need_mask, vis_motion=vis_motion, |
|
shape_param=shape_param, test_sample=test_sample, cross_id=self.cfg.get("cross_id", False), src_driven=src_driven) |
|
else: |
|
print("==="*16*3, "\nuse gaga tracked results:", gaga_track_type, "\n"+"==="*16*3) |
|
motion_seq = prepare_gaga_motion_seqs(motion_seqs_dir, motion_img_dir, save_root=dump_tmp_dir, fps=motion_video_read_fps, |
|
bg_color=rendered_bg, aspect_standard=aspect_standard, enlarge_ratio=[1.0, 1.0],
|
render_image_res=render_size, multiply=16, |
|
need_mask=motion_img_need_mask, vis_motion=vis_motion, |
|
shape_param=shape_param, test_sample=test_sample, gaga_track_type=gaga_track_type) |
|
|
|
|
|
|
|
motion_seq["flame_params"]["betas"] = shape_param.unsqueeze(0) |
|
|
|
start_time = time.time() |
|
device="cuda" |
|
dtype=torch.float32 |
|
|
|
self.model.to(dtype) |
|
print("start to inference...................") |
|
with torch.no_grad(): |
|
|
|
res = self.model.infer_single_view(image.unsqueeze(0).to(device, dtype), None, None, |
|
render_c2ws=motion_seq["render_c2ws"].to(device), |
|
render_intrs=motion_seq["render_intrs"].to(device), |
|
render_bg_colors=motion_seq["render_bg_colors"].to(device), |
|
flame_params={k:v.to(device) for k, v in motion_seq["flame_params"].items()}) |
|
|
|
print(f"time elapsed: {time.time() - start_time}") |
|
rgb = res["comp_rgb"].detach().cpu().numpy() |
|
rgb = (np.clip(rgb, 0, 1.0) * 255).astype(np.uint8) |
|
only_pred = rgb |
|
if vis_motion: |
|
|
|
import cv2 |
|
vis_ref_img = np.tile(cv2.resize(vis_ref_img, (rgb[0].shape[1], rgb[0].shape[0]), interpolation=cv2.INTER_AREA)[None, :, :, :], (rgb.shape[0], 1, 1, 1)) |
|
blend_ratio = 0.7 |
|
blend_res = ((1 - blend_ratio) * rgb + blend_ratio * motion_seq["vis_motion_render"]).astype(np.uint8) |
|
|
|
rgb = np.concatenate([vis_ref_img, rgb, motion_seq["vis_motion_render"]], axis=2) |
|
|
|
os.makedirs(os.path.dirname(dump_video_path), exist_ok=True) |
|
|
|
self.save_imgs_2_video(rgb, dump_video_path, render_fps) |
|
if save_img and dump_image_dir is not None: |
|
for i in range(rgb.shape[0]): |
|
save_file = os.path.join(dump_image_dir, f"{i:04d}.png") |
|
Image.fromarray(only_pred[i]).save(save_file) |
|
if save_ply and dump_mesh_path is not None: |
|
res["3dgs"][i][0][0].save_ply(os.path.join(dump_image_dir, f"{i:04d}.ply")) |
|
|
|
dump_cano_dir = "./exps/cano_gs/" |
|
os.makedirs(dump_cano_dir, exist_ok=True)
|
cano_ply_pth = os.path.join(dump_cano_dir, os.path.basename(dump_image_dir) + ".ply") |
|
|
|
|
|
cano_ply_pth = os.path.join(dump_cano_dir, os.path.basename(dump_image_dir) + "_gs_offset.ply") |
|
res['cano_gs_lst'][0].save_ply(cano_ply_pth, rgb2sh=False, offset2xyz=True, albedo2rgb=False) |
|
|
|
|
|
def save_color_points(points, colors, sv_pth, sv_fd="debug_vis/dataloader/"): |
|
points = points.squeeze().detach().cpu().numpy() |
|
colors = colors.squeeze().detach().cpu().numpy() |
|
sv_pth = os.path.join(sv_fd, sv_pth) |
|
os.makedirs(sv_fd, exist_ok=True)
|
with open(sv_pth, 'w') as of: |
|
for point, color in zip(points, colors): |
|
print('v', point[0], point[1], point[2], color[0], color[1], color[2], file=of) |
|
|
|
|
|
save_color_points(res['cano_gs_lst'][0].xyz, res["cano_gs_lst"][0].shs[:, 0, :], "framework_img.obj", sv_fd=dump_cano_dir) |
|
|
|
|
|
import trimesh |
|
vtxs = res['cano_gs_lst'][0].xyz - res['cano_gs_lst'][0].offset |
|
vtxs = vtxs.detach().cpu().numpy() |
|
faces = self.model.renderer.flame_model.faces.detach().cpu().numpy() |
|
mesh = trimesh.Trimesh(vertices=vtxs, faces=faces) |
|
mesh.export(os.path.join(dump_cano_dir, os.path.basename(dump_image_dir) + '_shaped_mesh.obj')) |
|
|
|
|
|
import lam.models.rendering.utils.mesh_utils as mesh_utils |
|
vtxs = res['cano_gs_lst'][0].xyz.detach().cpu() |
|
faces = self.model.renderer.flame_model.faces.detach().cpu() |
|
colors = res['cano_gs_lst'][0].shs.squeeze(1).detach().cpu() |
|
pth = os.path.join(dump_cano_dir, os.path.basename(dump_image_dir) + '_textured_mesh.obj') |
|
print("Save textured mesh to:", pth) |
|
mesh_utils.save_obj(pth, vtxs, faces, textures=colors, texture_type="vertex") |
|
|
|
|
|
|
|
|
|
|
|
def infer(self): |
|
image_paths = [] |
|
|
|
if os.path.isfile(self.cfg.image_input): |
|
omit_prefix = os.path.dirname(self.cfg.image_input) |
|
image_paths = [self.cfg.image_input] |
|
else: |
|
|
|
|
|
image_paths = glob(os.path.join(self.cfg.image_input, "*.jpg")) |
|
omit_prefix = self.cfg.image_input |
|
|
|
""" |
|
# image_paths = glob("train_data/demo_export/DEMOVIDEO/*/images/00000_00.png") |
|
image_paths = glob("train_data/vfhq_test/VFHQ_TEST/Clip+G0DGRma_p48+P0+C0+F11208-11383/images/00000_00.png") |
|
image_paths = glob("train_data/SIDE_FACE/*/images/00000_00.png") |
|
image_paths = glob("train_data/vfhq_test/VFHQ_TEST/*/images/00000_00.png") |
|
|
|
import json |
|
# uids = json.load(open("./train_data/vfhq_vhap/selected_id.json", 'r'))["self_id"] |
|
# image_paths = [os.path.join("train_data/vfhq_test/VFHQ_TEST/", uid, "images/00000_00.png") for uid in uids] |
|
image_paths = glob("train_data/vfhq_test/vfhq_test_nooffset_export/*/images/00000_00.png") |
|
# image_paths = glob("train_data/nersemble_vhap/export/017_SEN-01-cramp_small_danger_v16_DS4_whiteBg_staticOffset_maskBelowLine/images/00000_00.png") |
|
# image_paths = glob("train_data/nersemble_vhap/export/374_SEN-01-cramp_small_danger_v16_DS4_whiteBg_staticOffset_maskBelowLine/images/00000_00.png") |
|
image_paths = glob("train_data/nersemble_vhap/export/375_SEN-01-cramp_small_danger_v16_DS4_whiteBg_staticOffset_maskBelowLine/images/00000_00.png") |
|
|
|
image_paths = glob("train_data/vfhq_test/vfhq_test_nooffset_export/*/images/00000_00.png") |
|
""" |
|
|
|
|
|
|
|
image_paths = glob("train_data/vfhq_test/vfhq_test_nooffset_export/*/images/00000_00.png") |
|
|
|
|
|
print(len(image_paths), image_paths) |
|
|
|
|
|
|
|
image_paths = ["train_data/vfhq_test/VFHQ_TEST/Clip+G0DGRma_p48+P0+C0+F11208-11383/images/00000_00.png"] |
|
|
|
image_paths = glob("train_data/vfhq_test/vfhq_test_nooffset_export/*/images/00000_00.png") |
|
|
|
uids = ['Clip+1qf8dZpLED0+P2+C1+F5731-5855', 'Clip+8vcxTHoDadk+P3+C0+F27918-28036', 'Clip+gsHu2fb3aj0+P0+C0+F17563-17742'] |
|
image_paths = ["train_data/vfhq_test/vfhq_test_nooffset_export/*/images/00000_00.png".replace("*", item) for item in uids] |
|
|
|
image_paths = glob("train_data/vfhq_test/vfhq_test_nooffset_export/*/images/00000_00.png") |
|
|
|
image_paths = glob("train_data/vfhq_test/vfhq_test_nooffset_export/*/images/00000_00.png") |
|
|
|
image_paths = glob("train_data/test_2w_cases/*/images/00000_00.png") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if "hdtf" in image_paths[0]: |
|
image_paths = image_paths[self.cfg.get("rank", 0)::self.cfg.get("nodes", 1)] |
|
|
|
gaga_track_type = self.cfg.get("gaga_track_type", "") |
|
if gaga_track_type is None: |
|
gaga_track_type = "" |
|
print("==="*16*3, "\nUse gaga_track_type:", gaga_track_type, "\n"+"==="*16*3) |
|
|
|
if self.cfg.get("cross_id", False): |
|
import json |
|
cross_id_lst = json.load(open("train_data/Cross-identity-info.json", 'r')) |
|
src2driven = {item["src"]: item["driven"] for item in cross_id_lst} |
|
|
|
for image_path in tqdm(image_paths, disable=not self.accelerator.is_local_main_process): |
|
try: |
|
|
|
motion_seqs_dir = self.cfg.motion_seqs_dir |
|
if "VFHQ_TEST" in image_path or "vfhq_test_nooffset_export" in image_path or "hdtf" in image_path: |
|
motion_seqs_dir = os.path.join(*image_path.split('/')[:-2], "flame_param") |
|
|
|
if self.cfg.get("cross_id", False): |
|
src = motion_seqs_dir.split('/')[-2] |
|
driven = src2driven[src] |
|
motion_seqs_dir = motion_seqs_dir.replace(src, driven) |
|
|
|
print("motion_seqs_dir:", motion_seqs_dir) |
|
|
|
image_name = os.path.basename(image_path) |
|
uid = image_name.split('.')[0] |
|
subdir_path = os.path.dirname(image_path).replace(omit_prefix, '') |
|
subdir_path = subdir_path[1:] if subdir_path.startswith('/') else subdir_path |
|
|
|
subdir_path = gaga_track_type |
|
if self.cfg.get("cross_id", False): |
|
subdir_path = "cross_id" |
|
print("==="*16*3, "\n"+ "subdir_path:", subdir_path, "\n"+"==="*16*3) |
|
uid = os.path.basename(os.path.dirname(os.path.dirname(image_path))) |
|
print("subdir_path and uid:", subdir_path, uid) |
|
dump_video_path = os.path.join( |
|
self.cfg.video_dump, |
|
subdir_path, |
|
f'{uid}.mp4', |
|
) |
|
dump_image_dir = os.path.join( |
|
self.cfg.image_dump, |
|
subdir_path, |
|
f'{uid}' |
|
) |
|
dump_tmp_dir = os.path.join( |
|
self.cfg.image_dump, |
|
subdir_path, |
|
"tmp_res" |
|
) |
|
dump_mesh_path = os.path.join( |
|
self.cfg.mesh_dump, |
|
subdir_path, |
|
|
|
) |
|
os.makedirs(dump_image_dir, exist_ok=True) |
|
os.makedirs(dump_tmp_dir, exist_ok=True) |
|
os.makedirs(dump_mesh_path, exist_ok=True) |
|
|
|
|
|
|
|
|
|
|
|
self.infer_single( |
|
image_path, |
|
motion_seqs_dir=motion_seqs_dir, |
|
motion_img_dir=self.cfg.motion_img_dir, |
|
motion_video_read_fps=self.cfg.motion_video_read_fps, |
|
export_video=self.cfg.export_video, |
|
export_mesh=self.cfg.export_mesh, |
|
dump_tmp_dir=dump_tmp_dir, |
|
dump_image_dir=dump_image_dir, |
|
dump_video_path=dump_video_path, |
|
dump_mesh_path=dump_mesh_path, |
|
gaga_track_type=gaga_track_type |
|
) |
|
except Exception:
|
traceback.print_exc() |
|
|