import json
import os
import types
from urllib.parse import urlparse

import cv2
import diffusers
import gradio as gr
import numpy as np
import spaces
import torch
from einops import rearrange
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from PIL import Image, ImageOps
from safetensors.torch import load_file
from torch.nn import functional as F
from torchdiffeq import odeint_adjoint as odeint

from echoflow.common import instantiate_class_from_config, unscale_latents
from echoflow.common.models import (
    ContrastiveModel,
    DiffuserSTDiT,
    ResNet18,
    SegDiTTransformer2DModel,
)

torch.set_grad_enabled(False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

print(f"Using device: {device}")

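# Latent tensor dimensions: batch, video frames, latent channels, latent height/width.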
B, T, C, H, W = 1, 64, 4, 28, 28

VIEWS = ["A4C", "PSAX", "PLAX"]


def load_model(path):
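    """Instantiate a diffusers-style model from a Hugging Face Hub URL.

    URLs of the form https://huggingface.co./<repo_id>/tree/main/<subfolder> are resolved
    by downloading config.json and diffusion_pytorch_model.safetensors into ./tmp,
    then the class named in config.json is instantiated via from_pretrained.
    """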
    if path.startswith("http"):
        parsed_url = urlparse(path)
        if "huggingface.co" in parsed_url.netloc:
            parts = parsed_url.path.strip("/").split("/")
            repo_id = "/".join(parts[:2])

            subfolder = None
            if len(parts) > 3:
                subfolder = "/".join(parts[4:])

            local_root = "./tmp"
            local_dir = os.path.join(local_root, repo_id.replace("/", "_"))
            if subfolder:
                local_dir = os.path.join(local_root, subfolder)
            os.makedirs(local_root, exist_ok=True)

            config_file = hf_hub_download(
                repo_id=repo_id,
                subfolder=subfolder,
                filename="config.json",
                local_dir=local_root,
                repo_type="model",
                token=os.getenv("READ_HF_TOKEN"),
                local_dir_use_symlinks=False,
            )

            assert os.path.exists(config_file)

            hf_hub_download(
                repo_id=repo_id,
                filename="diffusion_pytorch_model.safetensors",
                subfolder=subfolder,
                local_dir=local_root,
                local_dir_use_symlinks=False,
                token=os.getenv("READ_HF_TOKEN"),
            )

            path = local_dir

    model_root = os.path.join(config_file.split("config.json")[0])
    json_path = os.path.join(model_root, "config.json")
    assert os.path.exists(json_path)

    with open(json_path, "r") as f:
        config = json.load(f)

    klass_name = config["_class_name"]
    klass = getattr(diffusers, klass_name, None) or globals().get(klass_name, None)
    assert (
        klass is not None
    ), f"Could not find class {klass_name} in diffusers or global scope."
    assert hasattr(
        klass, "from_pretrained"
    ), f"Class {klass_name} does not support 'from_pretrained'."

    return klass.from_pretrained(path)


def load_reid(path):
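    """Load a view-specific re-identification backbone from the Hub.

    Downloads config.yaml and backbone.safetensors, rebuilds the backbone from the
    config, loads the weights, and returns it in eval mode on the target device.
    """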
    parsed_url = urlparse(path)
    parts = parsed_url.path.strip("/").split("/")
    repo_id = "/".join(parts[:2])
    subfolder = "/".join(parts[4:])

    local_root = "./tmp"

    config_file = hf_hub_download(
        repo_id=repo_id,
        subfolder=subfolder,
        filename="config.yaml",
        local_dir=local_root,
        repo_type="model",
        token=os.getenv("READ_HF_TOKEN"),
        local_dir_use_symlinks=False,
    )

    weights_file = hf_hub_download(
        repo_id=repo_id,
        subfolder=subfolder,
        filename="backbone.safetensors",
        local_dir=local_root,
        repo_type="model",
        token=os.getenv("READ_HF_TOKEN"),
        local_dir_use_symlinks=False,
    )

    config = OmegaConf.load(config_file)
    backbone = instantiate_class_from_config(config.backbone)
    backbone = ContrastiveModel.patch_backbone(
        backbone, config.model.args.in_channels, config.model.args.out_channels
    )
    state_dict = load_file(weights_file)
    backbone.load_state_dict(state_dict)
    backbone = backbone.to(device, dtype=dtype)
    backbone.eval()
    return backbone


def get_vae_scaler(path):
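    """Load the latent scaling statistics used by unscale_latents onto the device."""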
    scaler = torch.load(path)
    scaler = {k: v.to(device) for k, v in scaler.items()}
    return scaler

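# Global models, loaded once at startup:
#   lifm - latent image generator (flow matching), conditioned on view class + LV mask
#   vae  - decodes 4x28x28 latents back to pixel space
#   reid - per-view re-identification backbones, pre-computed training anatomy features,
#          and the correlation thresholds (tau) used by the privacy check
#   lvfm - latent video generator (flow matching), conditioned on EF + a reference frame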
lifm = load_model("https://huggingface.co./HReynaud/EchoFlow/tree/main/lifm/FMiT-S2-4f4")
lifm = lifm.to(device, dtype=dtype)
lifm.eval()

vae = load_model("https://huggingface.co./HReynaud/EchoFlow/tree/main/vae/avae-4f4")
vae = vae.to(device, dtype=dtype)
vae.eval()
vae_scaler = get_vae_scaler("assets/scaling.pt")

reid = {
    "anatomies": {
        "A4C": torch.cat(
            [
                torch.load("assets/anatomies_dynamic.pt"),
                torch.load("assets/anatomies_ped_a4c.pt"),
            ],
            dim=0,
        ),
        "PSAX": torch.load("assets/anatomies_ped_psax.pt"),
        "PLAX": torch.load("assets/anatomies_lvh.pt"),
    },
    "models": {
        "A4C": load_reid(
            "https://huggingface.co./HReynaud/EchoFlow/tree/main/reid/dynamic-4f4"
        ),
        "PSAX": load_reid(
            "https://huggingface.co./HReynaud/EchoFlow/tree/main/reid/ped_psax-4f4"
        ),
        "PLAX": load_reid(
            "https://huggingface.co./HReynaud/EchoFlow/tree/main/reid/lvh-4f4"
        ),
    },
    "tau": {
        "A4C": 0.9997,
        "PSAX": 0.9997,
        "PLAX": 0.9997,
    },
}

lvfm = load_model("https://huggingface.co./HReynaud/EchoFlow/tree/main/lvfm/FMvT-S2-4f4")
lvfm = lvfm.to(device, dtype=dtype)
lvfm.eval()


def load_default_mask():
    """Load the default mask from disk. If not found, return a blank black mask."""
    default_mask_path = os.path.join("assets", "default_mask.png")
    try:
        if os.path.exists(default_mask_path):
            mask = Image.open(default_mask_path).convert("L")
            mask = mask.resize((400, 400), Image.Resampling.LANCZOS)
            mask = ImageOps.autocontrast(mask, cutoff=0)
            return np.array(mask)
    except Exception as e:
        print(f"Error loading default mask: {e}")

    return np.zeros((400, 400), dtype=np.uint8)


def preprocess_mask(mask):
    """Ensure mask is properly formatted for the model."""
    if mask is None:
        return np.zeros((112, 112), dtype=np.uint8)

    if isinstance(mask, dict) and "composite" in mask:
        mask = mask["composite"]

    if isinstance(mask, np.ndarray):
        mask_pil = Image.fromarray(mask)
    else:
        mask_pil = mask

    mask_pil = mask_pil.convert("L")
    mask_pil = ImageOps.autocontrast(mask_pil, cutoff=0)
    mask_pil = mask_pil.point(lambda p: 255 if p > 127 else 0)
    mask_pil = mask_pil.resize((112, 112), Image.Resampling.LANCZOS)

    return np.array(mask_pil)


@spaces.GPU(duration=3)
@torch.no_grad()
def generate_latent_image(mask, class_selection, sampling_steps=50):
    """Generate a latent image based on mask, class selection, and sampling steps"""
    mask = preprocess_mask(mask)
    mask = torch.from_numpy(mask).to(device, dtype=dtype)
    mask = mask.unsqueeze(0).unsqueeze(0)
    mask = F.interpolate(mask, size=(H, W), mode="bilinear", align_corners=False)
    mask = 1.0 * (mask > 0)

    class_idx = VIEWS.index(class_selection)
    class_idx = torch.tensor([class_idx], device=device, dtype=torch.long)

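    # Sampling integrates the learned flow from t=1 (Gaussian noise) down to t=0 (clean latent).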
    timesteps = torch.linspace(
        1.0, 0.0, steps=sampling_steps + 1, device=device, dtype=dtype
    )

    forward_kwargs = {
        "class_labels": class_idx,
        "segmentation": mask,
    }

    z_1 = torch.randn(
        (B, C, H, W),
        device=device,
        dtype=dtype,
    )

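    # torchdiffeq's odeint calls fn(t, y); temporarily wrap the model's forward so the
    # arguments are reordered, the conditioning kwargs are injected, and the diffusers
    # output object is unwrapped to its .sample tensor.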
    lifm.forward_original = lifm.forward

    def new_forward(self, t, y, *args, **kwargs):
        kwargs = {**kwargs, **forward_kwargs}
        return self.forward_original(y, t.view(1), *args, **kwargs).sample

    lifm.forward = types.MethodType(new_forward, lifm)

    with torch.autocast("cuda"):
        latent_image = odeint(
            lifm,
            z_1,
            timesteps,
            atol=1e-5,
            rtol=1e-5,
            adjoint_params=lifm.parameters(),
            method="euler",
        )[-1]

    lifm.forward = lifm.forward_original

    latent_image = latent_image.detach().cpu().numpy()

    return latent_image

@spaces.GPU(duration=3)
@torch.no_grad()
def decode_images(latents):
    """Decode latent representations to pixel space using the VAE.

    Args:
        latents: A numpy array of shape [B, C, H, W] for a single image
            or [B, C, T, H, W] for sequences/animations

    Returns:
        numpy array of decoded images, [H, W, 3] for a single image
        (the batch dimension of 1 is squeezed out) or [B, 3, T, H, W] for sequences
    """
    global vae
    if latents is None:
        return None

    vae = vae.to(device, dtype=dtype)
    vae.eval()

    if not isinstance(latents, torch.Tensor):
        latents = torch.from_numpy(latents).to(device, dtype=dtype)

    latents = unscale_latents(latents, vae_scaler)

    is_sequence = len(latents.shape) == 5

    if is_sequence:
        B, C, T, H, W = latents.shape
        latents = rearrange(latents[0], "c t h w -> t c h w")
    else:
        B, C, H, W = latents.shape

    with torch.no_grad():
        decoded = []
        for i in range(latents.shape[0]):
            decoded.append(vae.decode(latents[i : i + 1].float()).sample)
        decoded = torch.cat(decoded, dim=0)

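    # The VAE outputs values in [-1, 1]; map them to [0, 255] and cast to uint8.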
    decoded = (decoded + 1) * 128
    decoded = decoded.clamp(0, 255).to(torch.uint8).cpu()

    if is_sequence:
        decoded = rearrange(decoded, "t c h w -> c t h w").unsqueeze(0)
    else:
        decoded = decoded.squeeze()
        decoded = decoded.permute(1, 2, 0)

    return decoded.numpy()


def decode_latent_to_pixel(latent_image):
    """Decode a single latent image to pixel space"""
    if latent_image is None:
        return None

    if len(latent_image.shape) == 3:
        latent_image = latent_image[None, ...]

    decoded_image = decode_images(latent_image)
    decoded_image = cv2.resize(
        decoded_image, (400, 400), interpolation=cv2.INTER_NEAREST
    )

    return decoded_image

@spaces.GPU(duration=3)
@torch.no_grad()
def check_privacy(latent_image_numpy, class_selection):
    """Check if the latent image is too similar to database images"""
    latent_image = torch.from_numpy(latent_image_numpy).to(device, dtype=dtype)
    reid_model = reid["models"][class_selection].to(device, dtype=dtype)
    real_anatomies = reid["anatomies"][class_selection]
    tau = reid["tau"][class_selection]

    with torch.no_grad():
        features = reid_model(latent_image).sigmoid().cpu()

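    # Correlate the candidate's re-id features against every stored training anatomy;
    # if the best match exceeds tau, the image is considered re-identifiable.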
    corr = torch.corrcoef(torch.cat([real_anatomies, features], dim=0))[-1, :-1]
    corr = corr.max()

    if corr > tau:
        return (
            None,
            "⚠️ **Warning:** Generated image is too similar to training data. Privacy check failed.",
        )
    else:
        return (
            latent_image_numpy,
            "✅ **Success:** Generated image passed privacy check.",
        )

@spaces.GPU(duration=3)
@torch.no_grad()
def generate_animation(
    latent_image, ejection_fraction, sampling_steps=50, cfg_scale=1.0
):
    """Generate an animated sequence of latent images based on EF"""
    print("Generating animation...")

    if latent_image is None:
        return None

    lvefs = torch.tensor([ejection_fraction / 100.0], device=device, dtype=dtype)
    lvefs = lvefs[:, None, None].to(device, dtype)
    uncond_lvefs = -1 * torch.ones_like(lvefs)

    ref_images = torch.from_numpy(latent_image).to(device, dtype)
    ref_images = ref_images[:, :, None, :, :]
    ref_images = ref_images.repeat(1, 1, T, 1, 1)
    uncond_images = torch.zeros_like(ref_images)

    timesteps = torch.linspace(
        1.0, 0.0, steps=sampling_steps + 1, device=device, dtype=dtype
    )

    forward_kwargs = {
        "encoder_hidden_states": lvefs,
        "cond_image": ref_images,
    }

    z_1 = torch.randn(
        (B, C, T, H, W),
        device=device,
        dtype=dtype,
    )

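    # Same odeint adapter as in generate_latent_image, with classifier-free guidance:
    # when cfg_scale != 1, blend the conditional and unconditional predictions.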
    lvfm.forward_original = lvfm.forward

    def new_forward(self, t, y, *args, **kwargs):
        kwargs = {**kwargs, **forward_kwargs}

        pred = self.forward_original(y, t.repeat(y.size(0)), *args, **kwargs).sample

        if cfg_scale != 1.0:
            uncond_kwargs = {
                "encoder_hidden_states": uncond_lvefs,
                "cond_image": uncond_images,
            }
            uncond_pred = self.forward_original(
                y, t.repeat(y.size(0)), *args, **uncond_kwargs
            ).sample

            pred = uncond_pred + cfg_scale * (pred - uncond_pred)

        return pred

    lvfm.forward = types.MethodType(new_forward, lvfm)

    with torch.autocast("cuda"):
        synthetic_video = odeint(
            lvfm,
            z_1,
            timesteps,
            atol=1e-5,
            rtol=1e-5,
            adjoint_params=lvfm.parameters(),
            method="euler",
        )[-1]

    lvfm.forward = lvfm.forward_original

    print("Animation generated")

    return synthetic_video.detach().cpu()

@spaces.GPU(duration=3)
@torch.no_grad()
def decode_animation(latent_animation):
    """Decode a latent animation to pixel space"""
    if latent_animation is None:
        return None

    if not isinstance(latent_animation, torch.Tensor):
        latent_animation = torch.from_numpy(latent_animation)
    latent_animation = latent_animation.to(device, dtype=dtype)

    if len(latent_animation.shape) == 4:
        latent_animation = latent_animation[None, ...]

    decoded = decode_images(latent_animation)

    decoded = np.transpose(decoded[0], (1, 2, 3, 0))

    decoded = np.stack(
        [
            cv2.resize(frame, (400, 400), interpolation=cv2.INTER_NEAREST)
            for frame in decoded
        ]
    )

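    # Write the decoded frames to an MP4 file that the Gradio Video component can serve.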
    temp_file = "temp_video_2.mp4"
    fps = 32
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(temp_file, fourcc, fps, (400, 400))

    for frame in decoded:
        out.write(frame)
    out.release()

    return temp_file


def convert_latent_to_display(latent_image):
    """Convert multi-channel latent image to grayscale for display"""
    if latent_image is None:
        return None

    if len(latent_image.shape) == 4:
        display_image = np.squeeze(latent_image, axis=0)
        display_image = np.mean(display_image, axis=0)
    elif len(latent_image.shape) == 3:
        display_image = np.mean(latent_image, axis=0)
    else:
        display_image = latent_image

    display_image = (display_image - display_image.min()) / (
        display_image.max() - display_image.min() + 1e-8
    )

    display_image = (display_image * 255).astype(np.uint8)

    display_image = cv2.resize(
        display_image, (400, 400), interpolation=cv2.INTER_NEAREST
    )

    return display_image


@spaces.GPU(duration=3)
@torch.no_grad()
def latent_animation_to_grayscale(latent_animation):
    """Convert multi-channel latent animation to grayscale for display"""
    if latent_animation is None:
        return None

    if torch.is_tensor(latent_animation):
        latent_animation = latent_animation.detach().cpu().numpy()

    if len(latent_animation.shape) == 5:
        latent_animation = np.squeeze(latent_animation, axis=0)
        latent_animation = np.transpose(latent_animation, (1, 0, 2, 3))

    latent_animation = np.mean(latent_animation, axis=1)

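    # Normalize each frame independently to [0, 1], then scale to uint8 for display.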
    min_vals = latent_animation.min(axis=(1, 2), keepdims=True)
    max_vals = latent_animation.max(axis=(1, 2), keepdims=True)
    latent_animation = (latent_animation - min_vals) / (max_vals - min_vals + 1e-8)

    latent_animation = (latent_animation * 255).astype(np.uint8)

    resized_frames = []
    for frame in latent_animation:
        resized = cv2.resize(frame, (400, 400), interpolation=cv2.INTER_NEAREST)
        resized_frames.append(resized)

    grayscale_video = np.stack(resized_frames)

    grayscale_video = grayscale_video[..., None].repeat(3, axis=-1)

    temp_file = "temp_video.mp4"
    fps = 32

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(temp_file, fourcc, fps, (400, 400))

    for frame in grayscale_video:
        out.write(frame)

    out.release()

    return temp_file

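# Load the pre-drawn segmentation mask for a view and wrap it in the dict format
# expected by gr.ImageEditor ("background", "layers", "composite").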
def load_view_mask(view):
    mask_path = f"assets/{view.lower()}_seg.png"
    try:
        mask_image = Image.open(mask_path).convert("L")
        mask_image = mask_image.resize((400, 400), Image.Resampling.LANCZOS)

        mask_image = ImageOps.autocontrast(mask_image, cutoff=0)
        mask_array = np.array(mask_image)

        editor_value = {
            "background": np.zeros((400, 400), dtype=np.uint8),
            "layers": [mask_array],
            "composite": mask_array,
        }
        return editor_value
    except Exception as e:
        print(f"Error loading mask for view {view}: {e}")
        return None

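# Injected into the page <head>; once Gradio has rendered the Examples table, this
# script shortens the column headers so they fit on a single row.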
custom_js = """
<script>
console.log("Hello, world!");
(function() {
    // Poll every 500 ms for the existence of the header row
    const intervalId = setInterval(() => {
        console.log("Polling for header row");
        const headerRow = document.querySelector("tr.tr-head");
        if (headerRow) {
            const headers = headerRow.querySelectorAll("th");
            headers.forEach(cell => {
                const text = cell.innerText.trim();
                if (text === "Binary Mask") {
                    cell.innerText = "Mask";
                } else if (text === "View Class") {
                    cell.innerText = "View";
                } else if (text === "Number of Sampling Steps") {
                    cell.innerText = "Img Samp. Steps";
                } else if (text === "Ejection Fraction (%)") {
                    cell.innerText = "EF %";
                } else if (text === "Number of Sampling Steps.") {
                    cell.innerText = "Video Samp. Steps";
                } else if (text === "Classifier-Free Guidance Scale") {
                    cell.innerText = "CFG";
                } else if (text === "Filtered Latent Image") {
                    cell.innerText = "Filtered Image";
                }
            });
            clearInterval(intervalId);
            console.log("Headers updated.");
        }
    }, 500);
})();
</script>
"""

def create_demo():

    black_background = np.zeros((400, 400), dtype=np.uint8)

    try:
        mask_image = Image.open("assets/a4c_seg.png").convert("L")
        mask_image = mask_image.resize((400, 400), Image.Resampling.LANCZOS)

        mask_image = ImageOps.autocontrast(mask_image, cutoff=0)
        mask_image = mask_image.point(lambda p: 255 if p > 127 else 0)
        mask_array = np.array(mask_image)

        editor_value = {
            "background": black_background,
            "layers": [mask_array],
            "composite": mask_array,
        }
    except Exception as e:
        print(f"Error loading mask image: {e}")

        editor_value = black_background

    mask_input = gr.ImageEditor(
        label="Binary Mask",
        height=400,
        width=400,
        image_mode="L",
        value=editor_value,
        type="numpy",
        brush=gr.Brush(
            colors=["#ffffff"],
            color_mode="fixed",
            default_size=20,
            default_color="#ffffff",
        ),
        eraser=gr.Eraser(default_size=20),
        show_download_button=True,
        sources=[],
        canvas_size=(400, 400),
        fixed_canvas=True,
        layers=False,
        render=False,
    )

    class_selection = gr.Radio(
        choices=["A4C", "PSAX", "PLAX"],
        label="View Class",
        value="A4C",
        render=False,
    )

    sampling_steps = gr.Slider(
        minimum=1,
        maximum=200,
        value=100,
        step=1,
        label="Number of Sampling Steps",
        render=False,
    )

    ef_slider = gr.Slider(
        minimum=0,
        maximum=100,
        value=65,
        label="Ejection Fraction (%)",
        render=False,
    )

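    # Note: the trailing period in this label is intentional; the header-rewrite script
    # in custom_js uses it to tell this slider apart from the image sampling slider.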
    animation_steps = gr.Slider(
        minimum=1,
        maximum=200,
        value=100,
        step=1,
        label="Number of Sampling Steps.",
        render=False,
    )

    cfg_slider = gr.Slider(
        minimum=0,
        maximum=10,
        value=1,
        step=1,
        label="Classifier-Free Guidance Scale",
        render=False,
    )

    latent_image_display = gr.Image(
        label="Latent Image",
        type="numpy",
        height=400,
        width=400,
        render=False,
    )

    decoded_image_display = gr.Image(
        label="Decoded Image",
        type="numpy",
        height=400,
        width=400,
        render=False,
    )

    privacy_status = gr.Markdown(render=False)

    filtered_latent_display = gr.Image(
        label="Filtered Latent Image",
        type="numpy",
        height=400,
        width=400,
        render=False,
    )

    latent_animation_display = gr.Video(
        label="Latent Video",
        format="mp4",
        render=False,
        autoplay=True,
        loop=True,
    )

    decoded_animation_display = gr.Video(
        label="Decoded Video",
        format="mp4",
        render=False,
        autoplay=True,
        loop=True,
    )

    with gr.Blocks(theme=gr.themes.Soft(), head=custom_js) as demo:
        gr.Markdown(
            "# EchoFlow: A Foundation Model for Cardiac Ultrasound Image and Video Generation"
        )
        gr.Markdown("## Preprint: https://arxiv.org/abs/2503.22357")
        gr.Markdown("## Dataset Generation Pipeline")

        gr.Markdown(
            """
            This demo showcases EchoFlow's ability to generate synthetic echocardiogram images and videos while preserving patient privacy. The pipeline consists of four main steps:

            1. **Latent Image Generation**: Draw a mask to indicate the region where the Left Ventricle should appear. Select the desired cardiac view, and click "Generate Latent Image". This outputs a latent image, which can be decoded into a pixel space image by clicking "Decode to Pixel Space".
            2. **Privacy Filter**: When clicking "Run Privacy Check", the generated image will be checked against a database of all training anatomies to ensure it is sufficiently different from real patient data.
            3. **Latent Video Generation**: If the privacy check passes, the latent image can be animated into a video with the desired Ejection Fraction.
            4. **Video Decoding**: The video can be decoded back to pixel space by clicking "Decode Video".

            ### ⚙️ Parameters
            - **Sampling Steps**: Higher values produce better quality but take longer
            - **Ejection Fraction**: Controls the strength of heart contraction in the animation
            - **CFG Scale**: Controls how closely the animation follows the specified conditions
            """
        )

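        # gr.Examples needs a callable; this one simply forwards the pre-computed
        # example values to the corresponding output components.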
        def load_example(
            mask,
            view,
            steps,
            ef,
            anim_steps,
            cfg,
            latent,
            decoded,
            status,
            filtered,
            latent_vid,
            decoded_vid,
        ):

            return [
                mask,
                view,
                steps,
                ef,
                anim_steps,
                cfg,
                latent,
                decoded,
                status,
                filtered,
                latent_vid,
                decoded_vid,
            ]

        examples = gr.Examples(
            examples=[
                [
                    {
                        "background": np.zeros((400, 400), dtype=np.uint8),
                        "layers": [
                            np.array(
                                Image.open("assets/a4c_seg.png")
                                .convert("L")
                                .resize((400, 400))
                            )
                        ],
                        "composite": np.array(
                            Image.open("assets/a4c_seg.png")
                            .convert("L")
                            .resize((400, 400))
                        ),
                    },
                    "A4C",
                    100,
                    65,
                    100,
                    1.0,
                    Image.open("assets/examples/a4c_latent.png"),
                    Image.open("assets/examples/a4c_decoded.png"),
                    "✅ **Success:** Generated image passed privacy check.",
                    Image.open("assets/examples/a4c_filtered.png"),
                    "assets/examples/a4c_latent.mp4",
                    "assets/examples/a4c_decoded.mp4",
                ],
                [
                    {
                        "background": np.zeros((400, 400), dtype=np.uint8),
                        "layers": [
                            np.array(
                                Image.open("assets/psax_seg.png")
                                .convert("L")
                                .resize((400, 400))
                            )
                        ],
                        "composite": np.array(
                            Image.open("assets/psax_seg.png")
                            .convert("L")
                            .resize((400, 400))
                        ),
                    },
                    "PSAX",
                    100,
                    65,
                    100,
                    1.0,
                    Image.open("assets/examples/psax_latent.png"),
                    Image.open("assets/examples/psax_decoded.png"),
                    "✅ **Success:** Generated image passed privacy check.",
                    Image.open("assets/examples/psax_filtered.png"),
                    "assets/examples/psax_latent.mp4",
                    "assets/examples/psax_decoded.mp4",
                ],
                [
                    {
                        "background": np.zeros((400, 400), dtype=np.uint8),
                        "layers": [
                            np.array(
                                Image.open("assets/plax_seg.png")
                                .convert("L")
                                .resize((400, 400))
                            )
                        ],
                        "composite": np.array(
                            Image.open("assets/plax_seg.png")
                            .convert("L")
                            .resize((400, 400))
                        ),
                    },
                    "PLAX",
                    100,
                    65,
                    100,
                    1.0,
                    Image.open("assets/examples/plax_latent.png"),
                    Image.open("assets/examples/plax_decoded.png"),
                    "✅ **Success:** Generated image passed privacy check.",
                    Image.open("assets/examples/plax_filtered.png"),
                    "assets/examples/plax_latent.mp4",
                    "assets/examples/plax_decoded.mp4",
                ],
            ],
            inputs=[
                mask_input,
                class_selection,
                sampling_steps,
                ef_slider,
                animation_steps,
                cfg_slider,
                latent_image_display,
                decoded_image_display,
                privacy_status,
                filtered_latent_display,
                latent_animation_display,
                decoded_animation_display,
            ],
            fn=load_example,
            label="Click on an example to see the results immediately.",
            examples_per_page=3,
        )

        with gr.Row():

            with gr.Column():
                gr.Markdown(
                    '<img src="https://huggingface.co./spaces/HReynaud/EchoFlow/resolve/main/assets/h1.png" style="width: 100%; height: 75px; object-fit: contain;">'
                )
                gr.Markdown("### Latent Image Generation")

                with gr.Row():

                    with gr.Column(scale=1):
                        gr.Markdown("Draw the LV mask (white = region of interest)")

                        black_background = np.zeros((400, 400), dtype=np.uint8)

                        try:
                            mask_image = Image.open("assets/a4c_seg.png").convert("L")
                            mask_image = mask_image.resize(
                                (400, 400), Image.Resampling.LANCZOS
                            )

                            mask_image = ImageOps.autocontrast(mask_image, cutoff=0)
                            mask_image = mask_image.point(
                                lambda p: 255 if p > 127 else 0
                            )
                            mask_array = np.array(mask_image)

                            editor_value = {
                                "background": black_background,
                                "layers": [mask_array],
                                "composite": mask_array,
                            }
                        except Exception as e:
                            print(f"Error loading mask image: {e}")

                            editor_value = black_background

                        mask_input.render()
                        class_selection.render()
                        sampling_steps.render()

                generate_btn = gr.Button("Generate Latent Image", variant="primary")

                latent_image_display.render()

                decode_btn = gr.Button(
                    "Decode to Pixel Space (Optional)",
                    interactive=False,
                    variant="primary",
                )

                decoded_image_display.render()

            with gr.Column():
                gr.Markdown(
                    '<img src="https://huggingface.co./spaces/HReynaud/EchoFlow/resolve/main/assets/h2.png" style="width: 100%; height: 75px; object-fit: contain;">'
                )
                gr.Markdown("### Privacy Filter")
                gr.Markdown(
                    "Checks if the generated image is too similar to training data"
                )

                privacy_btn = gr.Button(
                    "Run Privacy Check", interactive=False, variant="primary"
                )

                privacy_status.render()

                filtered_latent_display.render()

            with gr.Column():
                gr.Markdown(
                    '<img src="https://huggingface.co./spaces/HReynaud/EchoFlow/resolve/main/assets/h3.png" style="width: 100%; height: 75px; object-fit: contain;">'
                )
                gr.Markdown("### Latent Video Generation")

                ef_slider.render()
                animation_steps.render()
                cfg_slider.render()

                animate_btn = gr.Button(
                    "Generate Video", interactive=False, variant="primary"
                )

                latent_animation_display.render()

            with gr.Column():
                gr.Markdown(
                    '<img src="https://huggingface.co./spaces/HReynaud/EchoFlow/resolve/main/assets/h4.png" style="width: 100%; height: 75px; object-fit: contain;">'
                )
                gr.Markdown("### Video Decoding")

                decode_animation_btn = gr.Button(
                    "Decode Video", interactive=False, variant="primary"
                )

                decoded_animation_display.render()

        latent_image_state = gr.State(None)
        filtered_latent_state = gr.State(None)
        latent_animation_state = gr.State(None)

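        # Event wiring: each stage's button becomes interactive only once the previous
        # stage has produced its output (tracked via the gr.State holders above).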
        class_selection.change(
            fn=load_view_mask,
            inputs=[class_selection],
            outputs=[mask_input],
            queue=False,
        )

        generate_btn.click(
            fn=generate_latent_image,
            inputs=[mask_input, class_selection, sampling_steps],
            outputs=[latent_image_state],
            queue=True,
        ).then(
            fn=convert_latent_to_display,
            inputs=[latent_image_state],
            outputs=[latent_image_display],
            queue=False,
        ).then(
            fn=lambda x: gr.Button(interactive=x is not None),
            inputs=[latent_image_state],
            outputs=[decode_btn],
            queue=False,
        ).then(
            fn=lambda x: gr.Button(interactive=x is not None),
            inputs=[latent_image_state],
            outputs=[privacy_btn],
            queue=False,
        )

        decode_btn.click(
            fn=decode_latent_to_pixel,
            inputs=[latent_image_state],
            outputs=[decoded_image_display],
            queue=True,
        ).then(
            fn=lambda x: gr.Button(interactive=x is not None),
            inputs=[decoded_image_display],
            outputs=[privacy_btn],
            queue=False,
        )

        privacy_btn.click(
            fn=check_privacy,
            inputs=[latent_image_state, class_selection],
            outputs=[filtered_latent_state, privacy_status],
            queue=True,
        ).then(
            fn=convert_latent_to_display,
            inputs=[filtered_latent_state],
            outputs=[filtered_latent_display],
            queue=False,
        ).then(
            fn=lambda x: gr.Button(interactive=x is not None),
            inputs=[filtered_latent_state],
            outputs=[animate_btn],
            queue=False,
        )

        animate_btn.click(
            fn=generate_animation,
            inputs=[filtered_latent_state, ef_slider, animation_steps, cfg_slider],
            outputs=[latent_animation_state],
            queue=True,
        ).then(
            fn=latent_animation_to_grayscale,
            inputs=[latent_animation_state],
            outputs=[latent_animation_display],
            queue=False,
        ).then(
            fn=lambda x: gr.Button(interactive=x is not None),
            inputs=[latent_animation_state],
            outputs=[decode_animation_btn],
            queue=False,
        )

        decode_animation_btn.click(
            fn=decode_animation,
            inputs=[latent_animation_state],
            outputs=[decoded_animation_display],
            queue=True,
        )

    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch()