from threading import Thread

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from transformers import TextIteratorStreamer

from .model import get_model_name_from_path, load_pretrained_model
from .model.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from .model.conversation import SeparatorStyle, conv_templates
from .model.mm_utils import KeywordsStoppingCriteria, process_image, tokenizer_image_token

class DescribeAnythingModel(nn.Module):
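    """Wraps a Describe Anything checkpoint: given an image and a binary region
    mask, produces a natural-language description of the masked region."""
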
    def __init__(self, model_path, conv_mode, prompt_mode, temperature, top_p, num_beams, max_new_tokens, **kwargs):
        super().__init__()
        
        self.model_path = model_path
        self.conv_mode = conv_mode
        self.prompt_mode = prompt_mode
        self.temperature = temperature
        self.top_p = top_p
        self.num_beams = num_beams
        self.max_new_tokens = max_new_tokens

        tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, None, **kwargs)
        model.config.image_processor = image_processor
        
        self.tokenizer = tokenizer
        self.model = model
        self.context_len = context_len
    
        self.model_name = get_model_name_from_path(model_path)
    
    def get_prompt(self, qs):
        if DEFAULT_IMAGE_TOKEN not in qs:
            raise ValueError("no <image> tag found in input.")

        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        return prompt, conv

    @staticmethod
    def mask_to_box(mask_np):
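        """Return the tight bounding box (x0, y0, w, h) of the non-zero region of `mask_np`."""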
        mask_coords = np.argwhere(mask_np)
        y0, x0 = mask_coords.min(axis=0)
        y1, x1 = mask_coords.max(axis=0) + 1
        
        h = y1 - y0
        w = x1 - x0

        return x0, y0, w, h

    @classmethod
    def crop_image(cls, pil_img, mask_np, crop_mode, min_box_w=48, min_box_h=48):
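        """Crop `pil_img` and `mask_np` according to `crop_mode`.

        Supported modes:
          - "full": return the image unchanged.
          - "crop": tight crop to the mask bounding box.
          - "context_crop": crop the bounding box plus one box width/height of
            context on each side, clamped to the image borders.
          - "focal_crop": like "context_crop", but the box is first enlarged to
            at least `min_box_w` x `min_box_h` around its center.
          - "crop_mask": tight crop with pixels outside the mask zeroed out.

        Returns the (possibly cropped) PIL image and an info dict holding the
        matching mask.
        """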
        if crop_mode == "full":
            # no crop
            info = dict(mask_np=mask_np)
            return pil_img, info

        if crop_mode == "crop":
            # crop image and mask
            x0, y0, w, h = cls.mask_to_box(mask_np)
            img_np = np.asarray(pil_img)
            assert img_np.shape[:2] == mask_np.shape, f"image shape mismatches with mask shape: {img_np.shape}, {mask_np.shape}"
            cropped_mask_np = mask_np[y0:y0+h, x0:x0+w]
            cropped_img_np = img_np[y0:y0+h, x0:x0+w]
            cropped_pil_img = Image.fromarray(cropped_img_np)
        elif crop_mode == "context_crop":
            # crop image and mask
            x0, y0, w, h = cls.mask_to_box(mask_np)
            img_np = np.asarray(pil_img)
            assert img_np.shape[:2] == mask_np.shape, f"image shape mismatches with mask shape: {img_np.shape}, {mask_np.shape}"
            img_h, img_w = img_np.shape[:2]
            cropped_mask_np = mask_np[max(y0-h, 0):min(y0+2*h, img_h), max(x0-w, 0):min(x0+2*w, img_w)]
            cropped_img_np = img_np[max(y0-h, 0):min(y0+2*h, img_h), max(x0-w, 0):min(x0+2*w, img_w)]
            cropped_pil_img = Image.fromarray(cropped_img_np)
        elif crop_mode == "focal_crop":
            # crop image and mask
            x0, y0, w, h = cls.mask_to_box(mask_np)
            img_np = np.asarray(pil_img)
            assert img_np.shape[:2] == mask_np.shape, f"image shape mismatches with mask shape: {img_np.shape}, {mask_np.shape}"
            img_h, img_w = img_np.shape[:2]

            xc, yc = x0 + w/2, y0 + h/2
            # focal_crop: the box needs at least min_box_w x min_box_h pixels; otherwise resizing to (384, 384) introduces artifacts that may be out of distribution
            w, h = max(w, min_box_w), max(h, min_box_h)
            x0, y0 = int(xc - w / 2), int(yc - h / 2)
            
            cropped_mask_np = mask_np[max(y0-h, 0):min(y0+2*h, img_h), max(x0-w, 0):min(x0+2*w, img_w)]
            cropped_img_np = img_np[max(y0-h, 0):min(y0+2*h, img_h), max(x0-w, 0):min(x0+2*w, img_w)]
            cropped_pil_img = Image.fromarray(cropped_img_np)
        elif crop_mode == "crop_mask":
            # crop image and mask
            x0, y0, w, h = cls.mask_to_box(mask_np)
            img_np = np.asarray(pil_img)
            assert img_np.shape[:2] == mask_np.shape, f"image shape mismatches with mask shape: {img_np.shape}, {mask_np.shape}"
            cropped_mask_np = mask_np[y0:y0+h, x0:x0+w]
            cropped_img_np = img_np[y0:y0+h, x0:x0+w]
            # Mask the image
            cropped_img_np = cropped_img_np * cropped_mask_np[..., None]
            cropped_pil_img = Image.fromarray(cropped_img_np)
        else:
            raise ValueError(f"Unsupported crop_mode: {crop_mode}")

        info = dict(mask_np=cropped_mask_np)
        return cropped_pil_img, info

    def get_description(self, image_pil, mask_pil, query, streaming=False):
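        """Describe the image region(s) selected by `mask_pil`.

        `image_pil` and `mask_pil` may each be a single PIL image or matching
        lists of images and masks. `query` must contain the image token (see
        `get_prompt`). With `streaming=True` an iterator over text chunks is
        returned instead of the final string."""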
        prompt, conv = self.get_prompt(query)
        if not isinstance(image_pil, (list, tuple)):
            assert not isinstance(mask_pil, (list, tuple)), "image_pil and mask_pil must either both be lists/tuples or both be single images."
            image_pils = [image_pil]
            mask_pils = [mask_pil]
        else:
            image_pils = image_pil
            mask_pils = mask_pil
        description = self.get_description_from_prompt(image_pils, mask_pils, prompt, conv, streaming=streaming)
        
        return description

    def get_image_tensor(self, image_pil, mask_pil, crop_mode, crop_mode2):
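        """Build the visual input tensor for one image/mask pair.

        The image is preprocessed under `crop_mode` and the first channel of the
        processed mask is appended to it along dim 1; if `crop_mode2` is given,
        a second view is built the same way and concatenated on as well."""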
        # Binarize the mask: any non-zero pixel is treated as part of the region
        mask_np = (np.asarray(mask_pil) > 0).astype(np.uint8)
        images_tensor, image_info = process_image(image_pil, self.model.config, None, pil_preprocess_fn=lambda pil_img: self.crop_image(pil_img, mask_np=mask_np, crop_mode=crop_mode))
        images_tensor = images_tensor[None].to(self.model.device, dtype=torch.float16)

        mask_np = image_info["mask_np"]
        mask_pil = Image.fromarray(mask_np * 255)
        
        masks_tensor = process_image(mask_pil, self.model.config, None)
        masks_tensor = masks_tensor[None].to(self.model.device, dtype=torch.float16)
        
        images_tensor = torch.cat((images_tensor, masks_tensor[:, :1, ...]), dim=1)

        if crop_mode2 is not None:
            images_tensor2, image_info2 = process_image(image_pil, self.model.config, None, pil_preprocess_fn=lambda pil_img: self.crop_image(pil_img, mask_np=mask_np, crop_mode=crop_mode2))
            images_tensor2 = images_tensor2[None].to(self.model.device, dtype=torch.float16)

            mask_np2 = image_info2["mask_np"]
            mask_pil2 = Image.fromarray(mask_np2 * 255)
            
            masks_tensor2 = process_image(mask_pil2, self.model.config, None)
            masks_tensor2 = masks_tensor2[None].to(self.model.device, dtype=torch.float16)

            images_tensor2 = torch.cat((images_tensor2, masks_tensor2[:, :1, ...]), dim=1)
        else:
            images_tensor2 = None
            
        return torch.cat((images_tensor, images_tensor2), dim=1) if images_tensor2 is not None else images_tensor
    
    def get_description_from_prompt(self, image_pils, mask_pils, prompt, conv, streaming=False):
        if streaming:
            return self.get_description_from_prompt_iterator(image_pils, mask_pils, prompt, conv, streaming=True)
        else:
            # If streaming is False, there will be only one output
            output = self.get_description_from_prompt_iterator(image_pils, mask_pils, prompt, conv, streaming=False)
            return next(output)

    def get_description_from_prompt_iterator(self, image_pils, mask_pils, prompt, conv, streaming=False):
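        """Run generation and yield the description (one final string, or incremental chunks when `streaming=True`)."""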
        crop_mode, crop_mode2 = self.prompt_mode.split("+")
        assert crop_mode == "full", "Current prompt only supports first crop as full (non-cropped). If you need other specifications, please update the prompt."
        
        assert len(image_pils) == len(mask_pils), f"image_pils and mask_pils must have the same length. Got {len(image_pils)} and {len(mask_pils)}."
        image_tensors = [self.get_image_tensor(image_pil, mask_pil, crop_mode=crop_mode, crop_mode2=crop_mode2) for image_pil, mask_pil in zip(image_pils, mask_pils)]
        
        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.model.device)

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)

        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True) if streaming else None
        generation_kwargs = dict(
            input_ids=input_ids,
            images=image_tensors,
            do_sample=self.temperature > 0,
            temperature=self.temperature,
            top_p=self.top_p,
            num_beams=self.num_beams,
            max_new_tokens=self.max_new_tokens,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
            streamer=streamer
        )


        if streaming:
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()
            
            generated_text = ""
            for new_text in streamer:
                generated_text += new_text
                if stop_str in generated_text:
                    generated_text = generated_text[:generated_text.find(stop_str)]
                    break
                yield new_text
            
            thread.join()
        else:
            with torch.inference_mode():
                output_ids = self.model.generate(**generation_kwargs)

            outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
            outputs = outputs.strip()
            if outputs.endswith(stop_str):
                outputs = outputs[: -len(stop_str)]
            outputs = outputs.strip()

            yield outputs
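
# Illustrative usage sketch (not executed by this module). The checkpoint path,
# conversation/prompt modes, file names, and query below are placeholder
# assumptions; adjust them to your setup.
#
#     from PIL import Image
#
#     dam = DescribeAnythingModel(
#         model_path="path/to/dam-checkpoint",
#         conv_mode="llava_v1",           # placeholder; must be a key in conv_templates
#         prompt_mode="full+focal_crop",  # the first crop mode must be "full"
#         temperature=0.2,
#         top_p=0.9,
#         num_beams=1,
#         max_new_tokens=256,
#     )
#     image = Image.open("image.png").convert("RGB")
#     mask = Image.open("mask.png").convert("L")  # non-zero pixels mark the region of interest
#     text = dam.get_description(image, mask, query="<image>\nDescribe the highlighted region in detail.")
#     print(text)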