Spaces:
Build error
Build error
File size: 3,323 Bytes
b41a54a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import numpy as np
from PIL import Image, UnidentifiedImageError
from pathlib import Path
from random import randint, choice
import torch
from torch.utils.data import Dataset
class TextImageDataset(Dataset):
    """Dataset of (caption tokens, per-pixel image tokens) pairs read from a folder.

    Every ``*.txt`` file under ``folder`` (searched recursively) is paired with the
    ``*.png`` / ``*.jpg`` / ``*.jpeg`` / ``*.bmp`` file that has the same stem; files
    without a partner are ignored. Files are keyed by stem only, so when several
    files in different sub-directories share a stem, only one of them is kept.
    """

    def __init__(self,
                 folder,
                 text_len=40,
                 truncate_captions=False,
                 text_tokenizer=None,
                 image_tokenizer=None,
                 shuffle=False
                 ):
        """
        @param folder: Folder containing images and text files matched by their paths' respective "stem"
        @param text_len: Length argument forwarded to ``text_tokenizer.tokenize``.
        @param truncate_captions: Rather than throw an exception, captions which are too long will be truncated.
            (Forwarded as ``truncate_text=`` to ``text_tokenizer.tokenize``.)
        @param text_tokenizer: Object exposing ``tokenize(text, text_len, truncate_text=...)`` whose result
            supports ``.squeeze(0)``. Required by ``__getitem__`` despite the ``None`` default.
        @param image_tokenizer: Mapping from ``str()`` of a 3-element RGB pixel array to a token id.
            Required by ``__getitem__`` despite the ``None`` default.
        @param shuffle: Only controls how a sample that fails to load is replaced: by a random index when
            True, otherwise by the next index (wrapping to 0). It does not reorder the dataset.
        """
        super().__init__()
        self.shuffle = shuffle
        path = Path(folder)

        text_files = {text_file.stem: text_file for text_file in path.glob('**/*.txt')}
        # Same precedence as before: on a stem clash the later pattern wins.
        image_files = {
            image_file.stem: image_file
            for pattern in ('**/*.png', '**/*.jpg', '**/*.jpeg', '**/*.bmp')
            for image_file in path.glob(pattern)
        }

        keys = image_files.keys() & text_files.keys()
        # Sorted so the index -> sample mapping is the same in every run; iterating
        # a set of strings gives an order that depends on hash randomization.
        self.keys = sorted(keys)
        self.text_files = {k: v for k, v in text_files.items() if k in keys}
        self.image_files = {k: v for k, v in image_files.items() if k in keys}
        self.text_len = text_len
        self.truncate_captions = truncate_captions
        self.text_tokenizer = text_tokenizer
        self.image_tokenizer = image_tokenizer

    def __len__(self):
        return len(self.keys)

    def random_sample(self):
        """Return the sample at a uniformly random index."""
        return self.__getitem__(randint(0, self.__len__() - 1))

    def sequential_sample(self, ind):
        """Return the sample following ``ind``, wrapping around to index 0."""
        return self.__getitem__(self._next_index(ind))

    def skip_sample(self, ind):
        """Return a replacement for ``ind``: random if ``shuffle`` is set, else the next sample."""
        return self.__getitem__(self._skip_index(ind))

    def _next_index(self, ind):
        # Index after ind, wrapping to 0 at (or past) the last position.
        return 0 if ind >= self.__len__() - 1 else ind + 1

    def _skip_index(self, ind):
        # Index to try after the sample at ind failed to load.
        if self.shuffle:
            return randint(0, self.__len__() - 1)
        return self._next_index(ind)

    def _load_sample(self, ind):
        # Load and tokenize the pair at ind. Returns None when the caption file has
        # no non-empty line or the image cannot be opened/decoded. An out-of-range
        # ind raises IndexError from self.keys[ind].
        key = self.keys[ind]
        text_file = self.text_files[key]
        image_file = self.image_files[key]

        descriptions = [line for line in text_file.read_text().split('\n') if line]
        if not descriptions:
            print(f"An exception occurred trying to load file {text_file}.")
            return None
        description = choice(descriptions)

        tokenized_text = self.text_tokenizer.tokenize(
            description,
            self.text_len,
            truncate_text=self.truncate_captions
        ).squeeze(0)

        try:
            # Closing via the context manager releases the file handle deterministically,
            # including when decoding fails part-way.
            with Image.open(image_file) as image:
                pixels = np.array(image.convert('RGB')).reshape(-1, 3)
        except (UnidentifiedImageError, OSError):
            print(f"An exception occurred trying to load file {image_file}.")
            return None

        # One token per pixel, looked up by numpy's string form of the RGB row
        # (e.g. '[255   0   0]'); a colour missing from image_tokenizer raises KeyError.
        tokenized_image = torch.tensor([self.image_tokenizer[str(pixel)] for pixel in pixels])
        return tokenized_text, tokenized_image

    def __getitem__(self, ind):
        """
        Return ``(tokenized_text, tokenized_image)`` for index ``ind``.

        A random non-empty line of the caption file is tokenized with ``text_tokenizer``. The image is
        converted to RGB and mapped pixel by pixel through ``image_tokenizer``, giving a 1-D tensor with
        one token per pixel. If the pair cannot be loaded, replacements are tried using the same rule as
        ``skip_sample``.

        @raises IndexError: if ``ind`` is out of range.
        @raises RuntimeError: if every entry of the dataset fails to load.
        """
        failed = set()
        # Iterate instead of recursing through skip_sample. A long run of bad files
        # can no longer exhaust the stack, and an all-bad dataset fails with a clear
        # error instead of a RecursionError.
        while True:
            sample = self._load_sample(ind)
            if sample is not None:
                return sample
            print(f"Skipping index {ind}")
            # Failures (empty caption, corrupt image) repeat on every read, so stop once each entry has failed.
            failed.add(ind % self.__len__())
            if len(failed) >= self.__len__():
                raise RuntimeError(
                    f"None of the {self.__len__()} text/image pairs in this dataset could be loaded."
                )
            ind = self._skip_index(ind)
|