# Spaces: Running on Zero
import torch | |
import numpy as np | |
from einops import rearrange | |
def sample_img_rays(x, img_fov=45):
    """
    Sample a unit-length viewing ray for each pixel of a pinhole image.

    Args:
        x: images (...,h,w); only the last two sizes, the dtype and the device are used
        img_fov: assumed full image fov in degrees for ray calculation; a scalar
            (used for both axes) or tuple(vertical, horizontal)
    Returns:
        img_rays (h,w,3) 3:<x,y,z>; the camera looks along +z and the outermost
        pixel rows/columns lie exactly on the fov boundary
    """
    h, w, dtype, device = *x.shape[-2:], x.dtype, x.device
    # Half fov in radians, in float64 so tan() stays accurate even for very
    # narrow fovs (the former sqrt(sec^2 - 1) form cancelled to 0 in float32).
    half_fov = torch.as_tensor(img_fov, dtype=torch.float64) * (torch.pi / 360)
    # Extent of the image plane at z=1 per axis: tan(half fov), ordered [y,x].
    y_max, x_max = half_fov.tan().expand(2).tolist()
    y_coords = torch.linspace(-y_max, y_max, h, dtype=dtype, device=device)
    x_coords = torch.linspace(-x_max, x_max, w, dtype=dtype, device=device)
    ys, xs = torch.meshgrid(y_coords, x_coords, indexing='ij')
    xyz = torch.stack([xs, ys, torch.ones_like(xs)], dim=-1)  # (h,w,<x,y,z>)
    return xyz / xyz.norm(dim=-1, keepdim=True)
def gen_rotation_matrix(angles):
    """
    Build rotation matrices from per-axis rotation angles.

    The result equals Rz(angles[...,2]) @ Ry(angles[...,1]) @ Rx(angles[...,0]).

    Args:
        angles: axis-wise rotations in degrees [0,360] (...,3) for (x,y,z)
    Returns:
        rot_mat (...,3,3)
    """
    batch_shape = angles.shape[:-1]
    radians = 2 * torch.pi * angles / 360  # degrees -> radians
    ca, cb, cc = radians.cos().unbind(-1)
    sa, sb, sc = radians.sin().unbind(-1)
    entries = [
        cb * cc, sa * sb * cc - ca * sc, ca * sb * cc + sa * sc,
        cb * sc, sa * sb * sc + ca * cc, ca * sb * sc - sa * cc,
        -sb, sa * cb, ca * cb,
    ]
    # (...,9) row-major -> (...,3,3)
    return torch.stack(entries, dim=-1).reshape(*batch_shape, 3, 3)
def cart_2_spherical(pts):
    """
    Convert Cartesian to spherical coordinates.

    The azimuth is measured in the x-z plane from +z towards +x; the
    inclination is the elevation angle towards +y.

    Args:
        pts: input pts (...,<x,y,z>)
    Returns:
        ret (...,<theta,phi,r>) (<azimuth,inclination,radius>) (radians);
        theta in [-pi, pi], phi in [-pi/2, pi/2]; never NaN for finite input
    """
    x, y, z = pts.unbind(-1)
    r = pts.norm(dim=-1)
    # atan2 is defined everywhere, unlike sign(x)*arccos(...): points on the
    # -z axis get theta = +-pi (not 0) and points on the y axis / at the
    # origin get theta = 0 (not NaN).
    theta = torch.atan2(x, z)
    # Same as arcsin(y/r), but finite when r == 0 and immune to |y/r| > 1 rounding.
    phi = torch.atan2(y, torch.hypot(x, z))
    return torch.stack([theta, phi, r], dim=-1)
