# dam/describe_anything_model.py
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from .model.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from .model.conversation import SeparatorStyle, conv_templates
from .model.mm_utils import KeywordsStoppingCriteria, process_image, tokenizer_image_token
from .model import get_model_name_from_path, load_pretrained_model
from transformers import TextIteratorStreamer
from threading import Thread
class DescribeAnythingModel(nn.Module):
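    """Wrapper around a Describe Anything Model (DAM) checkpoint that
    generates a text description for the region of an image selected by a
    binary mask.

    The mask is fed to the model as an extra input channel alongside the RGB
    image; depending on ``prompt_mode``, a second cropped view of the masked
    region is concatenated as well (see ``get_image_tensor``).
    """
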
def __init__(self, model_path, conv_mode, prompt_mode, temperature, top_p, num_beams, max_new_tokens, **kwargs):
super().__init__()
self.model_path = model_path
self.conv_mode = conv_mode
self.prompt_mode = prompt_mode
self.temperature = temperature
self.top_p = top_p
self.num_beams = num_beams
self.max_new_tokens = max_new_tokens
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, None, **kwargs)
model.config.image_processor = image_processor
self.tokenizer = tokenizer
self.model = model
self.context_len = context_len
self.model_name = get_model_name_from_path(model_path)
def get_prompt(self, qs):
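        """Render ``qs`` (which must contain the image token) into the
        configured conversation template; return the prompt string and the
        conversation object."""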
if DEFAULT_IMAGE_TOKEN not in qs:
            raise ValueError(f"No {DEFAULT_IMAGE_TOKEN} token found in the input query.")
conv = conv_templates[self.conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
return prompt, conv
@staticmethod
def mask_to_box(mask_np):
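        """Return the tight bounding box (x0, y0, w, h) around the nonzero
        pixels of a 2D mask array. The mask must contain at least one
        nonzero pixel."""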
mask_coords = np.argwhere(mask_np)
y0, x0 = mask_coords.min(axis=0)
y1, x1 = mask_coords.max(axis=0) + 1
h = y1 - y0
w = x1 - x0
return x0, y0, w, h
@classmethod
def crop_image(cls, pil_img, mask_np, crop_mode, min_box_w=48, min_box_h=48):
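        """Crop ``pil_img`` (and the matching mask) according to ``crop_mode``:

        - "full": no cropping; return the image unchanged.
        - "crop": tight crop to the mask's bounding box.
        - "context_crop": bounding box expanded by one box width/height on
          each side (clamped to the image), i.e. up to 3x the box size.
        - "focal_crop": like "context_crop", but the box is first enlarged to
          at least ``min_box_w`` x ``min_box_h`` around the box center.
        - "crop_mask": tight crop with the background zeroed out by the mask.

        Returns the (possibly cropped) PIL image and a dict holding the
        correspondingly cropped mask.
        """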
if crop_mode == "full":
# no crop
info = dict(mask_np=mask_np)
return pil_img, info
if crop_mode == "crop":
# crop image and mask
x0, y0, w, h = cls.mask_to_box(mask_np)
img_np = np.asarray(pil_img)
            assert img_np.shape[:2] == mask_np.shape, f"image shape {img_np.shape} does not match mask shape {mask_np.shape}"
cropped_mask_np = mask_np[y0:y0+h, x0:x0+w]
cropped_img_np = img_np[y0:y0+h, x0:x0+w]
cropped_pil_img = Image.fromarray(cropped_img_np)
elif crop_mode == "context_crop":
# crop image and mask
x0, y0, w, h = cls.mask_to_box(mask_np)
img_np = np.asarray(pil_img)
            assert img_np.shape[:2] == mask_np.shape, f"image shape {img_np.shape} does not match mask shape {mask_np.shape}"
img_h, img_w = img_np.shape[:2]
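            # Expand the box by one box width/height on each side, clamped to the image bounds.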
cropped_mask_np = mask_np[max(y0-h, 0):min(y0+2*h, img_h), max(x0-w, 0):min(x0+2*w, img_w)]
cropped_img_np = img_np[max(y0-h, 0):min(y0+2*h, img_h), max(x0-w, 0):min(x0+2*w, img_w)]
cropped_pil_img = Image.fromarray(cropped_img_np)
elif crop_mode == "focal_crop":
# crop image and mask
x0, y0, w, h = cls.mask_to_box(mask_np)
img_np = np.asarray(pil_img)
            assert img_np.shape[:2] == mask_np.shape, f"image shape {img_np.shape} does not match mask shape {mask_np.shape}"
img_h, img_w = img_np.shape[:2]
xc, yc = x0 + w/2, y0 + h/2
            # focal_crop: enforce at least min_box_w x min_box_h pixels; smaller crops resized to (384, 384) produce interpolation artifacts that may be out-of-distribution for the model.
w, h = max(w, min_box_w), max(h, min_box_h)
x0, y0 = int(xc - w / 2), int(yc - h / 2)
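            # Expand the (possibly enlarged) box by one box width/height on each side, clamped to the image bounds.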
cropped_mask_np = mask_np[max(y0-h, 0):min(y0+2*h, img_h), max(x0-w, 0):min(x0+2*w, img_w)]
cropped_img_np = img_np[max(y0-h, 0):min(y0+2*h, img_h), max(x0-w, 0):min(x0+2*w, img_w)]
cropped_pil_img = Image.fromarray(cropped_img_np)
elif crop_mode == "crop_mask":
# crop image and mask
x0, y0, w, h = cls.mask_to_box(mask_np)
img_np = np.asarray(pil_img)
            assert img_np.shape[:2] == mask_np.shape, f"image shape {img_np.shape} does not match mask shape {mask_np.shape}"
cropped_mask_np = mask_np[y0:y0+h, x0:x0+w]
cropped_img_np = img_np[y0:y0+h, x0:x0+w]
# Mask the image
cropped_img_np = cropped_img_np * cropped_mask_np[..., None]
cropped_pil_img = Image.fromarray(cropped_img_np)
else:
raise ValueError(f"Unsupported crop_mode: {crop_mode}")
info = dict(mask_np=cropped_mask_np)
return cropped_pil_img, info
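
    # Example for crop_image (a minimal sketch; the arrays are illustrative):
    #   img = Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))
    #   mask = np.zeros((256, 256), dtype=np.uint8); mask[64:128, 64:128] = 1
    #   crop, info = DescribeAnythingModel.crop_image(img, mask, "context_crop")
    #   # -> crop is img[0:192, 0:192]: the 64x64 box plus 64 px of context on each side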
def get_description(self, image_pil, mask_pil, query, streaming=False):
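        """Describe the region(s) of ``image_pil`` selected by ``mask_pil``.

        Accepts either a single image/mask pair or two parallel sequences of
        them. With ``streaming=True`` a generator over text fragments is
        returned instead of the final string.
        """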
prompt, conv = self.get_prompt(query)
if not isinstance(image_pil, (list, tuple)):
            assert not isinstance(mask_pil, (list, tuple)), "image_pil and mask_pil must either both be sequences or both be single images."
image_pils = [image_pil]
mask_pils = [mask_pil]
else:
image_pils = image_pil
mask_pils = mask_pil
description = self.get_description_from_prompt(image_pils, mask_pils, prompt, conv, streaming=streaming)
return description
def get_image_tensor(self, image_pil, mask_pil, crop_mode, crop_mode2):
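        """Build the model input tensor for one image/mask pair.

        The mask is binarized, preprocessed like the image, and its first
        channel is appended to the image tensor, so the model sees the region
        as an extra input channel. If ``crop_mode2`` is set, a second view is
        produced with that crop mode and concatenated along the same
        dimension.
        """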
        # Binarize the mask: any nonzero pixel counts as foreground (True).
        mask_np = (np.asarray(mask_pil) > 0).astype(np.uint8)
        images_tensor, image_info = process_image(image_pil, self.model.config, None, pil_preprocess_fn=lambda pil_img: self.crop_image(pil_img, mask_np=mask_np, crop_mode=crop_mode))
images_tensor = images_tensor[None].to(self.model.device, dtype=torch.float16)
mask_np = image_info["mask_np"]
mask_pil = Image.fromarray(mask_np * 255)
masks_tensor = process_image(mask_pil, self.model.config, None)
masks_tensor = masks_tensor[None].to(self.model.device, dtype=torch.float16)
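        # Append the first channel of the preprocessed mask to the image tensor along dim 1.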
images_tensor = torch.cat((images_tensor, masks_tensor[:, :1, ...]), dim=1)
if crop_mode2 is not None:
images_tensor2, image_info2 = process_image(image_pil, self.model.config, None, pil_preprocess_fn=lambda pil_img: self.crop_image(pil_img, mask_np=mask_np, crop_mode=crop_mode2))
images_tensor2 = images_tensor2[None].to(self.model.device, dtype=torch.float16)
mask_np2 = image_info2["mask_np"]
mask_pil2 = Image.fromarray(mask_np2 * 255)
masks_tensor2 = process_image(mask_pil2, self.model.config, None)
masks_tensor2 = masks_tensor2[None].to(self.model.device, dtype=torch.float16)
images_tensor2 = torch.cat((images_tensor2, masks_tensor2[:, :1, ...]), dim=1)
else:
images_tensor2 = None
return torch.cat((images_tensor, images_tensor2), dim=1) if images_tensor2 is not None else images_tensor
def get_description_from_prompt(self, image_pils, mask_pils, prompt, conv, streaming=False):
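        """Return the description for a pre-rendered prompt: the full string
        when ``streaming`` is False, otherwise a generator of fragments."""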
if streaming:
return self.get_description_from_prompt_iterator(image_pils, mask_pils, prompt, conv, streaming=True)
else:
# If streaming is False, there will be only one output
output = self.get_description_from_prompt_iterator(image_pils, mask_pils, prompt, conv, streaming=False)
return next(output)
def get_description_from_prompt_iterator(self, image_pils, mask_pils, prompt, conv, streaming=False):
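        """Generator that drives token generation for one prompt.

        ``prompt_mode`` must have the form "full+<crop_mode2>" (e.g.
        "full+focal_crop"): the first view is the full image and the second
        is a crop of the masked region.
        """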
crop_mode, crop_mode2 = self.prompt_mode.split("+")
        assert crop_mode == "full", "The current prompt only supports 'full' (uncropped) as the first view; update the prompt to use other crop modes."
assert len(image_pils) == len(mask_pils), f"image_pils and mask_pils must have the same length. Got {len(image_pils)} and {len(mask_pils)}."
image_tensors = [self.get_image_tensor(image_pil, mask_pil, crop_mode=crop_mode, crop_mode2=crop_mode2) for image_pil, mask_pil in zip(image_pils, mask_pils)]
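        # Tokenize the prompt, splicing IMAGE_TOKEN_INDEX in place of the image token.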
        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.model.device)
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True) if streaming else None
generation_kwargs = dict(
input_ids=input_ids,
images=image_tensors,
            do_sample=self.temperature > 0,
temperature=self.temperature,
top_p=self.top_p,
num_beams=self.num_beams,
max_new_tokens=self.max_new_tokens,
use_cache=True,
stopping_criteria=[stopping_criteria],
streamer=streamer
)
if streaming:
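            # Run generate() on a background thread and consume tokens from the streamer as they arrive.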
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
thread.start()
generated_text = ""
for new_text in streamer:
generated_text += new_text
if stop_str in generated_text:
generated_text = generated_text[:generated_text.find(stop_str)]
break
yield new_text
thread.join()
else:
with torch.inference_mode():
output_ids = self.model.generate(**generation_kwargs)
outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[: -len(stop_str)]
outputs = outputs.strip()
yield outputs
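
# Example usage (a minimal sketch; the checkpoint path, conv_mode, and
# prompt_mode below are placeholders -- substitute the values shipped with
# this demo):
#
#   from PIL import Image
#   dam = DescribeAnythingModel(
#       model_path="path/to/dam-checkpoint",   # placeholder path
#       conv_mode="llava_v1",                  # placeholder conversation template
#       prompt_mode="full+focal_crop",
#       temperature=0.2,
#       top_p=0.9,
#       num_beams=1,
#       max_new_tokens=256,
#   )
#   image = Image.open("image.png").convert("RGB")
#   mask = Image.open("mask.png").convert("L")  # nonzero pixels mark the region
#   print(dam.get_description(image, mask, "<image>\nDescribe the masked region in detail."))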