|
import os |
|
from dataclasses import dataclass, field |
|
from collections import defaultdict |
|
try:
    # Prefer the rasterizer build that also returns depth and alpha maps ("wda");
    # fall back to the standard diff-gaussian-rasterization package otherwise.
    from diff_gaussian_rasterization_wda import GaussianRasterizationSettings, GaussianRasterizer
except ImportError:
    from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
|
from plyfile import PlyData, PlyElement |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import numpy as np |
|
import math |
|
import copy |
|
from diffusers.utils import is_torch_version |
|
from lam.models.rendering.flame_model.flame import FlameHeadSubdivided |
|
from lam.models.transformer import TransformerDecoder |
|
from pytorch3d.transforms import matrix_to_quaternion |
|
from lam.models.rendering.utils.typing import * |
|
from lam.models.rendering.utils.utils import trunc_exp, MLP |
|
from lam.models.rendering.gaussian_model import GaussianModel |
|
from einops import rearrange, repeat |
|
from pytorch3d.ops.points_normals import estimate_pointcloud_normals |
|
# Force EGL so PyOpenGL-based renderers work headlessly (no X server required).
os.environ["PYOPENGL_PLATFORM"] = "egl"
|
from pytorch3d.structures import Meshes, Pointclouds |
|
from pytorch3d.renderer import ( |
|
AmbientLights, |
|
PerspectiveCameras, |
|
SoftSilhouetteShader, |
|
SoftPhongShader, |
|
RasterizationSettings, |
|
MeshRenderer, |
|
MeshRendererWithFragments, |
|
MeshRasterizer, |
|
TexturesVertex, |
|
) |
|
from pytorch3d.renderer.blending import BlendParams, softmax_rgb_blend |
|
import lam.models.rendering.utils.mesh_utils as mesh_utils |
|
from lam.models.rendering.utils.point_utils import depth_to_normal |
|
from pytorch3d.ops.interp_face_attrs import interpolate_face_attributes |
|
|
|
def inverse_sigmoid(x):
    return np.log(x / (1 - x))
|
|
|
|
|
def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): |
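    """Build a 4x4 world-to-view (w2c) matrix from rotation ``R`` and translation ``t``.

    The camera center can optionally be shifted by ``translate`` and rescaled by
    ``scale`` in world space before the matrix is re-inverted.
    """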
|
Rt = np.zeros((4, 4)) |
|
Rt[:3, :3] = R.transpose() |
|
Rt[:3, 3] = t |
|
Rt[3, 3] = 1.0 |
|
|
|
C2W = np.linalg.inv(Rt) |
|
cam_center = C2W[:3, 3] |
|
cam_center = (cam_center + translate) * scale |
|
C2W[:3, 3] = cam_center |
|
Rt = np.linalg.inv(C2W) |
|
return np.float32(Rt) |
|
|
|
def getProjectionMatrix(znear, zfar, fovX, fovY): |
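    """Perspective projection matrix for the given near/far planes and fields of view.

    Depth is mapped to [0, 1] (z_sign = +1), the convention expected by the
    Gaussian-splatting rasterizer used below.
    """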
|
tanHalfFovY = math.tan((fovY / 2)) |
|
tanHalfFovX = math.tan((fovX / 2)) |
|
|
|
top = tanHalfFovY * znear |
|
bottom = -top |
|
right = tanHalfFovX * znear |
|
left = -right |
|
|
|
P = torch.zeros(4, 4) |
|
|
|
z_sign = 1.0 |
|
|
|
P[0, 0] = 2.0 * znear / (right - left) |
|
P[1, 1] = 2.0 * znear / (top - bottom) |
|
P[0, 2] = (right + left) / (right - left) |
|
P[1, 2] = (top + bottom) / (top - bottom) |
|
P[3, 2] = z_sign |
|
P[2, 2] = z_sign * zfar / (zfar - znear) |
|
P[2, 3] = -(zfar * znear) / (zfar - znear) |
|
return P |
|
|
|
def intrinsic_to_fov(intrinsic, w, h): |
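    """Convert pinhole focal lengths (fx, fy) to horizontal/vertical FoV in radians."""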
|
fx, fy = intrinsic[0, 0], intrinsic[1, 1] |
|
fov_x = 2 * torch.arctan2(w, 2 * fx) |
|
fov_y = 2 * torch.arctan2(h, 2 * fy) |
|
return fov_x, fov_y |
|
|
|
|
|
class Camera: |
|
def __init__(self, w2c, intrinsic, FoVx, FoVy, height, width, trans=np.array([0.0, 0.0, 0.0]), scale=1.0) -> None: |
|
self.FoVx = FoVx |
|
self.FoVy = FoVy |
|
self.height = int(height) |
|
self.width = int(width) |
|
self.world_view_transform = w2c.transpose(0, 1) |
|
self.intrinsic = intrinsic |
|
|
|
self.zfar = 100.0 |
|
self.znear = 0.01 |
|
|
|
self.trans = trans |
|
self.scale = scale |
|
|
|
self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).to(w2c.device) |
|
self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) |
|
self.camera_center = self.world_view_transform.inverse()[3, :3] |
|
|
|
@staticmethod |
|
def from_c2w(c2w, intrinsic, height, width): |
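        """Build a Camera from a camera-to-world pose and a 3x3 pinhole intrinsic.

        Minimal usage sketch (tensor names are illustrative):
            cam = Camera.from_c2w(c2w_4x4, intrinsic_3x3, height=512, width=512)
        """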
|
w2c = torch.inverse(c2w) |
|
FoVx, FoVy = intrinsic_to_fov(intrinsic, w=torch.tensor(width, device=w2c.device), h=torch.tensor(height, device=w2c.device)) |
|
return Camera(w2c=w2c, intrinsic=intrinsic, FoVx=FoVx, FoVy=FoVy, height=height, width=width) |
|
|
|
|
|
class GSLayer(nn.Module): |
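    """Predicts 3D Gaussian attributes from per-point features.

    Each attribute (xyz offset, scaling, rotation, opacity, SH/RGB colour) has its own
    linear head; opacity and rotation can instead be fixed to constants, and ``pred_res``
    switches the heads to predicting residuals on top of raw attributes from a previous pass.
    """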
|
def __init__(self, in_channels, use_rgb, |
|
clip_scaling=0.2, |
|
init_scaling=-5.0, |
|
scale_sphere=False, |
|
init_density=0.1, |
|
sh_degree=None, |
|
xyz_offset=True, |
|
restrict_offset=True, |
|
xyz_offset_max_step=None, |
|
fix_opacity=False, |
|
fix_rotation=False, |
|
use_fine_feat=False, |
|
pred_res=False, |
|
): |
|
super().__init__() |
|
self.clip_scaling = clip_scaling |
|
self.use_rgb = use_rgb |
|
self.restrict_offset = restrict_offset |
|
self.xyz_offset = xyz_offset |
|
self.xyz_offset_max_step = xyz_offset_max_step |
|
self.fix_opacity = fix_opacity |
|
self.fix_rotation = fix_rotation |
|
self.use_fine_feat = use_fine_feat |
|
self.scale_sphere = scale_sphere |
|
self.pred_res = pred_res |
|
|
|
        self.attr_dict = {
|
"shs": (sh_degree + 1) ** 2 * 3, |
|
"scaling": 3 if not scale_sphere else 1, |
|
"xyz": 3, |
|
"opacity": None, |
|
"rotation": None |
|
} |
|
if not self.fix_opacity: |
|
self.attr_dict["opacity"] = 1 |
|
if not self.fix_rotation: |
|
self.attr_dict["rotation"] = 4 |
|
|
|
self.out_layers = nn.ModuleDict() |
|
for key, out_ch in self.attr_dict.items(): |
|
if out_ch is None: |
|
layer = nn.Identity() |
|
else: |
|
if key == "shs" and use_rgb: |
|
out_ch = 3 |
|
if key == "shs": |
|
shs_out_ch = out_ch |
|
if pred_res: |
|
layer = nn.Linear(in_channels+out_ch, out_ch) |
|
else: |
|
layer = nn.Linear(in_channels, out_ch) |
|
|
|
if not (key == "shs" and use_rgb): |
|
if key == "opacity" and self.fix_opacity: |
|
pass |
|
elif key == "rotation" and self.fix_rotation: |
|
pass |
|
else: |
|
nn.init.constant_(layer.weight, 0) |
|
nn.init.constant_(layer.bias, 0) |
|
if key == "scaling": |
|
nn.init.constant_(layer.bias, init_scaling) |
|
elif key == "rotation": |
|
if not self.fix_rotation: |
|
nn.init.constant_(layer.bias, 0) |
|
nn.init.constant_(layer.bias[0], 1.0) |
|
elif key == "opacity": |
|
if not self.fix_opacity: |
|
nn.init.constant_(layer.bias, inverse_sigmoid(init_density)) |
|
self.out_layers[key] = layer |
|
|
|
if self.use_fine_feat: |
|
fine_shs_layer = nn.Linear(in_channels, shs_out_ch) |
|
nn.init.constant_(fine_shs_layer.weight, 0) |
|
nn.init.constant_(fine_shs_layer.bias, 0) |
|
self.out_layers["fine_shs"] = fine_shs_layer |
|
|
|
def forward(self, x, pts, x_fine=None, gs_raw_attr=None, ret_raw=False, vtx_sym_idxs=None): |
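        """Map per-point features ``x`` [N, C] and template points ``pts`` [N, 3] to a GaussianModel.

        With ``ret_raw`` the pre-activation attribute values are returned as well, so a
        later residual pass can consume them via ``gs_raw_attr``.
        """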
|
assert len(x.shape) == 2 |
|
ret = {} |
|
if ret_raw: |
|
raw_attr = {} |
|
ori_x = x |
|
for k in self.attr_dict: |
|
|
|
if vtx_sym_idxs is not None and k in ["shs", "scaling", "opacity", "rotation"]: |
|
|
|
|
|
x = ori_x[vtx_sym_idxs.to(x.device), :] |
|
else: |
|
x = ori_x |
|
            layer = self.out_layers[k]
|
if self.pred_res and (not self.fix_opacity or k != "opacity") and (not self.fix_rotation or k != "rotation"): |
|
v = layer(torch.cat([gs_raw_attr[k], x], dim=-1)) |
|
v = gs_raw_attr[k] + v |
|
else: |
|
v = layer(x) |
|
if ret_raw: |
|
raw_attr[k] = v |
|
if k == "rotation": |
|
if self.fix_rotation: |
|
                    v = matrix_to_quaternion(torch.eye(3).type_as(x)[None, :, :].repeat(x.shape[0], 1, 1))
|
else: |
|
|
|
v = torch.nn.functional.normalize(v) |
|
elif k == "scaling": |
|
v = trunc_exp(v) |
|
if self.scale_sphere: |
|
assert v.shape[-1] == 1 |
|
v = torch.cat([v, v, v], dim=-1) |
|
if self.clip_scaling is not None: |
|
v = torch.clamp(v, min=0, max=self.clip_scaling) |
|
elif k == "opacity": |
|
if self.fix_opacity: |
|
v = torch.ones_like(x)[..., 0:1] |
|
else: |
|
v = torch.sigmoid(v) |
|
elif k == "shs": |
|
if self.use_rgb: |
|
v[..., :3] = torch.sigmoid(v[..., :3]) |
|
if self.use_fine_feat: |
|
v_fine = self.out_layers["fine_shs"](x_fine) |
|
v_fine = torch.tanh(v_fine) |
|
v = v + v_fine |
|
else: |
|
if self.use_fine_feat: |
|
v_fine = self.out_layers["fine_shs"](x_fine) |
|
v = v + v_fine |
|
v = torch.reshape(v, (v.shape[0], -1, 3)) |
|
elif k == "xyz": |
|
|
|
if self.restrict_offset: |
|
max_step = self.xyz_offset_max_step |
|
v = (torch.sigmoid(v) - 0.5) * max_step |
|
                if not self.xyz_offset:
                    raise NotImplementedError("predicting absolute xyz (xyz_offset=False) is not supported")
|
ret["offset"] = v |
|
v = pts + v |
|
ret[k] = v |
|
|
|
if ret_raw: |
|
return GaussianModel(**ret), raw_attr |
|
else: |
|
return GaussianModel(**ret) |
|
|
|
|
|
class PointEmbed(nn.Module): |
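    """Fourier (sin/cos) embedding of 3D points, concatenated with the raw coordinates
    and projected by a linear layer + LayerNorm."""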
|
def __init__(self, hidden_dim=48, dim=128): |
|
super().__init__() |
|
|
|
assert hidden_dim % 6 == 0 |
|
|
|
self.embedding_dim = hidden_dim |
|
e = torch.pow(2, torch.arange(self.embedding_dim // 6)).float() * np.pi |
|
e = torch.stack([ |
|
torch.cat([e, torch.zeros(self.embedding_dim // 6), |
|
torch.zeros(self.embedding_dim // 6)]), |
|
torch.cat([torch.zeros(self.embedding_dim // 6), e, |
|
torch.zeros(self.embedding_dim // 6)]), |
|
torch.cat([torch.zeros(self.embedding_dim // 6), |
|
torch.zeros(self.embedding_dim // 6), e]), |
|
]) |
|
self.register_buffer('basis', e) |
|
|
|
self.mlp = nn.Linear(self.embedding_dim+3, dim) |
|
self.norm = nn.LayerNorm(dim) |
|
|
|
@staticmethod |
|
def embed(input, basis): |
|
projections = torch.einsum( |
|
'bnd,de->bne', input, basis) |
|
embeddings = torch.cat([projections.sin(), projections.cos()], dim=2) |
|
return embeddings |
|
|
|
def forward(self, input): |
|
|
|
embed = self.mlp(torch.cat([self.embed(input, self.basis), input], dim=2)) |
|
embed = self.norm(embed) |
|
return embed |
|
|
|
|
|
class CrossAttnBlock(nn.Module): |
|
""" |
|
Transformer block that takes in a cross-attention condition. |
|
Designed for SparseLRM architecture. |
|
""" |
|
|
|
    def __init__(self, inner_dim: int, cond_dim: int, num_heads: int, eps: Optional[float] = None,
|
attn_drop: float = 0., attn_bias: bool = False, |
|
mlp_ratio: float = 4., mlp_drop: float = 0., feedforward=False): |
|
super().__init__() |
|
|
|
|
|
|
|
self.norm_q = nn.Identity() |
|
self.norm_k = nn.Identity() |
|
|
|
self.cross_attn = nn.MultiheadAttention( |
|
embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim, |
|
dropout=attn_drop, bias=attn_bias, batch_first=True) |
|
|
|
self.mlp = None |
|
if feedforward: |
|
self.norm2 = nn.LayerNorm(inner_dim, eps=eps) |
|
self.self_attn = nn.MultiheadAttention( |
|
embed_dim=inner_dim, num_heads=num_heads, |
|
dropout=attn_drop, bias=attn_bias, batch_first=True) |
|
self.norm3 = nn.LayerNorm(inner_dim, eps=eps) |
|
self.mlp = nn.Sequential( |
|
nn.Linear(inner_dim, int(inner_dim * mlp_ratio)), |
|
nn.GELU(), |
|
nn.Dropout(mlp_drop), |
|
nn.Linear(int(inner_dim * mlp_ratio), inner_dim), |
|
nn.Dropout(mlp_drop), |
|
) |
|
|
|
def forward(self, x, cond): |
|
|
|
|
|
x = self.cross_attn(self.norm_q(x), self.norm_k(cond), cond, need_weights=False)[0] |
|
if self.mlp is not None: |
|
before_sa = self.norm2(x) |
|
x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0] |
|
x = x + self.mlp(self.norm3(x)) |
|
return x |
|
|
|
|
|
class DecoderCrossAttn(nn.Module): |
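    """Decodes per-point queries from latent tokens via cross-attention.

    Depending on ``decode_with_extra_info``, a second cross-attention against image
    features (precomputed DINOv2 tokens, or an on-the-fly DINOv2/ResNet-18 encoder)
    produces an additional "fine" output next to the "coarse" one.
    """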
|
def __init__(self, query_dim, context_dim, num_heads, mlp=False, decode_with_extra_info=None): |
|
super().__init__() |
|
self.query_dim = query_dim |
|
self.context_dim = context_dim |
|
|
|
self.cross_attn = CrossAttnBlock(inner_dim=query_dim, cond_dim=context_dim, |
|
num_heads=num_heads, feedforward=mlp, |
|
eps=1e-5) |
|
self.decode_with_extra_info = decode_with_extra_info |
|
if decode_with_extra_info is not None: |
|
if decode_with_extra_info["type"] == "dinov2p14_feat": |
|
context_dim = decode_with_extra_info["cond_dim"] |
|
self.cross_attn_color = CrossAttnBlock(inner_dim=query_dim, cond_dim=context_dim, |
|
num_heads=num_heads, feedforward=False, eps=1e-5) |
|
elif decode_with_extra_info["type"] == "decoder_dinov2p14_feat": |
|
from lam.models.encoders.dinov2_wrapper import Dinov2Wrapper |
|
self.encoder = Dinov2Wrapper(model_name='dinov2_vits14_reg', freeze=False, encoder_feat_dim=384) |
|
self.cross_attn_color = CrossAttnBlock(inner_dim=query_dim, cond_dim=384, |
|
num_heads=num_heads, feedforward=False, |
|
eps=1e-5) |
|
elif decode_with_extra_info["type"] == "decoder_resnet18_feat": |
|
from lam.models.encoders.xunet_wrapper import XnetWrapper |
|
self.encoder = XnetWrapper(model_name='resnet18', freeze=False, encoder_feat_dim=64) |
|
self.cross_attn_color = CrossAttnBlock(inner_dim=query_dim, cond_dim=64, |
|
num_heads=num_heads, feedforward=False, |
|
eps=1e-5) |
|
|
|
def resize_image(self, image, multiply): |
|
B, _, H, W = image.shape |
|
new_h, new_w = math.ceil(H / multiply) * multiply, math.ceil(W / multiply) * multiply |
|
image = F.interpolate(image, (new_h, new_w), align_corners=True, mode="bilinear") |
|
return image |
|
|
|
def forward(self, pcl_query, pcl_latent, extra_info=None): |
|
out = self.cross_attn(pcl_query, pcl_latent) |
|
if self.decode_with_extra_info is not None: |
|
out_dict = {} |
|
out_dict["coarse"] = out |
|
if self.decode_with_extra_info["type"] == "dinov2p14_feat": |
|
out = self.cross_attn_color(out, extra_info["image_feats"]) |
|
out_dict["fine"] = out |
|
return out_dict |
|
elif self.decode_with_extra_info["type"] == "decoder_dinov2p14_feat": |
|
img_feat = self.encoder(extra_info["image"]) |
|
out = self.cross_attn_color(out, img_feat) |
|
out_dict["fine"] = out |
|
return out_dict |
|
elif self.decode_with_extra_info["type"] == "decoder_resnet18_feat": |
|
image = extra_info["image"] |
|
image = self.resize_image(image, multiply=32) |
|
img_feat = self.encoder(image) |
|
out = self.cross_attn_color(out, img_feat) |
|
out_dict["fine"] = out |
|
return out_dict |
|
return out |
|
|
|
|
|
class GS3DRenderer(nn.Module): |
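    """FLAME-rigged 3D Gaussian head renderer.

    Latent per-point features are decoded into canonical-space Gaussian attributes
    (``GSLayer``), the Gaussians are re-posed per target view with the FLAME model,
    and each view is rasterized with the differentiable Gaussian-splatting rasterizer.
    """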
|
def __init__(self, human_model_path, subdivide_num, smpl_type, feat_dim, query_dim, |
|
use_rgb, sh_degree, xyz_offset_max_step, mlp_network_config, |
|
expr_param_dim, shape_param_dim, |
|
clip_scaling=0.2, |
|
scale_sphere=False, |
|
skip_decoder=False, |
|
fix_opacity=False, |
|
fix_rotation=False, |
|
decode_with_extra_info=None, |
|
gradient_checkpointing=False, |
|
add_teeth=True, |
|
teeth_bs_flag=False, |
|
oral_mesh_flag=False, |
|
**kwargs, |
|
): |
|
super().__init__() |
|
print(f"#########scale sphere:{scale_sphere}, add_teeth:{add_teeth}") |
|
self.gradient_checkpointing = gradient_checkpointing |
|
self.skip_decoder = skip_decoder |
|
self.smpl_type = smpl_type |
|
assert self.smpl_type == "flame" |
|
self.sym_rend2 = True |
|
self.teeth_bs_flag = teeth_bs_flag |
|
self.oral_mesh_flag = oral_mesh_flag |
|
self.render_rgb = kwargs.get("render_rgb", True) |
|
print("==="*16*3, "\n Render rgb:", self.render_rgb, "\n"+"==="*16*3) |
|
|
|
self.scaling_modifier = 1.0 |
|
self.sh_degree = sh_degree |
|
if use_rgb: |
|
self.sh_degree = 0 |
|
|
|
|
|
|
self.flame_model = FlameHeadSubdivided( |
|
300, |
|
100, |
|
add_teeth=add_teeth, |
|
add_shoulder=False, |
|
flame_model_path=f'{human_model_path}/flame_assets/flame/flame2023.pkl', |
|
flame_lmk_embedding_path=f"{human_model_path}/flame_assets/flame/landmark_embedding_with_eyes.npy", |
|
flame_template_mesh_path=f"{human_model_path}/flame_assets/flame/head_template_mesh.obj", |
|
flame_parts_path=f"{human_model_path}/flame_assets/flame/FLAME_masks.pkl", |
|
subdivide_num=subdivide_num, |
|
teeth_bs_flag=teeth_bs_flag, |
|
oral_mesh_flag=oral_mesh_flag |
|
) |
|
|
|
if not self.skip_decoder: |
|
self.pcl_embed = PointEmbed(dim=query_dim) |
|
|
|
self.mlp_network_config = mlp_network_config |
|
if self.mlp_network_config is not None: |
|
self.mlp_net = MLP(query_dim, query_dim, **self.mlp_network_config) |
|
|
|
init_scaling = -5.0 |
|
self.gs_net = GSLayer(in_channels=query_dim, |
|
use_rgb=use_rgb, |
|
sh_degree=self.sh_degree, |
|
clip_scaling=clip_scaling, |
|
scale_sphere=scale_sphere, |
|
init_scaling=init_scaling, |
|
init_density=0.1, |
|
xyz_offset=True, |
|
restrict_offset=True, |
|
xyz_offset_max_step=xyz_offset_max_step, |
|
fix_opacity=fix_opacity, |
|
fix_rotation=fix_rotation, |
|
use_fine_feat=True if decode_with_extra_info is not None and decode_with_extra_info["type"] is not None else False, |
|
) |
|
|
|
def forward_single_view(self, |
|
gs: GaussianModel, |
|
viewpoint_camera: Camera, |
|
background_color: Optional[Float[Tensor, "3"]], |
|
): |
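        """Rasterize one GaussianModel from one camera; returns RGB, alpha mask and depth maps."""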
|
|
|
screenspace_points = torch.zeros_like(gs.xyz, dtype=gs.xyz.dtype, requires_grad=True, device=self.device) + 0 |
|
        try:
            screenspace_points.retain_grad()
        except Exception:
            # gradient retention is optional (e.g. unavailable at inference time)
            pass
|
|
|
bg_color = background_color |
|
|
|
tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) |
|
tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) |
|
|
|
GSRSettings = GaussianRasterizationSettings |
|
GSR = GaussianRasterizer |
|
|
|
raster_settings = GSRSettings( |
|
image_height=int(viewpoint_camera.height), |
|
image_width=int(viewpoint_camera.width), |
|
tanfovx=tanfovx, |
|
tanfovy=tanfovy, |
|
bg=bg_color, |
|
scale_modifier=self.scaling_modifier, |
|
viewmatrix=viewpoint_camera.world_view_transform, |
|
projmatrix=viewpoint_camera.full_proj_transform.float(), |
|
sh_degree=self.sh_degree, |
|
campos=viewpoint_camera.camera_center, |
|
prefiltered=False, |
|
debug=False |
|
) |
|
|
|
rasterizer = GSR(raster_settings=raster_settings) |
|
|
|
means3D = gs.xyz |
|
means2D = screenspace_points |
|
opacity = gs.opacity |
|
|
|
|
|
|
|
scales = None |
|
rotations = None |
|
cov3D_precomp = None |
|
scales = gs.scaling |
|
rotations = gs.rotation |
|
|
|
|
|
|
|
shs = None |
|
colors_precomp = None |
|
if self.gs_net.use_rgb: |
|
colors_precomp = gs.shs.squeeze(1) |
|
else: |
|
shs = gs.shs |
|
|
|
|
|
|
|
with torch.autocast(device_type=self.device.type, dtype=torch.float32): |
|
raster_ret = rasterizer( |
|
means3D = means3D.float(), |
|
means2D = means2D.float(), |
|
shs = shs.float() if not self.gs_net.use_rgb else None, |
|
colors_precomp = colors_precomp.float() if colors_precomp is not None else None, |
|
opacities = opacity.float(), |
|
scales = scales.float(), |
|
rotations = rotations.float(), |
|
cov3D_precomp = cov3D_precomp |
|
) |
|
rendered_image, radii, rendered_depth, rendered_alpha = raster_ret |
|
|
|
ret = { |
|
"comp_rgb": rendered_image.permute(1, 2, 0), |
|
"comp_rgb_bg": bg_color, |
|
'comp_mask': rendered_alpha.permute(1, 2, 0), |
|
'comp_depth': rendered_depth.permute(1, 2, 0), |
|
} |
|
|
|
return ret |
|
|
|
def animate_gs_model(self, gs_attr: GaussianModel, query_points, flame_data, debug=False): |
|
""" |
|
query_points: [N, 3] |
|
""" |
|
device = gs_attr.xyz.device |
|
if debug: |
|
N = gs_attr.xyz.shape[0] |
|
gs_attr.xyz = torch.ones_like(gs_attr.xyz) * 0.0 |
|
|
|
rotation = matrix_to_quaternion(torch.eye(3).float()[None, :, :].repeat(N, 1, 1)).to(device) |
|
opacity = torch.ones((N, 1)).float().to(device) |
|
|
|
gs_attr.opacity = opacity |
|
gs_attr.rotation = rotation |
|
|
|
|
|
|
|
with torch.autocast(device_type=device.type, dtype=torch.float32): |
|
|
|
mean_3d = gs_attr.xyz |
|
|
|
num_view = flame_data["expr"].shape[0] |
|
mean_3d = mean_3d.unsqueeze(0).repeat(num_view, 1, 1) |
|
query_points = query_points.unsqueeze(0).repeat(num_view, 1, 1) |
|
|
|
if self.teeth_bs_flag: |
|
expr = torch.cat([flame_data['expr'], flame_data['teeth_bs']], dim=-1) |
|
else: |
|
expr = flame_data["expr"] |
|
ret = self.flame_model.animation_forward(v_cano=mean_3d, |
|
shape=flame_data["betas"].repeat(num_view, 1), |
|
expr=expr, |
|
rotation=flame_data["rotation"], |
|
neck=flame_data["neck_pose"], |
|
jaw=flame_data["jaw_pose"], |
|
eyes=flame_data["eyes_pose"], |
|
translation=flame_data["translation"], |
|
zero_centered_at_root_node=False, |
|
return_landmarks=False, |
|
return_verts_cano=False, |
|
|
|
static_offset=None, |
|
) |
|
mean_3d = ret["animated"] |
|
|
|
gs_attr_list = [] |
|
for i in range(num_view): |
|
gs_attr_copy = GaussianModel(xyz=mean_3d[i], |
|
opacity=gs_attr.opacity, |
|
rotation=gs_attr.rotation, |
|
scaling=gs_attr.scaling, |
|
shs=gs_attr.shs, |
|
albedo=gs_attr.albedo, |
|
lights=gs_attr.lights, |
|
offset=gs_attr.offset) |
|
gs_attr_list.append(gs_attr_copy) |
|
|
|
return gs_attr_list |
|
|
|
|
|
def forward_gs_attr(self, x, query_points, flame_data, debug=False, x_fine=None, vtx_sym_idxs=None): |
|
""" |
|
x: [N, C] Float[Tensor, "Np Cp"], |
|
query_points: [N, 3] Float[Tensor, "Np 3"] |
|
""" |
|
device = x.device |
|
if self.mlp_network_config is not None: |
|
x = self.mlp_net(x) |
|
if x_fine is not None: |
|
x_fine = self.mlp_net(x_fine) |
|
gs_attr: GaussianModel = self.gs_net(x, query_points, x_fine, vtx_sym_idxs=vtx_sym_idxs) |
|
return gs_attr |
|
|
|
|
|
def get_query_points(self, flame_data, device): |
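        """Return canonical (shape-only) FLAME vertices used as Gaussian query points."""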
|
with torch.no_grad(): |
|
with torch.autocast(device_type=device.type, dtype=torch.float32): |
|
|
|
|
|
positions = self.flame_model.get_cano_verts(shape_params=flame_data["betas"]) |
|
|
|
|
|
return positions, flame_data |
|
|
|
def query_latent_feat(self, |
|
positions: Float[Tensor, "*B N1 3"], |
|
flame_data, |
|
latent_feat: Float[Tensor, "*B N2 C"], |
|
extra_info): |
|
device = latent_feat.device |
|
if self.skip_decoder: |
|
gs_feats = latent_feat |
|
assert positions is not None |
|
else: |
|
assert positions is None |
|
if positions is None: |
|
positions, flame_data = self.get_query_points(flame_data, device) |
|
|
|
with torch.autocast(device_type=device.type, dtype=torch.float32): |
|
pcl_embed = self.pcl_embed(positions) |
|
gs_feats = pcl_embed |
|
|
|
return gs_feats, positions, flame_data |
|
|
|
def forward_single_batch( |
|
self, |
|
gs_list: list[GaussianModel], |
|
c2ws: Float[Tensor, "Nv 4 4"], |
|
intrinsics: Float[Tensor, "Nv 4 4"], |
|
height: int, |
|
width: int, |
|
background_color: Optional[Float[Tensor, "Nv 3"]], |
|
debug: bool=False, |
|
): |
|
out_list = [] |
|
self.device = gs_list[0].xyz.device |
|
for v_idx, (c2w, intrinsic) in enumerate(zip(c2ws, intrinsics)): |
|
out_list.append(self.forward_single_view( |
|
gs_list[v_idx], |
|
Camera.from_c2w(c2w, intrinsic, height, width), |
|
background_color[v_idx], |
|
)) |
|
|
|
out = defaultdict(list) |
|
for out_ in out_list: |
|
for k, v in out_.items(): |
|
out[k].append(v) |
|
out = {k: torch.stack(v, dim=0) for k, v in out.items()} |
|
out["3dgs"] = gs_list |
|
|
|
return out |
|
|
|
def get_sing_batch_smpl_data(self, smpl_data, bidx): |
|
smpl_data_single_batch = {} |
|
for k, v in smpl_data.items(): |
|
smpl_data_single_batch[k] = v[bidx] |
|
if k == "betas" or (k == "joint_offset") or (k == "face_offset"): |
|
smpl_data_single_batch[k] = v[bidx:bidx+1] |
|
return smpl_data_single_batch |
|
|
|
def get_single_view_smpl_data(self, smpl_data, vidx): |
|
smpl_data_single_view = {} |
|
for k, v in smpl_data.items(): |
|
assert v.shape[0] == 1 |
|
if k == "betas" or (k == "joint_offset") or (k == "face_offset") or (k == "transform_mat_neutral_pose"): |
|
smpl_data_single_view[k] = v |
|
else: |
|
smpl_data_single_view[k] = v[:, vidx: vidx + 1] |
|
return smpl_data_single_view |
|
|
|
def forward_gs(self, |
|
gs_hidden_features: Float[Tensor, "B Np Cp"], |
|
query_points: Float[Tensor, "B Np_q 3"], |
|
flame_data, |
|
additional_features: Optional[dict] = None, |
|
debug: bool = False, |
|
**kwargs): |
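        """Decode latent features into one canonical-space GaussianModel per batch element.

        Also returns the (possibly re-generated) query points, the flame_data, and the
        decoded per-point features for optional later refinement.
        """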
|
|
|
batch_size = gs_hidden_features.shape[0] |
|
|
|
query_gs_features, query_points, flame_data = self.query_latent_feat(query_points, flame_data, gs_hidden_features, |
|
additional_features) |
|
|
|
gs_model_list = [] |
|
all_query_points = [] |
|
for b in range(batch_size): |
|
all_query_points.append(query_points[b:b+1, :]) |
|
if isinstance(query_gs_features, dict): |
|
ret_gs = self.forward_gs_attr(query_gs_features["coarse"][b], query_points[b], None, debug, |
|
x_fine=query_gs_features["fine"][b], vtx_sym_idxs=None) |
|
else: |
|
ret_gs = self.forward_gs_attr(query_gs_features[b], query_points[b], None, debug, vtx_sym_idxs=None) |
|
|
|
ret_gs.update_albedo(ret_gs.shs.clone()) |
|
|
|
gs_model_list.append(ret_gs) |
|
|
|
query_points = torch.cat(all_query_points, dim=0) |
|
return gs_model_list, query_points, flame_data, query_gs_features |
|
|
|
def forward_res_refine_gs(self, |
|
gs_hidden_features: Float[Tensor, "B Np Cp"], |
|
query_points: Float[Tensor, "B Np_q 3"], |
|
flame_data, |
|
additional_features: Optional[dict] = None, |
|
debug: bool = False, |
|
gs_raw_attr_list: list = None, |
|
**kwargs): |
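        """Residual refinement pass over raw Gaussian attributes.

        Note: this relies on ``self.gs_refine_net``, which is not created in this module's
        ``__init__`` and is assumed to be added by the code that instantiates the renderer.
        """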
|
|
|
batch_size = gs_hidden_features.shape[0] |
|
|
|
query_gs_features, query_points, flame_data = self.query_latent_feat(query_points, flame_data, gs_hidden_features, |
|
additional_features) |
|
|
|
gs_model_list = [] |
|
for b in range(batch_size): |
|
gs_model = self.gs_refine_net(query_gs_features[b], query_points[b], x_fine=None, gs_raw_attr=gs_raw_attr_list[b]) |
|
gs_model_list.append(gs_model) |
|
return gs_model_list, query_points, flame_data, query_gs_features |
|
|
|
def forward_animate_gs(self, gs_model_list, query_points, flame_data, c2w, intrinsic, height, width, |
|
background_color, debug=False): |
|
batch_size = len(gs_model_list) |
|
out_list = [] |
|
|
|
for b in range(batch_size): |
|
gs_model = gs_model_list[b] |
|
query_pt = query_points[b] |
|
animatable_gs_model_list: list[GaussianModel] = self.animate_gs_model(gs_model, |
|
query_pt, |
|
self.get_sing_batch_smpl_data(flame_data, b), |
|
debug=debug) |
|
assert len(animatable_gs_model_list) == c2w.shape[1] |
|
out_list.append(self.forward_single_batch( |
|
animatable_gs_model_list, |
|
c2w[b], |
|
intrinsic[b], |
|
height, width, |
|
background_color[b] if background_color is not None else None, |
|
debug=debug)) |
|
|
|
out = defaultdict(list) |
|
for out_ in out_list: |
|
for k, v in out_.items(): |
|
out[k].append(v) |
|
for k, v in out.items(): |
|
if isinstance(v[0], torch.Tensor): |
|
out[k] = torch.stack(v, dim=0) |
|
else: |
|
out[k] = v |
|
|
|
render_keys = ["comp_rgb", "comp_mask", "comp_depth"] |
|
for key in render_keys: |
|
out[key] = rearrange(out[key], "b v h w c -> b v c h w") |
|
|
|
return out |
|
|
|
def project_single_view_feats(self, img_vtx_ids, feats, nv, inter_feat=True): |
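        """Scatter per-pixel features onto mesh vertices and average them.

        ``img_vtx_ids`` [B, H, W, K] holds up to K visible vertex ids per pixel (id 0 is
        treated as background); the matching pixels of ``feats`` [C, h, w] are accumulated
        per vertex with scatter_add and divided by their counts, yielding an [nv, C] tensor.
        """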
|
b, h, w, k = img_vtx_ids.shape |
|
c, ih, iw = feats.shape |
|
vtx_ids = img_vtx_ids |
|
if h != ih or w != iw: |
|
if inter_feat: |
|
feats = torch.nn.functional.interpolate( |
|
rearrange(feats, "(b c) h w -> b c h w", b=1).float(), (h, w) |
|
).squeeze(0) |
|
vtx_ids = rearrange(vtx_ids, "b (c h) w k -> (b k) c h w", c=1).long().squeeze(1) |
|
else: |
|
vtx_ids = torch.nn.functional.interpolate( |
|
rearrange(vtx_ids, "b (c h) w k -> (b k) c h w", c=1).float(), (ih, iw), mode="nearest" |
|
).long().squeeze(1) |
|
else: |
|
vtx_ids = rearrange(vtx_ids, "b h w k -> (b k) h w", b=1).long() |
|
vis_mask = vtx_ids > 0 |
|
vtx_ids = vtx_ids[vis_mask] |
|
vtx_ids = repeat(vtx_ids, "n -> n c", c=c) |
|
|
|
feats = repeat(feats, "c h w -> k h w c", k=k).to(vtx_ids.device) |
|
feats = feats[vis_mask, :] |
|
|
|
sums = torch.zeros((nv, c), dtype=feats.dtype, device=feats.device) |
|
counts = torch.zeros((nv), dtype=torch.int64, device=feats.device) |
|
|
|
sums.scatter_add_(0, vtx_ids, feats) |
|
one_hot = torch.ones_like(vtx_ids[:, 0], dtype=torch.int64).to(feats.device) |
|
counts.scatter_add_(0, vtx_ids[:, 0], one_hot) |
|
clamp_counts = counts.clamp(min=1) |
|
mean_feats = sums / clamp_counts.view(-1, 1) |
|
return mean_feats |
|
|
|
def forward(self, |
|
gs_hidden_features: Float[Tensor, "B Np Cp"], |
|
query_points: Float[Tensor, "B Np 3"], |
|
flame_data, |
|
c2w: Float[Tensor, "B Nv 4 4"], |
|
intrinsic: Float[Tensor, "B Nv 4 4"], |
|
height, |
|
width, |
|
additional_features: Optional[Float[Tensor, "B C H W"]] = None, |
|
background_color: Optional[Float[Tensor, "B Nv 3"]] = None, |
|
debug: bool = False, |
|
**kwargs): |
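        """Full pipeline: decode canonical Gaussians from latent features, animate them with
        the per-view FLAME parameters, and render every view in the batch.
        """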
|
|
|
|
|
gs_model_list, query_points, flame_data, query_gs_features = self.forward_gs(gs_hidden_features, query_points, flame_data=flame_data, |
|
additional_features=additional_features, debug=debug) |
|
|
|
out = self.forward_animate_gs(gs_model_list, query_points, flame_data, c2w, intrinsic, height, width, background_color, debug) |
|
|
|
return out |
|
|
|
|
|
def test_head(): |
|
import cv2 |
|
|
|
human_model_path = "./pretrained_models/human_model_files" |
|
device = "cuda" |
|
|
|
from accelerate.utils import set_seed |
|
set_seed(1234) |
|
|
|
from lam.datasets.video_head import VideoHeadDataset |
|
root_dir = "./train_data/vfhq_vhap/export" |
|
meta_path = "./train_data/vfhq_vhap/label/valid_id_list.json" |
|
|
|
|
|
dataset = VideoHeadDataset(root_dirs=root_dir, meta_path=meta_path, sample_side_views=7, |
|
render_image_res_low=512, render_image_res_high=512, |
|
render_region_size=(512, 512), source_image_res=512, |
|
enlarge_ratio=[0.8, 1.2], |
|
debug=False) |
|
|
|
data = dataset[0] |
|
|
|
def get_flame_params(data): |
|
flame_params = {} |
|
        flame_keys = ['root_pose', 'body_pose', 'jaw_pose', 'leye_pose', 'reye_pose', 'lhand_pose', 'rhand_pose',
                      'expr', 'trans', 'betas', 'rotation', 'neck_pose', 'eyes_pose', 'translation']
|
for k, v in data.items(): |
|
if k in flame_keys: |
|
|
|
flame_params[k] = data[k] |
|
return flame_params |
|
|
|
flame_data = get_flame_params(data) |
|
|
|
flame_data_tmp = {} |
|
for k, v in flame_data.items(): |
|
flame_data_tmp[k] = v.unsqueeze(0).to(device) |
|
print(k, v.shape) |
|
flame_data = flame_data_tmp |
|
|
|
c2ws = data["c2ws"].unsqueeze(0).to(device) |
|
intrs = data["intrs"].unsqueeze(0).to(device) |
|
render_images = data["render_image"].numpy() |
|
render_h = data["render_full_resolutions"][0, 0] |
|
    render_w = data["render_full_resolutions"][0, 1]
|
render_bg_colors = data["render_bg_colors"].unsqueeze(0).to(device) |
|
print("c2ws", c2ws.shape, "intrs", intrs.shape, intrs) |
|
|
|
gs_render = GS3DRenderer(human_model_path=human_model_path, subdivide_num=2, smpl_type="flame", |
|
feat_dim=64, query_dim=64, use_rgb=True, sh_degree=3, mlp_network_config=None, |
|
xyz_offset_max_step=0.0001, expr_param_dim=10, shape_param_dim=10, |
|
fix_opacity=True, fix_rotation=True, clip_scaling=0.001, add_teeth=False) |
|
gs_render.to(device) |
|
|
|
out = gs_render.forward(gs_hidden_features=torch.zeros((1, 2048, 64)).float().to(device), |
|
query_points=None, |
|
flame_data=flame_data, |
|
c2w=c2ws, |
|
intrinsic=intrs, |
|
height=render_h, |
|
width=render_w, |
|
background_color=render_bg_colors, |
|
debug=False) |
|
|
|
os.makedirs("./debug_vis/gs_render", exist_ok=True) |
|
for k, v in out.items(): |
|
if k == "comp_rgb_bg": |
|
print("comp_rgb_bg", v) |
|
continue |
|
for b_idx in range(len(v)): |
|
if k == "3dgs": |
|
for v_idx in range(len(v[b_idx])): |
|
v[b_idx][v_idx].save_ply(f"./debug_vis/gs_render/{b_idx}_{v_idx}.ply") |
|
continue |
|
for v_idx in range(v.shape[1]): |
|
save_path = os.path.join("./debug_vis/gs_render", f"{b_idx}_{v_idx}_{k}.jpg") |
|
if "normal" in k: |
|
img = ((v[b_idx, v_idx].permute(1, 2, 0).detach().cpu().numpy() + 1.0) / 2. * 255).astype(np.uint8) |
|
else: |
|
img = (v[b_idx, v_idx].permute(1, 2, 0).detach().cpu().numpy() * 255).astype(np.uint8) |
|
print(v[b_idx, v_idx].shape, img.shape, save_path) |
|
if "mask" in k: |
|
render_img = render_images[v_idx].transpose(1, 2, 0) * 255 |
|
blend_img = (render_images[v_idx].transpose(1, 2, 0) * 255 * 0.5 + np.tile(img, (1, 1, 3)) * 0.5).clip(0, 255).astype(np.uint8) |
|
cv2.imwrite(save_path, np.hstack([np.tile(img, (1, 1, 3)), render_img.astype(np.uint8), blend_img])[:, :, (2, 1, 0)]) |
|
else: |
|
print(save_path, k) |
|
cv2.imwrite(save_path, img) |
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
test_head() |
|
|