Spaces:
Running
on
Zero
Running
on
Zero
import base64 | |
import io | |
from typing import Union | |
import ipywidgets as widgets | |
import numpy as np | |
import torch | |
from PIL import Image | |
from shap_e.models.nn.camera import DifferentiableCameraBatch, DifferentiableProjectiveCamera | |
from shap_e.models.transmitter.base import Transmitter, VectorDecoder | |
from shap_e.rendering.torch_mesh import TorchMesh | |
from shap_e.util.collections import AttrDict | |
def create_pan_cameras(size: int, device: torch.device) -> DifferentiableCameraBatch:
    """
    Build a batch of 20 projective cameras arranged on a ring around the origin.

    For each angle theta sampled by np.linspace over [0, 2*pi] (both endpoints
    are sampled), the camera's z axis is the normalized vector
    (sin(theta), cos(theta), -0.5), its x axis is (cos(theta), -sin(theta), 0),
    its y axis is cross(z, x), and its origin is placed at -4 * z.

    :param size: used as both the width and the height of every camera.
    :param device: torch device the camera tensors are moved to.
    :return: a DifferentiableCameraBatch of shape (1, 20) wrapping a single
             DifferentiableProjectiveCamera with x_fov = y_fov = 0.7.
    """

    def _as_tensor(rows) -> torch.Tensor:
        # Stack the per-view 3-vectors into one float32 tensor on the target device.
        return torch.from_numpy(np.stack(rows, axis=0)).float().to(device)

    frames = []
    for angle in np.linspace(0, 2 * np.pi, num=20):
        z_axis = np.array([np.sin(angle), np.cos(angle), -0.5])
        z_axis /= np.sqrt(np.sum(z_axis**2))  # normalize to unit length
        x_axis = np.array([np.cos(angle), -np.sin(angle), 0.0])
        frames.append((-z_axis * 4, x_axis, np.cross(z_axis, x_axis), z_axis))

    # Transpose the list of (origin, x, y, z) tuples into four per-field sequences.
    origins, xs, ys, zs = zip(*frames)

    flat_camera = DifferentiableProjectiveCamera(
        origin=_as_tensor(origins),
        x=_as_tensor(xs),
        y=_as_tensor(ys),
        z=_as_tensor(zs),
        width=size,
        height=size,
        x_fov=0.7,
        y_fov=0.7,
    )
    return DifferentiableCameraBatch(shape=(1, len(xs)), flat_camera=flat_camera)
def decode_latent_images(
    xm: Union[Transmitter, VectorDecoder],
    latent: torch.Tensor,
    cameras: DifferentiableCameraBatch,
    rendering_mode: str = "stf",
):
    """
    Render a latent through `xm`'s renderer and return the views as PIL images.

    :param xm: model providing `renderer.render_views`. For a Transmitter the
               latent is converted to parameters by `xm.encoder`; otherwise by
               `xm` itself.
    :param latent: a single latent; a leading dimension is added with
                   `latent[None]` before it is converted to parameters.
    :param cameras: the cameras to render from.
    :param rendering_mode: forwarded unchanged to the renderer options.
    :return: a list with one PIL image per entry along the first axis of the
             rendered output for the (single) latent.
    """
    # A Transmitter exposes bottleneck_to_params on its encoder; anything else
    # is expected to expose it directly.
    param_source = xm.encoder if isinstance(xm, Transmitter) else xm

    # The output is quantized to uint8 and handed to PIL, so gradients can never
    # be consumed. Disabling autograd avoids retaining the graph and the
    # activations of every rendered view.
    with torch.no_grad():
        decoded = xm.renderer.render_views(
            AttrDict(cameras=cameras),
            params=param_source.bottleneck_to_params(latent[None]),
            options=AttrDict(rendering_mode=rendering_mode, render_with_direction=False),
        )
        # Clamp to the 8-bit range, take the first entry of the leading
        # dimension added above, and move the result to host memory.
        arr = decoded.channels.clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
    return [Image.fromarray(x) for x in arr]
def decode_latent_mesh(
    xm: Union[Transmitter, VectorDecoder],
    latent: torch.Tensor,
) -> TorchMesh:
    """
    Decode a latent into a mesh by running an "stf" render and keeping its raw mesh.

    Only `raw_meshes[0]` of the render result is returned, so the render itself
    is done with 2x2 pan cameras to keep it as cheap as possible.

    :param xm: model providing `renderer.render_views`. For a Transmitter the
               latent is converted to parameters by `xm.encoder`; otherwise by
               `xm` itself.
    :param latent: a single latent; a leading dimension is added with
                   `latent[None]`, and the cameras are created on its device.
    :return: the first raw mesh produced by the render.
    """
    if isinstance(xm, Transmitter):
        param_source = xm.encoder
    else:
        param_source = xm

    # lowest resolution possible
    render_inputs = AttrDict(cameras=create_pan_cameras(2, latent.device))
    render_params = param_source.bottleneck_to_params(latent[None])
    render_options = AttrDict(rendering_mode="stf", render_with_direction=False)

    rendered = xm.renderer.render_views(
        render_inputs,
        params=render_params,
        options=render_options,
    )
    return rendered.raw_meshes[0]
def gif_widget(images):
    """
    Wrap a sequence of images in an HTML widget that shows them as an animated GIF.

    The first image's `save` method writes the GIF; the remaining images are
    passed as `append_images`. The file is saved with `duration=100` and
    `loop=0` (per Pillow's GIF writer: 100 ms per frame, looping forever).

    :param images: non-empty, sliceable sequence of images; indexing an empty
                   sequence raises IndexError.
    :return: an ipywidgets HTML widget whose <img> embeds the GIF as a base64
             data URI.
    """
    first_frame, later_frames = images[0], images[1:]
    with io.BytesIO() as buffer:
        first_frame.save(
            buffer,
            format="GIF",
            save_all=True,
            append_images=later_frames,
            duration=100,
            loop=0,
        )
        # getvalue() returns the whole buffer regardless of the current position.
        encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
    return widgets.HTML(f'<img src="data:image/gif;base64,{encoded}" />')