|
|
|
|
|
import argparse |
|
import os |
|
import random |
|
import time |
|
|
|
import torch |
|
import numpy as np |
|
|
|
from models.ltx.ltx_vace import LTXVace |
|
from annotators.utils import save_one_video, save_one_image, get_annotator |
|
|
|
MAX_HEIGHT = 720 |
|
MAX_WIDTH = 1280 |
|
MAX_NUM_FRAMES = 257 |
|
|
|
def get_total_gpu_memory(): |
|
if torch.cuda.is_available(): |
|
total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3) |
|
return total_memory |
|
return None |
|
|
|
|
|
def seed_everething(seed: int): |
|
random.seed(seed) |
|
np.random.seed(seed) |
|
torch.manual_seed(seed) |
|
if torch.cuda.is_available(): |
|
torch.cuda.manual_seed(seed) |
|
|
|
def get_parser(): |
|
parser = argparse.ArgumentParser( |
|
description="Load models from separate directories and run the pipeline." |
|
) |
|
|
|
|
|
parser.add_argument( |
|
"--ckpt_path", |
|
type=str, |
|
default='models/VACE-LTX-Video-0.9/ltx-video-2b-v0.9.safetensors', |
|
help="Path to a safetensors file that contains all model parts.", |
|
) |
|
parser.add_argument( |
|
"--text_encoder_path", |
|
type=str, |
|
default='models/VACE-LTX-Video-0.9', |
|
help="Path to a safetensors file that contains all model parts.", |
|
) |
|
parser.add_argument( |
|
"--src_video", |
|
type=str, |
|
default=None, |
|
help="The file of the source video. Default None.") |
|
parser.add_argument( |
|
"--src_mask", |
|
type=str, |
|
default=None, |
|
help="The file of the source mask. Default None.") |
|
parser.add_argument( |
|
"--src_ref_images", |
|
type=str, |
|
default=None, |
|
help="The file list of the source reference images. Separated by ','. Default None.") |
|
parser.add_argument( |
|
"--save_dir", |
|
type=str, |
|
default=None, |
|
help="Path to the folder to save output video, if None will save in results/ directory.", |
|
) |
|
parser.add_argument("--seed", type=int, default="42") |
|
|
|
|
|
parser.add_argument( |
|
"--num_inference_steps", type=int, default=40, help="Number of inference steps" |
|
) |
|
parser.add_argument( |
|
"--num_images_per_prompt", |
|
type=int, |
|
default=1, |
|
help="Number of images per prompt", |
|
) |
|
parser.add_argument( |
|
"--context_scale", |
|
type=float, |
|
default=1.0, |
|
help="Context scale for the pipeline", |
|
) |
|
parser.add_argument( |
|
"--guidance_scale", |
|
type=float, |
|
default=3, |
|
help="Guidance scale for the pipeline", |
|
) |
|
parser.add_argument( |
|
"--stg_scale", |
|
type=float, |
|
default=1, |
|
help="Spatiotemporal guidance scale for the pipeline. 0 to disable STG.", |
|
) |
|
parser.add_argument( |
|
"--stg_rescale", |
|
type=float, |
|
default=0.7, |
|
help="Spatiotemporal guidance rescaling scale for the pipeline. 1 to disable rescale.", |
|
) |
|
parser.add_argument( |
|
"--stg_mode", |
|
type=str, |
|
default="stg_a", |
|
help="Spatiotemporal guidance mode for the pipeline. Can be either stg_a or stg_r.", |
|
) |
|
parser.add_argument( |
|
"--stg_skip_layers", |
|
type=str, |
|
default="19", |
|
help="Attention layers to skip for spatiotemporal guidance. Comma separated list of integers.", |
|
) |
|
parser.add_argument( |
|
"--image_cond_noise_scale", |
|
type=float, |
|
default=0.15, |
|
help="Amount of noise to add to the conditioned image", |
|
) |
|
parser.add_argument( |
|
"--height", |
|
type=int, |
|
default=512, |
|
help="The height of the output video only if src_video is empty.", |
|
) |
|
parser.add_argument( |
|
"--width", |
|
type=int, |
|
default=768, |
|
help="The width of the output video only if src_video is empty.", |
|
) |
|
parser.add_argument( |
|
"--num_frames", |
|
type=int, |
|
default=97, |
|
help="The frames of the output video only if src_video is empty.", |
|
) |
|
parser.add_argument( |
|
"--frame_rate", type=int, default=25, help="Frame rate for the output video" |
|
) |
|
|
|
parser.add_argument( |
|
"--precision", |
|
choices=["bfloat16", "mixed_precision"], |
|
default="bfloat16", |
|
help="Sets the precision for the transformer and tokenizer. Default is bfloat16. If 'mixed_precision' is enabled, it moves to mixed-precision.", |
|
) |
|
|
|
|
|
parser.add_argument( |
|
"--decode_timestep", |
|
type=float, |
|
default=0.05, |
|
help="Timestep for decoding noise", |
|
) |
|
parser.add_argument( |
|
"--decode_noise_scale", |
|
type=float, |
|
default=0.025, |
|
help="Noise level for decoding noise", |
|
) |
|
|
|
|
|
parser.add_argument( |
|
"--prompt", |
|
type=str, |
|
required=True, |
|
help="Text prompt to guide generation", |
|
) |
|
parser.add_argument( |
|
"--negative_prompt", |
|
type=str, |
|
default="worst quality, inconsistent motion, blurry, jittery, distorted", |
|
help="Negative prompt for undesired features", |
|
) |
|
|
|
parser.add_argument( |
|
"--offload_to_cpu", |
|
action="store_true", |
|
help="Offloading unnecessary computations to CPU.", |
|
) |
|
parser.add_argument( |
|
"--use_prompt_extend", |
|
default='plain', |
|
choices=['plain', 'ltx_en', 'ltx_en_ds'], |
|
help="Whether to use prompt extend." |
|
) |
|
return parser |
|
|
|
def main(args): |
|
args = argparse.Namespace(**args) if isinstance(args, dict) else args |
|
|
|
print(f"Running generation with arguments: {args}") |
|
|
|
seed_everething(args.seed) |
|
|
|
offload_to_cpu = False if not args.offload_to_cpu else get_total_gpu_memory() < 30 |
|
|
|
assert os.path.exists(args.ckpt_path) and os.path.exists(args.text_encoder_path) |
|
|
|
ltx_vace = LTXVace(ckpt_path=args.ckpt_path, |
|
text_encoder_path=args.text_encoder_path, |
|
precision=args.precision, |
|
stg_skip_layers=args.stg_skip_layers, |
|
stg_mode=args.stg_mode, |
|
offload_to_cpu=offload_to_cpu) |
|
|
|
src_ref_images = args.src_ref_images.split(',') if args.src_ref_images is not None else [] |
|
if args.use_prompt_extend and args.use_prompt_extend != 'plain': |
|
prompt = get_annotator(config_type='prompt', config_task=args.use_prompt_extend, return_dict=False).forward(args.prompt) |
|
print(f"Prompt extended from '{args.prompt}' to '{prompt}'") |
|
else: |
|
prompt = args.prompt |
|
|
|
output = ltx_vace.generate(src_video=args.src_video, |
|
src_mask=args.src_mask, |
|
src_ref_images=src_ref_images, |
|
prompt=prompt, |
|
negative_prompt=args.negative_prompt, |
|
seed=args.seed, |
|
num_inference_steps=args.num_inference_steps, |
|
num_images_per_prompt=args.num_images_per_prompt, |
|
context_scale=args.context_scale, |
|
guidance_scale=args.guidance_scale, |
|
stg_scale=args.stg_scale, |
|
stg_rescale=args.stg_rescale, |
|
frame_rate=args.frame_rate, |
|
image_cond_noise_scale=args.image_cond_noise_scale, |
|
decode_timestep=args.decode_timestep, |
|
decode_noise_scale=args.decode_noise_scale, |
|
output_height=args.height, |
|
output_width=args.width, |
|
num_frames=args.num_frames) |
|
|
|
|
|
if args.save_dir is None: |
|
save_dir = os.path.join('results', 'vace_ltxv', time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))) |
|
else: |
|
save_dir = args.save_dir |
|
if not os.path.exists(save_dir): |
|
os.makedirs(save_dir) |
|
frame_rate = output['info']['frame_rate'] |
|
|
|
ret_data = {} |
|
if output['out_video'] is not None: |
|
save_path = os.path.join(save_dir, 'out_video.mp4') |
|
out_video = (torch.clamp(output['out_video'] / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) * 255).cpu().numpy().astype(np.uint8) |
|
save_one_video(save_path, out_video, fps=frame_rate) |
|
print(f"Save out_video to {save_path}") |
|
ret_data['out_video'] = save_path |
|
if output['src_video'] is not None: |
|
save_path = os.path.join(save_dir, 'src_video.mp4') |
|
src_video = (torch.clamp(output['src_video'] / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) * 255).cpu().numpy().astype(np.uint8) |
|
save_one_video(save_path, src_video, fps=frame_rate) |
|
print(f"Save src_video to {save_path}") |
|
ret_data['src_video'] = save_path |
|
if output['src_mask'] is not None: |
|
save_path = os.path.join(save_dir, 'src_mask.mp4') |
|
src_mask = (torch.clamp(output['src_mask'], min=0.0, max=1.0).permute(1, 2, 3, 0) * 255).cpu().numpy().astype(np.uint8) |
|
save_one_video(save_path, src_mask, fps=frame_rate) |
|
print(f"Save src_mask to {save_path}") |
|
ret_data['src_mask'] = save_path |
|
if output['src_ref_images'] is not None: |
|
for i, ref_img in enumerate(output['src_ref_images']): |
|
save_path = os.path.join(save_dir, f'src_ref_image_{i}.png') |
|
ref_img = (torch.clamp(ref_img.squeeze(1), min=0.0, max=1.0).permute(1, 2, 0) * 255).cpu().numpy().astype(np.uint8) |
|
save_one_image(save_path, ref_img, use_type='pil') |
|
print(f"Save src_ref_image_{i} to {save_path}") |
|
ret_data[f'src_ref_image_{i}'] = save_path |
|
return ret_data |
|
|
|
|
|
if __name__ == "__main__": |
|
args = get_parser().parse_args() |
|
main(args) |