import os
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from PIL import Image as PILImage, ImageDraw, ImageFont
from imwatermark import WatermarkEncoder
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from diffusers.utils.torch_utils import randn_tensor
from transformers import MT5Tokenizer, MT5EncoderModel

# Determine device and torch dtype: half precision only when CUDA is present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Token length every prompt is padded / truncated to before encoding.
_MAX_TEXT_LENGTH = 48
# Side length images are resized to before watermarking (128 px is too small
# for the watermark).
_WATERMARK_SIZE = 256

# Load MT5 tokenizer and encoder once at import time; they are shared by every
# QPipeline instance (can be replaced with private model + token if needed).
tokenizer = MT5Tokenizer.from_pretrained("google/mt5-small", use_safetensors=True)
encoder_model = MT5EncoderModel.from_pretrained(
    "google/mt5-small", use_safetensors=True
).to(device=device, dtype=torch_dtype)
encoder_model.eval()


class QPipeline(DiffusionPipeline):
    """Text-conditioned diffusion pipeline that produces watermarked images.

    Prompts are encoded with the module-level MT5 encoder. Its last hidden
    state conditions ``unet`` through ``encoder_hidden_states`` while
    ``scheduler`` drives the denoising loop.
    """

    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    def add_watermark(self, img: PILImage.Image) -> PILImage.Image:
        """Resize *img* to 256x256 and embed an invisible watermark in it.

        The payload is the ``WATERMARK_URL`` environment variable (default
        ``"hf.co/lqume/new-hanzi"``), encoded as UTF-8 bytes and embedded with
        imwatermark's ``dwtDct`` method.
        """
        # Resize first: 128 px is too small to carry the watermark.
        img = img.resize((_WATERMARK_SIZE, _WATERMARK_SIZE), resample=PILImage.BICUBIC)
        watermark_str = os.getenv("WATERMARK_URL", "hf.co/lqume/new-hanzi")

        encoder = WatermarkEncoder()
        encoder.set_watermark('bytes', watermark_str.encode('utf-8'))

        # imwatermark works on NumPy arrays; force 3-channel RGB first.
        img_np = np.asarray(img.convert("RGB"))
        watermarked_np = encoder.encode(img_np, 'dwtDct')
        return PILImage.fromarray(watermarked_np)

    def _encode_texts(self, texts: List[str], target_device, target_dtype):
        """Tokenize and MT5-encode *texts*; return (hidden_states, bool_mask) for the UNet."""
        tokenized = tokenizer(
            texts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=_MAX_TEXT_LENGTH,
        )
        # The encoder lives on the module-level ``device``, which need not be
        # where the UNet was placed.
        input_ids = tokenized["input_ids"].to(device=device, dtype=torch.long)
        attention_mask = tokenized["attention_mask"].to(device=device, dtype=torch.long)
        encoded = encoder_model.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # Hand the conditioning to the UNet on its own device/dtype so the
        # cross-attention inputs match the latents.
        hidden_states = encoded.last_hidden_state.to(device=target_device, dtype=target_dtype)
        return hidden_states, attention_mask.to(target_device).bool()

    def _latent_shape(self, batch_size: int) -> Tuple[int, ...]:
        """Shape of the initial noise tensor for *batch_size* samples."""
        sample_size = self.unet.config.sample_size
        if isinstance(sample_size, int):
            sample_size = (sample_size, sample_size)
        return (batch_size, self.unet.config.in_channels, *sample_size)

    @torch.no_grad()
    def __call__(
        self,
        texts: Union[str, List[str]],
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 20,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple[List[PILImage.Image]]]:
        """Generate one image per prompt.

        Args:
            texts: A prompt or a list of prompts; one image is produced for each.
            batch_size: Ignored. The batch size is always the number of
                prompts; the parameter is kept for backward compatibility.
            generator: Optional RNG(s) used for the initial noise and passed
                to every scheduler step.
            num_inference_steps: Number of denoising steps.
            output_type: ``"pil"`` returns watermarked PIL images; any other
                value returns the raw NumPy array shaped (batch, H, W, C) with
                values clamped to [0, 1], without a watermark.
            return_dict: When False, return a one-element tuple instead of an
                ``ImagePipelineOutput``.

        Returns:
            ``ImagePipelineOutput(images=...)`` or ``(images,)``.

        Raises:
            ValueError: If no prompts are given.
        """
        # A bare string would otherwise be batched character by character.
        texts = [texts] if isinstance(texts, str) else list(texts)
        if not texts:
            raise ValueError("texts must contain at least one prompt")
        batch_size = len(texts)

        unet_device = self.device
        unet_dtype = self.unet.dtype
        hidden_states, attention_mask = self._encode_texts(texts, unet_device, unet_dtype)

        # Initial noise, drawn in the UNet's dtype so the first forward pass
        # does not mix precisions.
        image = randn_tensor(
            self._latent_shape(batch_size),
            generator=generator,
            device=unet_device,
            dtype=unet_dtype,
        )

        # Run denoising loop
        self.scheduler.set_timesteps(num_inference_steps)
        for timestep in self.progress_bar(self.scheduler.timesteps):
            noise_pred = self.unet(
                image,
                timestep,
                encoder_hidden_states=hidden_states,
                encoder_attention_mask=attention_mask,
                return_dict=False,
            )[0]
            image = self.scheduler.step(
                noise_pred, timestep, image, generator=generator, return_dict=False
            )[0]

        # NOTE(review): clamping directly to [0, 1] assumes the model samples
        # in [0, 1] (no [-1, 1] -> [0, 1] rescale) -- confirm against training.
        # Upcast before NumPy so the later *255 conversion is not done in fp16.
        image = image.clamp(0, 1).float().cpu().permute(0, 2, 3, 1).numpy()

        if output_type == "pil":
            image = [self.add_watermark(img) for img in self.numpy_to_pil(image)]

        if not return_dict:
            return (image,)
        return ImagePipelineOutput(images=image)