File size: 4,118 Bytes
f2de1ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from PIL import Image as PILImage, ImageDraw, ImageFont
from imwatermark import WatermarkEncoder

from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from diffusers.utils.torch_utils import randn_tensor
from transformers import MT5Tokenizer, MT5EncoderModel
from typing import List, Optional, Tuple, Union

# Module-level runtime configuration shared by QPipeline below.
# Use the current CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half precision only on GPU; CPU keeps float32 (presumably because fp16 kernels
# are slow or unsupported on CPU).
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Load the MT5 tokenizer and encoder once, at import time (can be replaced with
# a private model + token if needed). from_pretrained resolves the checkpoint
# from the Hugging Face Hub or the local cache, so importing this module may
# trigger a download.
# NOTE(review): `use_safetensors` is a model-weight loading option and presumably
# has no effect on the tokenizer — confirm before relying on it.
tokenizer = MT5Tokenizer.from_pretrained("google/mt5-small", use_safetensors=True)
# Encoder weights are moved to `device` and cast to `torch_dtype`; QPipeline
# sends its token ids to this same `device` before encoding.
encoder_model = MT5EncoderModel.from_pretrained("google/mt5-small", use_safetensors=True).to(device=device, dtype=torch_dtype)
# Inference only: eval() switches off training-mode behavior such as dropout.
encoder_model.eval()

class QPipeline(DiffusionPipeline):
    """Text-conditioned diffusion pipeline driven by the module-level MT5 encoder.

    Prompts are tokenized with the module-level ``tokenizer`` and encoded with
    ``encoder_model``; the resulting hidden states condition ``unet`` through
    ``encoder_hidden_states`` while ``scheduler`` drives the denoising loop.
    PIL outputs are resized to 256x256 and invisibly watermarked.
    """

    def __init__(self, unet, scheduler):
        super().__init__()
        # Registering the components lets DiffusionPipeline track them
        # (e.g. ``self.device`` is derived from the registered modules).
        self.register_modules(unet=unet, scheduler=scheduler)

    def add_watermark(self, img: PILImage.Image) -> PILImage.Image:
        """Return a 256x256 RGB copy of *img* carrying an invisible watermark.

        The payload is the UTF-8 encoding of the ``WATERMARK_URL`` environment
        variable (default ``"hf.co/lqume/new-hanzi"``), embedded with
        imwatermark's ``dwtDct`` method.

        Args:
            img: Image to watermark; any PIL mode, it is converted to RGB.

        Returns:
            A new 256x256 RGB image.
        """
        # Always resized to 256x256 (up- or down-scaled): a 128x128 image is
        # too small to hold the dwtDct watermark.
        img = img.resize((256, 256), resample=PILImage.BICUBIC)

        watermark_str = os.getenv("WATERMARK_URL", "hf.co/lqume/new-hanzi")
        encoder = WatermarkEncoder()
        encoder.set_watermark('bytes', watermark_str.encode('utf-8'))

        # imwatermark works on (H, W, 3) uint8 arrays.
        # NOTE(review): imwatermark documents cv2-style BGR input; RGB is passed
        # here — confirm the decoding side uses the same channel order.
        img_np = np.asarray(img.convert("RGB"))
        watermarked_np = encoder.encode(img_np, 'dwtDct')

        return PILImage.fromarray(watermarked_np)

    @torch.no_grad()
    def __call__(
        self,
        texts: List[str],
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 20,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple[List[PILImage.Image]]]:
        """Generate one image per prompt in *texts*.

        Args:
            texts: Prompts to render. A bare ``str`` is treated as a single
                prompt.
            batch_size: Ignored; kept for backward compatibility. The batch size
                is always ``len(texts)`` so the noise batch matches the number
                of encoded prompts.
            generator: Optional RNG (or one per prompt) used for the initial
                noise and passed to ``scheduler.step``.
            num_inference_steps: Number of denoising steps.
            output_type: ``"pil"`` returns watermarked PIL images; any other
                value returns the clamped ``(batch, H, W, C)`` NumPy array.
            return_dict: If False, return a 1-tuple ``(images,)`` instead of an
                :class:`ImagePipelineOutput`.

        Raises:
            ValueError: If *texts* contains no prompts.
        """
        # len() of a bare string counts characters while the tokenizer treats
        # it as one prompt; wrap it so both agree on the batch size.
        if isinstance(texts, str):
            texts = [texts]
        if len(texts) == 0:
            raise ValueError("`texts` must contain at least one prompt.")
        batch_size = len(texts)

        unet_device = self.device
        unet_dtype = self.unet.dtype

        # Tokenize to a fixed length of 48 tokens and encode on the encoder's
        # own device (the module-level `device`).
        tokenized = tokenizer(
            texts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=48,
        )
        input_ids = tokenized["input_ids"].to(device=device, dtype=torch.long)
        attention_mask = tokenized["attention_mask"].to(device=device, dtype=torch.long)
        encoded = encoder_model.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # The UNet may sit on a different device / dtype than the encoder (e.g.
        # loaded in fp32 while the encoder runs in fp16); align the conditioning
        # with the UNet to avoid device/dtype mismatch errors.
        encoder_hidden_states = encoded.last_hidden_state.to(device=unet_device, dtype=unet_dtype)
        encoder_attention_mask = attention_mask.to(unet_device).bool()

        # Initial noise at the UNet's configured resolution.
        sample_size = self.unet.config.sample_size
        if isinstance(sample_size, int):
            sample_size = (sample_size, sample_size)
        image_shape = (batch_size, self.unet.config.in_channels, *sample_size)
        image = randn_tensor(image_shape, generator=generator, device=unet_device, dtype=unet_dtype)
        # Sigma-based schedulers expect noise scaled by init_noise_sigma;
        # it is 1.0 for DDPM/DDIM, so this is a no-op there.
        image = image * self.scheduler.init_noise_sigma

        # Denoising loop.
        self.scheduler.set_timesteps(num_inference_steps)
        for timestep in self.progress_bar(self.scheduler.timesteps):
            # Identity for DDPM/DDIM; required by schedulers that rescale inputs.
            model_input = self.scheduler.scale_model_input(image, timestep)
            noise_pred = self.unet(
                model_input,
                timestep,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]
            image = self.scheduler.step(noise_pred, timestep, image, generator=generator, return_dict=False)[0]

        # The model output is treated as already in [0, 1]: clamp, then move to
        # NumPy in channels-last layout.
        image = image.clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = [self.add_watermark(img) for img in self.numpy_to_pil(image)]

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)