import tempfile

import gradio as gr
import imageio
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import Compose, Normalize, Resize, ToTensor

from model.hub import MultiInputResShiftHub
from utils.utils import denorm

# Run on GPU when available; the model is used for inference only.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = MultiInputResShiftHub.from_pretrained("vfontech/Multiple-Input-Resshift-VFI")
model.requires_grad_(False).to(device).eval()

# Preprocessing pipeline: the model expects 256x448 RGB inputs
# normalized to [-1, 1] (mean/std of 0.5 per channel).
transform = Compose([
    Resize((256, 448)),
    ToTensor(),
    Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])


def to_numpy(img_tensor: torch.Tensor) -> np.ndarray:
    """Convert a normalized CHW tensor back to a uint8 HWC RGB array.

    Undoes the [-1, 1] normalization applied by ``transform``, moves the
    tensor to CPU, and scales to the 0-255 byte range for display/encoding.
    """
    img_np = (
        denorm(img_tensor, mean=[0.5] * 3, std=[0.5] * 3)
        .squeeze()
        .permute(1, 2, 0)  # CHW -> HWC for image libraries
        .cpu()
        .numpy()
    )
    img_np = np.clip(img_np, 0, 1)
    return (img_np * 255).astype(np.uint8)


def interpolate(img0_pil: Image.Image, img2_pil: Image.Image, tau: float = 0.5, num_samples: int = 1) -> tuple:
    """Interpolate between two frames with the ResShift VFI model.

    Args:
        img0_pil: First (earlier) frame.
        img2_pil: Last (later) frame.
        tau: Interpolation position in [0, 1]; only used when
            ``num_samples == 1``.
        num_samples: Number of intermediate frames. 1 yields a single
            image; >1 yields an MP4 of the interpolated sequence.

    Returns:
        ``(image, None)`` when ``num_samples == 1``, ``(None, video_path)``
        when ``num_samples > 1``, or ``(None, None)`` on failure.
    """
    img0 = transform(img0_pil.convert("RGB")).unsqueeze(0).to(device)
    img2 = transform(img2_pil.convert("RGB")).unsqueeze(0).to(device)

    try:
        # Disable autograd during generation: parameters are frozen, but
        # without no_grad() the forward pass would still build a graph.
        with torch.no_grad():
            if num_samples == 1:
                # Single image
                img1 = model.reverse_process([img0, img2], tau)
                return Image.fromarray(to_numpy(img1)), None

            # Multiple images -> video
            frames = [to_numpy(img0)]
            # NOTE(review): linspace includes t=0 and t=1, which nearly
            # duplicate the appended endpoint frames — possibly intentional
            # padding; confirm before excluding the endpoints.
            for t in np.linspace(0, 1, num_samples):
                img = model.reverse_process([img0, img2], float(t))
                frames.append(to_numpy(img))
            frames.append(to_numpy(img2))

        # Close the handle immediately; delete=False keeps the path valid
        # so imageio can write to it (fixes a file-descriptor leak).
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            temp_path = tmp.name
        imageio.mimsave(temp_path, frames, fps=8)
        return None, temp_path
    except Exception as e:
        # Top-level UI boundary: report and return empty outputs rather
        # than crashing the Gradio app.
        print(f"Error during interpolation: {e}")
        return None, None


def build_demo() -> gr.Blocks:
    """Assemble the Gradio Blocks UI for the interpolation demo."""
    header = """

🎞️ Multi-Input ResShift Diffusion VFI

Efficient and stochastic video frame interpolation for hand-drawn animation.

arXiv HF Colab GitHub

Usage:

"""
    with gr.Blocks() as demo:
        gr.HTML(header)
        with gr.Row():
            img0 = gr.Image(type="pil", label="Initial Image (frame1)")
            img2 = gr.Image(type="pil", label="Final Image (frame3)")
        with gr.Row():
            tau = gr.Slider(0.0, 1.0, step=0.05, value=0.5, label="Tau Value (only if Num Samples = 1)")
            samples = gr.Slider(1, 20, step=1, value=1, label="Number of Samples")
        btn = gr.Button("Generate")
        with gr.Row():
            output_img = gr.Image(label="Interpolated Image (if num_samples = 1)")
            output_vid = gr.Video(label="Interpolation in video (if num_samples > 1)")

        btn.click(interpolate, inputs=[img0, img2, tau, samples], outputs=[output_img, output_vid])

        gr.Examples(
            examples=[
                ["_data/example_images/frame1.png", "_data/example_images/frame3.png", 0.5, 1],
            ],
            inputs=[img0, img2, tau, samples],
        )
    return demo


if __name__ == "__main__":
    demo = build_demo()
    demo.launch(server_name="0.0.0.0", ssr_mode=False)