# Source: vace-demo/vace/gradios/vace_ltx_demo.py (uploaded by maffia, "Upload 94 files", commit 690f890, verified)
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import os
import sys
import datetime
import imageio
import numpy as np
import torch
import gradio as gr
sys.path.insert(0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-3]))
from vace.models.ltx.ltx_vace import LTXVace
class FixedSizeQueue:
    """Bounded list of items ordered newest first.

    Each ``add`` inserts at the front; when the list then holds more than
    ``max_size`` items, the last (oldest) one is dropped.
    """

    def __init__(self, max_size):
        # Upper bound on the number of items kept after each ``add``.
        self.max_size = max_size
        # Stored items, newest first.
        self.queue = []

    def add(self, item):
        """Insert ``item`` at the front, evicting the oldest entry if over capacity."""
        self.queue.insert(0, item)
        if len(self.queue) > self.max_size:
            self.queue.pop()

    def get(self):
        """Return a snapshot (shallow copy) of the items, newest first.

        A copy is returned so that callers (e.g. UI handlers that receive the
        gallery list) cannot mutate the internal history by accident.
        """
        return list(self.queue)

    def __repr__(self):
        return str(self.queue)
class VACEInference:
    """Gradio front-end for the LTXVace video generation pipeline.

    Builds the UI components (``create_ui``), wires the buttons
    (``set_callbacks``), and writes each generated video as an MP4 into
    ``cfg.save_dir`` (``generate``).
    """

    def __init__(self, cfg, skip_load=False, gallery_share=True, gallery_share_limit=5):
        """Store the configuration and, unless ``skip_load``, build the pipeline.

        Args:
            cfg: namespace providing ``save_dir`` plus the model options
                ``ckpt_path``, ``text_encoder_path``, ``precision``,
                ``stg_skip_layers``, ``stg_mode`` and ``offload_to_cpu``.
            skip_load: if True, ``LTXVace`` is not constructed and ``self.pipe``
                is left unset.
            gallery_share: if True, every result is pushed into one history list
                held on this instance, and that list is what the gallery shows.
            gallery_share_limit: maximum number of videos kept in that history.
        """
        self.cfg = cfg
        self.save_dir = cfg.save_dir
        self.gallery_share = gallery_share
        self.gallery_share_data = FixedSizeQueue(max_size=gallery_share_limit)
        if not skip_load:
            # Read model options from ``cfg`` (not a module-level global) so the
            # class also works when imported from another module.
            self.pipe = LTXVace(ckpt_path=cfg.ckpt_path,
                                text_encoder_path=cfg.text_encoder_path,
                                precision=cfg.precision,
                                stg_skip_layers=cfg.stg_skip_layers,
                                stg_mode=cfg.stg_mode,
                                offload_to_cpu=cfg.offload_to_cpu)

    def create_ui(self, *args, **kwargs):
        """Create all Gradio components; must be called inside a ``gr.Blocks`` context."""
        gr.Markdown("""
<div style="text-align: center; font-size: 24px; font-weight: bold; margin-bottom: 15px;">
<a href="https://ali-vilab.github.io/VACE-Page/" style="text-decoration: none; color: inherit;">VACE-LTXV Demo</a>
</div>
""")
        # Source video and mask inputs.
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                self.src_video = gr.Video(
                    label="src_video",
                    sources=['upload'],
                    value=None,
                    interactive=True)
            with gr.Column(scale=1, min_width=0):
                self.src_mask = gr.Video(
                    label="src_mask",
                    sources=['upload'],
                    value=None,
                    interactive=True)
        # Up to three reference images (passed on as file paths).
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                with gr.Row(equal_height=True):
                    self.src_ref_image_1 = gr.Image(label='src_ref_image_1',
                                                    height=200,
                                                    interactive=True,
                                                    type='filepath',
                                                    image_mode='RGB',
                                                    sources=['upload'],
                                                    elem_id="src_ref_image_1",
                                                    format='png')
                    self.src_ref_image_2 = gr.Image(label='src_ref_image_2',
                                                    height=200,
                                                    interactive=True,
                                                    type='filepath',
                                                    image_mode='RGB',
                                                    sources=['upload'],
                                                    elem_id="src_ref_image_2",
                                                    format='png')
                    self.src_ref_image_3 = gr.Image(label='src_ref_image_3',
                                                    height=200,
                                                    interactive=True,
                                                    type='filepath',
                                                    image_mode='RGB',
                                                    sources=['upload'],
                                                    elem_id="src_ref_image_3",
                                                    format='png')
        # Positive / negative prompts.
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1):
                self.prompt = gr.Textbox(
                    show_label=False,
                    placeholder="positive_prompt_input",
                    elem_id='positive_prompt',
                    container=True,
                    autofocus=True,
                    elem_classes='type_row',
                    visible=True,
                    lines=2)
                self.negative_prompt = gr.Textbox(
                    show_label=False,
                    value="worst quality, inconsistent motion, blurry, jittery, distorted",
                    placeholder="negative_prompt_input",
                    elem_id='negative_prompt',
                    container=True,
                    autofocus=False,
                    elem_classes='type_row',
                    visible=True,
                    interactive=True,
                    lines=1)
        # Sampling controls.
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                with gr.Row(equal_height=True):
                    self.sample_steps = gr.Slider(
                        label='sample_steps',
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=40,
                        interactive=True)
                    self.context_scale = gr.Slider(
                        label='context_scale',
                        minimum=0.0,
                        maximum=2.0,
                        step=0.1,
                        value=1.0,
                        interactive=True)
                    self.guide_scale = gr.Slider(
                        label='guide_scale',
                        minimum=1,
                        maximum=10,
                        step=0.5,
                        value=3.0,
                        interactive=True)
                    self.infer_seed = gr.Slider(minimum=-1,
                                                maximum=10000000,
                                                value=2025,
                                                label="Seed")
        # Output geometry / timing (free-text fields, parsed in ``generate``).
        with gr.Accordion(label="Usable without source video", open=False):
            with gr.Row(equal_height=True):
                self.output_height = gr.Textbox(
                    label='resolutions_height',
                    value=512,
                    interactive=True)
                self.output_width = gr.Textbox(
                    label='resolutions_width',
                    value=768,
                    interactive=True)
                self.frame_rate = gr.Textbox(
                    label='frame_rate',
                    value=25,
                    interactive=True)
                self.num_frames = gr.Textbox(
                    label='num_frames',
                    value=97,
                    interactive=True)
        # Run / refresh buttons.
        with gr.Row(equal_height=True):
            with gr.Column(scale=5):
                self.generate_button = gr.Button(
                    value='Run',
                    elem_classes='type_row',
                    elem_id='generate_button',
                    visible=True)
            with gr.Column(scale=1):
                self.refresh_button = gr.Button(value='\U0001f504')  # 🔄
        #
        self.output_gallery = gr.Gallery(
            label="output_gallery",
            value=[],
            interactive=False,
            allow_preview=True,
            preview=True)

    @staticmethod
    def _parse_number(value, label, cast):
        """Convert a UI field value with ``cast``; raise ``gr.Error`` naming the field on failure."""
        try:
            return cast(value)
        except (TypeError, ValueError) as e:
            raise gr.Error(f"Invalid {label}: {value!r}") from e

    def _save_video(self, video, frame_rate):
        """Write ``video`` to ``self.save_dir`` as an H.264 MP4 and return the file path.

        NOTE(review): ``video`` is assumed to be a (C, T, H, W) tensor in [-1, 1].
        This is inferred from the rescale and ``permute(1, 2, 3, 0)`` below;
        TODO confirm against LTXVace's output.

        Raises:
            gr.Error: if the video cannot be encoded or written.
        """
        # Zero-padded, portable timestamp (the old '%-H' is glibc-only) with
        # microseconds, so two runs in the same second don't overwrite each other.
        name = '{0:%Y%m%d-%H%M%S-%f}'.format(datetime.datetime.now())
        video_path = os.path.join(self.save_dir, f'cur_gallery_{name}.mp4')
        # [-1, 1] -> [0, 1] -> [0, 255], reordered so iteration yields (H, W, C) frames.
        video_frames = (torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) * 255).cpu().numpy().astype(np.uint8)
        try:
            # The context manager closes the writer even if a frame fails to encode.
            with imageio.get_writer(video_path, fps=frame_rate, codec='libx264', quality=8, macro_block_size=1) as writer:
                for frame in video_frames:
                    writer.append_data(frame)
        except Exception as e:
            raise gr.Error(f"Video save error: {e}") from e
        print(video_path)
        return video_path

    def generate(self, output_gallery, src_video, src_mask, src_ref_image_1, src_ref_image_2, src_ref_image_3, prompt, negative_prompt, sample_steps, context_scale, guide_scale, infer_seed, output_height, output_width, frame_rate, num_frames):
        """Run the pipeline on the UI inputs, save the result and return gallery items.

        Arguments follow the order of ``self.gen_inputs``. ``output_gallery`` is
        received only because the gallery is wired in as an input; it is unused.
        The resolution, frame-rate and frame-count fields come from
        ``gr.Textbox`` components, so they are parsed to numbers here. The seed
        and step sliders are coerced to ``int``.

        Returns:
            List of video paths for the gallery: the shared history (newest first)
            when ``gallery_share`` is on, otherwise only the new video.

        Raises:
            gr.Error: on an unparsable numeric field or a failed video write.
        """
        output = self.pipe.generate(src_video=src_video,
                                    src_mask=src_mask,
                                    src_ref_images=[src_ref_image_1, src_ref_image_2, src_ref_image_3],
                                    prompt=prompt,
                                    negative_prompt=negative_prompt,
                                    seed=self._parse_number(infer_seed, 'Seed', int),
                                    num_inference_steps=self._parse_number(sample_steps, 'sample_steps', int),
                                    num_images_per_prompt=1,
                                    context_scale=context_scale,
                                    guidance_scale=guide_scale,
                                    # float so fractional rates such as 23.976 are accepted
                                    frame_rate=self._parse_number(frame_rate, 'frame_rate', float),
                                    output_height=self._parse_number(output_height, 'resolutions_height', int),
                                    output_width=self._parse_number(output_width, 'resolutions_width', int),
                                    num_frames=self._parse_number(num_frames, 'num_frames', int))
        # The pipeline reports the frame rate actually used (it may differ from the UI field).
        video_path = self._save_video(output['out_video'], output['info']['frame_rate'])
        if self.gallery_share:
            self.gallery_share_data.add(video_path)
            return self.gallery_share_data.get()
        return [video_path]

    def set_callbacks(self, **kwargs):
        """Attach ``generate`` to the Run button and the gallery refresh to the 🔄 button."""
        self.gen_inputs = [self.output_gallery, self.src_video, self.src_mask, self.src_ref_image_1, self.src_ref_image_2, self.src_ref_image_3, self.prompt, self.negative_prompt, self.sample_steps, self.context_scale, self.guide_scale, self.infer_seed, self.output_height, self.output_width, self.frame_rate, self.num_frames]
        self.gen_outputs = [self.output_gallery]
        self.generate_button.click(self.generate,
                                   inputs=self.gen_inputs,
                                   outputs=self.gen_outputs,
                                   queue=True)
        # With sharing on, refresh shows the shared history; otherwise it echoes the current gallery back.
        self.refresh_button.click(lambda x: self.gallery_share_data.get() if self.gallery_share else x, inputs=[self.output_gallery], outputs=[self.output_gallery])
