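"""DescribeAnythingModel: a thin inference wrapper around the Describe Anything
vision-language model. Given an image and a binary region mask, it builds the
conversation prompt, preprocesses the (optionally cropped) image/mask pair, and
generates a localized description, either as a single string or as a token stream.
"""
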
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from .model.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from .model.conversation import SeparatorStyle, conv_templates
from .model.mm_utils import KeywordsStoppingCriteria, process_image, tokenizer_image_token
from .model import get_model_name_from_path, load_pretrained_model
from transformers import TextIteratorStreamer
from threading import Thread


class DescribeAnythingModel(nn.Module):
    def __init__(self, model_path, conv_mode, prompt_mode, temperature, top_p, num_beams, max_new_tokens, **kwargs):
        super().__init__()

        self.model_path = model_path
        self.conv_mode = conv_mode
        self.prompt_mode = prompt_mode
        self.temperature = temperature
        self.top_p = top_p
        self.num_beams = num_beams
        self.max_new_tokens = max_new_tokens

        tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, None, **kwargs)
        model.config.image_processor = image_processor

        self.tokenizer = tokenizer
        self.model = model
        self.context_len = context_len
        self.model_name = get_model_name_from_path(model_path)

    def get_prompt(self, qs):
        if DEFAULT_IMAGE_TOKEN not in qs:
            raise ValueError("no <image> tag found in input.")

        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        return prompt, conv

    @staticmethod
    def mask_to_box(mask_np):
        # Tight bounding box (x0, y0, w, h) around the non-zero region of the mask.
        mask_coords = np.argwhere(mask_np)
        y0, x0 = mask_coords.min(axis=0)
        y1, x1 = mask_coords.max(axis=0) + 1

        h = y1 - y0
        w = x1 - x0

        return x0, y0, w, h

    @classmethod
    def crop_image(cls, pil_img, mask_np, crop_mode, min_box_w=48, min_box_h=48):
        if crop_mode == "full":
            # no crop
            info = dict(mask_np=mask_np)
            return pil_img, info

        if crop_mode == "crop":
            # crop image and mask to the tight bounding box
            x0, y0, w, h = cls.mask_to_box(mask_np)
            img_np = np.asarray(pil_img)
            assert img_np.shape[:2] == mask_np.shape, f"image shape mismatches with mask shape: {img_np.shape}, {mask_np.shape}"
            cropped_mask_np = mask_np[y0:y0+h, x0:x0+w]
            cropped_img_np = img_np[y0:y0+h, x0:x0+w]
            cropped_pil_img = Image.fromarray(cropped_img_np)
        elif crop_mode == "context_crop":
            # crop image and mask with one box-width/height of context on each side
            x0, y0, w, h = cls.mask_to_box(mask_np)
            img_np = np.asarray(pil_img)
            assert img_np.shape[:2] == mask_np.shape, f"image shape mismatches with mask shape: {img_np.shape}, {mask_np.shape}"
            img_h, img_w = img_np.shape[:2]
            cropped_mask_np = mask_np[max(y0-h, 0):min(y0+2*h, img_h), max(x0-w, 0):min(x0+2*w, img_w)]
            cropped_img_np = img_np[max(y0-h, 0):min(y0+2*h, img_h), max(x0-w, 0):min(x0+2*w, img_w)]
            cropped_pil_img = Image.fromarray(cropped_img_np)
        elif crop_mode == "focal_crop":
            # crop image and mask around an enlarged focal box with context
            x0, y0, w, h = cls.mask_to_box(mask_np)
            img_np = np.asarray(pil_img)
            assert img_np.shape[:2] == mask_np.shape, f"image shape mismatches with mask shape: {img_np.shape}, {mask_np.shape}"
            img_h, img_w = img_np.shape[:2]

            xc, yc = x0 + w/2, y0 + h/2
            # focal_crop: need to have at least min_box_w and min_box_h pixels, otherwise resizing to (384, 384) leads to artifacts that may be OOD
            w, h = max(w, min_box_w), max(h, min_box_h)
            x0, y0 = int(xc - w / 2), int(yc - h / 2)

            cropped_mask_np = mask_np[max(y0-h, 0):min(y0+2*h, img_h), max(x0-w, 0):min(x0+2*w, img_w)]
            cropped_img_np = img_np[max(y0-h, 0):min(y0+2*h, img_h), max(x0-w, 0):min(x0+2*w, img_w)]
            cropped_pil_img = Image.fromarray(cropped_img_np)
        elif crop_mode == "crop_mask":
            # crop image and mask, then zero out pixels outside the mask
            x0, y0, w, h = cls.mask_to_box(mask_np)
            img_np = np.asarray(pil_img)
            assert img_np.shape[:2] == mask_np.shape, f"image shape mismatches with mask shape: {img_np.shape}, {mask_np.shape}"
            cropped_mask_np = mask_np[y0:y0+h, x0:x0+w]
            cropped_img_np = img_np[y0:y0+h, x0:x0+w]
            # Mask the image
            cropped_img_np = cropped_img_np * cropped_mask_np[..., None]
            cropped_pil_img = Image.fromarray(cropped_img_np)
        else:
            raise ValueError(f"Unsupported crop_mode: {crop_mode}")

        info = dict(mask_np=cropped_mask_np)
        return cropped_pil_img, info

    def get_description(self, image_pil, mask_pil, query, streaming=False):
        prompt, conv = self.get_prompt(query)
        if not isinstance(image_pil, (list, tuple)):
            assert not isinstance(mask_pil, (list, tuple)), "image_pil and mask_pil must both be lists/tuples or both be single images."
            image_pils = [image_pil]
            mask_pils = [mask_pil]
        else:
            image_pils = image_pil
            mask_pils = mask_pil
        description = self.get_description_from_prompt(image_pils, mask_pils, prompt, conv, streaming=streaming)

        return description

    def get_image_tensor(self, image_pil, mask_pil, crop_mode, crop_mode2):
        # the pil has True/False (if the value is non-zero, then we treat it as True)
        mask_np = (np.asarray(mask_pil) > 0).astype(np.uint8)
        images_tensor, image_info = process_image(image_pil, self.model.config, None, pil_preprocess_fn=lambda pil_img: self.crop_image(image_pil, mask_np=mask_np, crop_mode=crop_mode))
        images_tensor = images_tensor[None].to(self.model.device, dtype=torch.float16)

        mask_np = image_info["mask_np"]
        mask_pil = Image.fromarray(mask_np * 255)

        masks_tensor = process_image(mask_pil, self.model.config, None)
        masks_tensor = masks_tensor[None].to(self.model.device, dtype=torch.float16)

        images_tensor = torch.cat((images_tensor, masks_tensor[:, :1, ...]), dim=1)

        if crop_mode2 is not None:
            images_tensor2, image_info2 = process_image(image_pil, self.model.config, None, pil_preprocess_fn=lambda pil_img: self.crop_image(pil_img, mask_np=mask_np, crop_mode=crop_mode2))
            images_tensor2 = images_tensor2[None].to(self.model.device, dtype=torch.float16)

            mask_np2 = image_info2["mask_np"]
            mask_pil2 = Image.fromarray(mask_np2 * 255)

            masks_tensor2 = process_image(mask_pil2, self.model.config, None)
            masks_tensor2 = masks_tensor2[None].to(self.model.device, dtype=torch.float16)

            images_tensor2 = torch.cat((images_tensor2, masks_tensor2[:, :1, ...]), dim=1)
        else:
            images_tensor2 = None

        return torch.cat((images_tensor, images_tensor2), dim=1) if images_tensor2 is not None else images_tensor

    def get_description_from_prompt(self, image_pils, mask_pils, prompt, conv, streaming=False):
        if streaming:
            return self.get_description_from_prompt_iterator(image_pils, mask_pils, prompt, conv, streaming=True)
        else:
            # If streaming is False, there will be only one output
            output = self.get_description_from_prompt_iterator(image_pils, mask_pils, prompt, conv, streaming=False)
            return next(output)

    def get_description_from_prompt_iterator(self, image_pils, mask_pils, prompt, conv, streaming=False):
        crop_mode, crop_mode2 = self.prompt_mode.split("+")
        assert crop_mode == "full", "Current prompt only supports first crop as full (non-cropped). If you need other specifications, please update the prompt."
        assert len(image_pils) == len(mask_pils), f"image_pils and mask_pils must have the same length. Got {len(image_pils)} and {len(mask_pils)}."
        image_tensors = [self.get_image_tensor(image_pil, mask_pil, crop_mode=crop_mode, crop_mode2=crop_mode2) for image_pil, mask_pil in zip(image_pils, mask_pils)]

        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)

        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True) if streaming else None

        generation_kwargs = dict(
            input_ids=input_ids,
            images=image_tensors,
            do_sample=True if self.temperature > 0 else False,
            temperature=self.temperature,
            top_p=self.top_p,
            num_beams=self.num_beams,
            max_new_tokens=self.max_new_tokens,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
            streamer=streamer,
        )

        if streaming:
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()

            generated_text = ""
            for new_text in streamer:
                generated_text += new_text
                if stop_str in generated_text:
                    generated_text = generated_text[:generated_text.find(stop_str)]
                    break
                yield new_text

            thread.join()
        else:
            with torch.inference_mode():
                output_ids = self.model.generate(**generation_kwargs)

            outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
            outputs = outputs.strip()
            if outputs.endswith(stop_str):
                outputs = outputs[: -len(stop_str)]
            outputs = outputs.strip()

            yield outputs
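

# ---------------------------------------------------------------------------
# Minimal usage sketch: drives the class above end to end on a single image
# and region mask. The checkpoint id, conversation mode, prompt mode, sampling
# settings, and file names below are assumptions for illustration; substitute
# the values used by your own deployment.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    dam = DescribeAnythingModel(
        model_path="nvidia/DAM-3B",      # assumed Hugging Face checkpoint id
        conv_mode="v1",                  # assumed conversation template
        prompt_mode="full+focal_crop",   # first crop must be "full" (asserted above)
        temperature=0.2,
        top_p=0.9,
        num_beams=1,
        max_new_tokens=512,
    )

    image_pil = Image.open("example.jpg").convert("RGB")        # hypothetical input image
    mask_pil = Image.open("example_mask.png").convert("L")      # non-zero pixels mark the region

    # The prompt must contain the image token; get_prompt() raises otherwise.
    query = f"{DEFAULT_IMAGE_TOKEN}\nDescribe the masked region in detail."
    print(dam.get_description(image_pil, mask_pil, query))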