def sample_pano_img(img, pts, h_fov_ratio=1, w_fov_ratio=1):
    """
    Bilinearly sample points from a panoramic (equirectangular) image.

    Args:
        img: pano-image (...,c,h,w); any channel count c (e.g. 3:<rgb>)
        pts: spherical points to sample from img (...,sh,sw,3:<azimuth,inclination,radius>)
            in radians. With ratio 1, azimuth [-pi,pi) spans the width left to right
            and inclination [-pi/2,pi/2] spans the rows top to bottom. The radius is ignored.
        *_fov_ratio: ratio of full fov for pano (1 = full 180 deg vertical /
            360 deg horizontal coverage)
    Returns:
        sampled_img (...,c,sh,sw)
    """
    h, w = img.shape[-2:]
    sh, sw = pts.shape[-3:-1]
    h_conv, w_conv = h / h_fov_ratio, w / w_fov_ratio
    img = img.flatten(-2).movedim(-2, -1)  # (...,c,h,w) -> (...,n,c), n = h*w row-major
    channels = img.shape[-1]
    pts = pts.flatten(-3, -2)  # (...,sh,sw,3) -> (...,m,3), m = sh*sw
    # Convert (pts) radians to fractional pixel indices.
    # Inclination (-pi/2,+pi/2) -> rows. Clamp rather than wrap, so a ray at a
    # pole stays on that pole's edge row instead of jumping to the opposite one.
    h_inds = ((pts[..., 1] + torch.pi / 2) / torch.pi).clamp(0, 1) * h_conv
    # Azimuth (-pi,+pi) -> columns; periodic, so wrap.
    w_inds = (((pts[..., 0] + torch.pi) / (2 * torch.pi)) % 1) * w_conv
    # Integer neighbours for bilinear interp (int64 so w*row cannot overflow).
    h_l = h_inds.to(torch.long).clamp(0, h - 1)
    w_l = w_inds.to(torch.long).clamp(0, w - 1)
    h_r = (h_l + 1).clamp(0, h - 1)
    if w_fov_ratio == 1:
        # A full 360 deg pano is periodic: the last column's right neighbour is column 0.
        w_r = (w_l + 1) % w
    else:
        w_r = (w_l + 1).clamp(0, w - 1)
    # Interpolation weights.
    h_p_r, w_p_r = h_inds - h_l, w_inds - w_l
    h_p_l, w_p_l = 1 - h_p_r, 1 - w_p_r
    # Linearize inds,weights; corner order (l,l),(l,r),(r,l),(r,r).
    row_offsets = torch.stack([w * h_l, w * h_r], dim=-1)  # (...,m,2)
    cols = torch.stack([w_l, w_r], dim=-1)  # (...,m,2)
    inds = (row_offsets[..., :, None] + cols[..., None, :]).flatten(-2).moveaxis(-1, 0)  # (4,...,m)
    row_w = torch.stack([h_p_l, h_p_r], dim=-1)
    col_w = torch.stack([w_p_l, w_p_r], dim=-1)
    weights = (row_w[..., :, None] * col_w[..., None, :]).flatten(-2).moveaxis(-1, 0)  # (4,...,m)
    # Do bilin interp: gather the 4 corner pixels of every sample point.
    img_extract = img.unsqueeze(0).expand(4, *img.shape).gather(
        -2, inds.unsqueeze(-1).expand(*inds.shape, channels))  # (4,...,m,c)
    sampled_img = (weights[..., None] * img_extract).sum(0)  # (4,...,m,c) -> (...,m,c)
    sampled_img = sampled_img.movedim(-1, -2)  # (...,c,m)
    return sampled_img.reshape(*sampled_img.shape[:-1], sh, sw)  # (...,c,sh,sw)
def sample_perspective_img(pano_img, output_shape, fov=None, rot=None):
    """
    Sample a perspective (pinhole) image from a panoramic image.

    Args:
        pano_img: pano-image numpy.array (h,w,3:<rgb>); values are cast to uint8
        output_shape: output image dimensions tuple(h,w)
        fov: desired perspective image fov; int or tuple(vertical,horizontal) in degrees [0,180).
            If None, each axis is drawn uniformly from [30,90).
        rot: axis-wise rotations; tuple(pitch,yaw,roll) in degrees [0,360].
            If None, drawn uniformly from pitch [-10,10), yaw [-135,90), roll [-20,20).
    Returns:
        sampled_img numpy.array (h,w,3:<rgb>) uint8, fov, rot (numpy.array)
    """
    if fov is None:
        fov = torch.tensor([30, 30]) + torch.tensor([60, 60]) * torch.rand(2)  # (v-fov,h-fov)
        fov = (fov[0].item(), fov[1].item())
    if rot is None:
        # rot w.r.t (x,y,z) aka (pitch,yaw,roll)
        rot = -torch.tensor([10, 135, 20]) + torch.tensor([20, 225, 40]) * torch.rand(3)
    else:
        rot = torch.tensor(rot)
    pano = torch.tensor(pano_img, dtype=torch.uint8).moveaxis(-1, 0)  # (3,h,w)
    out_dtype = pano.dtype
    pano = pano.to(torch.float)
    img_rays = sample_img_rays(
        torch.empty(output_shape, dtype=pano.dtype, device=pano.device), img_fov=fov)
    rot_mat = gen_rotation_matrix(rot.to(pano.dtype))[None, None, :]  # (3,3) -> (1,1,3,3)
    rot_img_rays = torch.matmul(rot_mat, img_rays.unsqueeze(-1)).squeeze(-1)  # (h,w,3)
    spher_rot_img_rays = cart_2_spherical(rot_img_rays)  # (h,w,3)
    # sample img
    sampled = sample_pano_img(pano, spher_rot_img_rays)  # (3,h,w) float
    # Round to nearest before the integer cast; a bare cast truncates, which
    # biases every pixel darker (e.g. 4.9999 -> 4). Clamp to the uint8 range.
    sampled = sampled.round().clamp(0, 255).to(out_dtype)
    return sampled.moveaxis(0, -1).numpy(), fov, rot.numpy()