from typing import *
from abc import abstractmethod
import os
import json

import torch
import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset


class StandardDatasetBase(Dataset):
    """
    Base class for standard datasets.

    Each root directory must contain a ``metadata.csv`` with a ``sha256``
    column. The rows kept by ``filter_metadata`` become the dataset instances.

    Args:
        roots (str): comma-separated paths to the dataset roots
    """

    def __init__(self,
        roots: str,
    ):
        super().__init__()
        self.roots = roots.split(',')
        self.instances = []
        self.metadata = pd.DataFrame()

        self._stats = {}
        kept_frames = []
        for root in self.roots:
            # normpath drops a trailing separator; without it basename() would
            # return '' and every such root would share the same stats key.
            key = os.path.basename(os.path.normpath(root))
            self._stats[key] = {}
            metadata = pd.read_csv(os.path.join(root, 'metadata.csv'))
            self._stats[key]['Total'] = len(metadata)
            metadata, stats = self.filter_metadata(metadata)
            self._stats[key].update(stats)
            self.instances.extend([(root, sha256) for sha256 in metadata['sha256'].values])
            # Non-inplace so the frame returned by filter_metadata is not mutated.
            kept_frames.append(metadata.set_index('sha256'))

        # Concatenate once instead of growing an empty frame inside the loop.
        if kept_frames:
            self.metadata = pd.concat(kept_frames)

    @abstractmethod
    def filter_metadata(self, metadata: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, int]]:
        """
        Select the usable rows of one root's metadata.

        Args:
            metadata: contents of the root's ``metadata.csv``

        Returns:
            The filtered metadata and a dict of statistics that is merged
            into the per-root stats shown by ``__str__``.
        """
        pass

    @abstractmethod
    def get_instance(self, root: str, instance: str) -> Dict[str, Any]:
        """
        Load one instance.

        Args:
            root: dataset root the instance belongs to
            instance: the instance's ``sha256`` value
        """
        pass

    def __len__(self):
        """Return the number of instances kept across all roots."""
        return len(self.instances)

    def __getitem__(self, index) -> Dict[str, Any]:
        """
        Return the instance at ``index``.

        If loading fails for any reason, the exception is printed and a
        uniformly random instance is loaded instead (retrying recursively).
        """
        try:
            root, instance = self.instances[index]
            return self.get_instance(root, instance)
        except Exception as e:
            print(e)
            return self.__getitem__(np.random.randint(0, len(self)))

    def __str__(self):
        """Return a summary: class name, instance count and per-root stats."""
        lines = [
            self.__class__.__name__,
            f' - Total instances: {len(self)}',
            ' - Sources:',
        ]
        for key, stats in self._stats.items():
            lines.append(f' - {key}:')
            lines.extend(f' - {k}: {v}' for k, v in stats.items())
        return '\n'.join(lines)


class TextConditionedMixin:
    """
    Mixin that attaches a text condition to each instance.

    Reads the ``captions`` metadata column, which must hold a JSON-encoded
    value per row, and puts one randomly chosen caption in ``pack['cond']``.
    """

    def __init__(self, roots, **kwargs):
        super().__init__(roots, **kwargs)
        self.captions = {}
        if self.instances:
            # Look the column up once rather than building a full row per instance.
            captions_column = self.metadata['captions']
            for _, sha256 in self.instances:
                self.captions[sha256] = json.loads(captions_column.loc[sha256])

    def filter_metadata(self, metadata):
        """Keep only rows whose ``captions`` entry is not missing."""
        metadata, stats = super().filter_metadata(metadata)
        metadata = metadata[metadata['captions'].notna()]
        stats['With captions'] = len(metadata)
        return metadata, stats

    def get_instance(self, root, instance):
        """Load the instance and set ``pack['cond']`` to a random caption."""
        pack = super().get_instance(root, instance)
        text = np.random.choice(self.captions[instance])
        pack['cond'] = text
        return pack


class ImageConditionedMixin:
    """
    Mixin that attaches an image condition to each instance.

    Args:
        roots: forwarded to the next class in the MRO
        image_size: side length in pixels of the square conditioning image
    """

    def __init__(self, roots, *, image_size=518, **kwargs):
        # Set before super().__init__ so it is available during base initialisation.
        self.image_size = image_size
        super().__init__(roots, **kwargs)

    def filter_metadata(self, metadata):
        """Keep only rows selected by the ``cond_rendered`` column (used as a mask)."""
        metadata, stats = super().filter_metadata(metadata)
        metadata = metadata[metadata['cond_rendered']]
        stats['Cond rendered'] = len(metadata)
        return metadata, stats

    def get_instance(self, root, instance):
        """
        Load the instance and set ``pack['cond']`` to a conditioning image.

        A random frame listed in ``renders_cond/<instance>/transforms.json``
        is opened, cropped to a square around the non-zero region of its
        fourth channel (enlarged by a factor of 1.2), resized to
        ``image_size``, and multiplied by that channel.

        Returns:
            The pack from the parent class with ``pack['cond']`` set to a
            float tensor of shape ``(3, image_size, image_size)``.
        """
        pack = super().get_instance(root, instance)

        image_root = os.path.join(root, 'renders_cond', instance)
        with open(os.path.join(image_root, 'transforms.json')) as f:
            transforms = json.load(f)
        frames = transforms['frames']
        frame = frames[np.random.randint(len(frames))]

        image_path = os.path.join(image_root, frame['file_path'])
        # The context manager releases the file handle even if a step below raises.
        with Image.open(image_path) as source:
            # Channel 3 is used as alpha -- assumes the renders are RGBA.
            alpha = np.array(source.getchannel(3))
            rows, cols = alpha.nonzero()
            bbox = [cols.min(), rows.min(), cols.max(), rows.max()]
            center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
            hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
            aug_size_ratio = 1.2
            aug_hsize = hsize * aug_size_ratio
            aug_center_offset = [0, 0]
            aug_center = [center[0] + aug_center_offset[0], center[1] + aug_center_offset[1]]
            aug_bbox = [
                int(aug_center[0] - aug_hsize),
                int(aug_center[1] - aug_hsize),
                int(aug_center[0] + aug_hsize),
                int(aug_center[1] + aug_hsize),
            ]
            image = source.crop(aug_bbox)
            image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)

        alpha = image.getchannel(3)
        image = image.convert('RGB')
        image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
        alpha = torch.tensor(np.array(alpha)).float() / 255.0
        # Zero out the colour wherever alpha is zero.
        image = image * alpha.unsqueeze(0)
        pack['cond'] = image

        return pack