from typing import Any, Dict, Tuple
from abc import abstractmethod
import os
import json
import torch
import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset


class StandardDatasetBase(Dataset):
"""
Base class for standard datasets.
Args:
roots (str): paths to the dataset
"""
def __init__(self,
roots: str,
):
super().__init__()
self.roots = roots.split(',')
self.instances = []
self.metadata = pd.DataFrame()
self._stats = {}
for root in self.roots:
key = os.path.basename(root)
self._stats[key] = {}
metadata = pd.read_csv(os.path.join(root, 'metadata.csv'))
self._stats[key]['Total'] = len(metadata)
metadata, stats = self.filter_metadata(metadata)
self._stats[key].update(stats)
self.instances.extend([(root, sha256) for sha256 in metadata['sha256'].values])
metadata.set_index('sha256', inplace=True)
self.metadata = pd.concat([self.metadata, metadata])
@abstractmethod
def filter_metadata(self, metadata: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, int]]:
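        """
        Filter the raw metadata and report per-filter instance counts.

        Returns the filtered DataFrame and a dict mapping a human-readable
        filter name to the number of rows that survive it.
        """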
pass
@abstractmethod
def get_instance(self, root: str, instance: str) -> Dict[str, Any]:
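        """
        Load a single instance, identified by its sha256, from the given root.
        """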
pass
def __len__(self):
return len(self.instances)
def __getitem__(self, index) -> Dict[str, Any]:
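        # If loading fails (e.g. a corrupted file), log the error and fall
        # back to a random instance so one bad sample does not stop iteration.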
try:
root, instance = self.instances[index]
return self.get_instance(root, instance)
except Exception as e:
print(e)
return self.__getitem__(np.random.randint(0, len(self)))
def __str__(self):
lines = []
lines.append(self.__class__.__name__)
lines.append(f' - Total instances: {len(self)}')
lines.append(f' - Sources:')
for key, stats in self._stats.items():
lines.append(f' - {key}:')
for k, v in stats.items():
lines.append(f' - {k}: {v}')
return '\n'.join(lines)


class TextConditionedMixin:
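    """
    Mixin that attaches a text caption condition to each instance.

    Requires a 'captions' column in metadata.csv holding a JSON list of
    strings per instance; rows without captions are filtered out and one
    caption is sampled at random per __getitem__ call.
    """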
def __init__(self, roots, **kwargs):
super().__init__(roots, **kwargs)
self.captions = {}
for instance in self.instances:
sha256 = instance[1]
self.captions[sha256] = json.loads(self.metadata.loc[sha256]['captions'])
def filter_metadata(self, metadata):
metadata, stats = super().filter_metadata(metadata)
metadata = metadata[metadata['captions'].notna()]
stats['With captions'] = len(metadata)
return metadata, stats
def get_instance(self, root, instance):
pack = super().get_instance(root, instance)
text = np.random.choice(self.captions[instance])
pack['cond'] = text
return pack


class ImageConditionedMixin:
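    """
    Mixin that attaches a conditioning image to each instance.

    Requires a boolean 'cond_rendered' metadata column and, per instance, a
    renders_cond/<sha256>/ directory with RGBA renders and a transforms.json
    listing the views. A random view is cropped around the object, resized
    to `image_size`, and composited onto a black background.
    """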
def __init__(self, roots, *, image_size=518, **kwargs):
self.image_size = image_size
super().__init__(roots, **kwargs)
def filter_metadata(self, metadata):
metadata, stats = super().filter_metadata(metadata)
        metadata = metadata[metadata['cond_rendered']]
stats['Cond rendered'] = len(metadata)
return metadata, stats
def get_instance(self, root, instance):
pack = super().get_instance(root, instance)
image_root = os.path.join(root, 'renders_cond', instance)
with open(os.path.join(image_root, 'transforms.json')) as f:
metadata = json.load(f)
n_views = len(metadata['frames'])
view = np.random.randint(n_views)
metadata = metadata['frames'][view]
image_path = os.path.join(image_root, metadata['file_path'])
image = Image.open(image_path)
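        # Tight bounding box of the object, taken from the alpha channel.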
alpha = np.array(image.getchannel(3))
        bbox = alpha.nonzero()
bbox = [bbox[1].min(), bbox[0].min(), bbox[1].max(), bbox[0].max()]
center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
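        # Enlarge the square crop around the object center by a fixed margin.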
aug_size_ratio = 1.2
aug_hsize = hsize * aug_size_ratio
aug_center_offset = [0, 0]
aug_center = [center[0] + aug_center_offset[0], center[1] + aug_center_offset[1]]
aug_bbox = [int(aug_center[0] - aug_hsize), int(aug_center[1] - aug_hsize), int(aug_center[0] + aug_hsize), int(aug_center[1] + aug_hsize)]
image = image.crop(aug_bbox)
image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
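        # Convert to a [3, H, W] float tensor and composite onto black
        # using the (resized) alpha channel.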
alpha = image.getchannel(3)
image = image.convert('RGB')
image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
alpha = torch.tensor(np.array(alpha)).float() / 255.0
image = image * alpha.unsqueeze(0)
pack['cond'] = image
return pack
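

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original upload): one way the pieces
# above could be combined into a concrete dataset. The class names, the
# 'latents/<sha256>.npz' payload layout, and the dataset paths below are
# hypothetical assumptions for this example, not an API of this repository.
# Note that the mixin must precede the concrete dataset class so that its
# filter_metadata / get_instance wrappers run first and chain via super().
class _ExampleLatentDataset(StandardDatasetBase):
    def filter_metadata(self, metadata):
        # Base-level filtering; no extra criteria in this sketch.
        return metadata, {}

    def get_instance(self, root, instance):
        # Hypothetical payload: one .npz file of latents per sha256.
        latent = np.load(os.path.join(root, 'latents', f'{instance}.npz'))
        return {'latent': torch.from_numpy(latent['mean']).float()}


class _ExampleImageConditionedLatentDataset(ImageConditionedMixin, _ExampleLatentDataset):
    pass


if __name__ == '__main__':
    # Hypothetical roots; each must contain metadata.csv and renders_cond/.
    dataset = _ExampleImageConditionedLatentDataset('/data/setA,/data/setB')
    print(dataset)        # per-source filtering statistics
    sample = dataset[0]   # {'latent': ..., 'cond': cropped RGB image tensor}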