from pathlib import Path
from random import randint, choice

import numpy as np
import torch
from PIL import Image, UnidentifiedImageError
from torch.utils.data import Dataset
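
# Expected on-disk layout: each image (<stem>.png/.jpg/.jpeg/.bmp) is paired
# with a caption file (<stem>.txt) that shares the same stem, anywhere under
# the dataset folder, e.g.
#   data/cats/001.jpg
#   data/cats/001.txt   (one caption per line)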


class TextImageDataset(Dataset):
    def __init__(self,
                 folder,
                 text_len=40,
                 truncate_captions=False,
                 text_tokenizer=None,
                 image_tokenizer=None,
                 shuffle=False
                 ):
        """
        @param folder: Folder containing images and text files matched by their paths' respective "stem".
        @param text_len: Maximum caption length, in tokens.
        @param truncate_captions: If True, captions that are too long are truncated rather than raising an exception.
        @param text_tokenizer: Tokenizer used to encode captions.
        @param image_tokenizer: Mapping used to encode image pixel values into tokens.
        @param shuffle: If True, a failed sample is replaced by a random one instead of the next one in order.
        """
        super().__init__()
        self.shuffle = shuffle
        path = Path(folder)

        # Gather candidate caption and image files anywhere under the folder.
        text_files = [*path.glob('**/*.txt')]
        image_files = [
            *path.glob('**/*.png'), *path.glob('**/*.jpg'),
            *path.glob('**/*.jpeg'), *path.glob('**/*.bmp')
        ]

        # Index files by stem and keep only stems that have both a caption and an image.
        text_files = {text_file.stem: text_file for text_file in text_files}
        image_files = {image_file.stem: image_file for image_file in image_files}
        keys = (image_files.keys() & text_files.keys())

        self.keys = list(keys)
        self.text_files = {k: v for k, v in text_files.items() if k in keys}
        self.image_files = {k: v for k, v in image_files.items() if k in keys}
        self.text_len = text_len
        self.truncate_captions = truncate_captions
        self.text_tokenizer = text_tokenizer
        self.image_tokenizer = image_tokenizer

    def __len__(self):
        return len(self.keys)

    def random_sample(self):
        return self.__getitem__(randint(0, self.__len__() - 1))

    def sequential_sample(self, ind):
        if ind >= self.__len__() - 1:
            return self.__getitem__(0)
        return self.__getitem__(ind + 1)

    def skip_sample(self, ind):
        # Fall back to another sample when the current one cannot be loaded.
        if self.shuffle:
            return self.random_sample()
        return self.sequential_sample(ind=ind)

    def __getitem__(self, ind):
        key = self.keys[ind]

        text_file = self.text_files[key]
        image_file = self.image_files[key]

        # One caption per non-empty line; pick one at random.
        descriptions = text_file.read_text().split('\n')
        descriptions = list(filter(lambda t: len(t) > 0, descriptions))
        try:
            description = choice(descriptions)
        except IndexError as zero_captions_in_file_ex:
            print(f"An exception occurred trying to load file {text_file}.")
            print(f"Skipping index {ind}")
            return self.skip_sample(ind)

        tokenized_text = self.text_tokenizer.tokenize(
            description,
            self.text_len,
            truncate_text=self.truncate_captions
        ).squeeze(0)

        try:
            image = Image.open(image_file).convert('RGB')
            # Flatten the image to one RGB triplet per pixel and map each triplet
            # (via its string representation) to a token id through image_tokenizer.
            pixels = np.array(image).reshape(-1, 3)
            tokenized_image = [self.image_tokenizer[str(idx)] for idx in pixels]
            tokenized_image = torch.tensor(tokenized_image)
        except (UnidentifiedImageError, OSError) as corrupt_image_exceptions:
            print(f"An exception occurred trying to load file {image_file}.")
            print(f"Skipping index {ind}")
            return self.skip_sample(ind)

        # Success
        return tokenized_text, tokenized_image
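

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal sketch of wiring this dataset into a DataLoader. The tokenizer
# classes below are hypothetical stand-ins, not part of this module: the
# dataset only assumes that `text_tokenizer` exposes
# `tokenize(text, text_len, truncate_text=...)` returning a (1, text_len)
# tensor, and that `image_tokenizer` behaves like a dict mapping the string
# form of an RGB triplet to an integer token id.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    class DummyTextTokenizer:
        def tokenize(self, text, text_len, truncate_text=False):
            # Hypothetical tokenizer: one token per character, padded or
            # truncated to text_len, with a leading batch dimension.
            ids = [ord(c) % 256 for c in text][:text_len]
            ids += [0] * (text_len - len(ids))
            return torch.tensor(ids).unsqueeze(0)

    class DummyImageTokenizer(dict):
        # Hypothetical pixel vocabulary: any unseen RGB triplet maps to token 0.
        def __missing__(self, key):
            return 0

    dataset = TextImageDataset(
        folder='./data',  # expects ./data/<stem>.txt and ./data/<stem>.jpg pairs
        text_len=40,
        truncate_captions=True,
        text_tokenizer=DummyTextTokenizer(),
        image_tokenizer=DummyImageTokenizer(),
        shuffle=True,
    )

    # Batching assumes all images share the same resolution, so the flattened
    # pixel-token tensors have equal length and can be stacked.
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    for tokenized_text, tokenized_image in loader:
        print(tokenized_text.shape, tokenized_image.shape)
        break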