# paths

from pathlib import Path

# torch

import torch
import torchvision.transforms.functional as F
from einops import repeat

# Text2Punks and Tokenizer

from text2punks.text2punk import Text2Punks, CLIP
from text2punks.tokenizer import txt_tokenizer

# select device

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# load decoder

codebook = torch.load('./text2punks/data/codebook.pt')

# helper fns

def exists(val):
    return val is not None

def resize(image_tensor, size):
    return F.resize(image_tensor, (size, size), F.InterpolationMode.NEAREST)

def to_pil_image(image_tensor):
    return F.to_pil_image(image_tensor.type(torch.uint8))


def model_loader(text2punk_path, clip_path):
    # load pre-trained TEXT2PUNKS model

    text2punk_path = Path(text2punk_path)
    assert text2punk_path.exists(), 'trained Text2Punks must exist'

    load_obj = torch.load(str(text2punk_path), map_location=device)
    text2punks_params, weights = load_obj.pop('hparams'), load_obj.pop('weights')

    text2punk = Text2Punks(**text2punks_params).to(device)
    text2punk.load_state_dict(weights)

    # load pre-trained CLIP model

    clip_path = Path(clip_path)
    assert clip_path.exists(), 'trained CLIP must exist'

    load_obj = torch.load(str(clip_path), map_location=device)
    clip_params, weights = load_obj.pop('hparams'), load_obj.pop('weights')

    clip = CLIP(**clip_params).to(device)
    clip.load_state_dict(weights)

    return text2punk, clip


def generate_image(prompt_text, top_k, temperature, num_images, batch_size, top_prediction, text2punk_model, clip_model, codebook=codebook):
    # tokenize the prompt and replicate it once per requested image
    text = txt_tokenizer.tokenize(prompt_text, text2punk_model.text_seq_len, truncate_text=True).to(device)
    text = repeat(text, '() n -> b n', b = num_images)

    img_outputs = []
    score_outputs = []

    # sample candidates in batches to keep memory usage bounded
    for text_chunk in text.split(batch_size):
        images, scores = text2punk_model.generate_images(text_chunk, codebook.to(device), clip = clip_model, filter_thres = top_k, temperature = temperature)
        img_outputs.append(images)
        score_outputs.append(scores)

    img_outputs = torch.cat(img_outputs)
    score_outputs = torch.cat(score_outputs)

    # rank all candidates by their CLIP scores and keep only the top predictions
    similarity = score_outputs.softmax(dim=-1)
    values, indices = similarity.topk(top_prediction)

    img_outputs = img_outputs[indices]
    score_outputs = score_outputs[indices]

    return img_outputs, score_outputs
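

# example usage -- a minimal sketch, not part of the original module. The checkpoint
# paths, prompt, and sampling values below are illustrative assumptions; point the
# paths at your own trained Text2Punks and CLIP weights before running.

if __name__ == '__main__':
    # hypothetical checkpoint locations
    text2punk_checkpoint = './text2punks/data/text2punk.pt'
    clip_checkpoint = './text2punks/data/clip.pt'

    text2punk_model, clip_model = model_loader(text2punk_checkpoint, clip_checkpoint)

    # sample 32 candidates in batches of 8 and keep the 4 best CLIP-ranked punks
    images, scores = generate_image(
        prompt_text='An ape punk with a red cap and a cigarette',
        top_k=0.9,
        temperature=1.0,
        num_images=32,
        batch_size=8,
        top_prediction=4,
        text2punk_model=text2punk_model,
        clip_model=clip_model,
    )

    # upscale each generated punk (assumed to be a channel-first pixel tensor) and save it
    for i, image in enumerate(images):
        to_pil_image(resize(image, 256)).save(f'punk_{i}.png')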