import os
import torch
import spaces
import gradio as gr
import numpy as np
from PIL import Image
import ml_collections
from torchvision.utils import save_image, make_grid
import torch.nn.functional as F
import einops
import random
import torchvision.transforms as standard_transforms
from huggingface_hub import hf_hub_download

hf_hub_download(repo_id="thu-ml/unidiffuser-v1", filename="autoencoder_kl.pth", local_dir='./models')
hf_hub_download(repo_id="mespinosami/COP-GEN-Beta", filename="nnet_ema_114000.pth", local_dir='./models')

import sys
sys.path.append('./src/COP-GEN-Beta')

import libs
from dpm_solver_pp import DPM_Solver, NoiseScheduleVP
from sample_n_triffuser import set_seed, stable_diffusion_beta_schedule, unpreprocess
import utils
from diffusers import AutoencoderKL
from Triffuser import *  # resolved via the sys.path entry above; a relative import fails in a top-level script
# Function to load model
def load_model(device='cuda'):
    nnet = Triffuser(num_modalities=4)
    checkpoint = torch.load('models/nnet_ema_114000.pth', map_location=device)
    nnet.load_state_dict(checkpoint)
    nnet.to(device)
    nnet.eval()

    autoencoder = libs.autoencoder.get_model(pretrained_path="models/autoencoder_kl.pth")
    autoencoder.to(device)
    autoencoder.eval()

    return nnet, autoencoder
print('Loading COP-GEN-Beta model...')
nnet, autoencoder = load_model()
to_PIL = standard_transforms.ToPILImage()
print('[DONE]')
def get_config(generate_modalities, condition_modalities, seed, num_inference_steps=50):
    config = ml_collections.ConfigDict()

    config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    config.seed = seed
    config.n_samples = 1
    config.z_shape = (4, 32, 32)  # Shape of the latent vectors

    config.sample = {
        'sample_steps': num_inference_steps,
        'algorithm': "dpm_solver",
    }

    # Model config
    config.num_modalities = 4  # 4 modalities: DEM, S1RTC, S2L1C, S2L2A
    config.modalities = ['dem', 's1_rtc', 's2_l1c', 's2_l2a']

    # Network config
    config.nnet = {
        'name': 'triffuser_multi_post_ln',
        'img_size': 32,
        'in_chans': 4,
        'patch_size': 2,
        'embed_dim': 1024,
        'depth': 20,
        'num_heads': 16,
        'mlp_ratio': 4,
        'qkv_bias': False,
        'pos_drop_rate': 0.,
        'drop_rate': 0.,
        'attn_drop_rate': 0.,
        'mlp_time_embed': False,
        'num_modalities': 4,
        'use_checkpoint': True,
    }

    # Parse generate and condition modalities
    config.generate_modalities = generate_modalities
    config.generate_modalities = sorted(config.generate_modalities, key=lambda x: config.modalities.index(x))
    config.condition_modalities = condition_modalities if condition_modalities else []
    config.condition_modalities = sorted(config.condition_modalities, key=lambda x: config.modalities.index(x))
    config.generate_modalities_mask = [mod in config.generate_modalities for mod in config.modalities]
    config.condition_modalities_mask = [mod in config.condition_modalities for mod in config.modalities]

    # Validate modalities
    valid_modalities = {'s2_l1c', 's2_l2a', 's1_rtc', 'dem'}
    for mod in config.generate_modalities + config.condition_modalities:
        if mod not in valid_modalities:
            raise ValueError(f"Invalid modality: {mod}. Must be one of {valid_modalities}")

    # Check that generate and condition modalities don't overlap
    if set(config.generate_modalities) & set(config.condition_modalities):
        raise ValueError("Generate and condition modalities must be different")

    # Default data paths
    config.nnet_path = 'models/nnet_ema_114000.pth'
    # config.autoencoder = {"pretrained_path": "assets/stable-diffusion/autoencoder_kl_ema.pth"}

    return config
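# Example (illustrative): with the modality order ['dem', 's1_rtc', 's2_l1c', 's2_l2a'],
# get_config(['s2_l2a'], ['s2_l1c'], seed=42) produces
#   generate_modalities_mask  = [False, False, False, True]
#   condition_modalities_mask = [False, False, True, False]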
# Function to prepare images for inference
def prepare_images(images):
    transforms = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(mean=(0.5,), std=(0.5,))
    ])
    img_tensors = []
    for img in images:
        img_tensors.append(transforms(img))  # [C, H, W] tensor scaled to [-1, 1]; no batch dimension yet
    return img_tensors
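# Note: config.z_shape is (4, 32, 32) and the Stable Diffusion KL autoencoder used here
# downsamples by a factor of 8, so inputs are assumed to be 256x256 RGB PIL images.
# Example (hypothetical file name):
#   imgs = prepare_images([Image.open('s2_l1c_patch.png').convert('RGB').resize((256, 256))])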
def run_inference(config, nnet, autoencoder, img_tensors):
    set_seed(config.seed)

    img_tensors = [tensor.to(config.device) for tensor in img_tensors]

    # Create a context tensor for all modalities
    img_contexts = torch.randn(config.num_modalities, 1, 2 * config.z_shape[0],
                               config.z_shape[1], config.z_shape[2], device=config.device)

    with torch.no_grad():
        # Encode the input images with autoencoder
        z_conds = [autoencoder.encode_moments(tensor.unsqueeze(0)) for tensor in img_tensors]

        # Create mapping of conditional modalities indices to the encoded inputs
        cond_indices = [i for i, is_cond in enumerate(config.condition_modalities_mask) if is_cond]

        # Check if we have the right number of inputs
        if len(cond_indices) != len(z_conds):
            raise ValueError(f"Number of conditioning modalities ({len(cond_indices)}) must match number of input images ({len(z_conds)})")

        # Assign each encoded input to the corresponding modality
        for i, z_cond in zip(cond_indices, z_conds):
            img_contexts[i] = z_cond

        # Sample values from the distribution (mean and variance)
        z_imgs = torch.stack([autoencoder.sample(img_context) for img_context in img_contexts])

    # Generate initial noise for the modalities being generated
    _z_init = torch.randn(len(config.generate_modalities), 1, *z_imgs[0].shape[1:], device=config.device)
    def combine_joint(z_list):
        """Combine individual modality tensors into a single concatenated tensor"""
        return torch.concat([einops.rearrange(z_i, 'B C H W -> B (C H W)') for z_i in z_list], dim=-1)

    def split_joint(x, z_imgs, config):
        """
        Split the combined tensor back into individual modality tensors
        and arrange them according to the full set of modalities
        """
        C, H, W = config.z_shape
        z_dim = C * H * W
        z_generated = x.split([z_dim] * len(config.generate_modalities), dim=1)
        z_generated = {modality: einops.rearrange(z_i, 'B (C H W) -> B C H W', C=C, H=H, W=W)
                       for z_i, modality in zip(z_generated, config.generate_modalities)}
        z = []
        for i, modality in enumerate(config.modalities):
            if modality in config.generate_modalities:  # Modalities that are being denoised
                z.append(z_generated[modality])
            elif modality in config.condition_modalities:  # Modalities that are being conditioned on
                z.append(z_imgs[i])
            else:  # Modalities that are ignored
                z.append(torch.randn(x.shape[0], C, H, W, device=config.device))
        return z
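    # Layout note: each latent of shape (B, 4, 32, 32) is flattened to (B, 4096), so with
    # k generated modalities the joint tensor x has shape (B, k * 4096); split_joint
    # reverses this and re-inserts the conditioning latents unchanged.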
    _x_init = combine_joint(_z_init)  # Initial tensor for the modalities being generated

    _betas = stable_diffusion_beta_schedule()
    N = len(_betas)

    def model_fn(x, t_continuous):
        t = t_continuous * N
        # Create timesteps for each modality based on the generate mask
        timesteps = [t if mask else torch.zeros_like(t) for mask in config.generate_modalities_mask]
        # Split the input into a list of tensors for all modalities
        z = split_joint(x, z_imgs, config)
        # Call the network with the right format
        z_out = nnet(z, t_imgs=timesteps)
        # Select only the generated modalities for the denoising process
        z_out_generated = [z_out[i]
                           for i, modality in enumerate(config.modalities)
                           if modality in config.generate_modalities]
        # Combine the outputs back into a single tensor
        return combine_joint(z_out_generated)
    # Sample using the DPM-Solver with exact parameters from sample_n_triffuser.py
    noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=config.device).float())
    dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)

    # Generate samples
    with torch.no_grad():
        with torch.autocast(device_type=config.device):
            x = dpm_solver.sample(_x_init, steps=config.sample.sample_steps, eps=1. / N, T=1.)

    # Split the result back into individual modality tensors
    _zs = split_joint(x, z_imgs, config)

    # Replace conditional modalities with the original images
    for i, mask in enumerate(config.condition_modalities_mask):
        if mask:
            _zs[i] = z_imgs[i]

    # Decode and unprocess the generated samples
    generated_samples = []
    for i, modality in enumerate(config.modalities):
        if modality in config.generate_modalities:
            sample = autoencoder.decode(_zs[i])  # Decode the latent representation
            sample = unpreprocess(sample)  # Unpreprocess to [0, 1] range
            generated_samples.append((modality, sample))

    return generated_samples
def custom_inference(images, generate_modalities, condition_modalities, num_inference_steps, seed=None):
    """
    Run custom inference with user-specified parameters

    Args:
        images: List of conditioning PIL images, ordered to match condition_modalities
        generate_modalities: List of modalities to generate
        condition_modalities: List of modalities to condition on
        num_inference_steps: Number of DPM-Solver sampling steps
        seed: Random seed; a random seed is drawn if None

    Returns:
        Dict mapping modality names to generated tensors
    """
    if seed is None:
        seed = random.randint(0, int(1e8))

    img_tensors = prepare_images(images)
    config = get_config(generate_modalities, condition_modalities, seed=seed)
    config.sample.sample_steps = num_inference_steps

    generated_samples = run_inference(config, nnet, autoencoder, img_tensors)
    results = {modality: tensor for modality, tensor in generated_samples}

    return results
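# Example call (illustrative; `l1c_img` stands for a hypothetical 256x256 PIL image):
#   results = custom_inference(images=[l1c_img],
#                              generate_modalities=['s2_l2a'],
#                              condition_modalities=['s2_l1c'],
#                              num_inference_steps=50)
#   results['s2_l2a']  # generated tensor, batch dimension first, values in [0, 1]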
@spaces.GPU  # request a ZeroGPU device for the duration of each generation call
def generate_output(s2l1c_input, s2l2a_input, s1rtc_input, dem_input, num_inference_steps_slider, seed_number, ignore_seed):
    seed = seed_number if not ignore_seed else None

    s2l2a_active = s2l2a_input is not None
    s2l1c_active = s2l1c_input is not None
    s1rtc_active = s1rtc_input is not None
    dem_active = dem_input is not None

    if s2l2a_active and s2l1c_active and s1rtc_active and dem_active:
        gr.Warning("All four modalities are already provided, so there is nothing to generate. Clear at least one input that you would like to have generated.")
        return s2l1c_input, s2l2a_input, s1rtc_input, dem_input

    # Collect the provided inputs keyed by modality rather than by UI order
    input_images = {}
    if s2l1c_active:
        input_images['s2_l1c'] = s2l1c_input
    if s2l2a_active:
        input_images['s2_l2a'] = s2l2a_input
    if s1rtc_active:
        input_images['s1_rtc'] = s1rtc_input
    if dem_active:
        input_images['dem'] = dem_input

    condition_modalities = list(input_images.keys())

    # Sort modalities and collect images in the same order
    sorted_modalities = sorted(condition_modalities, key=lambda x: ['dem', 's1_rtc', 's2_l1c', 's2_l2a'].index(x))
    sorted_images = [input_images[mod] for mod in sorted_modalities]

    imgs_out = custom_inference(
        images=sorted_images,
        generate_modalities=[el for el in ['s2_l1c', 's2_l2a', 's1_rtc', 'dem'] if el not in condition_modalities],
        condition_modalities=sorted_modalities,
        num_inference_steps=num_inference_steps_slider,
        seed=seed
    )

    # Collect outputs: pass provided inputs through, convert generated tensors to PIL
    output = []
    if s2l1c_active:
        output.append(s2l1c_input)
    else:
        output.append(to_PIL(imgs_out['s2_l1c'][0]))
    if s2l2a_active:
        output.append(s2l2a_input)
    else:
        output.append(to_PIL(imgs_out['s2_l2a'][0]))
    if s1rtc_active:
        output.append(s1rtc_input)
    else:
        output.append(to_PIL(imgs_out['s1_rtc'][0]))
    if dem_active:
        output.append(dem_input)
    else:
        output.append(to_PIL(imgs_out['dem'][0]))

    return output
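# The Blocks UI that calls generate_output is not part of the snippet above. The sketch
# below is a minimal stand-in: component labels, value ranges, and layout are assumptions;
# only the callback signature and output order follow generate_output.
with gr.Blocks(title="COP-GEN-Beta") as demo:
    with gr.Row():
        s2l1c_img = gr.Image(label="S2 L1C", type="pil")
        s2l2a_img = gr.Image(label="S2 L2A", type="pil")
        s1rtc_img = gr.Image(label="S1 RTC", type="pil")
        dem_img = gr.Image(label="DEM", type="pil")
    steps_slider = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Inference steps")
    seed_number = gr.Number(value=42, precision=0, label="Seed")
    ignore_seed = gr.Checkbox(value=True, label="Use a random seed")
    generate_btn = gr.Button("Generate missing modalities")
    generate_btn.click(
        generate_output,
        inputs=[s2l1c_img, s2l2a_img, s1rtc_img, dem_img, steps_slider, seed_number, ignore_seed],
        outputs=[s2l1c_img, s2l2a_img, s1rtc_img, dem_img],
    )

demo.launch()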