if __name__ == '__main__':
    # Script entry point: parse CLI options, build the UI and launch the Gradio server.
    # ``args`` stays a module-level name on purpose (it is the config passed to VACEInference).
    parser = argparse.ArgumentParser(description='Argparser for VACE-LTXV Demo:\n')
    parser.add_argument('--server_port', dest='server_port', help='Port the Gradio server listens on.', type=int, default=7860)
    parser.add_argument('--server_name', dest='server_name', help='Host/interface the Gradio server binds to.', default='0.0.0.0')
    parser.add_argument('--root_path', dest='root_path', help="root_path forwarded to Gradio's launch().", default=None)
    parser.add_argument('--save_dir', dest='save_dir', help='Directory where generated videos are written.', default='cache')
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default='models/VACE-LTX-Video-0.9/ltx-video-2b-v0.9.safetensors',
        help="Path to a safetensors file that contains all model parts.",
    )
    parser.add_argument(
        "--text_encoder_path",
        type=str,
        default='models/VACE-LTX-Video-0.9',
        help="Path to the text encoder (by default, the VACE-LTX-Video-0.9 model directory).",
    )
    parser.add_argument(
        "--stg_mode",
        type=str,
        default="stg_a",
        help="Spatiotemporal guidance mode for the pipeline. Can be either stg_a or stg_r.",
    )
    parser.add_argument(
        "--stg_skip_layers",
        type=str,
        default="19",
        help="Attention layers to skip for spatiotemporal guidance. Comma separated list of integers.",
    )
    parser.add_argument(
        "--precision",
        choices=["bfloat16", "mixed_precision"],
        default="bfloat16",
        help="Sets the precision for the transformer and tokenizer. Default is bfloat16. If 'mixed_precision' is enabled, it moves to mixed-precision.",
    )
    parser.add_argument(
        "--offload_to_cpu",
        action="store_true",
        help="Offloading unnecessary computations to CPU.",
    )
    args = parser.parse_args()

    # exist_ok makes a separate existence check unnecessary.
    os.makedirs(args.save_dir, exist_ok=True)

    with gr.Blocks() as demo:
        infer_gr = VACEInference(args, skip_load=False, gallery_share=True, gallery_share_limit=5)
        infer_gr.create_ui()
        infer_gr.set_callbacks()
        # Only the output directory may be served to clients.
        allowed_paths = [args.save_dir]
        demo.queue(status_update_rate=1).launch(server_name=args.server_name,
                                                server_port=args.server_port,
                                                root_path=args.root_path,
                                                allowed_paths=allowed_paths,
                                                show_error=True, debug=True)