# neochar/utils.py
import os
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image as PILImage
from imwatermark import WatermarkEncoder
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from diffusers.utils.torch_utils import randn_tensor
from transformers import MT5Tokenizer, MT5EncoderModel
# Determine device and torch dtype
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Load the MT5 tokenizer and text encoder (swap in a private model plus auth token if needed).
# Note: use_safetensors only applies to model weights, not to the tokenizer.
tokenizer = MT5Tokenizer.from_pretrained("google/mt5-small")
encoder_model = MT5EncoderModel.from_pretrained(
    "google/mt5-small", use_safetensors=True
).to(device=device, dtype=torch_dtype)
encoder_model.eval()
class QPipeline(DiffusionPipeline):
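    """Text-to-image diffusion pipeline that conditions a UNet on MT5 text embeddings."""
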
def __init__(self, unet, scheduler):
super().__init__()
self.register_modules(unet=unet, scheduler=scheduler)
    def add_watermark(self, img: PILImage.Image) -> PILImage.Image:
        """Embed an invisible DWT-DCT watermark into a generated image."""
        # Upscale to 256x256 first; 128x128 is too small for the watermark to survive
        img = img.resize((256, 256), resample=PILImage.BICUBIC)
        watermark_str = os.getenv("WATERMARK_URL", "hf.co/lqume/new-hanzi")
        encoder = WatermarkEncoder()
        encoder.set_watermark('bytes', watermark_str.encode('utf-8'))
        # imwatermark operates on OpenCV-style BGR arrays, so flip the channel order
        # around the encode call
        img_bgr = np.asarray(img.convert("RGB"))[:, :, ::-1]
        watermarked_bgr = encoder.encode(img_bgr, 'dwtDct')
        # Flip back to RGB before rebuilding the PIL image
        return PILImage.fromarray(watermarked_bgr[:, :, ::-1])
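
    # A sketch of how the embedded watermark could be verified later (not used by the
    # pipeline itself; the bit length passed to the decoder must match the payload):
    #
    #   from imwatermark import WatermarkDecoder
    #   decoder = WatermarkDecoder('bytes', 8 * len(b"hf.co/lqume/new-hanzi"))
    #   recovered = decoder.decode(np.asarray(watermarked_img)[:, :, ::-1], 'dwtDct')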
@torch.no_grad()
    def __call__(
        self,
        texts: List[str],
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 20,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple[List[PILImage.Image]]]:
        # One image is generated per prompt, so the batch size follows from the input
        batch_size = len(texts)
# Tokenize input text
tokenized = tokenizer(
texts,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=48
)
        input_ids = tokenized["input_ids"].to(device)
        attention_mask = tokenized["attention_mask"].to(device)
        # Encode the prompts into text hidden states for cross-attention conditioning
        encoded = encoder_model(input_ids=input_ids, attention_mask=attention_mask)
# Prepare noise tensor
if isinstance(self.unet.config.sample_size, int):
image_shape = (
batch_size,
self.unet.config.in_channels,
self.unet.config.sample_size,
self.unet.config.sample_size,
)
else:
image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)
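        # randn_tensor accepts either one torch.Generator or a list with one generator
        # per sample, which makes individual images in the batch reproducible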
image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=torch_dtype)
# Run denoising loop
self.scheduler.set_timesteps(num_inference_steps)
for timestep in self.progress_bar(self.scheduler.timesteps):
noise_pred = self.unet(
image,
timestep,
encoder_hidden_states=encoded.last_hidden_state,
encoder_attention_mask=attention_mask.bool(),
return_dict=False
)[0]
image = self.scheduler.step(noise_pred, timestep, image, generator=generator, return_dict=False)[0]
        # Post-process: samples are assumed to already lie in [0, 1] (no [-1, 1]
        # rescale is applied here); move to CPU and NHWC layout for PIL conversion
        image = image.clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
image = [self.add_watermark(img) for img in image]
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
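

if __name__ == "__main__":
    # Usage sketch. Assumptions: the repo ships a trained UNet2DConditionModel and a
    # DDPM-style scheduler, and "lqume/new-hanzi" (guessed from the watermark URL
    # above) hosts them under the standard subfolder layout; adjust as needed.
    from diffusers import DDPMScheduler, UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained(
        "lqume/new-hanzi", subfolder="unet"
    ).to(device=device, dtype=torch_dtype)
    scheduler = DDPMScheduler.from_pretrained("lqume/new-hanzi", subfolder="scheduler")
    pipe = QPipeline(unet=unet, scheduler=scheduler)
    output = pipe(texts=["水"], num_inference_steps=20)  # example single-character prompt
    output.images[0].save("sample.png")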