File size: 4,879 Bytes
0887e6b d5e34b6 0887e6b a89535f 0887e6b d5e34b6 a89535f 870001d a89535f 870001d a89535f 870001d a89535f 17c3fdf a89535f 17c3fdf a89535f 17c3fdf a89535f 17c3fdf a89535f 0887e6b a89535f 6edbb29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import gradio as gr
from PIL import Image
from torchvision.transforms import Compose, ToTensor, Resize, Normalize
import numpy as np
import imageio
import tempfile
from utils.utils import denorm
from model.hub import MultiInputResShiftHub
import torch
# Use the GPU when available; the model and all input tensors live on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained multi-input ResShift VFI model from the Hugging Face Hub.
# NOTE(review): from_pretrained downloads weights on first run — needs network access.
model = MultiInputResShiftHub.from_pretrained("vfontech/Multiple-Input-Resshift-VFI")
# Inference only: freeze all parameters, move to the target device, eval mode.
model.requires_grad_(False).to(device).eval()

# Preprocessing for every input frame: resize to the model's fixed 256x448
# resolution, convert to a CHW float tensor, and normalize to [-1, 1]
# (mean/std of 0.5 per RGB channel).
transform = Compose([
    Resize((256, 448)),
    ToTensor(),
    Normalize(mean=[0.5]*3, std=[0.5]*3),
])
def to_numpy(img_tensor: torch.Tensor) -> np.ndarray:
    """Convert a normalized model tensor into a uint8 HWC RGB array.

    Undoes the [-1, 1] normalization applied by ``transform`` (mean/std 0.5),
    drops the batch dimension, and rescales to the 0-255 byte range.
    """
    denormed = denorm(img_tensor, mean=[0.5] * 3, std=[0.5] * 3)
    # squeeze() removes the leading batch dim; permute CHW -> HWC for imaging libs.
    hwc = denormed.squeeze().permute(1, 2, 0).cpu().numpy()
    clipped = np.clip(hwc, 0, 1)
    return (clipped * 255).astype(np.uint8)
def interpolate(img0_pil: Image.Image,
                img2_pil: Image.Image,
                tau: float = 0.5,
                num_samples: int = 1) -> tuple:
    """Interpolate between two frames with the ResShift VFI model.

    Args:
        img0_pil: First (earlier) frame.
        img2_pil: Last (later) frame.
        tau: Interpolation position in [0, 1]; only used when num_samples == 1.
        num_samples: Number of intermediate frames to generate.

    Returns:
        (image, None) when num_samples == 1, (None, video_path) when
        num_samples > 1, or (None, None) if inference fails.
    """
    img0 = transform(img0_pil.convert("RGB")).unsqueeze(0).to(device)
    img2 = transform(img2_pil.convert("RGB")).unsqueeze(0).to(device)
    try:
        if num_samples == 1:
            # Single interpolated frame at position tau.
            img1 = model.reverse_process([img0, img2], tau)
            return Image.fromarray(to_numpy(img1)), None
        else:
            # Multiple frames -> video. The two input frames bracket the
            # sequence, so sample only strictly interior taus: the linspace
            # endpoints are dropped (num_samples + 2 points, sliced [1:-1])
            # to avoid duplicating img0/img2 and wasting two model calls.
            frames = [to_numpy(img0)]
            for t in np.linspace(0.0, 1.0, num_samples + 2)[1:-1]:
                img = model.reverse_process([img0, img2], float(t))
                frames.append(to_numpy(img))
            frames.append(to_numpy(img2))
            # delete=False keeps the file on disk for Gradio to serve; the
            # context manager closes the handle immediately (the original
            # never closed it, leaking a file descriptor per request).
            with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
                temp_path = tmp.name
            imageio.mimsave(temp_path, frames, fps=8)
            return None, temp_path
    except Exception as e:
        # Best-effort demo endpoint: log and return empty outputs so the UI
        # shows nothing instead of crashing.
        print(f"Error during interpolation: {e}")
        return None, None
def build_demo() -> gr.Blocks:
    """Build the Gradio Blocks UI for the frame-interpolation demo.

    Layout: header HTML, two input images, tau/num-samples sliders, a
    Generate button, and paired image/video outputs (exactly one of which
    is populated per run, mirroring interpolate()'s return contract).
    """
    # Static header: title, badge links (paper, model, Colab, code) and
    # usage notes. Rendered verbatim via gr.HTML.
    header = """
    <div style="text-align: center; padding: 1.5rem 0;">
        <h1 style="font-size: 2.4rem; margin-bottom: 0.5rem;">🎞️ Multi-Input ResShift Diffusion VFI</h1>
        <p style="font-size: 1.1rem; color: #444;">
            Efficient and stochastic video frame interpolation for hand-drawn animation.
        </p>
        <div style="display: flex; justify-content: center; flex-wrap: wrap; gap: 12px; margin: 1rem 0;">
            <a href="https://arxiv.org/pdf/2504.05402">
                <img src="https://img.shields.io/badge/arXiv-Paper-A42C25.svg" alt="arXiv">
            </a>
            <a href="https://huggingface.co./vfontech/Multiple-Input-Resshift-VFI">
                <img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HF">
            </a>
            <a href="https://colab.research.google.com/drive/1MGYycbNMW6Mxu5MUqw_RW_xxiVeHK5Aa#scrollTo=EKaYCioiP3tQ">
                <img src="https://img.shields.io/badge/Colab-Demo-green.svg" alt="Colab">
            </a>
            <a href="https://github.com/VicFonch/Multi-Input-Resshift-Diffusion-VFI">
                <img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github" alt="GitHub">
            </a>
        </div>
        <div style="max-width: 700px; margin: 0 auto; font-size: 0.96rem; color: #333;">
            <p style="margin-bottom: 0.5rem;"><strong>Usage:</strong></p>
            <ul style="list-style-type: none; padding: 0; line-height: 1.6;">
                <li>All images are resized to <strong>256×448</strong>.</li>
                <li>If <code>Number of Samples = 1</code>, generates a single interpolated frame using Tau.</li>
                <li>If <code>Number of Samples > 1</code>, Tau is ignored and a full interpolation sequence is generated.</li>
            </ul>
        </div>
    </div>
    """
    with gr.Blocks() as demo:
        gr.HTML(header)
        # Input frames: the model interpolates between these two images.
        with gr.Row():
            img0 = gr.Image(type="pil", label="Initial Image (frame1)")
            img2 = gr.Image(type="pil", label="Final Image (frame3)")
        # Controls: tau only applies in single-frame mode (see header notes).
        with gr.Row():
            tau = gr.Slider(0.0, 1.0, step=0.05, value=0.5, label="Tau Value (only if Num Samples = 1)")
            samples = gr.Slider(1, 20, step=1, value=1, label="Number of Samples")
        btn = gr.Button("Generate")
        # Outputs: interpolate() fills exactly one of these and returns None
        # for the other, depending on num_samples.
        with gr.Row():
            output_img = gr.Image(label="Interpolated Image (if num_samples = 1)")
            output_vid = gr.Video(label="Interpolation in video (if num_samples > 1)")
        btn.click(interpolate, inputs=[img0, img2, tau, samples], outputs=[output_img, output_vid])
        # Bundled example pair; paths are relative to the app's working directory.
        gr.Examples(
            examples=[
                ["_data/example_images/frame1.png", "_data/example_images/frame3.png", 0.5, 1],
            ],
            inputs=[img0, img2, tau, samples],
        )
    return demo
if __name__ == "__main__":
    demo = build_demo()
    # Bind to all interfaces so the app is reachable from outside a container
    # (e.g. Hugging Face Spaces); SSR mode disabled for compatibility.
    # (Removed a stale commented-out `demo.launch()` left behind here.)
    demo.launch(server_name="0.0.0.0", ssr_mode=False)