Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import torch.nn as nn | |
import numpy as np | |
from PIL import Image | |
from .model.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX | |
from .model.conversation import SeparatorStyle, conv_templates | |
from .model.mm_utils import KeywordsStoppingCriteria, process_image, tokenizer_image_token | |
from .model import get_model_name_from_path, load_pretrained_model | |
from transformers import TextIteratorStreamer | |
from threading import Thread | |
class DescribeAnythingModel(nn.Module):
    """Wrap a pretrained multimodal model to describe a masked region of an image.

    ``get_description`` accepts a single image/mask pair or parallel lists/tuples
    of them, builds the prompt from ``conv_mode``, prepares image+mask tensors
    according to ``prompt_mode`` and runs ``model.generate`` (optionally streaming).
    """

    # Crop modes accepted by ``crop_image``.
    _CROP_MODES = ("full", "crop", "context_crop", "focal_crop", "crop_mask")

    def __init__(self, model_path, conv_mode, prompt_mode, temperature, top_p, num_beams, max_new_tokens, **kwargs):
        """Load the pretrained model and store generation settings.

        Args:
            model_path: path/identifier passed to ``load_pretrained_model`` and
                ``get_model_name_from_path``.
            conv_mode: key into ``conv_templates`` used to build prompts.
            prompt_mode: ``"<crop_mode>"`` or ``"<crop_mode>+<crop_mode2>"``; the
                first crop mode must currently be ``"full"``.
            temperature: sampling temperature; sampling is enabled only when > 0.
            top_p, num_beams, max_new_tokens: forwarded to ``model.generate``.
            **kwargs: forwarded to ``load_pretrained_model``.
        """
        super().__init__()
        self.model_path = model_path
        self.conv_mode = conv_mode
        self.prompt_mode = prompt_mode
        self.temperature = temperature
        self.top_p = top_p
        self.num_beams = num_beams
        self.max_new_tokens = max_new_tokens
        tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, None, **kwargs)
        # Stored on the config, which is what process_image receives below
        # (presumably it reads the processor from there — confirm in mm_utils).
        model.config.image_processor = image_processor
        self.tokenizer = tokenizer
        self.model = model
        self.context_len = context_len
        self.model_name = get_model_name_from_path(model_path)

    def get_prompt(self, qs):
        """Render query ``qs`` with the ``conv_mode`` conversation template.

        Returns:
            ``(prompt, conv)``: the prompt string and the conversation object
            (whose separators later determine the stop string).

        Raises:
            ValueError: if ``qs`` does not contain ``DEFAULT_IMAGE_TOKEN``.
        """
        if DEFAULT_IMAGE_TOKEN not in qs:
            raise ValueError("no <image> tag found in input.")
        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        # Empty second-role turn, presumably so the prompt ends where the reply starts.
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        return prompt, conv

    @staticmethod
    def mask_to_box(mask_np):
        """Return the tight ``(x0, y0, w, h)`` box around the non-zero entries of a 2-D mask.

        Raises:
            ValueError: if the mask has no non-zero entries.
        """
        mask_coords = np.argwhere(mask_np)
        if mask_coords.size == 0:
            raise ValueError("mask is empty; cannot compute a bounding box")
        y0, x0 = mask_coords.min(axis=0)
        # +1 turns the inclusive max coordinate into an exclusive end.
        y1, x1 = mask_coords.max(axis=0) + 1
        return x0, y0, x1 - x0, y1 - y0

    @staticmethod
    def _context_window(x0, y0, w, h, img_h, img_w):
        """Return (row, col) slices for the box grown by its own size on every side, clipped to the image."""
        rows = slice(max(y0 - h, 0), min(y0 + 2 * h, img_h))
        cols = slice(max(x0 - w, 0), min(x0 + 2 * w, img_w))
        return rows, cols

    @classmethod
    def crop_image(cls, pil_img, mask_np, crop_mode, min_box_w=48, min_box_h=48):
        """Crop an image and its mask around the masked region.

        Args:
            pil_img: source image (anything ``np.asarray`` accepts, e.g. a PIL image).
            mask_np: mask array whose shape equals the image's height/width;
                non-zero entries mark the region.
            crop_mode: one of
                ``"full"`` – no crop, inputs returned unchanged;
                ``"crop"`` – tight bounding box;
                ``"context_crop"`` – bounding box grown by its own size on each
                side (3x), clipped to the image;
                ``"focal_crop"`` – like ``"context_crop"`` but the box is first grown
                to at least ``min_box_w`` x ``min_box_h`` around its centre;
                ``"crop_mask"`` – tight bounding box with pixels outside the mask zeroed.
            min_box_w, min_box_h: minimum box size used by ``"focal_crop"``.

        Returns:
            ``(image, info)`` where ``info["mask_np"]`` is the mask cropped identically.

        Raises:
            ValueError: for an unsupported ``crop_mode`` or an empty mask.
        """
        if crop_mode not in cls._CROP_MODES:
            raise ValueError(f"Unsupported crop_mode: {crop_mode}")
        if crop_mode == "full":
            # no crop
            info = dict(mask_np=mask_np)
            return pil_img, info

        x0, y0, w, h = cls.mask_to_box(mask_np)
        img_np = np.asarray(pil_img)
        assert img_np.shape[:2] == mask_np.shape, f"image shape mismatches with mask shape: {img_np.shape}, {mask_np.shape}"
        img_h, img_w = img_np.shape[:2]

        if crop_mode in ("crop", "crop_mask"):
            rows, cols = slice(y0, y0 + h), slice(x0, x0 + w)
        else:
            if crop_mode == "focal_crop":
                xc, yc = x0 + w / 2, y0 + h / 2
                # focal_crop: need to have at least min_box_w and min_box_h pixels, otherwise resizing to (384, 384) leads to artifacts that may be OOD
                w, h = max(w, min_box_w), max(h, min_box_h)
                x0, y0 = int(xc - w / 2), int(yc - h / 2)
            rows, cols = cls._context_window(x0, y0, w, h, img_h, img_w)

        cropped_mask_np = mask_np[rows, cols]
        cropped_img_np = img_np[rows, cols]
        if crop_mode == "crop_mask":
            # Boolean mask: non-zero means "keep", so 0/255 masks don't overflow,
            # and the image dtype is preserved (dtype * bool -> dtype).
            keep = cropped_mask_np > 0
            if cropped_img_np.ndim == 3:
                keep = keep[..., None]  # broadcast over the channel axis
            cropped_img_np = cropped_img_np * keep
        info = dict(mask_np=cropped_mask_np)
        return Image.fromarray(cropped_img_np), info

    def get_description(self, image_pil, mask_pil, query, streaming=False):
        """Describe the masked region(s) for ``query``.

        Args:
            image_pil, mask_pil: a single image and mask, or parallel lists/tuples
                of them (both must be sequences, or neither).
            query: user query; must contain ``DEFAULT_IMAGE_TOKEN``.
            streaming: if true, return a generator of text chunks.

        Returns:
            The description string, or a generator of chunks when streaming.
        """
        prompt, conv = self.get_prompt(query)
        image_is_seq = isinstance(image_pil, (list, tuple))
        assert image_is_seq == isinstance(mask_pil, (list, tuple)), "image_pil and mask_pil must be both list or tuple or not list or tuple."
        if image_is_seq:
            image_pils, mask_pils = image_pil, mask_pil
        else:
            image_pils, mask_pils = [image_pil], [mask_pil]
        return self.get_description_from_prompt(image_pils, mask_pils, prompt, conv, streaming=streaming)

    def _image_and_mask_tensor(self, image_pil, mask_np, crop_mode):
        """Crop image and mask with ``crop_mode``; return image channels with one mask channel appended (dim=1)."""
        device = self.model.device
        images_tensor, image_info = process_image(
            image_pil,
            self.model.config,
            None,
            pil_preprocess_fn=lambda pil_img: self.crop_image(pil_img, mask_np=mask_np, crop_mode=crop_mode),
        )
        images_tensor = images_tensor[None].to(device, dtype=torch.float16)
        # Re-encode the cropped 0/1 mask as a 0/255 image so it goes through the same preprocessing.
        cropped_mask_pil = Image.fromarray(image_info["mask_np"] * 255)
        masks_tensor = process_image(cropped_mask_pil, self.model.config, None)
        masks_tensor = masks_tensor[None].to(device, dtype=torch.float16)
        # Only the first channel of the processed mask is kept as the extra input channel.
        return torch.cat((images_tensor, masks_tensor[:, :1, ...]), dim=1)

    def get_image_tensor(self, image_pil, mask_pil, crop_mode, crop_mode2):
        """Build the model input tensor for one image/mask pair.

        Each crop yields image channels plus one mask channel; when ``crop_mode2``
        is not None the second crop's channels are concatenated after the first.
        """
        # the pil has True/False (if the value is non-zero, then we treat it as True)
        mask_np = (np.asarray(mask_pil) > 0).astype(np.uint8)
        # Both crops start from the same full-resolution mask.
        images_tensor = self._image_and_mask_tensor(image_pil, mask_np, crop_mode)
        if crop_mode2 is None:
            return images_tensor
        images_tensor2 = self._image_and_mask_tensor(image_pil, mask_np, crop_mode2)
        return torch.cat((images_tensor, images_tensor2), dim=1)

    def get_description_from_prompt(self, image_pils, mask_pils, prompt, conv, streaming=False):
        """Describe already-normalized image/mask lists for a rendered prompt.

        Returns a generator of text chunks when ``streaming`` is true, otherwise
        the final description string.
        """
        outputs = self.get_description_from_prompt_iterator(image_pils, mask_pils, prompt, conv, streaming=streaming)
        if streaming:
            return outputs
        # If streaming is False, there will be only one output
        return next(outputs)

    @staticmethod
    def _split_prompt_mode(prompt_mode):
        """Split ``"<crop_mode>[+<crop_mode2>]"``; ``crop_mode2`` is None when there is no ``"+"``."""
        crop_mode, sep, crop_mode2 = prompt_mode.partition("+")
        return crop_mode, (crop_mode2 if sep else None)

    def get_description_from_prompt_iterator(self, image_pils, mask_pils, prompt, conv, streaming=False):
        """Run generation and yield text.

        When ``streaming`` is true, yields decoded chunks as they arrive from a
        background generation thread, stopping before the conversation's stop
        string. Otherwise yields the full stripped output once.
        """
        crop_mode, crop_mode2 = self._split_prompt_mode(self.prompt_mode)
        assert crop_mode == "full", "Current prompt only supports first crop as full (non-cropped). If you need other specifications, please update the prompt."
        assert len(image_pils) == len(mask_pils), f"image_pils and mask_pils must have the same length. Got {len(image_pils)} and {len(mask_pils)}."
        image_tensors = [
            self.get_image_tensor(image_pil, mask_pil, crop_mode=crop_mode, crop_mode2=crop_mode2)
            for image_pil, mask_pil in zip(image_pils, mask_pils)
        ]
        # Same device as the image tensors (instead of a hard-coded .cuda()).
        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.model.device)
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        stopping_criteria = KeywordsStoppingCriteria([stop_str], self.tokenizer, input_ids)
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True) if streaming else None
        generation_kwargs = dict(
            input_ids=input_ids,
            images=image_tensors,
            do_sample=self.temperature > 0,
            temperature=self.temperature,
            top_p=self.top_p,
            num_beams=self.num_beams,
            max_new_tokens=self.max_new_tokens,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
            streamer=streamer,
        )
        if streaming:
            # generate() runs in a background thread and feeds the streamer.
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()
            generated_text = ""
            for new_text in streamer:
                prev_len = len(generated_text)
                generated_text += new_text
                stop_at = generated_text.find(stop_str)
                if stop_at != -1:
                    # Emit the part of this chunk that precedes the stop string, then stop.
                    if stop_at > prev_len:
                        yield generated_text[prev_len:stop_at]
                    break
                yield new_text
            thread.join()
        else:
            with torch.inference_mode():
                output_ids = self.model.generate(**generation_kwargs)
            outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
            if outputs.endswith(stop_str):
                outputs = outputs[: -len(stop_str)]
            yield outputs.strip()