|
import os |
|
import numpy as np |
|
import torch |
|
from torch import nn |
|
from torch.utils.data import Dataset, DataLoader |
|
from torchvision import transforms as T |
|
from PIL import Image as PILImage, ImageDraw, ImageFont |
|
from imwatermark import WatermarkEncoder |
|
|
|
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput |
|
from diffusers.utils.torch_utils import randn_tensor |
|
from transformers import MT5Tokenizer, MT5EncoderModel |
|
from typing import List, Optional, Tuple, Union |
|
|
|
|
|
# Pick the execution target once: CUDA with half precision when a GPU is
# visible to torch, otherwise CPU with full precision.
if torch.cuda.is_available():
    device = torch.device("cuda")
    torch_dtype = torch.float16
else:
    device = torch.device("cpu")
    torch_dtype = torch.float32

# Module-level text-conditioning stack used by QPipeline.__call__.
# Both objects are loaded at import time from the "google/mt5-small" checkpoint.
tokenizer = MT5Tokenizer.from_pretrained("google/mt5-small", use_safetensors=True)

encoder_model = MT5EncoderModel.from_pretrained(
    "google/mt5-small",
    use_safetensors=True,
)
# Place the encoder on the selected device/dtype and switch it to eval mode,
# since it is only used for inference in this file.
encoder_model = encoder_model.to(device=device, dtype=torch_dtype)
encoder_model.eval()
|
|
|
class QPipeline(DiffusionPipeline):
    """Text-conditioned image generation pipeline with an invisible watermark.

    Prompts are tokenized and encoded with the module-level ``tokenizer`` and
    ``encoder_model``.  Starting from Gaussian noise, the registered ``unet``
    is evaluated at every timestep of the registered ``scheduler``.  When PIL
    output is requested, each image is resized to 256x256 and watermarked.
    """

    def __init__(self, unet, scheduler):
        """Register the denoising ``unet`` and the noise ``scheduler``."""
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    def add_watermark(self, img: PILImage.Image) -> PILImage.Image:
        """Return a 256x256 RGB copy of ``img`` carrying an invisible watermark.

        The payload is the UTF-8 encoding of the ``WATERMARK_URL`` environment
        variable, defaulting to ``"hf.co/lqume/new-hanzi"``, embedded with the
        ``'dwtDct'`` method of ``imwatermark.WatermarkEncoder``.

        Args:
            img: Source image; any mode that ``convert("RGB")`` accepts.

        Returns:
            A new PIL image built from the encoder's output array.
        """
        # NOTE(review): presumably the fixed 256x256 size is chosen to satisfy
        # the watermark encoder's minimum image size -- TODO confirm.
        img = img.resize((256, 256), resample=PILImage.BICUBIC)

        watermark_str = os.getenv("WATERMARK_URL", "hf.co/lqume/new-hanzi")
        encoder = WatermarkEncoder()
        encoder.set_watermark('bytes', watermark_str.encode('utf-8'))

        # NOTE(review): the array is handed over in RGB channel order; confirm
        # that whoever decodes the watermark uses the same ordering.
        img_np = np.asarray(img.convert("RGB"))
        watermarked_np = encoder.encode(img_np, 'dwtDct')

        return PILImage.fromarray(watermarked_np)

    @torch.no_grad()
    def __call__(
        self,
        texts: Union[str, List[str]],
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 20,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple[List[PILImage.Image]]]:
        """Generate one image per prompt in ``texts``.

        Args:
            texts: Prompts to condition on.  A single string is treated as a
                one-element batch.
            batch_size: Accepted for interface compatibility only; the actual
                batch size is always ``len(texts)``.
            generator: Optional torch generator(s) forwarded to the initial
                noise sampling and to every ``scheduler.step`` call.
            num_inference_steps: Number of timesteps handed to
                ``scheduler.set_timesteps``.
            output_type: ``"pil"`` returns watermarked PIL images; any other
                value returns the un-watermarked NumPy array in
                (batch, height, width, channels) layout.
            return_dict: When false, return a 1-tuple ``(images,)`` instead of
                an ``ImagePipelineOutput``.

        Returns:
            ``ImagePipelineOutput(images=...)`` or ``(images,)``.

        Raises:
            ValueError: If ``texts`` is empty.
        """
        # A bare string would otherwise make len(texts) count characters while
        # the tokenizer sees a single sequence, so the noise batch and the
        # conditioning batch would disagree.
        if isinstance(texts, str):
            texts = [texts]
        else:
            texts = list(texts)
        if not texts:
            raise ValueError("`texts` must contain at least one prompt.")

        batch_size = len(texts)

        # Tokenize to a fixed length of 48 and run the module-level encoder on
        # the device it was placed on at import time.
        tokenized = tokenizer(
            texts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=48,
        )
        input_ids = tokenized["input_ids"].to(device=device, dtype=torch.long)
        attention_mask = tokenized["attention_mask"].to(device=device, dtype=torch.long)

        encoded = encoder_model.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # The UNet may live on a different device / precision than the
        # module-level encoder (e.g. after ``pipe.to(...)``); align the
        # conditioning tensors with it once, outside the denoising loop.
        unet_dtype = getattr(self.unet, "dtype", torch_dtype)
        hidden_states = encoded.last_hidden_state.to(device=self.device, dtype=unet_dtype)
        cross_attention_mask = attention_mask.to(self.device).bool()

        sample_size = self.unet.config.sample_size
        if isinstance(sample_size, int):
            image_shape = (
                batch_size,
                self.unet.config.in_channels,
                sample_size,
                sample_size,
            )
        else:
            image_shape = (batch_size, self.unet.config.in_channels, *sample_size)

        image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=unet_dtype)

        self.scheduler.set_timesteps(num_inference_steps)

        for timestep in self.progress_bar(self.scheduler.timesteps):
            noise_pred = self.unet(
                image,
                timestep,
                encoder_hidden_states=hidden_states,
                encoder_attention_mask=cross_attention_mask,
                return_dict=False,
            )[0]

            image = self.scheduler.step(
                noise_pred, timestep, image, generator=generator, return_dict=False
            )[0]

        # NOTE(review): samples are clamped to [0, 1] without rescaling, which
        # presumes the model was trained on [0, 1] pixel values -- TODO confirm.
        image = image.clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)
            # Watermarking works on PIL images, so it only applies here.
            image = [self.add_watermark(img) for img in image]

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
|
|
|
|