import torch
import numpy as np
from einops import rearrange


def sample_img_rays(x, img_fov=45):
    """
    Samples a unit ray for each pixel in an image (pinhole camera looking along +z)

    Args:
        x: images (...,h,w); only the trailing shape, dtype and device are used
        img_fov: assumed image fov in degrees, in [0,180); scalar or tuple(h,w)
    Returns:
        img_rays (h,w,3): unit-norm (x,y,z) ray per pixel; x increases with the
            column index, y increases with the row index, z = 1 before normalizing
    Raises:
        ValueError: if any fov component lies outside [0,180)
    """
    h, w = x.shape[-2:]
    dtype, device = x.dtype, x.device

    half_fov = torch.deg2rad(torch.as_tensor(img_fov, dtype=torch.float64)) / 2
    if ((half_fov < 0) | (half_fov >= torch.pi / 2)).any():
        raise ValueError(f"img_fov must be in [0,180) degrees, got {img_fov}")
    # Half-extent of the image plane at z=1 is tan(fov/2); expand() lets a scalar
    # fov cover both axes. Python floats keep linspace independent of device/version.
    y_max, x_max = half_fov.tan().expand(2).tolist()  # [y,x]

    ys = torch.linspace(-y_max, y_max, h, dtype=dtype, device=device)
    xs = torch.linspace(-x_max, x_max, w, dtype=dtype, device=device)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')
    xyz = torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1)  # (h,w,3)
    return xyz / xyz.norm(dim=-1, keepdim=True)


def gen_rotation_matrix(angles):
    """
    Generate rotation matrices from per-axis rotation angles

    The result is R = Rz(angles[...,2]) @ Ry(angles[...,1]) @ Rx(angles[...,0]),
    i.e. the x rotation is applied first, then y, then z.

    Args:
        angles: axis-wise rotations about (x,y,z) in degrees [0,360] (...,3)
    Returns:
        rot_mat (...,3,3)
    """
    dims = angles.shape[:-1]
    angles = 2*torch.pi*angles/360  # degrees -> radians
    angles = rearrange(angles, '... a -> a ...')  # (3,...)
    cos = angles.cos()
    sin = angles.sin()
    rot_mat = torch.stack([
        cos[1]*cos[2], sin[0]*sin[1]*cos[2]-cos[0]*sin[2], cos[0]*sin[1]*cos[2]+sin[0]*sin[2],
        cos[1]*sin[2], sin[0]*sin[1]*sin[2]+cos[0]*cos[2], cos[0]*sin[1]*sin[2]-sin[0]*cos[2],
        -sin[1],       sin[0]*cos[1],                      cos[0]*cos[1],
    ], dim=-1).reshape(*dims, 3, 3)  # (...,9) -> (...,3,3)
    return rot_mat


def cart_2_spherical(pts):
    """
    Convert Cartesian to spherical coordinates

    Args:
        pts: input pts (...,3) as (x,y,z)
    Returns:
        ret (...,3) as (theta,phi,r), angles in radians:
            theta: azimuth in [-pi,pi], 0 along +z, positive toward +x
            phi: elevation in [-pi/2,pi/2], positive toward +y
            r: Euclidean norm
    """
    x, y, z = pts.moveaxis(-1, 0)
    r = pts.norm(dim=-1)
    horiz = torch.hypot(x, z)
    # atan2 is correct in every quadrant (x == 0 with z < 0 maps to +-pi, not 0),
    # stays finite on the axes, and is better conditioned than arccos/arcsin.
    theta = torch.atan2(x, z)
    phi = torch.atan2(y, horiz)
    return torch.stack([theta, phi, r], dim=-1)


def sample_pano_img(img, pts, h_fov_ratio=1, w_fov_ratio=1):
    """
    Bilinearly sample points from an equirectangular panoramic image

    Args:
        img: pano-image (...,c,h,w), floating point, any number of channels c
        pts: spherical points (theta,phi,r) to sample from img (...,sh,sw,3);
            r is ignored
        *_fov_ratio: fraction of the full sphere (360 deg horizontally, 180 deg
            vertically) the pano covers, starting at theta=-pi / phi=-pi/2
    Returns:
        sampled_img (...,c,sh,sw)
    """
    h, w = img.shape[-2:]
    sh, sw = pts.shape[-3:-1]
    img = rearrange(img, '... c h w -> ... (h w) c')  # (...,n,c)
    pts = rearrange(pts, '... h w c -> ... (h w) c')  # (...,m,3)
    c = img.shape[-1]

    # radians -> continuous pixel coordinates, with pixel i centred on coordinate i
    # (hence the -0.5). Elevation is clamped, never wrapped: the poles do not join.
    h_coord = ((pts[..., 1] + torch.pi/2) / torch.pi).clamp(0, 1) * (h / h_fov_ratio) - 0.5  # elevation [-pi/2,+pi/2]
    w_coord = (((pts[..., 0] + torch.pi) / (2*torch.pi)) % 1) * (w / w_fov_ratio) - 0.5  # azimuth [-pi,+pi)

    h_lo_f, w_lo_f = h_coord.floor(), w_coord.floor()
    h_frac, w_frac = h_coord - h_lo_f, w_coord - w_lo_f  # weights of the upper neighbour
    h_lo, w_lo = h_lo_f.long(), w_lo_f.long()
    h_hi, w_hi = h_lo + 1, w_lo + 1

    h_lo, h_hi = h_lo.clamp(0, h-1), h_hi.clamp(0, h-1)
    if w_fov_ratio == 1:
        # full 360 deg pano: the last column neighbours the first one
        w_lo, w_hi = w_lo % w, w_hi % w
    else:
        w_lo, w_hi = w_lo.clamp(0, w-1), w_hi.clamp(0, w-1)

    # accumulate the four bilinear taps
    sampled_img = 0
    for rows, h_wt in ((h_lo, 1 - h_frac), (h_hi, h_frac)):
        for cols, w_wt in ((w_lo, 1 - w_frac), (w_hi, w_frac)):
            idx = rows * w + cols  # (...,m) flat index into the (h w) axis
            idx = idx[..., None].expand(*idx.shape, c)  # (...,m,c)
            sampled_img = sampled_img + (h_wt * w_wt)[..., None] * img.gather(-2, idx)
    sampled_img = rearrange(sampled_img, '... (h w) c -> ... c h w', h=sh, w=sw)
    return sampled_img


def sample_perspective_img(pano_img, output_shape, fov=None, rot=None):
    """
    Sample a perspective image from an equirectangular panoramic image

    Args:
        pano_img: pano-image numpy.array (h,w,c); floating-point arrays keep their
            dtype, any other dtype is cast to uint8
        output_shape: output image dimensions tuple(h,w)
        fov: desired perspective image fov; int or tuple(vertical,horizontal) in
            degrees [0,180); drawn uniformly from [30,90) per axis when None
        rot: axis-wise rotations; tuple(pitch,yaw,roll) in degrees [0,360]; when
            None, drawn uniformly with pitch in [-10,10), yaw in [-135,90),
            roll in [-20,20)
    Returns:
        sampled_img numpy.array (h,w,c) in the output dtype, fov, rot numpy.array
    """
    if fov is None:
        fov = torch.tensor([30, 30]) + torch.tensor([60, 60])*torch.rand(2)  # (v-fov,h-fov)
        fov = (fov[0].item(), fov[1].item())
    if rot is None:
        rot = -torch.tensor([10, 135, 20]) + torch.tensor([20, 225, 40])*torch.rand(3)  # rot w.r.t (x,y,z) aka (pitch,yaw,roll)
    else:
        rot = torch.tensor(rot)

    # ascontiguousarray also accepts negative-stride views (e.g. img[..., ::-1])
    pano = torch.tensor(np.ascontiguousarray(pano_img))
    if pano.is_floating_point():
        out_dtype = pano.dtype
    else:
        pano = pano.to(torch.uint8)
        out_dtype = torch.uint8
    work_dtype = torch.float64 if pano.dtype == torch.float64 else torch.float32
    pano = pano.moveaxis(-1, 0).to(work_dtype)  # (c,H,W)

    img_rays = sample_img_rays(torch.empty(output_shape, dtype=work_dtype, device=pano.device), img_fov=fov)  # (h,w,3)
    rot_mat = gen_rotation_matrix(rot.to(work_dtype))  # (3,3)
    rot_img_rays = img_rays @ rot_mat.T  # rotates every ray: (R v)^T = v^T R^T
    spher_rot_img_rays = cart_2_spherical(rot_img_rays)  # (h,w,3)

    sampled = sample_pano_img(pano, spher_rot_img_rays).moveaxis(0, -1)  # (h,w,c)
    if out_dtype == torch.uint8:
        # round instead of truncating, and guard against float drift past [0,255]
        sampled = sampled.round().clamp(0, 255)
    return sampled.to(out_dtype).numpy(), fov, rot.numpy()