# pathlib
from pathlib import Path

# torch
import torch
import torchvision.transforms.functional as F

# einops
from einops import repeat
# Text2Punks and Tokenizer
from text2punks.text2punk import Text2Punks, CLIP
from text2punks.tokenizer import txt_tokenizer
# select device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# load the codebook that decodes image token ids to RGB pixel values
codebook = torch.load('./text2punks/data/codebook.pt', map_location=device)
# helper fns

def exists(val):
    return val is not None

def resize(image_tensor, size):
    # nearest-neighbour interpolation keeps pixel-art edges crisp
    return F.resize(image_tensor, (size, size), F.InterpolationMode.NEAREST)

def to_pil_image(image_tensor):
    # cast to uint8 so values are read as 0-255 RGB
    return F.to_pil_image(image_tensor.type(torch.uint8))
def model_loader(text2punk_path, clip_path):
    # load pre-trained Text2Punks model
    text2punk_path = Path(text2punk_path)
    assert text2punk_path.exists(), 'trained Text2Punks must exist'

    load_obj = torch.load(str(text2punk_path), map_location=device)
    text2punks_params, weights = load_obj.pop('hparams'), load_obj.pop('weights')

    text2punk = Text2Punks(**text2punks_params).to(device)
    text2punk.load_state_dict(weights)

    # load pre-trained CLIP model
    clip_path = Path(clip_path)
    assert clip_path.exists(), 'trained CLIP must exist'

    load_obj = torch.load(str(clip_path), map_location=device)
    clip_params, weights = load_obj.pop('hparams'), load_obj.pop('weights')

    clip = CLIP(**clip_params).to(device)
    clip.load_state_dict(weights)

    return text2punk, clip
def generate_image(prompt_text, top_k, temperature, num_images, batch_size, top_prediction, text2punk_model, clip_model, codebook=codebook):
    # tokenize the prompt and repeat it once per requested image
    text = txt_tokenizer.tokenize(prompt_text, text2punk_model.text_seq_len, truncate_text=True).to(device)
    text = repeat(text, '() n -> b n', b=num_images)

    img_outputs = []
    score_outputs = []

    # sample in chunks of batch_size to bound memory use
    for text_chunk in text.split(batch_size):
        images, scores = text2punk_model.generate_images(text_chunk, codebook.to(device), clip=clip_model, filter_thres=top_k, temperature=temperature)
        img_outputs.append(images)
        score_outputs.append(scores)

    img_outputs = torch.cat(img_outputs)
    score_outputs = torch.cat(score_outputs)

    # rank candidates by CLIP score and keep the top_prediction best
    similarity = score_outputs.softmax(dim=-1)
    _, indices = similarity.topk(top_prediction)

    img_outputs = img_outputs[indices]
    score_outputs = score_outputs[indices]

    return img_outputs, score_outputs
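
# --- usage sketch ---
# A minimal example of how the pieces above fit together. The checkpoint
# paths, prompt, and sampling values below are illustrative assumptions,
# not values shipped with this module.
if __name__ == '__main__':
    # hypothetical checkpoint paths
    text2punk, clip = model_loader('./text2punks/data/text2punk.pt', './text2punks/data/clip.pt')

    images, scores = generate_image(
        prompt_text='An Ape CryptoPunk that has 2 attributes, a knitted cap, and a medical mask',
        top_k=0.9,
        temperature=1.0,
        num_images=8,
        batch_size=4,
        top_prediction=2,
        text2punk_model=text2punk,
        clip_model=clip,
    )

    # upscale each 24x24 punk with the nearest-neighbour helper and save it
    for i, image in enumerate(images):
        to_pil_image(resize(image, 256)).save(f'punk_{i}.png')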