# IDMR-demo/src/dataset.py
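"""Dataset classes for the IDMR demo.

TrainDataset yields (query, positive) training pairs, plus a hard negative
for "hard_neg" datasets; EvalDataset builds a deduplicated (text, image)
candidate pool for retrieval evaluation; FlickrDataset wraps the Flickr 1k
image-text retrieval test split. Texts store a Phi-style image placeholder
that is rewritten to the target backbone's token at load time.
"""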
from typing import Tuple
import datasets
from datasets import load_dataset, concatenate_datasets
from torch.utils.data import Dataset
from PIL import Image
import os
from torchvision.transforms import RandAugment
def get_randaugment_transform(n=2, m=9):
return RandAugment(num_ops=n, magnitude=m)
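# Prepend the Phi image placeholder to each text field of a training row.
# Currently only referenced by the commented-out map call in TrainDataset.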
def add_prompt_template(data):
data["qry"] = f"<|image_1|>{data['qry']}"
data["pos_text"] = f"<|image_1|>{data['pos_text']}"
data["hard_neg_text"] = f"<|image_1|>{data['hard_neg_text']}"
return data
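# Per-backbone image placeholder tokens. Training and eval text is stored
# with the Phi token; __getitem__ rewrites it to the target backbone's token.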
Phi_Image_token = "<|image_1|>"
Llava_Image_token = "<image>"
Qwen_Image_token = "<|image_pad|>"
Internvl_Image_token = "<image>"
class TrainDataset(Dataset):
def __init__(self, data_args, model_args):
self.data_args = data_args
self.model_args = model_args
self.transform = None
if self.data_args.randaugment:
self.transform = get_randaugment_transform()
train_data = []
if data_args.subset_name is not None:
print(f"Loading {len(data_args.subset_name)} datasets: {data_args.subset_name}")
for subset in data_args.subset_name:
dataset_name = os.path.join(self.data_args.dataset_name, subset)
                subset_data = load_dataset(
                    dataset_name,
                    split=self.data_args.dataset_split,
                )
train_data.append(subset_data)
self.train_data = concatenate_datasets(train_data)
self.train_data = self.train_data.shuffle(seed=42)
else:
            train_data = load_dataset(
                self.data_args.dataset_name,
                split=self.data_args.dataset_split,
            )
if "hard_neg" in self.data_args.dataset_name:
# self.train_data = train_data.map(add_prompt_template, num_proc=8)
print(train_data)
else:
self.train_data = train_data
if self.data_args.num_samples:
# self.train_data = self.train_data[:self.data_args.num_samples]
self.train_data = self.train_data.select(range(self.data_args.num_samples))
print(f"len of train_data: {len(self.train_data)}")
def __len__(self):
return len(self.train_data)
def _process_image(self, image, resolution):
if image is None:
return None
if resolution == "high":
image = image.resize((1344, 1344))
elif resolution == "low":
image = image.resize((336, 336))
elif resolution == "clip":
image = image.resize((224, 224))
return image
def _get_image(self, img_path):
if img_path == "":
return None
if img_path.startswith('/'):
full_img_path = img_path
else:
full_img_path = os.path.join(self.data_args.image_dir, img_path)
image = Image.open(full_img_path)
if self.model_args.model_backbone == "llava_next":
# TODO: make it configurable
return self._process_image(image, "high")
elif self.model_args.model_backbone == "qwen":
return self._process_image(image, "low")
elif self.model_args.model_backbone == "internvl_2_5":
# TODO: make it configurable
return self._process_image(image, "high")
else:
return image
    def __getitem__(self, item) -> Tuple:
        # Returns (qry_text, qry_image, pos_text, pos_image), plus
        # (hard_neg_text, hard_neg_image) for hard-negative datasets.
data_item = self.train_data[item]
qry_text, qry_image_path, pos_text, pos_image_path = (
data_item["qry"], data_item["qry_image_path"],
data_item["pos_text"], data_item["pos_image_path"],
)
qry_image = self._get_image(qry_image_path)
        if self.transform and qry_image is not None:
            qry_image = self.transform(qry_image)
if self.model_args.model_backbone == "llava_next":
# Update image token
qry_text = qry_text.replace(Phi_Image_token, Llava_Image_token)
pos_text = pos_text.replace(Phi_Image_token, Llava_Image_token)
elif self.model_args.model_backbone == "qwen":
qry_text = qry_text.replace(Phi_Image_token, Qwen_Image_token)
pos_text = pos_text.replace(Phi_Image_token, Qwen_Image_token)
elif self.model_args.model_backbone == "internvl_2_5":
qry_text = qry_text.replace(Phi_Image_token, Internvl_Image_token)
pos_text = pos_text.replace(Phi_Image_token, Internvl_Image_token)
if "hard_neg" in self.data_args.dataset_name:
hard_neg_text, hard_neg_image_path = (
data_item["hard_neg_text"], data_item["hard_neg_image_path"],
)
if self.model_args.model_backbone == "llava_next":
# Update image token
hard_neg_text = hard_neg_text.replace(Phi_Image_token, Llava_Image_token)
elif self.model_args.model_backbone == "internvl_2_5":
hard_neg_text = hard_neg_text.replace(Phi_Image_token, Internvl_Image_token)
return (
qry_text, qry_image,
pos_text, self._get_image(pos_image_path),
hard_neg_text, self._get_image(hard_neg_image_path)
)
return (
qry_text, qry_image,
pos_text, self._get_image(pos_image_path)
)
class EvalDataset(Dataset):
def __init__(self, data_args, model_args, subset, text_field, img_path_field):
"""
(text_field, image_field) -> ("qry_text", "qry_img_path") or ("tgt_text", "tgt_img_path")
"""
self.data_args = data_args
self.model_args = model_args
if data_args.subset_name is not None:
self.eval_data = load_dataset(
self.data_args.dataset_name,
subset,
split=self.data_args.dataset_split,
)
else:
self.eval_data = load_dataset(
self.data_args.dataset_name,
split=self.data_args.dataset_split,
)
print(f"len of eval_data: {len(self.eval_data)}")
self.paired_data = self.get_paired_data(text_field, img_path_field)
self.paired_dataset = datasets.Dataset.from_dict({
"text": [pair["text"] for pair in self.paired_data],
"img_path": [pair["img_path"] for pair in self.paired_data]
})
def __len__(self):
return len(self.paired_dataset)
def __getitem__(self, item):
text, img_path = self.paired_dataset[item]["text"], self.paired_dataset[item]["img_path"]
if self.model_args.model_backbone == "llava_next":
# Update llava image token
text = text.replace(Phi_Image_token, Llava_Image_token)
elif self.model_args.model_backbone == "qwen":
text = text.replace(Phi_Image_token, Qwen_Image_token)
elif self.model_args.model_backbone == "internvl_2_5":
text = text.replace(Phi_Image_token, Internvl_Image_token)
        return text, self._get_image(img_path)
def _process_image(self, image, resolution):
if image is None:
return None
if resolution == "high":
image = image.resize((1344, 1344))
else:
image = image.resize((336, 336))
return image
def _get_image(self, img_path):
if img_path == "":
return None
if img_path.startswith("/"):
full_img_path = img_path
else:
full_img_path = os.path.join(self.data_args.image_dir, img_path)
image = Image.open(full_img_path)
if self.model_args.model_backbone == "llava_next":
return self._process_image(image, "high")
elif self.model_args.model_backbone == "internvl_2_5":
return self._process_image(image, "high")
        else:
            return image
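    # Build the deduplicated (text, image-path) candidate pool used for
    # retrieval evaluation; text and image fields may be str or list valued.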
def get_paired_data(self, text_field, img_path_field):
"""
(text_field, image_field) -> ("qry_text", "qry_img_path") or ("tgt_text", "tgt_img_path")
"""
unique_pair = set()
for row in self.eval_data:
if isinstance(row[text_field], str):
if row[text_field]:
unique_pair.add((row[text_field], row[img_path_field]))
else:
                    if isinstance(row[img_path_field], list):
for img_path in row[img_path_field]:
unique_pair.add((row[text_field], img_path))
else:
unique_pair.add((row[text_field], row[img_path_field]))
            elif isinstance(row[text_field], list):
                assert isinstance(row[img_path_field], list) and len(row[img_path_field]) == len(row[text_field])
for text, img_path in zip(row[text_field], row[img_path_field]):
unique_pair.add((text, img_path))
paired_data = [{"text": text, "img_path": img_path} for text, img_path in unique_pair]
return paired_data
class FlickrDataset(Dataset):
def __init__(self, modality, model_backbone):
self.model_backbone = model_backbone
self.modality = modality
self.raw_data = load_dataset("nlphuji/flickr_1k_test_image_text_retrieval", split="test")
if modality == "image":
self.eval_data, self.image_names = self.get_image_data()
else:
self.eval_data, self.image_names = self.get_text_data()
def __len__(self):
return len(self.eval_data)
def __getitem__(self, idx):
text, image = self.eval_data[idx]
if self.model_backbone == "llava_next":
# Update llava image token
text = text.replace(Phi_Image_token, Llava_Image_token)
image = self._process_image(image, "high")
return text, image
def _process_image(self, image, resolution):
if image is None:
return None
if resolution == "high":
image = image.resize((1344, 1344))
else:
image = image.resize((336, 336))
return image
    def _get_image(self, img_path):
        # Unused in the Flickr flow (images come straight from the dataset
        # rows); note it reads self.data_args, which this class never sets.
        if img_path == "":
            return None
        full_img_path = os.path.join(self.data_args.image_dir, img_path)
image = Image.open(full_img_path)
if self.model_backbone == "llava_next":
return self._process_image(image, "high")
else:
return image
return image
def get_image_data(self):
eval_data, image_names = [], []
# i2t
inst = "<|image_1|> Find an image caption describing the given image." # llava-1344-step1k4, i2t=94.0, t2i=80.26
# inst = "<|image_1|> Represent the given image for image caption retrieval." # llava-1344-step1k4, i2t=94.6, t2i=78.98
# t2i
# inst = "<|image_1|> Represent the given image." # MSCOCO t2i
for row in self.raw_data:
eval_data.append((inst, row["image"]))
image_names.append(row["filename"])
return eval_data, image_names
def get_text_data(self):
eval_data, image_names = [], []
# i2t
inst = ""
# t2i
# inst = "Retrieve an image that matches the given caption: "
# inst = "Find me an everyday image that matches the given caption." # MSCOCO t2i
for row in self.raw_data:
for caption in row["caption"]:
# eval_data.append((caption, None))
eval_data.append((inst + caption, None))
image_names.append(row["filename"])
return eval_data, image_names
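

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. The
    # SimpleNamespace objects stand in for the parsed data/model argument
    # dataclasses, and the dataset_name / image_dir values below are
    # hypothetical placeholders.
    from types import SimpleNamespace

    data_args = SimpleNamespace(
        dataset_name="data/IDMR-train",  # hypothetical local dataset path
        subset_name=None,                # or a list of subset directories
        dataset_split="train",
        num_samples=4,
        randaugment=False,
        image_dir="data/images",         # hypothetical image root
    )
    model_args = SimpleNamespace(model_backbone="phi3_v")  # default token path
    train_set = TrainDataset(data_args, model_args)
    qry_text, qry_image, pos_text, pos_image = train_set[0]
    print(len(train_set), qry_text)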