import io
import re
import types

import torch
from PIL import Image
from qwen_vl_utils import fetch_image
from transformers import (
    ProcessorMixin,
    SiglipImageProcessor,
    BatchFeature,
    Qwen2VLImageProcessor,
    PreTrainedTokenizer,
)

from .utils import (
    process_anyres_image,
    BLACK_IMG_ENV,
    DEFAULT_IM_END_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_VI_END_TOKEN,
    DEFAULT_VI_START_TOKEN,
    DEFAULT_VIDEO_TOKEN,
    IMAGE_TOKEN_INDEX,
    SEQ_MAX_LEN,
)

# Default SigLIP image-processor settings: fixed 384x384 resize, normalized with mean/std 0.5.
siglip_processor_config = {
    "do_normalize": True,
    "do_rescale": True,
    "do_resize": True,
    "image_mean": [0.5, 0.5, 0.5],
    "image_processor_type": "SiglipImageProcessor",
    "image_std": [0.5, 0.5, 0.5],
    "processor_class": "SiglipProcessor",
    "resample": 3,  # PIL.Image.BICUBIC
    "rescale_factor": 0.00392156862745098,
    "size": {"height": 384, "width": 384},
}

# Default Qwen2-VL image-processor settings: dynamic resolution bounded by
# [min_pixels, max_pixels], 14x14 patches with 2x2 spatial merge, CLIP normalization stats.
qwen2vl_processor_config = {
    "min_pixels": 3136,
    "max_pixels": 12845056,
    "patch_size": 14,
    "temporal_patch_size": 2,
    "merge_size": 2,
    "image_mean": [0.48145466, 0.4578275, 0.40821073],
    "image_std": [0.26862954, 0.26130258, 0.27577711],
    "image_processor_type": "Qwen2VLImageProcessor",
    "processor_class": "Qwen2VLProcessor",
}


class ValleyProcessor(ProcessorMixin):
    """Processor that combines a Qwen2-style tokenizer with SigLIP and Qwen2-VL image preprocessing."""

    attributes = ["tokenizer"]
    optional_attributes = [
        "max_pixels",
        "min_pixels",
        "anyres",
        "only_crop_single_image",
        "grid_pinpoints",
        "use_special_start_end_token",
    ]
    tokenizer_class = "AutoTokenizer"

    def __init__(self, tokenizer=None, **kwargs):
        super().__init__(tokenizer, **kwargs)
        self.black_img = BLACK_IMG_ENV
        self.siglip_image_processor = SiglipImageProcessor.from_dict(siglip_processor_config)
        self.qwen2vl_image_processor = Qwen2VLImageProcessor.from_dict(qwen2vl_processor_config)

        self.anyres = kwargs.get("anyres", True)
        self.grid_pinpoints = kwargs.get("grid_pinpoints", "(1x1),...,(3x3)")
        self.only_crop_single_image = kwargs.get("only_crop_single_image", True)
        self.use_special_start_end_token = kwargs.get("use_special_start_end_token", True)
        self.only_navit = kwargs.get("only_navit", False)

    def preprocess_images_siglip(self, images) -> torch.FloatTensor:
        """Load images (paths, PIL images, or raw bytes) and encode them with the SigLIP processor."""
        if isinstance(images[0], str):
            images_pil = [Image.open(img).convert("RGB") for img in images]
        elif isinstance(images[0], Image.Image):
            images_pil = [img.convert("RGB") for img in images]
        elif isinstance(images[0], bytes):
            images_pil = [Image.open(io.BytesIO(img)).convert("RGB") for img in images]
        else:
            raise ValueError("unsupported image type")

        processed_images = []
        have_multi_images = len(images_pil) > 1
        for img in images_pil:
            if self.anyres:
                if not self.only_crop_single_image or not have_multi_images:
                    # AnyRes cropping driven by grid_pinpoints.
                    image = process_anyres_image(img, self.siglip_image_processor, self.grid_pinpoints)
                else:
                    image = [self.siglip_image_processor(img, return_tensors="pt")["pixel_values"][0]]
            else:
                image = self.siglip_image_processor(img, return_tensors="pt")["pixel_values"][0]

            processed_images.append(image)

        if not self.anyres:
            return torch.stack(processed_images, dim=0)
        else:
            return [torch.stack(img, dim=0) for img in processed_images]

    def preprocess_images_qwen2vl(self, images) -> dict:
        """Load images (paths, PIL images, or raw bytes) and encode them with the Qwen2-VL processor."""
        if isinstance(images[0], str):
            images_pil = [Image.open(img).convert("RGB") for img in images]
        elif isinstance(images[0], Image.Image):
            images_pil = [img.convert("RGB") for img in images]
        elif isinstance(images[0], bytes):
            images_pil = [Image.open(io.BytesIO(img)).convert("RGB") for img in images]
        else:
            raise ValueError("unsupported image type")

        image_sizes = [[x.size for x in images_pil]]
        data_dict_qwen2vl = self.qwen2vl_image_processor(
            [fetch_image({"image": img}) for img in images_pil],
            return_tensors="pt",
        )
        data_dict_qwen2vl["image_sizes"] = image_sizes

        return data_dict_qwen2vl

    def preprocess_multimodal(self, conversations):
        """Optionally wrap every image placeholder in the non-system turns with the special start/end tokens."""
        for sentence in conversations:
            if sentence["role"] == "system":
                continue
            segs = re.split(DEFAULT_IMAGE_TOKEN, sentence["content"])
            if self.use_special_start_end_token:
                sentence["content"] = (DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN).join(segs)
            else:
                sentence["content"] = DEFAULT_IMAGE_TOKEN.join(segs)

        return conversations

    def preprocess_qwen2(
        self,
        conversations,
        tokenizer: PreTrainedTokenizer,
        has_image: bool = False,
        inference: bool = False,
        only_mask_system: bool = False,
    ) -> dict:
        """Tokenize a conversation with the Qwen2 chat template and build labels that mask prompt tokens."""
        conv = types.SimpleNamespace(
            system="You are a helpful assistant.",
            roles=("user", "assistant"),
            version="qwen2",
            offset=0,
            sep="<|im_start|>",
            sep2="<|im_end|>\n",
        )

        assert conversations[0]["role"] == "system"
        if conversations[0]["content"] is None:
            conversations[0]["content"] = conv.system

        for j, sentence in enumerate(conversations[1:]):
            role = sentence["role"]
            assert role == conv.roles[j % 2], "The conversation sequence is incorrect."

        conversation_str = tokenizer.apply_chat_template(conversations, tokenize=False, add_generation_prompt=inference)

        # Tokenize round by round; label tokens are set to -100 for everything before the
        # assistant reply (or only up to the user turn when only_mask_system is set).
        rounds = conversation_str.split(conv.sep2)
        input_ids_ = torch.tensor([], dtype=torch.int64)
        targets_ = torch.tensor([], dtype=torch.int64)
        for i, rou in enumerate(rounds):
            if rou == "":
                continue
            if (not inference) or (i < (len(rounds) - 1)):
                rou += conv.sep2
            if has_image:
                cur_input_ids_ = self.tokenizer_image_token(rou, tokenizer, return_tensors="pt")
                input_ids_ = torch.cat([input_ids_, cur_input_ids_], dim=0)
                if only_mask_system:
                    mask_len = len(self.tokenizer_image_token(
                        re.sub(rf"{conv.roles[0]}\n[\s\S]*", f"{conv.roles[0]}:", rou), tokenizer))
                else:
                    mask_len = len(self.tokenizer_image_token(
                        re.sub(rf"{conv.roles[1]}\n[\s\S]*", f"{conv.roles[1]}:", rou), tokenizer))
                targets_ = torch.cat([targets_, torch.tensor([-100] * mask_len), cur_input_ids_[mask_len:]], dim=0)
            else:
                cur_input_ids_ = tokenizer(rou, return_tensors="pt")["input_ids"][0, :]
                input_ids_ = torch.cat([input_ids_, cur_input_ids_], dim=0)
                mask_len = len(tokenizer(re.sub(rf"{conv.roles[1]}\n[\s\S]*", f"{conv.roles[1]}:", rou))["input_ids"][:])
                targets_ = torch.cat([targets_, torch.tensor([-100] * mask_len), cur_input_ids_[mask_len:]], dim=0)

        return {"input_ids": input_ids_, "labels": targets_}

    def tokenizer_image_token(
        self,
        prompt,
        tokenizer,
        image_token_index=IMAGE_TOKEN_INDEX,
        return_tensors=None,
    ):
        """Tokenize a prompt while replacing every image placeholder with image_token_index."""
        def split_with_token(string, token):
            # Split on the token but keep the token itself as a separate chunk.
            result = string.split(token)
            for i in range(len(result) - 1):
                result.insert(i * 2 + 1, token)
            return result

        if len(prompt) > SEQ_MAX_LEN:
            raise ValueError("sequence is too long")

        prompt_chunks = split_with_token(prompt, DEFAULT_IMAGE_TOKEN)
        input_ids, offset = ([tokenizer.bos_token_id], 1) if getattr(tokenizer, "bos_token", None) else ([], 0)
        token2index = {DEFAULT_IMAGE_TOKEN: image_token_index}
        for chunk in prompt_chunks:
            if chunk in token2index:
                input_ids.append(token2index[chunk])
            else:
                chunk_ids = tokenizer(chunk).input_ids
                # Skip a leading BOS token only if the tokenizer actually prepended one.
                if chunk_ids[0] != getattr(tokenizer, "bos_token_id", None):
                    offset = 0
                input_ids.extend(chunk_ids[offset:])

        if return_tensors is not None:
            if return_tensors == "pt":
                return torch.tensor(input_ids, dtype=torch.long)
            raise ValueError(f"Unsupported tensor type: {return_tensors}")
        return input_ids

    def __call__(self, messages, inference=True, **kwargs) -> BatchFeature:
        """Build model inputs (input_ids, labels, and image features) from a messages dict."""
        max_pixels = kwargs.get("max_pixels", self.max_pixels)
        min_pixels = kwargs.get("min_pixels", self.min_pixels)
        if max_pixels is not None:
            self.qwen2vl_image_processor.max_pixels = max_pixels
        if min_pixels is not None:
            self.qwen2vl_image_processor.min_pixels = min_pixels

        # Fall back to a black placeholder image for text-only inputs; cap multi-image inputs at 16 images.
        if "images" not in messages or not messages["images"] or not messages["images"][0]:
            images = [self.black_img]
        elif isinstance(messages["images"], str):
            images = [messages["images"]]
        else:
            images = messages["images"][:16]

        conversations = messages["conversations"]
        if conversations[0]["role"] != "system":
            conversations = [{"role": "system", "content": None}] + conversations

        # Prepend one image placeholder per image if the first user turn does not provide any.
        assert conversations[1]["role"] == "user"
        if images and "<image>" not in conversations[1]["content"]:
            image_token = " ".join(["<image>"] * len(images))
            conversations[1]["content"] = f"{image_token}\n{conversations[1]['content']}"

        if inference:
            assert conversations[-1]["role"] == "user", "the last message should be from the user when inference=True"

        if self.only_navit:
            processed_images_siglip = None
        else:
            processed_images_siglip = self.preprocess_images_siglip(images)
        processed_data_dict_qwen2vl = self.preprocess_images_qwen2vl(images)
        source = self.preprocess_multimodal(conversations)
        data_dict = self.preprocess_qwen2(source, self.tokenizer, has_image=True, only_mask_system=False, inference=inference)

        data_dict["input_ids"] = data_dict["input_ids"].unsqueeze(0)
        data_dict["labels"] = data_dict["labels"].unsqueeze(0)
        data_dict["images"] = [processed_images_siglip]

        return BatchFeature(data={**data_dict, **processed_data_dict_qwen2vl})

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)
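

# Minimal usage sketch (not part of the original module): it assumes a Qwen2-style chat
# tokenizer is available; the checkpoint name and image path below are illustrative only.
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")  # hypothetical checkpoint
#   processor = ValleyProcessor(tokenizer)
#   batch = processor(
#       {
#           "images": ["example.jpg"],  # paths, PIL images, or raw bytes
#           "conversations": [{"role": "user", "content": "Describe the image."}],
#       },
#       inference=True,
#   )
#   # batch is a BatchFeature containing input_ids, labels, images (SigLIP tensors),
#   # plus the Qwen2-VL image-processor outputs and image_sizes.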