upload model arch
- README.md +259 -454
- chat_template.json +1 -0
- config.json +20 -99
- configuration_eagle_chat.py → configuration_eagle2_5_vl.py +50 -39
- demo.py +49 -422
- image_processing_eagle2.py +715 -0
- image_processing_eagle2_5_vl_fast.py +458 -0
- modeling_eagle_chat.py → modeling_eagle2_5_vl.py +129 -193
- preprocessor_config.json +23 -0
- processing_eagle2_5_vl.py +738 -0
- processor_config.json +15 -0
README.md
CHANGED
@@ -18,7 +18,12 @@ tags:
# Eagle-2

[\[📂 GitHub\]](https://github.com/NVlabs/EAGLE) [\[📜 Eagle2 Tech Report\]](http://arxiv.org/abs/2501.14818)
-[\[
## Introduction

We are thrilled to release our latest Eagle2 series Vision-Language Model. Open-source Vision-Language Models (VLMs) have made significant strides in narrowing the gap with proprietary models. However, critical details about data strategies and implementation are often missing, limiting reproducibility and innovation. In this project, we focus on VLM post-training from a data-centric perspective, sharing insights into building effective data strategies from scratch. By combining these strategies with robust training recipes and model design, we introduce Eagle2, a family of performant VLMs. Our work aims to empower the open-source community to develop competitive VLMs with transparent processes.
@@ -66,490 +71,290 @@ We provide the following models:
-We provide a [
- pure text input
- single image input
- multiple image input
- video input

-###

```bash
pip install transformers
pip install flash-attn
```
-**Note**: Latest version of transformers is not compatible with the model.

-### 1. Prepare the Model worker

-<summary>Click to expand</summary>

```python
-"""
-A model worker executes the model.
-Copied and modified from https://github.com/OpenGVLab/InternVL/blob/main/streamlit_demo/model_worker.py
-"""
-# Importing torch before transformers can cause `segmentation fault`
-from transformers import AutoModel, AutoTokenizer, TextIteratorStreamer, AutoConfig
-
-import argparse
-import base64
-import json
-import os
-import decord
-import threading
-import time
-from io import BytesIO
-from threading import Thread
-import math
import requests
import torch
-import torchvision.transforms as T
-from PIL import Image
-from torchvision.transforms.functional import InterpolationMode
-
-import numpy as np
-
-IMAGENET_MEAN = (0.485, 0.456, 0.406)
-IMAGENET_STD = (0.229, 0.224, 0.225)
-
-SIGLIP_MEAN = (0.5, 0.5, 0.5)
-SIGLIP_STD = (0.5, 0.5, 0.5)
-
-
-def get_seq_frames(total_num_frames, desired_num_frames=-1, stride=-1):
-    """
-    Calculate the indices of frames to extract from a video.
-
-    Parameters:
-    total_num_frames (int): Total number of frames in the video.
-    desired_num_frames (int): Desired number of frames to extract.
-
-    Returns:
-    list: List of indices of frames to extract.
-    """
-
-    assert desired_num_frames > 0 or stride > 0 and not (desired_num_frames > 0 and stride > 0)
-
-    if stride > 0:
-        return list(range(0, total_num_frames, stride))
-
-    # Calculate the size of each segment from which a frame will be extracted
-    seg_size = float(total_num_frames - 1) / desired_num_frames
-
-    seq = []
-    for i in range(desired_num_frames):
-        # Calculate the start and end indices of each segment
-        start = int(np.round(seg_size * i))
-        end = int(np.round(seg_size * (i + 1)))
-
-        # Append the middle index of the segment to the list
-        seq.append((start + end) // 2)
-
-    return seq
-
-def build_video_prompt(meta_list, num_frames, time_position=False):
-    # if time_position is True, the frame_timestamp is used.
-    # 1. pass time_position, 2. use env TIME_POSITION
-    time_position = os.environ.get("TIME_POSITION", time_position)
-    prefix = f"This is a video:\n"
-    for i in range(num_frames):
-        if time_position:
-            frame_txt = f"Frame {i+1} sampled at {meta_list[i]:.2f} seconds: <image>\n"
-        else:
-            frame_txt = f"Frame {i+1}: <image>\n"
-        prefix += frame_txt
-    return prefix
-
-def load_video(video_path, num_frames=64, frame_cache_root=None):
-    if isinstance(video_path, str):
-        video = decord.VideoReader(video_path)
-    elif isinstance(video_path, dict):
-        assert False, 'we not support vidoe: "video_path" as input'
-    fps = video.get_avg_fps()
-    sampled_frames = get_seq_frames(len(video), num_frames)
-    samepld_timestamps = [i / fps for i in sampled_frames]
-    frames = video.get_batch(sampled_frames).asnumpy()
-    images = [Image.fromarray(frame) for frame in frames]
-
-    return images, build_video_prompt(samepld_timestamps, len(images), time_position=True)
-
-def load_image(image):
-    if isinstance(image, str) and os.path.exists(image):
-        return Image.open(image)
-    elif isinstance(image, dict):
-        if 'disk_path' in image:
-            return Image.open(image['disk_path'])
-        elif 'base64' in image:
-            return Image.open(BytesIO(base64.b64decode(image['base64'])))
-        elif 'url' in image:
-            response = requests.get(image['url'])
-            return Image.open(BytesIO(response.content))
-        elif 'bytes' in image:
-            return Image.open(BytesIO(image['bytes']))
-        else:
-            raise ValueError(f'Invalid image: {image}')
-    else:
-        raise ValueError(f'Invalid image: {image}')
-
-def build_transform(input_size, norm_type='imagenet'):
-    if norm_type == 'imagenet':
-        MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
-    elif norm_type == 'siglip':
-        MEAN, STD = SIGLIP_MEAN, SIGLIP_STD
-
-    transform = T.Compose([
-        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
-        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
-        T.ToTensor(),
-        T.Normalize(mean=MEAN, std=STD)
-    ])
-    return transform
-
-
-def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
-    """
-    previous version mainly foucs on ratio.
-    We also consider area ratio here.
-    """
-    best_factor = float('-inf')
-    best_ratio = (1, 1)
-    area = width * height
-    for ratio in target_ratios:
-        target_aspect_ratio = ratio[0] / ratio[1]
-        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
-        area_ratio = (ratio[0]*ratio[1]*image_size*image_size)/ area
-        """
-        new area > 60% of original image area is enough.
-        """
-        factor_based_on_area_n_ratio = min((ratio[0]*ratio[1]*image_size*image_size)/ area, 0.6)* \
-            min(target_aspect_ratio/aspect_ratio, aspect_ratio/target_aspect_ratio)
-
-        if factor_based_on_area_n_ratio > best_factor:
-            best_factor = factor_based_on_area_n_ratio
-            best_ratio = ratio
-
-    return best_ratio
-
-
-def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
-    orig_width, orig_height = image.size
-    aspect_ratio = orig_width / orig_height
-
-    # calculate the existing image aspect ratio
-    target_ratios = set(
-        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
-        i * j <= max_num and i * j >= min_num)
-    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
-
-    # find the closest aspect ratio to the target
-    target_aspect_ratio = find_closest_aspect_ratio(
-        aspect_ratio, target_ratios, orig_width, orig_height, image_size)
-
-    # calculate the target width and height
-    target_width = image_size * target_aspect_ratio[0]
-    target_height = image_size * target_aspect_ratio[1]
-    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
-
-    # resize the image
-    resized_img = image.resize((target_width, target_height))
-    processed_images = []
-    for i in range(blocks):
-        box = (
-            (i % (target_width // image_size)) * image_size,
-            (i // (target_width // image_size)) * image_size,
-            ((i % (target_width // image_size)) + 1) * image_size,
-            ((i // (target_width // image_size)) + 1) * image_size
-        )
-        # split the image
-        split_img = resized_img.crop(box)
-        processed_images.append(split_img)
-    assert len(processed_images) == blocks
-    if use_thumbnail and len(processed_images) != 1:
-        thumbnail_img = image.resize((image_size, image_size))
-        processed_images.append(thumbnail_img)
-    return processed_images
-
-def split_model(model_path, device):
-
-    device_map = {}
-    world_size = torch.cuda.device_count()
-    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
-    num_layers = config.llm_config.num_hidden_layers
-
-    print('world_size', world_size)
-    num_layers_per_gpu_ = math.floor(num_layers / (world_size - 1))
-    num_layers_per_gpu = [num_layers_per_gpu_] * world_size
-    num_layers_per_gpu[device] = num_layers - num_layers_per_gpu_ * (world_size-1)
-    print(num_layers_per_gpu)
-    layer_cnt = 0
-    for i, num_layer in enumerate(num_layers_per_gpu):
-        for j in range(num_layer):
-            device_map[f'language_model.model.layers.{layer_cnt}'] = i
-            layer_cnt += 1
-    device_map['vision_model'] = device
-    device_map['mlp1'] = device
-    device_map['language_model.model.tok_embeddings'] = device
-    device_map['language_model.model.embed_tokens'] = device
-    device_map['language_model.output'] = device
-    device_map['language_model.model.norm'] = device
-    device_map['language_model.lm_head'] = device
-    device_map['language_model.model.rotary_emb'] = device
-    device_map[f'language_model.model.layers.{num_layers - 1}'] = device
-    return device_map
-
-class ModelWorker:
-    def __init__(self, model_path, model_name,
-                 load_8bit, device):
-
-        if model_path.endswith('/'):
-            model_path = model_path[:-1]
-        if model_name is None:
-            model_paths = model_path.split('/')
-            if model_paths[-1].startswith('checkpoint-'):
-                self.model_name = model_paths[-2] + '_' + model_paths[-1]
-            else:
-                self.model_name = model_paths[-1]
-        else:
-            self.model_name = model_name
-
-        print(f'Loading the model {self.model_name}')
-
-        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
-        tokens_to_keep = ['<box>', '</box>', '<ref>', '</ref>']
-        tokenizer.additional_special_tokens = [item for item in tokenizer.additional_special_tokens if item not in tokens_to_keep]
-        self.tokenizer = tokenizer
-        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
-        model_type = config.vision_config.model_type
-        self.device = torch.cuda.current_device()
-        if model_type == 'siglip_vision_model':
-            self.norm_type = 'siglip'
-        elif model_type == 'MOB':
-            self.norm_type = 'siglip'
-        else:
-            self.norm_type = 'imagenet'
-
-        if any(x in model_path.lower() for x in ['34b']):
-            device_map = split_model(model_path, self.device)
-        else:
-            device_map = None
-
-        if device_map is not None:
-            self.model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16,
-                                                   low_cpu_mem_usage=True,
-                                                   device_map=device_map,
-                                                   trust_remote_code=True,
-                                                   load_in_8bit=load_8bit).eval()
-        else:
-            self.model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16,
-                                                   trust_remote_code=True,
-                                                   load_in_8bit=load_8bit).eval()
-
-        if not load_8bit and device_map is None:
-            self.model = self.model.to(device)
-        self.load_8bit = load_8bit
-
-        self.model_path = model_path
-        self.image_size = self.model.config.force_image_size
-        self.context_len = tokenizer.model_max_length
-        self.per_tile_len = 256
-
-    def reload_model(self):
-        del self.model
-        torch.cuda.empty_cache()
-        if self.device == 'auto':
-            os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
-            # This can make distributed deployment work properly
-            self.model = AutoModel.from_pretrained(
-                self.model_path,
-                load_in_8bit=self.load_8bit,
-                torch_dtype=torch.bfloat16,
-                device_map=self.device_map,
-                trust_remote_code=True).eval()
-        else:
-            self.model = AutoModel.from_pretrained(
-                self.model_path,
-                load_in_8bit=self.load_8bit,
-                torch_dtype=torch.bfloat16,
-                trust_remote_code=True).eval()
-        if not self.load_8bit and not self.device == 'auto':
-            self.model = self.model.cuda()
-
-    @torch.inference_mode()
-    def generate(self, params):
-        system_message = params['prompt'][0]['content']
-        send_messages = params['prompt'][1:]
-        max_input_tiles = params['max_input_tiles']
-        temperature = params['temperature']
-        top_p = params['top_p']
-        max_new_tokens = params['max_new_tokens']
-        repetition_penalty = params['repetition_penalty']
-        video_frame_num = params.get('video_frame_num', 64)
-        do_sample = True if temperature > 0.0 else False
-
-        global_image_cnt = 0
-        history, pil_images, max_input_tile_list = [], [], []
-        for message in send_messages:
-            if message['role'] == 'user':
-                prefix = ''
-                if 'image' in message:
-                    for image_data in message['image']:
-                        pil_images.append(load_image(image_data))
-                        prefix = prefix + f'<image {global_image_cnt + 1}><image>\n'
-                        global_image_cnt += 1
-                        max_input_tile_list.append(max_input_tiles)
-                if 'video' in message:
-                    for video_data in message['video']:
-                        video_frames, tmp_prefix = load_video(video_data, num_frames=video_frame_num)
-                        pil_images.extend(video_frames)
-                        prefix = prefix + tmp_prefix
-                        global_image_cnt += len(video_frames)
-                        max_input_tile_list.extend([1] * len(video_frames))
-                content = prefix + message['content']
-                history.append([content, ])
-            else:
-                history[-1].append(message['content'])
-        question, history = history[-1][0], history[:-1]
-
-        if global_image_cnt == 1:
-            question = question.replace('<image 1><image>\n', '<image>\n')
-            history = [[item[0].replace('<image 1><image>\n', '<image>\n'), item[1]] for item in history]
-
-        try:
-            assert len(max_input_tile_list) == len(pil_images), 'The number of max_input_tile_list and pil_images should be the same.'
-        except Exception as e:
-            from IPython import embed; embed()
-            exit()
-            print(f'Error: {e}')
-            print(f'max_input_tile_list: {max_input_tile_list}, pil_images: {pil_images}')
-            # raise e
-
-        old_system_message = self.model.system_message
-        self.model.system_message = system_message
-
-        transform = build_transform(input_size=self.image_size, norm_type=self.norm_type)
-        if len(pil_images) > 0:
-            max_input_tiles_limited_by_contect = params['max_input_tiles']
-            while True:
-                image_tiles = []
-                for current_max_input_tiles, pil_image in zip(max_input_tile_list, pil_images):
-                    if self.model.config.dynamic_image_size:
-                        tiles = dynamic_preprocess(
-                            pil_image, image_size=self.image_size, max_num=min(current_max_input_tiles, max_input_tiles_limited_by_contect),
-                            use_thumbnail=self.model.config.use_thumbnail)
-                    else:
-                        tiles = [pil_image]
-                    image_tiles += tiles
-                if (len(image_tiles) * self.per_tile_len < self.context_len):
-                    break
-                else:
-                    max_input_tiles_limited_by_contect -= 2
-
-                if max_input_tiles_limited_by_contect < 1:
-                    break
-
-            pixel_values = [transform(item) for item in image_tiles]
-            pixel_values = torch.stack(pixel_values).to(self.model.device, dtype=torch.bfloat16)
-            print(f'Split images to {pixel_values.shape}')
-        else:
-            pixel_values = None
-
-        generation_config = dict(
-            num_beams=1,
-            max_new_tokens=max_new_tokens,
-            do_sample=do_sample,
-            temperature=temperature,
-            repetition_penalty=repetition_penalty,
-            max_length=self.context_len,
-            top_p=top_p,
-        )
-
-        response = self.model.chat(
-            tokenizer=self.tokenizer,
-            pixel_values=pixel_values,
-            question=question,
-            history=history,
-            return_history=False,
-            generation_config=generation_config,
-        )
-        self.model.system_message = old_system_message
-        return {'text': response, 'error_code': 0}
-
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--model-path', type=str, default='nvidia/Eagle2-2B')
-    parser.add_argument('--model-name', type=str, default='Eagle2-2B')
-    parser.add_argument('--device', type=str, default='cuda')
-    parser.add_argument('--load-8bit', action='store_true')
-    args = parser.parse_args()
-    print(f'args: {args}')
-
-    worker = ModelWorker(
-        args.model_path,
-        args.model_name,
-        args.load_8bit,
-        args.device)
```

-- Single image input
```python
```

```python
```

```python
```

-###
```python
}
```

## TODO

# Eagle-2

[\[📂 GitHub\]](https://github.com/NVlabs/EAGLE) [\[📜 Eagle2 Tech Report\]](http://arxiv.org/abs/2501.14818)
+[\[🤗 HF Demo\]](https://huggingface.co/spaces/nvidia/Eagle2-Demo)
+
+# News:
+- We update the model arch to `eagle_2_5_vl` to support the `generate` feature.
+
+
## Introduction

We are thrilled to release our latest Eagle2 series Vision-Language Model. Open-source Vision-Language Models (VLMs) have made significant strides in narrowing the gap with proprietary models. However, critical details about data strategies and implementation are often missing, limiting reproducibility and innovation. In this project, we focus on VLM post-training from a data-centric perspective, sharing insights into building effective data strategies from scratch. By combining these strategies with robust training recipes and model design, we introduce Eagle2, a family of performant VLMs. Our work aims to empower the open-source community to develop competitive VLMs with transparent processes.

+We provide an [inference script](./demo.py) to help you quickly start using the model. We support different input types:
- pure text input
- single image input
- multiple image input
- video input

+### Install the dependencies

```bash
pip install transformers
pip install flash-attn
```

+### single image

```python
+from PIL import Image
import requests
+from transformers import AutoProcessor, AutoModel
import torch
+model = AutoModel.from_pretrained("nvidia/Eagle2-1B", trust_remote_code=True, torch_dtype=torch.bfloat16)
+processor = AutoProcessor.from_pretrained("nvidia/Eagle2-1B", trust_remote_code=True, use_fast=True)
+processor.tokenizer.padding_side = "left"
+
+messages = [
+    {
+        "role": "user",
+        "content": [
+            {
+                "type": "image",
+                "image": "https://www.ilankelman.org/stopsigns/australia.jpg",
+            },
+            {"type": "text", "text": "Describe this image."},
+        ],
+    }
+]
+
+text_list = [processor.apply_chat_template(
+    messages, tokenize=False, add_generation_prompt=True
+)]
+image_inputs, video_inputs = processor.process_vision_info(messages)
+inputs = processor(text=text_list, images=image_inputs, videos=video_inputs, return_tensors="pt", padding=True)
+inputs = inputs.to("cuda")
+model = model.to("cuda")
+generated_ids = model.generate(**inputs, max_new_tokens=1024)
+output_text = processor.batch_decode(
+    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+)
+print(output_text)
```
+
+### stream generation
+
+```python
+from PIL import Image
+import requests
+from transformers import AutoProcessor, AutoModel, AutoTokenizer
+import torch
+
+from transformers import TextIteratorStreamer
+import threading

+model = AutoModel.from_pretrained("nvidia/Eagle2-1B", trust_remote_code=True, attn_implementation='flash_attention_2', torch_dtype=torch.bfloat16)
+tokenizer = AutoTokenizer.from_pretrained("nvidia/Eagle2-1B", trust_remote_code=True, use_fast=True)
+processor = AutoProcessor.from_pretrained("nvidia/Eagle2-1B", trust_remote_code=True, use_fast=True)
+processor.tokenizer.padding_side = "left"
+
+messages = [
+    {
+        "role": "user",
+        "content": [
+            {
+                "type": "image",
+                "image": "https://www.ilankelman.org/stopsigns/australia.jpg",
+            },
+            {"type": "text", "text": "Describe this image."},
+        ],
+    }
+]
+
+text_list = [processor.apply_chat_template(
+    messages, tokenize=False, add_generation_prompt=True
+)]
+image_inputs, video_inputs = processor.process_vision_info(messages)
+inputs = processor(text=text_list, images=image_inputs, videos=video_inputs, return_tensors="pt", padding=True)
+inputs = inputs.to("cuda")
+model = model.to("cuda")
+
+streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+generation_kwargs = dict(
+    **inputs,
+    streamer=streamer,
+    max_new_tokens=1024,
+    do_sample=True,
+    top_p=0.95,
+    temperature=0.8
+)
+thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
+thread.start()
+
+
+for new_text in streamer:
+    print(new_text, end="", flush=True)
+```
+
+### multiple-images

```python
+from PIL import Image
+import requests
+from transformers import AutoProcessor, AutoModel
+import torch
+model = AutoModel.from_pretrained("nvidia/Eagle2-1B", trust_remote_code=True, torch_dtype=torch.bfloat16)
+processor = AutoProcessor.from_pretrained("nvidia/Eagle2-1B", trust_remote_code=True, use_fast=True)
+processor.tokenizer.padding_side = "left"
+
+messages = [
+    {
+        "role": "user",
+        "content": [
+            {
+                "type": "image",
+                "image": "https://www.ilankelman.org/stopsigns/australia.jpg",
+            },
+            {
+                "type": "image",
+                "image": "https://www.nvidia.com/content/dam/en-zz/Solutions/about-nvidia/logo-and-brand/[email protected]",
+            },
+            {"type": "text", "text": "Describe these two images."},
+        ],
+    }
+]
+
+text_list = [processor.apply_chat_template(
+    messages, tokenize=False, add_generation_prompt=True
+)]
+image_inputs, video_inputs = processor.process_vision_info(messages)
+inputs = processor(text=text_list, images=image_inputs, videos=video_inputs, return_tensors="pt", padding=True)
+inputs = inputs.to("cuda")
+model = model.to("cuda")
+generated_ids = model.generate(**inputs, max_new_tokens=1024)
+output_text = processor.batch_decode(
+    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+)
+print(output_text)
```

+### single video
+
```python
+
+from PIL import Image
+import requests
+from transformers import AutoProcessor, AutoModel
+import torch
+model = AutoModel.from_pretrained("nvidia/Eagle2-1B", trust_remote_code=True, torch_dtype=torch.bfloat16)
+processor = AutoProcessor.from_pretrained("nvidia/Eagle2-1B", trust_remote_code=True, use_fast=True)
+processor.tokenizer.padding_side = "left"
+
+messages = [
+    {
+        "role": "user",
+        "content": [
+            {
+                "type": "video",
+                "video": "../Eagle2-8B/space_woaudio.mp4",
+            },
+            {"type": "text", "text": "Describe this video."},
+        ],
+    }
+]
+
+text_list = [processor.apply_chat_template(
+    messages, tokenize=False, add_generation_prompt=True
+)]
+image_inputs, video_inputs, video_kwargs = processor.process_vision_info(messages, return_video_kwargs=True)
+
+inputs = processor(text=text_list, images=image_inputs, videos=video_inputs, return_tensors="pt", padding=True, videos_kwargs=video_kwargs)
+inputs = inputs.to("cuda")
+model = model.to("cuda")
+generated_ids = model.generate(**inputs, max_new_tokens=1024)
+output_text = processor.batch_decode(
+    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+)
+print(output_text)
+
```

+### multiple videos
+
```python
+from PIL import Image
+import requests
+from transformers import AutoProcessor, AutoModel
+import torch
+model = AutoModel.from_pretrained("nvidia/Eagle2-1B", trust_remote_code=True, torch_dtype=torch.bfloat16)
+processor = AutoProcessor.from_pretrained("nvidia/Eagle2-1B", trust_remote_code=True, use_fast=True)
+processor.tokenizer.padding_side = "left"
+
+messages = [
+    {
+        "role": "user",
+        "content": [
+            {
+                "type": "video",
+                "video": "../Eagle2-8B/space_woaudio.mp4",
+                "nframes": 10,
+            },
+            {
+                "type": "video",
+                "video": "../Eagle2-8B/video_ocr.mp4",
+                "nframes": 10,
+            },
+            {"type": "text", "text": "Describe these two videos respectively."},
+        ],
+    }
+]
+
+text_list = [processor.apply_chat_template(
+    messages, tokenize=False, add_generation_prompt=True
+)]
+image_inputs, video_inputs, video_kwargs = processor.process_vision_info(messages, return_video_kwargs=True)
+inputs = processor(text=text_list, images=image_inputs, videos=video_inputs, return_tensors="pt", padding=True, videos_kwargs=video_kwargs)
+inputs = inputs.to("cuda")
+model = model.to("cuda")
+generated_ids = model.generate(**inputs, max_new_tokens=1024)
+output_text = processor.batch_decode(
+    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+)
+print(output_text)
```

+### batch inference
+
```python
+from PIL import Image
+import requests
+from transformers import AutoProcessor, AutoModel
+import torch
+model = AutoModel.from_pretrained("nvidia/Eagle2-1B", trust_remote_code=True, torch_dtype=torch.bfloat16)
+processor = AutoProcessor.from_pretrained("nvidia/Eagle2-1B", trust_remote_code=True, use_fast=True)
+processor.tokenizer.padding_side = "left"
+
+messages1 = [
+    {
+        "role": "user",
+        "content": [
+            {
+                "type": "image",
+                "image": "https://www.ilankelman.org/stopsigns/australia.jpg",
+            },
+            {"type": "text", "text": "Describe this image."},
+        ],
+    }
+]
+
+messages2 = [
+    {
+        "role": "user",
+        "content": [
+            {
+                "type": "image",
+                "image": "https://www.nvidia.com/content/dam/en-zz/Solutions/about-nvidia/logo-and-brand/[email protected]",
+            },
+            {"type": "text", "text": "Describe this image."},
+        ],
    }
+]
+
+text_list = [processor.apply_chat_template(
+    messages, tokenize=False, add_generation_prompt=True
+) for messages in [messages1, messages2]]
+image_inputs, video_inputs = processor.process_vision_info([messages1, messages2])
+inputs = processor(text=text_list, images=image_inputs, videos=video_inputs, return_tensors="pt", padding=True)
+inputs = inputs.to("cuda")
+model = model.to("cuda")
+generated_ids = model.generate(**inputs, max_new_tokens=1024)
+output_text = processor.batch_decode(
+    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+)
+print(output_text)
```

## TODO
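The input types listed in the updated README include pure text, but the snippets above only cover image and video prompts. The sketch below is an editorial illustration that follows the same processor API; it assumes `process_vision_info` returns empty vision inputs for text-only messages and that the processor accepts such a call.

```python
# Hypothetical text-only usage, mirroring the image examples above.
# Assumption: the processor tolerates a request with no image or video inputs.
from transformers import AutoProcessor, AutoModel
import torch

model = AutoModel.from_pretrained("nvidia/Eagle2-1B", trust_remote_code=True, torch_dtype=torch.bfloat16).to("cuda")
processor = AutoProcessor.from_pretrained("nvidia/Eagle2-1B", trust_remote_code=True, use_fast=True)
processor.tokenizer.padding_side = "left"

messages = [
    {"role": "user", "content": [{"type": "text", "text": "Write a haiku about GPUs."}]}
]
text_list = [processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)]
image_inputs, video_inputs = processor.process_vision_info(messages)  # expected to be empty here
inputs = processor(text=text_list, images=image_inputs, videos=video_inputs, return_tensors="pt", padding=True).to("cuda")

generated_ids = model.generate(**inputs, max_new_tokens=256)
print(processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
```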
chat_template.json
ADDED
@@ -0,0 +1 @@
+{"chat_template": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}<image {{ image_count.value }}>{% endif %}<image-{{ image_count.value }}>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}<video {{ video_count.value }}>{% endif %}<video-{{ video_count.value }}>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"}
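For reference, tracing the template above by hand for a single user turn containing one image, with `add_generation_prompt=True` and no `add_vision_id`, gives the rendering sketched below. This is an illustration derived from the template text, not output captured from the repository.

```python
# Illustration only: the prompt the chat_template above is expected to render.
messages = [
    {"role": "user", "content": [
        {"type": "image", "image": "example.jpg"},  # placeholder path for illustration
        {"type": "text", "text": "Describe this image."},
    ]}
]
# processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# should produce:
#
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# <image-1>Describe this image.<|im_end|>
# <|im_start|>assistant
```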
config.json
CHANGED
@@ -1,32 +1,33 @@
{
-  "
+  "_attn_implementation": "flash_attention_2",
+  "_attn_implementation_autoset": false,
  "_name_or_path": "",
  "architectures": [
-    "
+    "Eagle2_5_VLForConditionalGeneration"
  ],
  "auto_map": {
-    "AutoConfig": "
-    "AutoModel": "
-  },
+    "AutoConfig": "configuration_eagle2_5_vl.Eagle2_5_VLConfig",
+    "AutoModel": "modeling_eagle2_5_vl.Eagle2_5_VLForConditionalGeneration"
+  },
  "downsample_ratio": 0.5,
  "dynamic_image_size": true,
-  "efficient_loss": true,
  "force_image_size": 448,
-  "
-  "
+  "image_token_index": 151667,
+  "max_dynamic_tiles": 12,
+  "min_dynamic_tiles": 1,
+  "mlp_checkpoint": false,
+  "model_type": "eagle_2_5_vl",
+  "pad2square": false,
+  "pre_feature_reduction": false,
+  "select_layer": -1,
+  "template": "qwen2-chat",
+  "text_config": {
    "_name_or_path": "./pretrained/Qwen2_5-1_5B-Instruct",
    "add_cross_attention": false,
    "architectures": [
      "Qwen2ForCausalLM"
    ],
    "attention_dropout": 0.0,
-    "attn_implementation": "flash_attention_2",
-    "auto_map": {
-      "AutoConfig": "configuration_qwen2.Qwen2Config",
-      "AutoModel": "modeling_qwen2.Qwen2Model",
-      "AutoModelForCausalLM": "modeling_qwen2.Qwen2ForCausalLM"
-    },
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": 151643,
@@ -102,105 +103,25 @@
    "use_sliding_window": false,
    "vocab_size": 151674
  },
-  "
-  "max_dynamic_patch": 12,
-  "min_dynamic_patch": 1,
-  "mlp_checkpoint": true,
-  "model_type": "eagle_chat",
-  "pad2square": false,
-  "pre_feature_reduction": false,
-  "ps_version": "v2",
-  "select_layer": -1,
-  "template": "qwen2-chat",
+  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
-  "transformers_version":
+  "transformers_version": "4.51.0",
  "use_backbone_lora": 0,
  "use_llm_lora": 0,
  "use_thumbnail": true,
  "vision_config": {
-    "_name_or_path": "",
-    "add_cross_attention": false,
-    "architectures": [
-      "SiglipVisionModel"
-    ],
    "attention_dropout": 0.0,
-    "auto_map": {
-      "AutoConfig": "configuration_siglip.SiglipVisionConfig",
-      "AutoModel": "modeling_siglip.SiglipVisionModel"
-    },
-    "bad_words_ids": null,
-    "begin_suppress_tokens": null,
-    "bos_token_id": null,
-    "chunk_size_feed_forward": 0,
-    "cross_attention_hidden_size": null,
-    "decoder_start_token_id": null,
-    "diversity_penalty": 0.0,
-    "do_sample": false,
    "drop_path_rate": 0.1,
-    "early_stopping": false,
-    "encoder_no_repeat_ngram_size": 0,
-    "eos_token_id": null,
-    "exponential_decay_length_penalty": null,
-    "finetuning_task": null,
-    "forced_bos_token_id": null,
-    "forced_eos_token_id": null,
    "hidden_act": "gelu_pytorch_tanh",
    "hidden_size": 1152,
-    "id2label": {
-      "0": "LABEL_0",
-      "1": "LABEL_1"
-    },
    "image_size": 448,
    "intermediate_size": 4304,
-    "is_decoder": false,
-    "is_encoder_decoder": false,
-    "label2id": {
-      "LABEL_0": 0,
-      "LABEL_1": 1
-    },
    "layer_norm_eps": 1e-06,
-    "length_penalty": 1.0,
-    "max_length": 20,
-    "min_length": 0,
    "model_type": "siglip_vision_model",
-    "no_repeat_ngram_size": 0,
    "num_attention_heads": 16,
-    "num_beam_groups": 1,
-    "num_beams": 1,
    "num_channels": 3,
    "num_hidden_layers": 27,
-    "num_image_tokens": 1024,
-    "num_return_sequences": 1,
-    "output_attentions": false,
-    "output_hidden_states": false,
-    "output_scores": false,
-    "pad_token_id": null,
    "patch_size": 14,
-    "
-    "projection_dim": 2048,
-    "projector_hidden_act": "gelu_fast",
-    "pruned_heads": {},
-    "remove_invalid_values": false,
-    "repetition_penalty": 1.0,
-    "return_dict": true,
-    "return_dict_in_generate": false,
-    "sep_token_id": null,
-    "suppress_tokens": null,
-    "task_specific_params": null,
-    "temperature": 1.0,
-    "tf_legacy_loss": false,
-    "tie_encoder_decoder": false,
-    "tie_word_embeddings": true,
-    "tokenizer_class": null,
-    "top_k": 50,
-    "top_p": 1.0,
-    "torch_dtype": "bfloat16",
-    "torchscript": false,
-    "transformers_version": "4.37.2",
-    "typical_p": 1.0,
-    "use_bfloat16": false,
-    "vision_use_head": false,
-    "_attn_implementation": "flash_attention_2"
-  }
+    "torch_dtype": "bfloat16"
+  }
}
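A quick sanity check of the numbers above (an illustration, not part of the commit): `force_image_size` 448 with `patch_size` 14 gives 32 x 32 = 1024 patches per tile, and `downsample_ratio` 0.5 halves each spatial side, leaving 256 visual tokens per tile, which matches the `per_tile_len = 256` hard-coded in the removed model-worker code above.

```python
# Token budget implied by the config: 448 / 14 = 32 patches per side, so 1024
# patches per 448x448 tile; a 0.5 downsample on each side leaves 16 x 16 = 256
# visual tokens per tile.
force_image_size, patch_size, downsample_ratio = 448, 14, 0.5
patches_per_side = force_image_size // patch_size                 # 32
patches_per_tile = patches_per_side ** 2                          # 1024
tokens_per_tile = int(patches_per_tile * downsample_ratio ** 2)   # 256
print(patches_per_side, patches_per_tile, tokens_per_tile)
```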
configuration_eagle_chat.py → configuration_eagle2_5_vl.py
RENAMED
@@ -1,80 +1,88 @@
# --------------------------------------------------------
-#
+# NVIDIA
# Copyright (c) 2025 NVIDIA
-# Licensed under The
+# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

import copy

-from transformers import
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
-from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
logger = logging.get_logger(__name__)

-class Eagle2ChatConfig(PretrainedConfig):
-    model_type = 'eagle_chat'
-    is_composition = True
+
+class Eagle2_5_VLConfig(PretrainedConfig):
+    model_type = 'eagle_2_5_vl'
+    is_composition = True
+    sub_configs = {"vision_config": SiglipVisionConfig, "text_config": Qwen2Config}
    def __init__(
        self,
        vision_config=None,
+        text_config=None,
        use_backbone_lora=0,
        use_llm_lora=0,
+        pad2square=False,
+        select_layer=-4,
        force_image_size=None,
        downsample_ratio=0.5,
        template=None,
        dynamic_image_size=False,
        use_thumbnail=False,
+        loss_version='v1',
+        min_dynamic_tiles=1,
+        max_dynamic_tiles=6,
+        mlp_checkpoint=False,
+        initializer_range=0.02,
+        _attn_implementation='flash_attention_2',
+        _attn_implementation_autoset=False,
+        llm_config=None,
+        image_token_index=None,
        **kwargs):
        super().__init__(**kwargs)

        if vision_config is None:
-            vision_config = {}
-            logger.info('vision_config is None. Initializing
+            vision_config = {'model_type': 'siglip_vision_model'}
+            logger.info('vision_config is None. Initializing the InternVisionConfig with default values.')

-        if
-            logger.info('
+        if text_config is None:
+            text_config = {'architectures': ['Qwen2ForCausalLM']}
+            logger.info('text_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`).')

        if vision_config['model_type'] == 'siglip_vision_model':
            self.vision_config = SiglipVisionConfig(**vision_config)
        else:
            raise ValueError('Unsupported model_type: {}'.format(vision_config['model_type']))

+
+        if text_config['architectures'][0] == 'LlamaForCausalLM':
+            self.text_config = LlamaConfig(**text_config)
+        elif text_config['architectures'][0] == 'Qwen2ForCausalLM':
+            self.text_config = Qwen2Config(**text_config)
        else:
-            raise ValueError('Unsupported architecture: {}'.format(
+            raise ValueError('Unsupported architecture: {}'.format(text_config['architectures'][0]))
        self.use_backbone_lora = use_backbone_lora
        self.use_llm_lora = use_llm_lora
+        self.mlp_checkpoint = mlp_checkpoint
+        self.pad2square = pad2square
        self.select_layer = select_layer
        self.force_image_size = force_image_size
        self.downsample_ratio = downsample_ratio
        self.template = template
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail = use_thumbnail
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        logger.info(f'
-        logger.info(f'
+        self.loss_version = loss_version
+        self.initializer_range = initializer_range
+        self.min_dynamic_tiles = min_dynamic_tiles
+        self.max_dynamic_tiles = max_dynamic_tiles
+        self.tie_word_embeddings = self.text_config.tie_word_embeddings
+        self._attn_implementation = _attn_implementation
+        self._attn_implementation_autoset = _attn_implementation_autoset
+        self.image_token_index = image_token_index
+        logger.info(f'min_dynamic_tiles: {self.min_dynamic_tiles}')
+        logger.info(f'max_dynamic_tiles: {self.max_dynamic_tiles}')

    def to_dict(self):
        """
@@ -85,18 +93,21 @@ class Eagle2ChatConfig(PretrainedConfig):
        """
        output = copy.deepcopy(self.__dict__)
        output['vision_config'] = self.vision_config.to_dict()
-        output['
+        output['text_config'] = self.text_config.to_dict()
        output['model_type'] = self.__class__.model_type
        output['use_backbone_lora'] = self.use_backbone_lora
        output['use_llm_lora'] = self.use_llm_lora
+        output['pad2square'] = self.pad2square
        output['select_layer'] = self.select_layer
        output['force_image_size'] = self.force_image_size
        output['downsample_ratio'] = self.downsample_ratio
        output['template'] = self.template
        output['dynamic_image_size'] = self.dynamic_image_size
        output['use_thumbnail'] = self.use_thumbnail
-        output['
-        output['
-        output['
+        output['min_dynamic_tiles'] = self.min_dynamic_tiles
+        output['max_dynamic_tiles'] = self.max_dynamic_tiles
+        output['tie_word_embeddings'] = self.tie_word_embeddings
+        output['_attn_implementation'] = self._attn_implementation
+        output['_attn_implementation_autoset'] = self._attn_implementation_autoset

        return output
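A small usage sketch (not taken from the repository) showing how the renamed config class is reached through the `auto_map` entries in config.json; it assumes the `nvidia/Eagle2-1B` checkpoint name used in the README examples.

```python
# Loads the composite Eagle2_5_VLConfig via trust_remote_code and inspects the
# sub-configs defined in configuration_eagle2_5_vl.py.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("nvidia/Eagle2-1B", trust_remote_code=True)
print(config.model_type)                  # expected: eagle_2_5_vl
print(type(config.text_config).__name__)  # Qwen2Config for this checkpoint
print(config.vision_config.model_type)    # siglip_vision_model
print(config.min_dynamic_tiles, config.max_dynamic_tiles, config.downsample_ratio)
```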
demo.py
CHANGED
@@ -1,428 +1,55 @@
(removed lines 1-428: the legacy model-worker script, identical to the code block removed from the old README.md usage section above)
|
285 |
-
|
286 |
-
@torch.inference_mode()
|
287 |
-
def generate(self, params):
|
288 |
-
system_message = params['prompt'][0]['content']
|
289 |
-
send_messages = params['prompt'][1:]
|
290 |
-
max_input_tiles = params['max_input_tiles']
|
291 |
-
temperature = params['temperature']
|
292 |
-
top_p = params['top_p']
|
293 |
-
max_new_tokens = params['max_new_tokens']
|
294 |
-
repetition_penalty = params['repetition_penalty']
|
295 |
-
video_frame_num = params.get('video_frame_num', 64)
|
296 |
-
do_sample = True if temperature > 0.0 else False
|
297 |
-
|
298 |
-
global_image_cnt = 0
|
299 |
-
history, pil_images, max_input_tile_list = [], [], []
|
300 |
-
|
301 |
-
for message in send_messages:
|
302 |
-
if message['role'] == 'user':
|
303 |
-
prefix = ''
|
304 |
-
if 'image' in message:
|
305 |
-
for image_data in message['image']:
|
306 |
-
pil_images.append(load_image(image_data))
|
307 |
-
prefix = prefix + f'<image {global_image_cnt + 1}><image>\n'
|
308 |
-
global_image_cnt += 1
|
309 |
-
max_input_tile_list.append(max_input_tiles)
|
310 |
-
if 'video' in message:
|
311 |
-
for video_data in message['video']:
|
312 |
-
video_frames, tmp_prefix = load_video(video_data, num_frames=video_frame_num)
|
313 |
-
pil_images.extend(video_frames)
|
314 |
-
prefix = prefix + tmp_prefix
|
315 |
-
global_image_cnt += len(video_frames)
|
316 |
-
max_input_tile_list.extend([1] * len(video_frames))
|
317 |
-
content = prefix + message['content']
|
318 |
-
history.append([content, ])
|
319 |
-
else:
|
320 |
-
history[-1].append(message['content'])
|
321 |
-
question, history = history[-1][0], history[:-1]
|
322 |
-
|
323 |
-
if global_image_cnt == 1:
|
324 |
-
question = question.replace('<image 1><image>\n', '<image>\n')
|
325 |
-
history = [[item[0].replace('<image 1><image>\n', '<image>\n'), item[1]] for item in history]
|
326 |
-
|
327 |
-
|
328 |
-
try:
|
329 |
-
assert len(max_input_tile_list) == len(pil_images), 'The number of max_input_tile_list and pil_images should be the same.'
|
330 |
-
except Exception as e:
|
331 |
-
from IPython import embed; embed()
|
332 |
-
exit()
|
333 |
-
print(f'Error: {e}')
|
334 |
-
print(f'max_input_tile_list: {max_input_tile_list}, pil_images: {pil_images}')
|
335 |
-
# raise e
|
336 |
-
|
337 |
-
old_system_message = self.model.system_message
|
338 |
-
self.model.system_message = system_message
|
339 |
-
|
340 |
-
transform = build_transform(input_size=self.image_size, norm_type=self.norm_type)
|
341 |
-
if len(pil_images) > 0:
|
342 |
-
max_input_tiles_limited_by_contect = params['max_input_tiles']
|
343 |
-
while True:
|
344 |
-
image_tiles = []
|
345 |
-
num_patches_list = []
|
346 |
-
for current_max_input_tiles, pil_image in zip(max_input_tile_list, pil_images):
|
347 |
-
if self.model.config.dynamic_image_size:
|
348 |
-
tiles = dynamic_preprocess(
|
349 |
-
pil_image, image_size=self.image_size, max_num=min(current_max_input_tiles, max_input_tiles_limited_by_contect),
|
350 |
-
use_thumbnail=self.model.config.use_thumbnail)
|
351 |
-
else:
|
352 |
-
tiles = [pil_image]
|
353 |
-
num_patches_list.append(len(tiles))
|
354 |
-
image_tiles += tiles
|
355 |
-
if (len(image_tiles) * self.per_tile_len < self.context_len):
|
356 |
-
break
|
357 |
-
else:
|
358 |
-
max_input_tiles_limited_by_contect -= 2
|
359 |
-
|
360 |
-
if max_input_tiles_limited_by_contect < 1:
|
361 |
-
break
|
362 |
-
|
363 |
-
pixel_values = [transform(item) for item in image_tiles]
|
364 |
-
|
365 |
-
|
366 |
-
pixel_values = torch.stack(pixel_values).to(self.model.device, dtype=torch.bfloat16)
|
367 |
-
|
368 |
-
else:
|
369 |
-
pixel_values = None
|
370 |
-
|
371 |
-
generation_config = dict(
|
372 |
-
num_beams=1,
|
373 |
-
max_new_tokens=max_new_tokens,
|
374 |
-
do_sample=do_sample,
|
375 |
-
temperature=temperature,
|
376 |
-
repetition_penalty=repetition_penalty,
|
377 |
-
max_length=self.context_len,
|
378 |
-
top_p=top_p,
|
379 |
-
)
|
380 |
-
print(f'pixel_values: {pixel_values.shape}')
|
381 |
-
response = self.model.chat(
|
382 |
-
tokenizer=self.tokenizer,
|
383 |
-
pixel_values=pixel_values,
|
384 |
-
question=question,
|
385 |
-
history=history,
|
386 |
-
return_history=False,
|
387 |
-
num_patches_list=num_patches_list,
|
388 |
-
generation_config=generation_config,
|
389 |
-
)
|
390 |
-
self.model.system_message = old_system_message
|
391 |
-
return {'text': response, 'error_code': 0}
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
|
|
|
|
|
396 |
|
397 |
-
if __name__ == '__main__':
|
398 |
-
parser = argparse.ArgumentParser()
|
399 |
-
parser.add_argument('--model-path', type=str, default='/home/zhidingy/workspace/eagle-next/internvl_chat/work_dirs/release/test/Eagle2-2B')
|
400 |
-
parser.add_argument('--model-name', type=str, default='Eagle2')
|
401 |
-
parser.add_argument('--device', type=str, default='cuda')
|
402 |
-
parser.add_argument('--load-8bit', action='store_true')
|
403 |
-
args = parser.parse_args()
|
404 |
-
print(f'args: {args}')
|
405 |
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
'prompt': prompt,
|
422 |
-
'max_input_tiles': 24,
|
423 |
-
'temperature': 0.7,
|
424 |
-
'top_p': 1.0,
|
425 |
-
'max_new_tokens': 4096,
|
426 |
-
'repetition_penalty': 1.0,
|
427 |
}
|
428 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import requests
|
3 |
+
from transformers import AutoProcessor, AutoModel, AutoTokenizer
|
4 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
+
from transformers import TextIteratorStreamer
|
7 |
+
import threading
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
+
model = AutoModel.from_pretrained("/home/zhidingy/workspace/eagle-next/internvl_chat/work_dirs/release/test/Eagle2-2B",trust_remote_code=True, attn_implementation='flash_attention_2', torch_dtype=torch.bfloat16)
|
11 |
+
tokenizer = AutoTokenizer.from_pretrained("/home/zhidingy/workspace/eagle-next/internvl_chat/work_dirs/release/test/Eagle2-2B", trust_remote_code=True, use_fast=True)
|
12 |
+
processor = AutoProcessor.from_pretrained("/home/zhidingy/workspace/eagle-next/internvl_chat/work_dirs/release/test/Eagle2-2B", trust_remote_code=True, use_fast=True)
|
13 |
+
processor.tokenizer.padding_side = "left"
|
14 |
+
|
15 |
+
messages = [
|
16 |
+
{
|
17 |
+
"role": "user",
|
18 |
+
"content": [
|
19 |
+
{
|
20 |
+
"type": "image",
|
21 |
+
"image": "https://www.ilankelman.org/stopsigns/australia.jpg",
|
22 |
+
},
|
23 |
+
{"type": "text", "text": "Describe this image."},
|
24 |
+
],
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
}
|
26 |
+
]
|
27 |
+
|
28 |
+
text_list = [processor.apply_chat_template(
|
29 |
+
messages, tokenize=False, add_generation_prompt=True
|
30 |
+
)]
|
31 |
+
image_inputs, video_inputs = processor.process_vision_info(messages)
|
32 |
+
inputs = processor(text = text_list, images=image_inputs, videos=video_inputs, return_tensors="pt", padding=True)
|
33 |
+
inputs = inputs.to("cuda")
|
34 |
+
model = model.to("cuda")
|
35 |
+
# generated_ids = model.generate(**inputs, max_new_tokens=1024)
|
36 |
+
# output_text = processor.batch_decode(
|
37 |
+
# generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
38 |
+
# )
|
39 |
+
# print(output_text)
|
40 |
+
|
41 |
+
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
42 |
+
|
43 |
+
generation_kwargs = dict(
|
44 |
+
**inputs,
|
45 |
+
streamer=streamer,
|
46 |
+
max_new_tokens=1024,
|
47 |
+
do_sample=True,
|
48 |
+
top_p=0.95,
|
49 |
+
temperature=0.8
|
50 |
+
)
|
51 |
+
thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
|
52 |
+
thread.start()
|
53 |
+
|
54 |
+
for new_text in streamer:
|
55 |
+
print(new_text, end="", flush=True)
|
image_processing_eagle2.py
ADDED
@@ -0,0 +1,715 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Image processor class for LLaVa-Onevision."""
|
16 |
+
|
17 |
+
import math
|
18 |
+
from typing import Dict, Iterable, List, Optional, Tuple, Union
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
|
22 |
+
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict, select_best_resolution
|
23 |
+
from transformers.image_transforms import (
|
24 |
+
PaddingMode,
|
25 |
+
convert_to_rgb,
|
26 |
+
pad,
|
27 |
+
resize,
|
28 |
+
to_channel_dimension_format,
|
29 |
+
)
|
30 |
+
from transformers.image_utils import (
|
31 |
+
OPENAI_CLIP_MEAN,
|
32 |
+
OPENAI_CLIP_STD,
|
33 |
+
IMAGENET_STANDARD_MEAN, # 0.5, 0.5, 0.5
|
34 |
+
IMAGENET_STANDARD_STD, # 0.5, 0.5, 0.5
|
35 |
+
ChannelDimension,
|
36 |
+
ImageInput,
|
37 |
+
PILImageResampling,
|
38 |
+
get_image_size,
|
39 |
+
infer_channel_dimension_format,
|
40 |
+
is_scaled_image,
|
41 |
+
make_flat_list_of_images,
|
42 |
+
to_numpy_array,
|
43 |
+
valid_images,
|
44 |
+
validate_preprocess_arguments,
|
45 |
+
)
|
46 |
+
from transformers.utils import TensorType, is_vision_available, logging
|
47 |
+
|
48 |
+
|
49 |
+
logger = logging.get_logger(__name__)
|
50 |
+
|
51 |
+
|
52 |
+
if is_vision_available():
|
53 |
+
from PIL import Image
|
54 |
+
|
55 |
+
def crop(img: np.ndarray, left: int, top: int, right: int, bottom: int, input_data_format: ChannelDimension) -> np.ndarray:
|
56 |
+
"""Crop the given numpy array.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
img (np.ndarray): Image to be cropped. Format should be (H, W, C) or (H, W).
|
60 |
+
left (int): The left coordinate of the crop box.
|
61 |
+
top (int): The top coordinate of the crop box.
|
62 |
+
right (int): The right coordinate of the crop box.
|
63 |
+
bottom (int): The bottom coordinate of the crop box.
|
64 |
+
|
65 |
+
Returns:
|
66 |
+
np.ndarray: Cropped image.
|
67 |
+
"""
|
68 |
+
if not isinstance(img, np.ndarray):
|
69 |
+
raise TypeError('img should be numpy array. Got {}'.format(type(img)))
|
70 |
+
|
71 |
+
if img.ndim not in [2, 3]:
|
72 |
+
raise ValueError('Image should have 2 or 3 dimensions. Got {}'.format(img.ndim))
|
73 |
+
|
74 |
+
if input_data_format == ChannelDimension.LAST:
|
75 |
+
img_height = img.shape[0]
|
76 |
+
img_width = img.shape[1]
|
77 |
+
else:
|
78 |
+
img_height = img.shape[1]
|
79 |
+
img_width = img.shape[2]
|
80 |
+
|
81 |
+
if top < 0 or left < 0 or bottom > img_height or right > img_width:
|
82 |
+
raise ValueError('Crop coordinates out of bounds')
|
83 |
+
|
84 |
+
if top >= bottom or left >= right:
|
85 |
+
raise ValueError('Invalid crop coordinates')
|
86 |
+
if input_data_format == ChannelDimension.LAST:
|
87 |
+
return img[top:bottom, left:right, :]
|
88 |
+
else:
|
89 |
+
return img[:, top:bottom, left:right]
|
90 |
+
|
91 |
+
# Copied from transformers.models.llava_next.image_processing_llava_next.divide_to_patches
|
92 |
+
def divide_to_patches(image: np.array, patch_size: int, input_data_format) -> List[np.array]:
|
93 |
+
"""
|
94 |
+
Divides an image into patches of a specified size.
|
95 |
+
|
96 |
+
Args:
|
97 |
+
image (`np.array`):
|
98 |
+
The input image.
|
99 |
+
patch_size (`int`):
|
100 |
+
The size of each patch.
|
101 |
+
input_data_format (`ChannelDimension` or `str`):
|
102 |
+
The channel dimension format of the input image.
|
103 |
+
|
104 |
+
Returns:
|
105 |
+
list: A list of np.array representing the patches.
|
106 |
+
"""
|
107 |
+
patches = []
|
108 |
+
height, width = get_image_size(image, channel_dim=input_data_format)
|
109 |
+
for i in range(0, height, patch_size):
|
110 |
+
for j in range(0, width, patch_size):
|
111 |
+
if input_data_format == ChannelDimension.LAST:
|
112 |
+
patch = image[i : i + patch_size, j : j + patch_size]
|
113 |
+
else:
|
114 |
+
patch = image[:, i : i + patch_size, j : j + patch_size]
|
115 |
+
patches.append(patch)
|
116 |
+
|
117 |
+
return patches
|
118 |
+
|
119 |
+
|
120 |
+
# Copied from transformers.models.llava_next.image_processing_llava_next.expand_to_square
|
121 |
+
def expand_to_square(image: np.array, background_color, input_data_format) -> np.array:
|
122 |
+
"""
|
123 |
+
Expands an image to a square by adding a background color.
|
124 |
+
"""
|
125 |
+
|
126 |
+
height, width = get_image_size(image, channel_dim=input_data_format)
|
127 |
+
if width == height:
|
128 |
+
return image
|
129 |
+
elif width > height:
|
130 |
+
result = np.ones((width, width, image.shape[2]), dtype=image.dtype) * background_color
|
131 |
+
result[(width - height) // 2 : (width - height) // 2 + height, :] = image
|
132 |
+
return result
|
133 |
+
else:
|
134 |
+
result = np.ones((height, height, image.shape[2]), dtype=image.dtype) * background_color
|
135 |
+
result[:, (height - width) // 2 : (height - width) // 2 + width] = image
|
136 |
+
return result
|
137 |
+
|
138 |
+
|
139 |
+
# Copied from transformers.models.llava_next.image_processing_llava_next._get_patch_output_size
|
140 |
+
def _get_patch_output_size(image, target_resolution, input_data_format):
|
141 |
+
original_height, original_width = get_image_size(image, channel_dim=input_data_format)
|
142 |
+
target_height, target_width = target_resolution
|
143 |
+
|
144 |
+
scale_w = target_width / original_width
|
145 |
+
scale_h = target_height / original_height
|
146 |
+
|
147 |
+
if scale_w < scale_h:
|
148 |
+
new_width = target_width
|
149 |
+
new_height = min(math.ceil(original_height * scale_w), target_height)
|
150 |
+
else:
|
151 |
+
new_height = target_height
|
152 |
+
new_width = min(math.ceil(original_width * scale_h), target_width)
|
153 |
+
|
154 |
+
return new_height, new_width
|
155 |
+
|
156 |
+
|
157 |
+
class Eagle2ImageProcessor(BaseImageProcessor):
|
158 |
+
r"""
|
159 |
+
Constructs a LLaVa-Onevision image processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame.
|
160 |
+
|
161 |
+
Args:
|
162 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
163 |
+
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
|
164 |
+
`do_resize` in the `preprocess` method.
|
165 |
+
size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`):
|
166 |
+
Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
|
167 |
+
the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
|
168 |
+
method.
|
169 |
+
image_grid_pinpoints (`List` *optional*, defaults to `[[672, 336], [336, 672], [672, 672], [336, 1008], [1008, 336]]`):
|
170 |
+
A list of possible resolutions to use for processing high resolution images. The best resolution is selected
|
171 |
+
based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
|
172 |
+
method. Not used for processinf videos.
|
173 |
+
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
174 |
+
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
|
175 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
176 |
+
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
|
177 |
+
the `preprocess` method.
|
178 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
179 |
+
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
|
180 |
+
method.
|
181 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
182 |
+
Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
|
183 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
|
184 |
+
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
185 |
+
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
186 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
|
187 |
+
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
188 |
+
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
189 |
+
Can be overridden by the `image_std` parameter in the `preprocess` method.
|
190 |
+
do_pad (`bool`, *optional*, defaults to `True`):
|
191 |
+
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
|
192 |
+
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
|
193 |
+
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
194 |
+
Whether to convert the image to RGB.
|
195 |
+
"""
|
196 |
+
|
197 |
+
model_input_names = ["pixel_values_videos"]
|
198 |
+
|
199 |
+
def __init__(
|
200 |
+
self,
|
201 |
+
do_resize: bool = True,
|
202 |
+
size: Dict[str, int] = None,
|
203 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
204 |
+
do_rescale: bool = True,
|
205 |
+
rescale_factor: Union[int, float] = 1 / 255,
|
206 |
+
do_normalize: bool = True,
|
207 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
208 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
209 |
+
do_pad: Optional[bool] = True,
|
210 |
+
do_convert_rgb: bool = True,
|
211 |
+
min_dynamic_tiles: int = 1,
|
212 |
+
max_dynamic_tiles: int = 12,
|
213 |
+
use_thumbnail: bool = True,
|
214 |
+
pad_during_tiling: bool = False,
|
215 |
+
**kwargs,
|
216 |
+
) -> None:
|
217 |
+
super().__init__(**kwargs)
|
218 |
+
size = size if size is not None else {"height": 384, "width": 384}
|
219 |
+
size = get_size_dict(size, default_to_square=False)
|
220 |
+
|
221 |
+
self.do_resize = do_resize
|
222 |
+
self.size = size
|
223 |
+
self.resample = resample
|
224 |
+
self.do_rescale = do_rescale
|
225 |
+
self.rescale_factor = rescale_factor
|
226 |
+
self.do_normalize = do_normalize
|
227 |
+
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
|
228 |
+
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
|
229 |
+
self.do_pad = do_pad
|
230 |
+
self.do_convert_rgb = do_convert_rgb
|
231 |
+
self.min_dynamic_tiles = min_dynamic_tiles
|
232 |
+
self.max_dynamic_tiles = max_dynamic_tiles
|
233 |
+
self.use_thumbnail = use_thumbnail
|
234 |
+
self.pad_during_tiling = pad_during_tiling
|
235 |
+
|
236 |
+
# Copied from transformers.models.llava_next.image_processing_llava_next.LlavaNextImageProcessor.pad
|
237 |
+
def pad(
|
238 |
+
self,
|
239 |
+
image: np.ndarray,
|
240 |
+
padding: Union[int, Tuple[int, int], Iterable[Tuple[int, int]]],
|
241 |
+
mode: PaddingMode = PaddingMode.CONSTANT,
|
242 |
+
constant_values: Union[float, Iterable[float]] = 0.0,
|
243 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
244 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
245 |
+
) -> np.ndarray:
|
246 |
+
"""
|
247 |
+
Pads the `image` with the specified `padding` and `mode`. Padding can be in the (`height`, `width`)
|
248 |
+
dimension of in the (`num_patches`) dimension. In the second case an iterable if tuples is expected
|
249 |
+
as input.
|
250 |
+
|
251 |
+
Args:
|
252 |
+
image (`np.ndarray`):
|
253 |
+
The image to pad.
|
254 |
+
padding (`int` or `Tuple[int, int]` or `Iterable[Tuple[int, int]]`):
|
255 |
+
Padding to apply to the edges of the height, width axes. Can be one of three formats:
|
256 |
+
- `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
|
257 |
+
- `((before, after),)` yields same before and after pad for height and width.
|
258 |
+
- `(pad,)` or int is a shortcut for before = after = pad width for all axes.
|
259 |
+
mode (`PaddingMode`):
|
260 |
+
The padding mode to use. Can be one of:
|
261 |
+
- `"constant"`: pads with a constant value.
|
262 |
+
- `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
|
263 |
+
vector along each axis.
|
264 |
+
- `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
|
265 |
+
- `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
|
266 |
+
constant_values (`float` or `Iterable[float]`, *optional*):
|
267 |
+
The value to use for the padding if `mode` is `"constant"`.
|
268 |
+
data_format (`str` or `ChannelDimension`, *optional*):
|
269 |
+
The channel dimension format for the output image. Can be one of:
|
270 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
271 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
272 |
+
If unset, will use same as the input image.
|
273 |
+
input_data_format (`str` or `ChannelDimension`, *optional*):
|
274 |
+
The channel dimension format for the input image. Can be one of:
|
275 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
276 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
277 |
+
If unset, will use the inferred format of the input image.
|
278 |
+
|
279 |
+
Returns:
|
280 |
+
`np.ndarray`: The padded image.
|
281 |
+
|
282 |
+
"""
|
283 |
+
|
284 |
+
# call the general `pad` if padding on `height/width`, otherwise it's the `num_patched` dim
|
285 |
+
if isinstance(padding, int) or len(padding) != 4:
|
286 |
+
return pad(image, padding, mode, constant_values, data_format, input_data_format)
|
287 |
+
|
288 |
+
if input_data_format is None:
|
289 |
+
input_data_format = infer_channel_dimension_format(image)
|
290 |
+
if mode == PaddingMode.CONSTANT:
|
291 |
+
image = np.pad(image, padding, mode="constant", constant_values=constant_values)
|
292 |
+
elif mode == PaddingMode.REFLECT:
|
293 |
+
image = np.pad(image, padding, mode="reflect")
|
294 |
+
elif mode == PaddingMode.REPLICATE:
|
295 |
+
image = np.pad(image, padding, mode="edge")
|
296 |
+
elif mode == PaddingMode.SYMMETRIC:
|
297 |
+
image = np.pad(image, padding, mode="symmetric")
|
298 |
+
else:
|
299 |
+
raise ValueError(f"Invalid padding mode: {mode}")
|
300 |
+
image = (
|
301 |
+
to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
|
302 |
+
)
|
303 |
+
return image
|
304 |
+
|
305 |
+
# Copied from transformers.models.llava_next.image_processing_llava_next.LlavaNextImageProcessor._resize_for_patching
|
306 |
+
def _resize_for_patching(
|
307 |
+
self, image: np.array, target_resolution: tuple, resample, input_data_format: ChannelDimension
|
308 |
+
) -> np.array:
|
309 |
+
"""
|
310 |
+
Resizes an image to a target resolution while maintaining aspect ratio.
|
311 |
+
|
312 |
+
Args:
|
313 |
+
image (np.array):
|
314 |
+
The input image.
|
315 |
+
target_resolution (tuple):
|
316 |
+
The target resolution (height, width) of the image.
|
317 |
+
resample (`PILImageResampling`):
|
318 |
+
Resampling filter to use if resizing the image.
|
319 |
+
input_data_format (`ChannelDimension` or `str`):
|
320 |
+
The channel dimension format of the input image.
|
321 |
+
|
322 |
+
Returns:
|
323 |
+
np.array: The resized and padded image.
|
324 |
+
"""
|
325 |
+
|
326 |
+
new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
|
327 |
+
# Resize the image
|
328 |
+
resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)
|
329 |
+
|
330 |
+
return resized_image
|
331 |
+
|
332 |
+
# Copied from transformers.models.llava_next.image_processing_llava_next.LlavaNextImageProcessor._pad_for_patching
|
333 |
+
def _pad_for_patching(
|
334 |
+
self, image: np.array, target_resolution: tuple, input_data_format: ChannelDimension
|
335 |
+
) -> np.array:
|
336 |
+
"""
|
337 |
+
Pad an image to a target resolution while maintaining aspect ratio.
|
338 |
+
"""
|
339 |
+
target_height, target_width = target_resolution
|
340 |
+
new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
|
341 |
+
|
342 |
+
paste_x = (target_width - new_width) // 2
|
343 |
+
paste_y = (target_height - new_height) // 2
|
344 |
+
|
345 |
+
padded_image = self.pad(image, padding=((paste_y, paste_y), (paste_x, paste_x)))
|
346 |
+
|
347 |
+
return padded_image
|
348 |
+
|
349 |
+
|
350 |
+
def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
|
351 |
+
"""
|
352 |
+
previous version mainly foucs on ratio.
|
353 |
+
We also consider area ratio here.
|
354 |
+
"""
|
355 |
+
best_factor = float('-inf')
|
356 |
+
best_ratio = (1, 1)
|
357 |
+
area = width * height
|
358 |
+
for ratio in target_ratios:
|
359 |
+
target_aspect_ratio = ratio[0] / ratio[1]
|
360 |
+
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
361 |
+
area_ratio = (ratio[0]*ratio[1]*image_size*image_size)/ area
|
362 |
+
"""
|
363 |
+
new area > 60% of original image area is enough.
|
364 |
+
"""
|
365 |
+
factor_based_on_area_n_ratio = min((ratio[0]*ratio[1]*image_size*image_size)/ area, 0.6)* \
|
366 |
+
min(target_aspect_ratio/aspect_ratio, aspect_ratio/target_aspect_ratio)
|
367 |
+
|
368 |
+
if factor_based_on_area_n_ratio > best_factor:
|
369 |
+
best_factor = factor_based_on_area_n_ratio
|
370 |
+
best_ratio = ratio
|
371 |
+
|
372 |
+
return best_ratio
|
373 |
+
|
374 |
+
|
375 |
+
def get_image_patches(
|
376 |
+
self,
|
377 |
+
image: np.array,
|
378 |
+
min_num: int,
|
379 |
+
max_num: int,
|
380 |
+
size: tuple,
|
381 |
+
tile_size: int,
|
382 |
+
use_thumbnail: bool,
|
383 |
+
resample: PILImageResampling,
|
384 |
+
data_format: ChannelDimension,
|
385 |
+
input_data_format: ChannelDimension,
|
386 |
+
):
|
387 |
+
image_size = get_image_size(image, channel_dim=input_data_format)
|
388 |
+
orig_height, orig_width = image_size
|
389 |
+
aspect_ratio = orig_width / orig_height
|
390 |
+
|
391 |
+
# calculate the existing image aspect ratio
|
392 |
+
target_ratios = set(
|
393 |
+
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
|
394 |
+
i * j <= max_num and i * j >= min_num)
|
395 |
+
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
396 |
+
|
397 |
+
# find the closest aspect ratio to the target
|
398 |
+
target_aspect_ratio = self.find_closest_aspect_ratio(
|
399 |
+
aspect_ratio, target_ratios, orig_width, orig_height, tile_size)
|
400 |
+
|
401 |
+
# calculate the target width and height
|
402 |
+
target_width = tile_size * target_aspect_ratio[0]
|
403 |
+
target_height = tile_size * target_aspect_ratio[1]
|
404 |
+
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
405 |
+
if self.pad_during_tiling:
|
406 |
+
resized_image = self._resize_for_patching(
|
407 |
+
image, (target_height, target_width), resample=resample, input_data_format=input_data_format
|
408 |
+
)
|
409 |
+
padded_image = self._pad_for_patching(resized_image, (target_height, target_width), input_data_format=input_data_format)
|
410 |
+
image_used_to_split = padded_image
|
411 |
+
else:
|
412 |
+
image_used_to_split = resize(image, (target_height, target_width), resample=resample, input_data_format=input_data_format)
|
413 |
+
|
414 |
+
processed_tiles = []
|
415 |
+
for i in range(blocks):
|
416 |
+
box = (
|
417 |
+
(i % (target_width // tile_size)) * tile_size,
|
418 |
+
(i // (target_width // tile_size)) * tile_size,
|
419 |
+
((i % (target_width // tile_size)) + 1) * tile_size,
|
420 |
+
((i // (target_width // tile_size)) + 1) * tile_size
|
421 |
+
)
|
422 |
+
# split the image
|
423 |
+
split_img = crop(image_used_to_split, box[0], box[1], box[2], box[3], input_data_format)
|
424 |
+
processed_tiles.append(split_img)
|
425 |
+
assert len(processed_tiles) == blocks
|
426 |
+
|
427 |
+
if use_thumbnail and len(processed_tiles) != 1:
|
428 |
+
thumbnail_img = resize(image, (tile_size, tile_size), resample=resample, input_data_format=input_data_format)
|
429 |
+
processed_tiles.append(thumbnail_img)
|
430 |
+
|
431 |
+
# make sure that all patches are in the input data format
|
432 |
+
processed_tiles = [
|
433 |
+
to_channel_dimension_format(tile, channel_dim=data_format, input_channel_dim=input_data_format)
|
434 |
+
for tile in processed_tiles
|
435 |
+
]
|
436 |
+
return processed_tiles
|
437 |
+
|
438 |
+
|
439 |
+
# Copied from transformers.models.llava_next.image_processing_llava_next.LlavaNextImageProcessor._pad_for_batching
|
440 |
+
def _pad_for_batching(
|
441 |
+
self,
|
442 |
+
pixel_values: List[np.ndarray],
|
443 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
444 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
445 |
+
):
|
446 |
+
"""
|
447 |
+
Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
|
448 |
+
|
449 |
+
Args:
|
450 |
+
pixel_values (`List[np.ndarray]`):
|
451 |
+
An array of pixel values of each images of shape (`batch_size`, `num_patches`, `image_in_3D`)
|
452 |
+
data_format (`str` or `ChannelDimension`, *optional*):
|
453 |
+
The channel dimension format for the output image. Can be one of:
|
454 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
455 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
456 |
+
If unset, will use same as the input image.
|
457 |
+
input_data_format (`str` or `ChannelDimension`, *optional*):
|
458 |
+
The channel dimension format for the input image. Can be one of:
|
459 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
460 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
461 |
+
If unset, will use the inferred format of the input image.
|
462 |
+
|
463 |
+
Returns:
|
464 |
+
List[`np.ndarray`]: The padded images.
|
465 |
+
"""
|
466 |
+
max_patch = max(len(x) for x in pixel_values)
|
467 |
+
pixel_values = [
|
468 |
+
self.pad(
|
469 |
+
image,
|
470 |
+
padding=((0, max_patch - image.shape[0]), (0, 0), (0, 0), (0, 0)),
|
471 |
+
data_format=data_format,
|
472 |
+
input_data_format=input_data_format,
|
473 |
+
)
|
474 |
+
for image in pixel_values
|
475 |
+
]
|
476 |
+
|
477 |
+
return pixel_values
|
478 |
+
|
479 |
+
def _preprocess(
|
480 |
+
self,
|
481 |
+
images: ImageInput,
|
482 |
+
do_resize: Optional[bool] = None,
|
483 |
+
size: Dict[str, int] = None,
|
484 |
+
resample: PILImageResampling = None,
|
485 |
+
do_rescale: Optional[bool] = None,
|
486 |
+
rescale_factor: Optional[float] = None,
|
487 |
+
do_normalize: Optional[bool] = None,
|
488 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
489 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
490 |
+
do_convert_rgb: Optional[bool] = None,
|
491 |
+
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
492 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
493 |
+
) -> Image.Image:
|
494 |
+
"""
|
495 |
+
Args:
|
496 |
+
images (`ImageInput`):
|
497 |
+
Batch of frames (one video) to preprocess. Expects a batch of frames with pixel values ranging from 0 to 255. If
|
498 |
+
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
499 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
500 |
+
Whether to resize the image.
|
501 |
+
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
502 |
+
Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
|
503 |
+
the longest edge resized to keep the input aspect ratio.
|
504 |
+
resample (`int`, *optional*, defaults to `self.resample`):
|
505 |
+
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
506 |
+
has an effect if `do_resize` is set to `True`.
|
507 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
508 |
+
Whether to rescale the image.
|
509 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
510 |
+
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
511 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
512 |
+
Whether to normalize the image.
|
513 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
514 |
+
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
515 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
516 |
+
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
517 |
+
`True`.
|
518 |
+
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
519 |
+
The channel dimension format for the output image. Can be one of:
|
520 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
521 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
522 |
+
- Unset: Use the channel dimension format of the input image.
|
523 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
524 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
525 |
+
from the input image. Can be one of:
|
526 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
527 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
528 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
529 |
+
"""
|
530 |
+
if do_resize:
|
531 |
+
assert False, 'do_resize is not supported'
|
532 |
+
images = [
|
533 |
+
resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
534 |
+
for image in images
|
535 |
+
]
|
536 |
+
|
537 |
+
if do_rescale:
|
538 |
+
images = [
|
539 |
+
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
540 |
+
for image in images
|
541 |
+
]
|
542 |
+
|
543 |
+
if do_normalize:
|
544 |
+
images = [
|
545 |
+
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
546 |
+
for image in images
|
547 |
+
]
|
548 |
+
|
549 |
+
images = [
|
550 |
+
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
551 |
+
]
|
552 |
+
|
553 |
+
return images
|
554 |
+
|
555 |
+
def preprocess(
|
556 |
+
self,
|
557 |
+
images: ImageInput,
|
558 |
+
do_resize: Optional[bool] = None,
|
559 |
+
size: Dict[str, int] = None,
|
560 |
+
resample: PILImageResampling = None,
|
561 |
+
do_rescale: Optional[bool] = None,
|
562 |
+
rescale_factor: Optional[float] = None,
|
563 |
+
do_normalize: Optional[bool] = None,
|
564 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
565 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
566 |
+
do_pad: Optional[bool] = None,
|
567 |
+
do_convert_rgb: Optional[bool] = None,
|
568 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
569 |
+
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
570 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
571 |
+
):
|
572 |
+
"""
|
573 |
+
Args:
|
574 |
+
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
575 |
+
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
576 |
+
tensor. Both channels-first and channels-last formats are supported.
|
577 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
578 |
+
Whether to resize the image.
|
579 |
+
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
580 |
+
Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
|
581 |
+
the longest edge resized to keep the input aspect ratio.
|
582 |
+
resample (`int`, *optional*, defaults to `self.resample`):
|
583 |
+
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
584 |
+
has an effect if `do_resize` is set to `True`.
|
585 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
586 |
+
Whether to rescale the image.
|
587 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
588 |
+
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
589 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
590 |
+
Whether to normalize the image.
|
591 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
592 |
+
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
593 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
594 |
+
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
595 |
+
`True`.
|
596 |
+
do_pad (`bool`, *optional*, defaults to `self.do_pad`):
|
597 |
+
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
|
598 |
+
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
|
599 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
600 |
+
Whether to convert the image to RGB.
|
601 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
602 |
+
The type of tensors to return. Can be one of:
|
603 |
+
- Unset: Return a list of `np.ndarray`.
|
604 |
+
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
605 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
606 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
607 |
+
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
608 |
+
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
609 |
+
The channel dimension format for the output image. Can be one of:
|
610 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
611 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
612 |
+
- Unset: Use the channel dimension format of the input image.
|
613 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
614 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
615 |
+
from the input image. Can be one of:
|
616 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
617 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
618 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
619 |
+
|
620 |
+
"""
|
621 |
+
do_resize = do_resize if do_resize is not None else self.do_resize
|
622 |
+
size = size if size is not None else self.size
|
623 |
+
size = get_size_dict(size, default_to_square=False)
|
624 |
+
resample = resample if resample is not None else self.resample
|
625 |
+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
626 |
+
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
627 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
628 |
+
image_mean = image_mean if image_mean is not None else self.image_mean
|
629 |
+
image_std = image_std if image_std is not None else self.image_std
|
630 |
+
do_pad = do_pad if do_pad is not None else self.do_pad
|
631 |
+
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
632 |
+
|
633 |
+
images = make_flat_list_of_images(images)
|
634 |
+
|
635 |
+
if not valid_images(images):
|
636 |
+
raise ValueError(
|
637 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
638 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
639 |
+
)
|
640 |
+
|
641 |
+
validate_preprocess_arguments(
|
642 |
+
do_rescale=do_rescale,
|
643 |
+
rescale_factor=rescale_factor,
|
644 |
+
do_normalize=do_normalize,
|
645 |
+
image_mean=image_mean,
|
646 |
+
image_std=image_std,
|
647 |
+
do_resize=do_resize,
|
648 |
+
size=size,
|
649 |
+
resample=resample,
|
650 |
+
)
|
651 |
+
|
652 |
+
if do_convert_rgb:
|
653 |
+
images = [convert_to_rgb(image) for image in images]
|
654 |
+
|
655 |
+
# All transformations expect numpy arrays.
|
656 |
+
images = [to_numpy_array(image) for image in images]
|
657 |
+
|
658 |
+
if do_rescale and is_scaled_image(images[0]):
|
659 |
+
logger.warning_once(
|
660 |
+
"It looks like you are trying to rescale already rescaled images. If the input"
|
661 |
+
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
662 |
+
)
|
663 |
+
|
664 |
+
if input_data_format is None:
|
665 |
+
# We assume that all images have the same channel dimension format.
|
666 |
+
input_data_format = infer_channel_dimension_format(images[0])
|
667 |
+
|
668 |
+
processed_images = []
|
669 |
+
image_sizes = [get_image_size(image, channel_dim=input_data_format) for image in images]
|
670 |
+
for image in images:
|
671 |
+
# convert image into a list of patches
|
672 |
+
# we intentially use the same data format as the input data format
|
673 |
+
size_tuple = (
|
674 |
+
(size["height"], size["width"])
|
675 |
+
if "height" in size and "width" in size
|
676 |
+
else (size["shortest_edge"], size["shortest_edge"])
|
677 |
+
)
|
678 |
+
image_patches = self.get_image_patches(
|
679 |
+
image,
|
680 |
+
min_num=self.min_dynamic_tiles,
|
681 |
+
max_num=self.max_dynamic_tiles,
|
682 |
+
size=size_tuple,
|
683 |
+
tile_size=size["height"],
|
684 |
+
resample=resample,
|
685 |
+
data_format=input_data_format,
|
686 |
+
input_data_format=input_data_format,
|
687 |
+
use_thumbnail=self.use_thumbnail,
|
688 |
+
)
|
689 |
+
|
690 |
+
# preprocess patches
|
691 |
+
pixel_values = self._preprocess(
|
692 |
+
image_patches,
|
693 |
+
do_resize=do_resize,
|
694 |
+
size=size_tuple,
|
695 |
+
resample=resample,
|
696 |
+
do_rescale=do_rescale,
|
697 |
+
rescale_factor=rescale_factor,
|
698 |
+
do_normalize=do_normalize,
|
699 |
+
image_mean=image_mean,
|
700 |
+
image_std=image_std,
|
701 |
+
data_format=data_format,
|
702 |
+
input_data_format=input_data_format,
|
703 |
+
)
|
704 |
+
pixel_values = np.array(pixel_values)
|
705 |
+
processed_images.append(pixel_values)
|
706 |
+
|
707 |
+
if do_pad:
|
708 |
+
processed_images = self._pad_for_batching(processed_images)
|
709 |
+
|
710 |
+
return BatchFeature(
|
711 |
+
data={"pixel_values": processed_images, "image_sizes": image_sizes}, tensor_type=return_tensors
|
712 |
+
)
|
713 |
+
|
714 |
+
|
715 |
+
__all__ = ["Eagle2ImageProcessor"]
|
image_processing_eagle2_5_vl_fast.py
ADDED
@@ -0,0 +1,458 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --------------------------------------------------------
# NVIDIA
# Copyright (c) 2025 NVIDIA
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
from functools import partial  # used by _prepare_input_videos below
from typing import List, Optional, Union

from transformers.image_processing_utils import BatchFeature, get_patch_output_size, select_best_resolution
from transformers.image_processing_utils_fast import (
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
    BaseImageProcessorFast,
    DefaultFastImageProcessorKwargs,
    divide_to_patches,
    group_images_by_shape,
    reorder_images,
)
from transformers.image_utils import (
    OPENAI_CLIP_MEAN,
    OPENAI_CLIP_STD,
    IMAGENET_STANDARD_MEAN,  # 0.5, 0.5, 0.5
    IMAGENET_STANDARD_STD,  # 0.5, 0.5, 0.5
    ChannelDimension,
    ImageInput,
    VideoInput,
    PILImageResampling,
    SizeDict,
    get_image_size,
    make_flat_list_of_images,
    make_batched_videos,
    validate_kwargs
)
from transformers.processing_utils import Unpack
from transformers.utils import TensorType, add_start_docstrings, is_torch_available, is_torchvision_v2_available


if is_torch_available():
    import torch
if is_torchvision_v2_available():
    from transformers.image_utils import pil_torch_interpolation_mapping

    from torchvision.transforms.v2 import functional as F
else:
    from torchvision.transforms import functional as F


def crop(img: torch.Tensor, left: int, top: int, right: int, bottom: int) -> torch.Tensor:
    """Crop the given image tensor.

    Args:
        img (torch.Tensor): Image to be cropped. Format should be (C, H, W).
        left (int): The left coordinate of the crop box.
        top (int): The top coordinate of the crop box.
        right (int): The right coordinate of the crop box.
        bottom (int): The bottom coordinate of the crop box.

    Returns:
        torch.Tensor: Cropped image.
    """
    if not isinstance(img, torch.Tensor):
        raise TypeError('img should be torch.Tensor. Got {}'.format(type(img)))

    if img.ndim not in [2, 3]:
        raise ValueError('Image should have 2 or 3 dimensions. Got {}'.format(img.ndim))

    img_height = img.shape[1]
    img_width = img.shape[2]
    if top < 0 or left < 0 or bottom > img_height or right > img_width:
        raise ValueError('Crop coordinates out of bounds')

    if top >= bottom or left >= right:
        raise ValueError('Invalid crop coordinates')

    return img[:, top:bottom, left:right]


class Eagle2_5_VLFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
    max_dynamic_tiles: Optional[int]
    min_dynamic_tiles: Optional[int]
    use_thumbnail: Optional[bool]
    pad_during_tiling: Optional[bool]
    do_pad: Optional[bool]


@add_start_docstrings(
    "Constructs a fast Eagle2.5-VL image processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame.",
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
    """
        image_grid_pinpoints (`List[List[int]]`, *optional*):
            A list of possible resolutions to use for processing high resolution images. The best resolution is selected
            based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
            method. Not used for processing videos.
        do_pad (`bool`, *optional*):
            Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
            number of patches in the batch. Padding will be applied to the bottom and right with zeros.
    """,
)
class Eagle2_5_VLImageProcessorFast(BaseImageProcessorFast):
    resample = PILImageResampling.BICUBIC
    image_mean = IMAGENET_STANDARD_MEAN
    image_std = IMAGENET_STANDARD_STD
    size = {"height": 448, "width": 448}
    default_to_square = False
    crop_size = None
    do_resize = True
    do_center_crop = None
    do_rescale = True
    do_normalize = True
    do_convert_rgb = True
    do_pad = True
    max_dynamic_tiles = 12
    min_dynamic_tiles = 1
    use_thumbnail = True
    pad_during_tiling = False
    valid_kwargs = Eagle2_5_VLFastImageProcessorKwargs
    model_input_names = ["pixel_values_videos"]

    def __init__(self, **kwargs: Unpack[Eagle2_5_VLFastImageProcessorKwargs]):
        super().__init__(**kwargs)

    @add_start_docstrings(
        BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
        """
            max_dynamic_tiles (`int`, *optional*):
                The maximum number of dynamic tiles to use for processing high resolution images.
            min_dynamic_tiles (`int`, *optional*):
                The minimum number of dynamic tiles to use for processing high resolution images.
            use_thumbnail (`bool`, *optional*):
                Whether to use a thumbnail for processing high resolution images.
            pad_during_tiling (`bool`, *optional*):
                Whether to pad the image during tiling.
            do_pad (`bool`, *optional*):
                Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
                number of patches in the batch. Padding will be applied to the bottom and right with zeros.
        """,
    )
    def preprocess(self, images: ImageInput, **kwargs: Unpack[Eagle2_5_VLFastImageProcessorKwargs]) -> BatchFeature:
        return super().preprocess(images, **kwargs)

    def _prepare_images_structure(
        self,
        images: ImageInput,
    ) -> ImageInput:
        """
        Prepare the images structure for processing.

        Args:
            images (`ImageInput`):
                The input images to process.

        Returns:
            `ImageInput`: The images with a valid nesting.
        """
        return make_flat_list_of_images(images)

    def _prepare_videos_structure(self, videos: VideoInput) -> VideoInput:
        return self._prepare_images_structure(videos)

    def _prepare_input_videos(
        self,
        videos: VideoInput,
        do_convert_rgb: Optional[bool] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        device: Optional["torch.device"] = None,
    ) -> list["torch.Tensor"]:
        """
        Prepare the input images for processing.
        """
        videos = self._prepare_videos_structure(videos)
        process_video_fn = partial(
            self._process_image,
            do_convert_rgb=do_convert_rgb,
            input_data_format=input_data_format,
            device=device,
        )
        # todo: yoni - check if we can parallelize this efficiently
        processed_videos = []
        for video in videos:
            processed_videos.append(process_video_fn(video))

        return processed_videos

    def _resize_for_patching(
        self,
        image: "torch.Tensor",
        target_resolution: tuple,
        interpolation: "F.InterpolationMode",
        input_data_format: ChannelDimension,
    ) -> "torch.Tensor":
        """
        Resizes an image to a target resolution while maintaining aspect ratio.

        Args:
            image ("torch.Tensor"):
                The input image.
            target_resolution (tuple):
                The target resolution (height, width) of the image.
            interpolation (`InterpolationMode`):
                Resampling filter to use if resizing the image.
            input_data_format (`ChannelDimension` or `str`):
                The channel dimension format of the input image.

        Returns:
            "torch.Tensor": The resized and padded image.
        """
        new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)

        # Resize the image
        resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation)

        return resized_image

    def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
        """
        The previous version mainly focused on the aspect ratio.
        We also consider the area ratio here.
        """
        best_factor = float('-inf')
        best_ratio = (1, 1)
        area = width * height
        for ratio in target_ratios:
            target_aspect_ratio = ratio[0] / ratio[1]
            ratio_diff = abs(aspect_ratio - target_aspect_ratio)
            area_ratio = (ratio[0] * ratio[1] * image_size * image_size) / area
            # new area > 60% of original image area is enough.
            factor_based_on_area_n_ratio = min((ratio[0] * ratio[1] * image_size * image_size) / area, 0.6) * \
                min(target_aspect_ratio / aspect_ratio, aspect_ratio / target_aspect_ratio)

            if factor_based_on_area_n_ratio > best_factor:
                best_factor = factor_based_on_area_n_ratio
                best_ratio = ratio

        return best_ratio

    def _pad_for_patching(
        self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension
    ) -> "torch.Tensor":
        """
        Pad an image to a target resolution while maintaining aspect ratio.
        """
        target_height, target_width = target_resolution
        new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)

        paste_x = (target_width - new_width) // 2
        paste_y = (target_height - new_height) // 2

        padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x, paste_y])

        return padded_image

    def _get_image_patches(
        self,
        image: "torch.Tensor",
        min_num: int,
        max_num: int,
        size: tuple,
        tile_size: int,
        use_thumbnail: bool,
        interpolation: "F.InterpolationMode",
        pad_during_tiling: bool,
    ) -> List["torch.Tensor"]:
        image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
        orig_height, orig_width = image_size
        aspect_ratio = orig_width / orig_height

        # calculate the existing image aspect ratio
        target_ratios = set(
            (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
            i * j <= max_num and i * j >= min_num)
        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

        # find the closest aspect ratio to the target
        target_aspect_ratio = self.find_closest_aspect_ratio(
            aspect_ratio, target_ratios, orig_width, orig_height, tile_size)

        # calculate the target width and height
        target_width = tile_size * target_aspect_ratio[0]
        target_height = tile_size * target_aspect_ratio[1]
        blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
        if pad_during_tiling:
            resized_image = self._resize_for_patching(
                image, (target_height, target_width), interpolation=interpolation, input_data_format=ChannelDimension.FIRST
            )
            padded_image = self._pad_for_patching(resized_image, (target_height, target_width), input_data_format=ChannelDimension.FIRST)
            image_used_to_split = padded_image
        else:
            image_used_to_split = F.resize(image, (target_height, target_width), interpolation=interpolation)

        processed_tiles = []
        for i in range(blocks):
            box = (
                (i % (target_width // tile_size)) * tile_size,
                (i // (target_width // tile_size)) * tile_size,
                ((i % (target_width // tile_size)) + 1) * tile_size,
                ((i // (target_width // tile_size)) + 1) * tile_size
            )
            # split the image
            split_img = crop(image_used_to_split, box[0], box[1], box[2], box[3])
            processed_tiles.append(split_img)
        assert len(processed_tiles) == blocks

        if use_thumbnail and len(processed_tiles) != 1:
            thumbnail_img = F.resize(image, (tile_size, tile_size), interpolation=interpolation)
            processed_tiles.append(thumbnail_img)

        return processed_tiles

    def _pad_for_batching(
        self,
        pixel_values: List["torch.Tensor"],
    ) -> List["torch.Tensor"]:
        """
        Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.

        Args:
            pixel_values (`List[torch.Tensor]`):
                An array of pixel values of each images of shape (`batch_size`, `num_patches`, `image_in_3D`)

        Returns:
            List[`torch.Tensor`]: The padded images.
        """
        max_patch = max(len(x) for x in pixel_values)
        pixel_values = [
            torch.nn.functional.pad(image, pad=[0, 0, 0, 0, 0, 0, 0, max_patch - image.shape[0]])
            for image in pixel_values
        ]

        return pixel_values

    def _preprocess(
        self,
        images: List["torch.Tensor"],
        do_resize: bool,
        size: SizeDict,
        max_dynamic_tiles: int,
        min_dynamic_tiles: int,
        use_thumbnail: bool,
        pad_during_tiling: bool,
        interpolation: Optional["F.InterpolationMode"],
        do_center_crop: bool,
        crop_size: SizeDict,
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, List[float]]],
        image_std: Optional[Union[float, List[float]]],
        do_pad: bool,
        return_tensors: Optional[Union[str, TensorType]],
    ) -> BatchFeature:
        processed_images = []
        image_sizes = []
        # Determine the size tuple
        if size and size.height and size.width:
            size_tuple = (size.height, size.width)
        else:
            size_tuple = (size.shortest_edge, size.shortest_edge)

        # Determine the patch size
        if crop_size and crop_size.height:
            tile_size = crop_size.height
        elif size and size.height:
            tile_size = size.height
        else:
            tile_size = size.shortest_edge

        for image in images:
            image_patches = self._get_image_patches(
                image,
                min_num=min_dynamic_tiles,
                max_num=max_dynamic_tiles,
                size=size_tuple,
                tile_size=tile_size,
                use_thumbnail=use_thumbnail,
                interpolation=interpolation,
                pad_during_tiling=pad_during_tiling,
            )

            # Group images by size for batched processing
            processed_image_patches_grouped = {}
            grouped_image_patches, grouped_image_patches_index = group_images_by_shape(image_patches)

            for shape, stacked_image_patches in grouped_image_patches.items():
                if do_resize:
                    stacked_image_patches = self.resize(
                        image=stacked_image_patches,
                        size=size,
                        interpolation=interpolation,
                    )
                if do_center_crop:
                    stacked_image_patches = self.center_crop(stacked_image_patches, crop_size)
                # Fused rescale and normalize
                stacked_image_patches = self.rescale_and_normalize(
                    stacked_image_patches, do_rescale, rescale_factor, do_normalize, image_mean, image_std
                )
                processed_image_patches_grouped[shape] = stacked_image_patches
            processed_image_patches = reorder_images(processed_image_patches_grouped, grouped_image_patches_index)
            processed_image_patches = (
                torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches
            )
            processed_images.append(processed_image_patches)
            image_sizes.append(get_image_size(image, ChannelDimension.FIRST))

        if do_pad:
            processed_images = self._pad_for_batching(processed_images)

        # processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
        processed_images = torch.cat(processed_images, dim=0) if return_tensors else processed_images
        return BatchFeature(
            data={"pixel_values": processed_images, "image_sizes": image_sizes}, tensor_type=return_tensors
        )

    def preprocess(self, images: ImageInput, videos: VideoInput = None, **kwargs: Unpack[Eagle2_5_VLFastImageProcessorKwargs]) -> BatchFeature:
        validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
        # Set default kwargs from self. This ensures that if a kwarg is not provided
        # by the user, it gets its default value from the instance, or is set to None.
        for kwarg_name in self.valid_kwargs.__annotations__:
            kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))

        # Extract parameters that are only used for preparing the input images
        do_convert_rgb = kwargs.pop("do_convert_rgb")
        input_data_format = kwargs.pop("input_data_format")
        device = kwargs.pop("device")
        # Prepare input images
        if images is not None:
            images = self._prepare_input_images(
                images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
            )

        if videos is not None:
            videos = self._prepare_input_images(
                images=videos, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
            )

        # Update kwargs that need further processing before being validated
        kwargs = self._further_process_kwargs(**kwargs)

        # Validate kwargs
        self._validate_preprocess_kwargs(**kwargs)

        # torch resize uses interpolation instead of resample
        resample = kwargs.pop("resample")
        kwargs["interpolation"] = (
            pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
        )

        # Pop kwargs that are not needed in _preprocess
        kwargs.pop("default_to_square")
        kwargs.pop("data_format")
        if images is not None:
            return self._preprocess(images, **kwargs)
        elif videos is not None:
            return self._preprocess(videos, **kwargs)


__all__ = ["Eagle2_5_VLImageProcessorFast"]
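The tiling logic above means a single high-resolution image expands into up to `max_dynamic_tiles` 448×448 tiles plus an optional thumbnail. Below is a minimal sketch of exercising the fast processor on its own; the local import path and the synthetic test image are assumptions, and whether it runs as-is depends on the `transformers` version this checkpoint targets:

```python
import torch
from PIL import Image

# Assumption: this script sits next to image_processing_eagle2_5_vl_fast.py from this repo.
from image_processing_eagle2_5_vl_fast import Eagle2_5_VLImageProcessorFast

processor = Eagle2_5_VLImageProcessorFast()            # defaults: 448x448 tiles, 1..12 tiles, thumbnail on
image = Image.new("RGB", (1344, 896), color="white")   # stand-in for a real 3:2 photo

out = processor.preprocess(images=[image], return_tensors="pt")
# For a ~3:2 input, find_closest_aspect_ratio typically picks a 3x2 grid,
# so pixel_values would hold 6 tiles plus 1 thumbnail, each 3x448x448.
print(out["pixel_values"].shape)   # e.g. torch.Size([7, 3, 448, 448])
print(out["image_sizes"])          # original (height, width) per image
```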
modeling_eagle_chat.py → modeling_eagle2_5_vl.py
RENAMED
Removed from the old modeling_eagle_chat.py (Eagle2ChatModel):

@@ -1,84 +1,118 @@
- import torch.utils.checkpoint
- import transformers
- from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
- from .configuration_eagle_chat import Eagle2ChatConfig
- def version_cmp(v1, v2, op='eq'):
-     import operator
-     from packaging import version
-     op_func = getattr(operator, op)
-     return op_func(version.parse(v1), version.parse(v2))
-     _supports_flex_attn = False
-     _supports_cache_class = False
-     _supports_quantized_cache = False
-     _supports_static_cache = False
-     _supports_attention_backend = False
-         self.template = config.template
-             config.vision_config._attn_implementation = 'eager'
-         llm_hidden_size = config.llm_config.hidden_size

@@ -211,165 +258,36 @@ class Eagle2ChatModel(PreTrainedModel):
-         # if there is vit_embeds.last_hidden_state, use it.
-             pass
-         else:
-             vit_embeds = vit_embeds[:, 1:, :]  # torch.Size([B, 1024, 1024])
-         if self.training and self.neftune_alpha is not None:
-             vit_embeds = self.noised_embed(vit_embeds, self.neftune_alpha)
-     def batch_chat(self,
-                    tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
-                    history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
-                    IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
-         if history is not None or return_history:
-             print('Now multi-turn chat is not supported in batch_chat.')
-             raise NotImplementedError
-
-         if image_counts is not None:
-             num_patches_list = image_counts
-             print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')
-
-         img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
-         self.img_context_token_id = img_context_token_id
-
-         if verbose and pixel_values is not None:
-             image_bs = pixel_values.shape[0]
-             print(f'dynamic ViT batch size: {image_bs}')
-
-         queries = []
-         for idx, num_patches in enumerate(num_patches_list):
-             question = questions[idx]
-             if pixel_values is not None and '<image>' not in question:
-                 question = '<image>\n' + question
-             template_messages = []
-             sep = tokenizer.eos_token
-             template_messages.append(('<|im_start|>user', question))
-             template_messages.append(('<|im_end|>assistant', None))
-             query = self.get_prompt(self.system_message, template_messages, sep)
-
-             image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
-             query = query.replace('<image>', image_tokens, 1)
-             queries.append(query)
-
-         tokenizer.padding_side = 'left'
-         model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
-         input_ids = model_inputs['input_ids'].cuda()
-         attention_mask = model_inputs['attention_mask'].cuda()
-         eos_token_id = tokenizer.convert_tokens_to_ids(sep)
-         generation_config['eos_token_id'] = eos_token_id
-         generation_output = self.generate(
-             pixel_values=pixel_values,
-             input_ids=input_ids,
-             attention_mask=attention_mask,
-             **generation_config
-         )
-         responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
-         responses = [response.split(sep)[0].strip() for response in responses]
-         return responses
-
-     def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
-              num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
-              verbose=False, llm_only=False):
-
-         if history is None and pixel_values is not None and '<image>' not in question:
-             question = '<image>\n' + question
-
-         if num_patches_list is None:
-             num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
-         assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
-
-         img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
-         self.img_context_token_id = img_context_token_id
-
-         template_messages = []
-         system_message = f'<|im_start|>system\n{self.system_message}'
-         sep = tokenizer.eos_token
-         eos_token_id = tokenizer.convert_tokens_to_ids(sep)
-
-         history = [] if history is None else history
-         for (old_question, old_answer) in history:
-             template_messages.append(('<|im_start|>user', old_question))
-             template_messages.append(('<|im_start|>assistant', old_answer))
-         template_messages.append(('<|im_start|>user', question))
-         template_messages.append(('<|im_end|>assistant', None))
-         query = self.get_prompt(system_message, template_messages, sep)
-
-         if verbose and pixel_values is not None:
-             image_bs = pixel_values.shape[0]
-             print(f'dynamic ViT batch size: {image_bs}')
-
-         for num_patches in num_patches_list:
-             image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
-             if llm_only:
-                 query = query.replace('<image>', '', 1)
-             else:
-                 query = query.replace('<image>', image_tokens, 1)
-
-         model_inputs = tokenizer(query, return_tensors='pt')
-         input_ids = model_inputs['input_ids'].cuda()
-         attention_mask = model_inputs['attention_mask'].cuda()
-         generation_config['eos_token_id'] = eos_token_id
-         generation_output = self.generate(
-             pixel_values=pixel_values,
-             input_ids=input_ids,
-             attention_mask=attention_mask,
-             **generation_config
-         )
-         response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
-         response = response.split(sep)[0].strip()
-         history.append((question, response))
-         if return_history:
-             return response, history
-         else:
-             query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
-             query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
-             if verbose:
-                 print(query_to_print, response)
-             return response
-
-     def get_prompt(self, system_prompt, messages, sep) -> str:
-         """Get the prompt for generation."""
-
-         ret = '' if system_prompt == '' else system_prompt + sep + '\n'
-         for role, message in messages:
-             if message:
-                 ret += role + '\n' + message + sep + '\n'
-             else:
-                 ret += role + '\n'
-         return ret

@@ -379,11 +297,10 @@ class Eagle2ChatModel(PreTrainedModel):
-         assert self.img_context_token_id is not None
New modeling_eagle2_5_vl.py, as shown in the diff hunks (unchanged regions are elided with `# ...`):

# --------------------------------------------------------
# NVIDIA
# Copyright (c) 2025 NVIDIA
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

import warnings
import inspect
from typing import Any, List, Optional, Tuple, Union
import torch
from torch import nn
import torch.distributed as dist
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
from transformers.models.llama.modeling_llama import LlamaForCausalLM
import torch.utils.checkpoint as cp
from transformers.models.siglip.modeling_siglip import SiglipVisionModel
from peft import LoraConfig, get_peft_model
from transformers.generation import GenerationMixin
from transformers import GenerationConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, logging
from .configuration_eagle2_5_vl import Eagle2_5_VLConfig
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings

logger = logging.get_logger(__name__)


# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/modeling_llava_onevision.py#L241C1-L280C1
EAGLE2_5_VL_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`Eagle2_5_VLConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

@add_start_docstrings(
    "The bare Eagle2_5_VL Model outputting raw hidden-states without any specific head on top.",
    EAGLE2_5_VL_START_DOCSTRING,
)
class Eagle2_5_VLPreTrainedModel(PreTrainedModel):
    config_class = Eagle2_5_VLConfig
    base_model_prefix = "model"
    main_input_name = 'input_ids'
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen2DecoderLayer", "LlamaDecoderLayer", "Siglip2EncoderLayer", "SiglipEncoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True
    _supports_static_cache = True
    _supports_quantized_cache = True
    _supports_sdpa = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class Eagle2_5_VLForConditionalGeneration(Eagle2_5_VLPreTrainedModel, GenerationMixin):
    config_class = Eagle2_5_VLConfig

    def __init__(self, config: Eagle2_5_VLConfig, vision_model=None, language_model=None):
        super().__init__(config)

        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))

        self.select_layer = config.select_layer
        self.downsample_ratio = config.downsample_ratio
        self.loss_version = config.loss_version
        self.mlp_checkpoint = config.mlp_checkpoint

        logger.info(f'num_image_token: {self.num_image_token}')
        logger.info(f'mlp_checkpoint: {self.mlp_checkpoint}')
        if vision_model is not None:
            self.vision_model = vision_model
        else:
            if config.vision_config.model_type == 'siglip_vision_model':
                config.vision_config._attn_implementation = 'flash_attention_2'
                self.vision_model = SiglipVisionModel(config.vision_config)
            else:
                raise NotImplementedError(f'{config.vision_config.model_type} is not implemented.')

        if language_model is not None:
            self.language_model = language_model
        else:
            if config.text_config.architectures[0] == 'LlamaForCausalLM':
                self.language_model = LlamaForCausalLM(config.text_config)
            elif config.text_config.architectures[0] == 'Qwen2ForCausalLM':
                # assert config.text_config._attn_implementation == 'flash_attention_2', f"Qwen2 must use flash_attention_2 but got {config.text_config._attn_implementation}"
                self.language_model = Qwen2ForCausalLM(config.text_config)
            else:
                raise NotImplementedError(f'{config.text_config.architectures[0]} is not implemented.')

        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.text_config.hidden_size

        self.mlp1 = nn.Sequential(
            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
            # ...
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size)
        )
        self.image_token_index = config.image_token_index
        self.neftune_alpha = None

        if config.use_backbone_lora:
            self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)

        self.use_llm_lora = config.use_llm_lora
        if config.use_llm_lora:
            self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)

        self.check_forward_kwargs()

    def check_forward_kwargs(self):
        # We intentionally avoid using **kwargs in forward because Hugging Face Transformers
        # has special handling for functions with **kwargs parameters that would affect
        # how our model is processed during training and inference.
        forward_params = inspect.signature(self.forward).parameters
        assert not any(k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values())

    def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
        lora_config = LoraConfig(
            r=r,
            target_modules=['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj',
                            'mlp.fc1', 'mlp.fc2'],
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )
        self.vision_model = get_peft_model(self.vision_model, lora_config)
        self.vision_model.print_trainable_parameters()

    def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
        lora_config = LoraConfig(
            r=r,
            target_modules=['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
            # ...
        self.language_model = get_peft_model(self.language_model, lora_config)
        self.language_model.enable_input_require_grads()
        self.language_model.print_trainable_parameters()
        self.use_llm_lora = True

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        # ...
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        num_tiles_list: Optional[List[torch.Tensor]] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # ...
        print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')

        input_ids = input_ids.reshape(B * N)
        selected = (input_ids == self.image_token_index)
        try:
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
        except Exception as e:
        # ...
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))

        x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def extract_feature(self, pixel_values):
        if self.select_layer == -1:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=False,
                return_dict=True)
            if hasattr(vit_embeds, 'last_hidden_state'):
                vit_embeds = vit_embeds.last_hidden_state

        else:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=True,
                return_dict=True).hidden_states[self.select_layer]

        h = w = int(vit_embeds.shape[1] ** 0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)  # torch.Size([B, 1024, 1024]) -> torch.Size([B, 16, 16, 4096])
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])  # torch.Size([B, 16, 16, 4096]) -> torch.Size([B, 256, 4096])
        if self.mlp_checkpoint and vit_embeds.requires_grad:
            vit_embeds = cp.checkpoint(self.mlp1, vit_embeds)
        else:
            vit_embeds = self.mlp1(vit_embeds)

        return vit_embeds

    @torch.no_grad()
    def generate(
        self,
        # ...
        visual_features: Optional[torch.FloatTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        output_hidden_states: Optional[bool] = None,
        image_sizes: Optional[List[Tuple[int, int]]] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:

        if pixel_values is not None:
            if visual_features is not None:
                vit_embeds = visual_features
        # ...
        input_embeds = input_embeds.reshape(B * N, C)

        input_ids = input_ids.reshape(B * N)
        selected = (input_ids == self.config.image_token_index)
        assert selected.sum() != 0
        input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
        # ...
        )

        return outputs

    # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_input_embeddings
    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_input_embeddings
    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_output_embeddings
    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_output_embeddings
    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_decoder
    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_decoder
    def get_decoder(self):
        return self.language_model.get_decoder()
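The `pixel_shuffle` + `mlp1` path in `extract_feature` is what turns each tile into 256 LLM tokens: a 32×32 SigLIP patch grid (1024 tokens) is folded spatially by `downsample_ratio = 0.5` into a 16×16 grid with 4× wider channels before the MLP projection. A standalone shape sketch follows; the first half of the method is not shown in the diff hunks, so the reshape below is an assumption reconstructed from the inline comments (InternVL-style shuffle), and the channel width 1024 is purely illustrative:

```python
import torch

def pixel_shuffle(x: torch.Tensor, scale_factor: float = 0.5) -> torch.Tensor:
    # Assumed reconstruction: fold spatial positions into channels, as described by the comments.
    n, w, h, c = x.size()
    # N, W, H, C --> N, W, H * scale, C // scale
    x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
    # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
    x = x.permute(0, 2, 1, 3).contiguous()
    # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
    x = x.view(n, int(h * scale_factor), int(w * scale_factor),
               int(c / (scale_factor * scale_factor)))
    x = x.permute(0, 2, 1, 3).contiguous()
    return x

tile_tokens = torch.randn(1, 32, 32, 1024)            # one tile's patch grid; C=1024 for illustration
folded = pixel_shuffle(tile_tokens, scale_factor=0.5)
print(folded.shape)                                   # torch.Size([1, 16, 16, 4096])
flat = folded.reshape(folded.shape[0], -1, folded.shape[-1])
print(flat.shape)                                     # torch.Size([1, 256, 4096]) -> 256 tokens per tile
```

The 256 tokens per tile match the `"tokens_per_tile": 256` entry in the preprocessor configuration below.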
preprocessor_config.json
ADDED
@@ -0,0 +1,23 @@
{
  "auto_map": {
    "AutoProcessor": "processing_eagle2_5_vl.Eagle2_5_VLProcessor",
    "AutoImageProcessor": "image_processing_eagle2_5_vl_fast.Eagle2_5_VLImageProcessorFast"
  },
  "image_processor_type": "Eagle2_5_VLImageProcessorFast",
  "processor_class": "Eagle2_5_VLProcessor",
  "image_mean": [0.5, 0.5, 0.5],
  "image_std": [0.5, 0.5, 0.5],
  "do_resize": false,
  "size": {
    "height": 448,
    "width": 448
  },
  "max_dynamic_tiles": 12,
  "min_dynamic_tiles": 1,
  "tokens_per_tile": 256,
  "use_thumbnail": true,
  "do_rescale": true,
  "do_normalize": true,
  "do_pad": false,
  "do_convert_rgb": true
}
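The `auto_map` block above is what lets the Auto classes resolve these custom classes directly from the repository (hence `trust_remote_code=True`). A minimal loading sketch, with the repository id as a placeholder to substitute with the actual model repo:

```python
from transformers import AutoProcessor, AutoImageProcessor

repo_id = "nvidia/<this-model-repo>"  # placeholder repo id

# Resolves processing_eagle2_5_vl.Eagle2_5_VLProcessor via auto_map
processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)

# Or only the fast image processor (image_processing_eagle2_5_vl_fast.Eagle2_5_VLImageProcessorFast)
image_processor = AutoImageProcessor.from_pretrained(repo_id, trust_remote_code=True)

print(type(processor).__name__, type(image_processor).__name__)
```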
processing_eagle2_5_vl.py
ADDED
@@ -0,0 +1,738 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""
|
16 |
+
Processor class for Eagle2_5_VL.
|
17 |
+
copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/processing_llava_onevision.py
|
18 |
+
"""
|
19 |
+
|
20 |
+
import math
|
21 |
+
import os
|
22 |
+
from typing import Iterable, List, Union, Literal
|
23 |
+
import base64
|
24 |
+
import sys
|
25 |
+
import time
|
26 |
+
import warnings
|
27 |
+
from functools import lru_cache
|
28 |
+
from io import BytesIO
|
29 |
+
import re
|
30 |
+
import requests
|
31 |
+
import torch
|
32 |
+
import torchvision
|
33 |
+
from packaging import version
|
34 |
+
from PIL import Image
|
35 |
+
from torchvision import io
|
36 |
+
from typing import Optional, Any
|
37 |
+
import numpy as np
|
38 |
+
|
39 |
+
from transformers.feature_extraction_utils import BatchFeature
|
40 |
+
from transformers.image_processing_utils import select_best_resolution
|
41 |
+
from transformers.image_utils import ImageInput, VideoInput, get_image_size, to_numpy_array
|
42 |
+
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
43 |
+
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
44 |
+
from transformers.utils import logging
|
45 |
+
from transformers.models.auto import AutoImageProcessor
|
46 |
+
|
47 |
+
logger = logging.get_logger(__name__)
|
48 |
+
|
49 |
+
|
50 |
+
|
51 |
+
FRAME_FACTOR = 2
|
52 |
+
FPS = 2.0
|
53 |
+
FPS_MIN_FRAMES = 4
|
54 |
+
FPS_MAX_FRAMES = 256
|
55 |
+
|
56 |
+
|
57 |
+
|
58 |
+
def adjust_by_factor(number: int, factor: int, method: Literal['round', 'ceil', 'floor'] = 'round') -> int:
|
59 |
+
"""Adjusts 'number' to the nearest, ceiling, or floor multiple of 'factor'."""
|
60 |
+
op = {'round': round, 'ceil': math.ceil, 'floor': math.floor}[method]
|
61 |
+
return op(number / factor) * factor
|
62 |
+
|
63 |
+
|
64 |
+
def to_rgb(pil_image: Image.Image) -> Image.Image:
|
65 |
+
if pil_image.mode == 'RGBA':
|
66 |
+
white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
|
67 |
+
white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
|
68 |
+
return white_background
|
69 |
+
else:
|
70 |
+
return pil_image.convert("RGB")
|
71 |
+
|
72 |
+
|
73 |
+
def fetch_image(ele: dict[str, str | Image.Image]) -> Image.Image:
|
74 |
+
if "image" in ele:
|
75 |
+
image = ele["image"]
|
76 |
+
else:
|
77 |
+
image = ele["image_url"]
|
78 |
+
image_obj = None
|
79 |
+
if isinstance(image, Image.Image):
|
80 |
+
image_obj = image
|
81 |
+
elif image.startswith("http://") or image.startswith("https://"):
|
82 |
+
response = requests.get(image, stream=True)
|
83 |
+
image_obj = Image.open(BytesIO(response.content))
|
84 |
+
elif image.startswith("file://"):
|
85 |
+
image_obj = Image.open(image[7:])
|
86 |
+
elif image.startswith("data:image"):
|
87 |
+
if "base64," in image:
|
88 |
+
_, base64_data = image.split("base64,", 1)
|
89 |
+
data = base64.b64decode(base64_data)
|
90 |
+
image_obj = Image.open(BytesIO(data))
|
91 |
+
else:
|
92 |
+
image_obj = Image.open(image)
|
93 |
+
if image_obj is None:
|
94 |
+
raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
|
95 |
+
image = to_rgb(image_obj)
|
96 |
+
if 'scale_factor' in ele:
|
97 |
+
scale_factor = ele['scale_factor']
|
98 |
+
image = image.resize((image.width * scale_factor, image.height * scale_factor), Image.BILINEAR)
|
99 |
+
return image
|
100 |
+
|
101 |
+
|
102 |
+
def smart_nframes(
|
103 |
+
ele: dict,
|
104 |
+
total_frames: int,
|
105 |
+
video_fps: int | float,
|
106 |
+
) -> int:
|
107 |
+
"""calculate the number of frames for video used for model inputs.
|
108 |
+
|
109 |
+
Args:
|
110 |
+
ele (dict): a dict contains the configuration of video.
|
111 |
+
support either `fps` or `nframes`:
|
112 |
+
- nframes: the number of frames to extract for model inputs.
|
113 |
+
- fps: the fps to extract frames for model inputs.
|
114 |
+
- min_frames: the minimum number of frames of the video, only used when fps is provided.
|
115 |
+
- max_frames: the maximum number of frames of the video, only used when fps is provided.
|
116 |
+
total_frames (int): the original total number of frames of the video.
|
117 |
+
video_fps (int | float): the original fps of the video.
|
118 |
+
|
119 |
+
Raises:
|
120 |
+
ValueError: nframes should in interval [FRAME_FACTOR, total_frames].
|
121 |
+
|
122 |
+
Returns:
|
123 |
+
int: the number of frames for video used for model inputs.
|
124 |
+
"""
|
125 |
+
assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
|
126 |
+
if "nframes" in ele:
|
127 |
+
nframes = adjust_by_factor(ele["nframes"], FRAME_FACTOR, method='round')
|
128 |
+
else:
|
129 |
+
fps = ele.get("fps", FPS)
|
130 |
+
min_frames = adjust_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR, method='ceil')
|
131 |
+
max_frames = adjust_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR, method='floor')
|
132 |
+
        nframes = total_frames / video_fps * fps
        if nframes > total_frames:
            logger.warning(f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]")
        nframes = min(min(max(nframes, min_frames), max_frames), total_frames)
        nframes = adjust_by_factor(nframes, FRAME_FACTOR, method='floor')
    if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
        nframes = total_frames
        # raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.")
    return nframes


def _read_video_torchvision(
    ele: dict,
) -> (torch.Tensor, float, list):
    """read video using torchvision.io.read_video and return also per-frame timestamps"""
    video_path = ele["video"]
    if version.parse(torchvision.__version__) < version.parse("0.19.0"):
        if "http://" in video_path or "https://" in video_path:
            warnings.warn("torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0.")
        if "file://" in video_path:
            video_path = video_path[7:]
    st = time.time()
    video, audio, info = io.read_video(
        video_path,
        start_pts=ele.get("video_start", 0.0),
        end_pts=ele.get("video_end", None),
        pts_unit="sec",
        output_format="TCHW",
    )
    total_frames, video_fps = video.size(0), info["video_fps"]
    logger.info(f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
    nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
    # Calculate frame indices and corresponding timestamps (based on video start time)
    idx = torch.linspace(0, total_frames - 1, nframes).round().long()
    start_time = ele.get("video_start", 0.0)
    timestamps = (start_time + idx.to(torch.float32) / video_fps).tolist()
    sample_fps = nframes / max(total_frames, 1e-6) * video_fps
    video = video[idx]
    return video, sample_fps, timestamps


def is_decord_available() -> bool:
    import importlib.util

    return importlib.util.find_spec("decord") is not None


def _read_video_decord(
    ele: dict,
) -> (torch.Tensor, float, list):
    """read video using decord.VideoReader and return also per-frame timestamps"""
    import decord
    video_path = ele["video"]
    st = time.time()
    vr = decord.VideoReader(video_path)
    if 'video_start' in ele or 'video_end' in ele:
        raise NotImplementedError("start_pts and end_pts are not supported in decord for now.")
    total_frames, video_fps = len(vr), vr.get_avg_fps()
    logger.info(f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
    nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
    idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
    start_time = ele.get("video_start", 0.0)  # TODO:
    timestamps = [start_time + i / video_fps for i in idx]
    video = vr.get_batch(idx).asnumpy()
    video = torch.tensor(video).permute(0, 3, 1, 2)  # Convert to TCHW format
    sample_fps = nframes / max(total_frames, 1e-6) * video_fps
    return video, sample_fps, timestamps


VIDEO_READER_BACKENDS = {
    "decord": _read_video_decord,
    "torchvision": _read_video_torchvision,
}


@lru_cache(maxsize=1)
def get_video_reader_backend() -> str:
    if is_decord_available():
        video_reader_backend = "decord"
    else:
        video_reader_backend = "torchvision"
    return video_reader_backend


def fetch_video(ele: dict, return_video_sample_fps: bool = False) -> torch.Tensor | list[Image.Image]:
    if isinstance(ele["video"], str):
        video_reader_backend = get_video_reader_backend()
        try:
            video, sample_fps, timestamps = VIDEO_READER_BACKENDS[video_reader_backend](ele)
        except Exception as e:
            logger.warning(f"video_reader_backend {video_reader_backend} error, falling back to torchvision, msg: {e}")
            video, sample_fps, timestamps = VIDEO_READER_BACKENDS["torchvision"](ele)

        nframes, _, height, width = video.shape

        if return_video_sample_fps:
            return video, sample_fps, timestamps
        return video
    else:
        assert isinstance(ele["video"], (list, tuple))
        process_info = ele.copy()
        process_info.pop("type", None)
        process_info.pop("video", None)
        images = [
            fetch_image({"image": video_element, **process_info})
            for video_element in ele["video"]
        ]
        nframes = adjust_by_factor(len(images), FRAME_FACTOR, method='ceil')
        if len(images) < nframes:
            images.extend([images[-1]] * (nframes - len(images)))

        timestamps = [-1 for i in range(nframes)]  # not sure about this
        if return_video_sample_fps:
            return images, process_info.pop("fps", 2.0), timestamps
        return images

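For reference, each video in the pipeline above is described by a single dict. A minimal sketch of calling `fetch_video` directly (the paths are placeholders, and the `fps` key follows the Qwen2.5-VL-style sampling interface this file is adapted from):

```python
# Hypothetical local clip, sampled at roughly 2 frames per second by smart_nframes.
video_element = {"video": "/tmp/example_clip.mp4", "fps": 2.0}
frames, sample_fps, timestamps = fetch_video(video_element, return_video_sample_fps=True)
# frames: (nframes, C, H, W) tensor; timestamps: one value in seconds per sampled frame

# A pre-extracted frame list is also accepted; it is padded up to a multiple of FRAME_FACTOR.
frame_list_element = {"video": ["frames/000.jpg", "frames/001.jpg"], "fps": 2.0}
frames_as_images = fetch_video(frame_list_element)  # returns a list of PIL images
```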
class Eagle2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
    # see processing_utils.ProcessingKwargs documentation for usage.
    _defaults = {
        "text_kwargs": {
            "padding": False,
        },
        "images_kwargs": {},
        "videos_kwargs": {"max_dynamic_tiles": 1},
    }


class Eagle2_5_VLProcessor(ProcessorMixin):
    r"""
    Constructs an Eagle2_5_VL processor which wraps an Eagle2_5_VL video processor, an Eagle2_5_VL image processor and an Eagle2_5_VL tokenizer into a single processor.

    [`Eagle2_5_VLProcessor`] offers all the functionalities of [`Eagle2_5_VLVideoProcessor`], [`Eagle2_5_VLImageProcessor`] and [`Eagle2_5_VLTokenizer`]. See the
    [`~Eagle2_5_VLVideoProcessor.__call__`], [`~Eagle2_5_VLProcessor.__call__`] and [`~Eagle2_5_VLProcessor.decode`] for more information.

    Args:
        image_processor ([`LlavaOnevisionImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`LlamaTokenizerFast`], *optional*):
            The tokenizer is a required input.
        num_image_tokens (`int`, *optional*):
            Number of image tokens for one image that will be returned by the vision tower.
        vision_feature_select_strategy (`str`, *optional*):
            The feature selection strategy used to select the vision feature from the vision backbone.
            Should be the same as in the model's config.
        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
            in a chat into a tokenizable string.
        image_token (`str`, *optional*, defaults to `"<image>"`):
            Special token used to denote image location.
        video_token (`str`, *optional*, defaults to `"<video>"`):
            Special token used to denote video location.
    """

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = [
        "chat_template",
        "num_image_tokens",
        "vision_feature_select_strategy",
        "image_token",
        "video_token",
        "images_kwargs",
        "videos_kwargs",
        "text_kwargs",
    ]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        vision_feature_select_strategy=None,
        chat_template=None,
        image_token='<IMG_CONTEXT>',
        video_token='<IMG_CONTEXT>',
        tokens_per_tile=256,
        image_placeholder='image',
        video_placeholder='video',
        image_start_token='<img>',
        image_end_token='</img>',
        **kwargs,
    ):
        self.vision_feature_select_strategy = vision_feature_select_strategy
        self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
        self.video_token = tokenizer.video_token if hasattr(tokenizer, "video_token") else video_token
        self.image_token_id = (
            tokenizer.image_token_id
            if getattr(tokenizer, "image_token_id", None)
            else tokenizer.convert_tokens_to_ids(self.image_token)
        )
        self.video_token_id = (
            tokenizer.video_token_id
            if getattr(tokenizer, "video_token_id", None)
            else tokenizer.convert_tokens_to_ids(self.video_token)
        )
        self.image_placeholder = image_placeholder
        self.video_placeholder = video_placeholder
        self.tokens_per_tile = tokens_per_tile
        self.image_start_token = image_start_token
        self.image_end_token = image_end_token
        if 'auto_map' in kwargs:
            self.auto_map = kwargs['auto_map']
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    def replace_media_placeholder(self, text, image_list, video_list, timestamps_list, fps_list, **output_kwargs):

        num_of_images_in_this_sample = 0
        num_of_videos_in_this_sample = 0
        # Regular expression pattern to match formats like <image-1> or <video-2>
        pattern = re.compile(rf"<({self.image_placeholder}|{self.video_placeholder})-(\d+)>")
        unified_frame_list = []

        image_min_dynamic_tiles = output_kwargs['images_kwargs'].get("min_dynamic_tiles", self.image_processor.min_dynamic_tiles)
        image_max_dynamic_tiles = output_kwargs['images_kwargs'].get("max_dynamic_tiles", self.image_processor.max_dynamic_tiles)
        image_use_thumbnail = output_kwargs['images_kwargs'].get("use_thumbnail", self.image_processor.use_thumbnail)
        video_min_dynamic_tiles = output_kwargs['videos_kwargs'].get("min_dynamic_tiles", self.image_processor.min_dynamic_tiles)
        video_max_dynamic_tiles = output_kwargs['videos_kwargs'].get("max_dynamic_tiles", self.image_processor.max_dynamic_tiles)
        video_use_thumbnail = output_kwargs['videos_kwargs'].get("use_thumbnail", self.image_processor.use_thumbnail)

        tile_size = self.image_processor.size.get("height", 448)

        # Function to replace tags in a single text
        def replace_in_text(text):
            # repl callback function for each match replacement operation
            def repl(match):
                nonlocal unified_frame_list
                nonlocal num_of_images_in_this_sample
                nonlocal num_of_videos_in_this_sample
                media_type = match.group(1)  # 'image' or 'video'
                idx_in_list = int(match.group(2)) - 1  # Convert to list index (0-based)
                # Select the corresponding path based on media type
                idx_mapper = {0: "first", 1: "second", 2: "third", 3: "fourth", 4: "fifth", 5: "sixth", 6: "seventh", 7: "eighth", 8: "ninth", 9: "tenth"}
                if media_type == 'image':
                    image_inputs = self.image_processor(images=[image_list[idx_in_list]], videos=None, **output_kwargs["images_kwargs"])
                    num_all_tiles = image_inputs["pixel_values"].shape[0]
                    special_placeholder = f"<image {idx_in_list+1}>{self.image_start_token}{self.image_token * num_all_tiles * self.tokens_per_tile}{self.image_end_token}"
                    unified_frame_list.append(image_inputs)
                    num_of_images_in_this_sample += 1

                elif media_type == 'video':
                    video_inputs = self.image_processor(images=None, videos=[video_list[idx_in_list]], **output_kwargs["videos_kwargs"])
                    num_all_tiles = video_inputs["pixel_values"].shape[0]
                    image_sizes = video_inputs["image_sizes"]
                    if timestamps_list is not None and -1 not in timestamps_list:
                        frame_timestamps = timestamps_list[idx_in_list]
                    else:
                        frame_timestamps = None
                    sampled_fps = fps_list[idx_in_list] if fps_list is not None else None

                    num_of_tiles_each_frame = [
                        self.get_number_tiles_based_on_image_size(image_size, video_min_dynamic_tiles, video_max_dynamic_tiles, video_use_thumbnail, tile_size)
                        for image_size in image_sizes
                    ]
                    assert sum(num_of_tiles_each_frame) == num_all_tiles, f"The number of tiles in each frame is not equal to the total number of tiles: {sum(num_of_tiles_each_frame)} != {num_all_tiles}"

                    if frame_timestamps is not None:
                        assert len(frame_timestamps) == len(num_of_tiles_each_frame), f"The number of timestamps is not equal to the number of frames: {len(frame_timestamps)} != {len(num_of_tiles_each_frame)}"
                        special_placeholder = [f"Frame {i+1} sample at {frame_timestamps[i]:.2f}s: {self.image_start_token}{self.image_token * num_of_tiles * self.tokens_per_tile}{self.image_end_token}" for i, num_of_tiles in enumerate(num_of_tiles_each_frame)]
                    else:
                        special_placeholder = [f"Frame {i+1}: {self.image_start_token}{self.image_token * num_of_tiles * self.tokens_per_tile}{self.image_end_token}" for i, num_of_tiles in enumerate(num_of_tiles_each_frame)]

                    if sampled_fps is not None:
                        special_placeholder = f"The {idx_mapper[idx_in_list]} video sampled with {sampled_fps:.2f} fps: " + "".join(special_placeholder)
                    else:
                        special_placeholder = f"The {idx_mapper[idx_in_list]} video: " + "".join(special_placeholder)
                    unified_frame_list.append(video_inputs)
                    num_of_videos_in_this_sample += 1
                else:
                    raise ValueError(f'Unknown media type: {media_type}')
                return special_placeholder
            return pattern.sub(repl, text)
        text = replace_in_text(text)
        if len(unified_frame_list) > 0:
            pixel_values = torch.cat([frame["pixel_values"] for frame in unified_frame_list])
            image_sizes = torch.cat([frame["image_sizes"] for frame in unified_frame_list])
        else:
            pixel_values = None
            image_sizes = None
        return text, pixel_values, image_sizes, num_of_images_in_this_sample, num_of_videos_in_this_sample

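To make the placeholder expansion above concrete, here is a small standalone sketch of just the tag-matching step, using the default `image`/`video` placeholder names instead of a full processor instance:

```python
import re

# Mirrors the pattern built in replace_media_placeholder with the default placeholders.
pattern = re.compile(r"<(image|video)-(\d+)>")

text = "<image-1> Describe the photo, then compare it with <video-1>."
for match in pattern.finditer(text):
    media_type, index = match.group(1), int(match.group(2))
    print(media_type, index)  # -> image 1, then video 1
# Each match is replaced by an <img>...</img> block containing
# <IMG_CONTEXT> repeated num_tiles * tokens_per_tile times.
```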
    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        audio=None,
        videos: VideoInput = None,
        **kwargs: Unpack[Eagle2_5_VLProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
        and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
        the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
        LlavaNextImageProcessor's [`~LlavaNextImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
        of the above two methods for more information.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch tensor.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **pixel_values_videos** -- Pixel values of a video input to be fed to a model. Returned when `videos` is not `None`.
            - **image_sizes** -- Size of each image that will be used to unpad an image. Returned when `images` is not `None`.
        """

        output_kwargs = self._merge_kwargs(
            Eagle2_5_VLProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        if isinstance(text, str):
            text_list = [text]
        elif not isinstance(text, list) and not isinstance(text[0], str):
            raise ValueError("Invalid input text. Please provide a string, or a list of strings")
        elif isinstance(text, list) and isinstance(text[0], str):
            text_list = text

        if images is None: images = []
        if videos is None: videos = []

        pixel_values_list = []
        image_sizes_list = []
        new_sample_list = []
        image_start_idx = 0
        video_start_idx = 0
        timestamps_batch = output_kwargs['videos_kwargs'].pop("timestamps", None)
        fps_batch = output_kwargs['videos_kwargs'].pop("fps", None)
        for sample in text_list:
            timestamps_list = timestamps_batch[video_start_idx:] if timestamps_batch is not None else None
            fps_list = fps_batch[video_start_idx:] if fps_batch is not None else None
            sample, pixel_values, image_sizes, num_of_images_in_this_sample, num_of_videos_in_this_sample = self.replace_media_placeholder(sample, images[image_start_idx:], videos[video_start_idx:], timestamps_list, fps_list, **output_kwargs)
            new_sample_list.append(sample)
            if pixel_values is not None:
                pixel_values_list.append(pixel_values)
                image_sizes_list.append(image_sizes)
            image_start_idx += num_of_images_in_this_sample
            video_start_idx += num_of_videos_in_this_sample

        if len(pixel_values_list) > 0:
            image_inputs = {"pixel_values": torch.cat(pixel_values_list), "image_sizes": torch.cat(image_sizes_list)}
        else:
            image_inputs = {}
        video_inputs = {}
        text_inputs = self.tokenizer(new_sample_list, **output_kwargs["text_kwargs"])
        return BatchFeature(data={**text_inputs, **image_inputs, **video_inputs})

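A minimal usage sketch for `__call__` might look like the following; the repository id and image path are placeholders, and `trust_remote_code=True` is assumed because the processor class ships with the repo rather than with the transformers library:

```python
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("path/to/eagle2_5_vl_checkpoint", trust_remote_code=True)  # placeholder id

image = Image.open("example.jpg")  # placeholder path
# The <image-1> tag is what replace_media_placeholder expands into tile tokens;
# for chat-style prompts, py_apply_chat_template (defined below) inserts it automatically.
prompt = "<image-1> Describe this image."
inputs = processor(text=[prompt], images=[image], return_tensors="pt")
print(inputs.keys())  # input_ids, attention_mask, pixel_values, image_sizes
```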
    def get_number_tiles_based_on_image_size(self, image_size: tuple, min_num: int, max_num: int, use_thumbnail: bool, tile_size: int) -> int:
        """
        Get the number of tiles based on the image size.
        """
        orig_height, orig_width = image_size
        aspect_ratio = orig_width / orig_height
        # calculate the existing image aspect ratio
        target_ratios = set(
            (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
            i * j <= max_num and i * j >= min_num)
        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

        # find the closest aspect ratio to the target
        target_aspect_ratio = self.image_processor.find_closest_aspect_ratio(
            aspect_ratio, target_ratios, orig_width, orig_height, tile_size)
        tiles_num = target_aspect_ratio[0] * target_aspect_ratio[1]
        if use_thumbnail and tiles_num > 1:
            tiles_num += 1
        return tiles_num

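As a rough illustration of the tile-count arithmetic, the sketch below assumes `find_closest_aspect_ratio` (defined in `image_processing_eagle2_5_vl_fast.py`) follows the usual InternVL-style rule of picking the grid whose aspect ratio is closest to the image's; the area-based tie-break of the real implementation is omitted here:

```python
def pick_grid(orig_width, orig_height, min_num=1, max_num=12):
    # Candidate (cols, rows) grids with min_num <= cols * rows <= max_num.
    candidates = {
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if min_num <= i * j <= max_num
    }
    aspect = orig_width / orig_height
    return min(candidates, key=lambda grid: abs(aspect - grid[0] / grid[1]))

grid = pick_grid(800, 1200)    # portrait image, aspect ratio 2:3 -> (2, 3)
num_tiles = grid[0] * grid[1]  # 6 tiles of tile_size x tile_size
num_tiles += 1                 # plus a thumbnail tile when use_thumbnail=True and num_tiles > 1
```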
    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))

    # override to save video-config in a separate config file
    def save_pretrained(self, save_directory, **kwargs):
        if os.path.isfile(save_directory):
            raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
        os.makedirs(save_directory, exist_ok=True)

        outputs = super().save_pretrained(save_directory, **kwargs)
        return outputs

    # override to load video-config from a separate config file
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        processor = super().from_pretrained(pretrained_model_name_or_path, **kwargs)

        # if return_unused_kwargs a tuple is returned where the second element is 'unused_kwargs'
        if isinstance(processor, tuple):
            processor = processor[0]
        return processor

    # Copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
    def process_vision_info(
        self,
        conversations: list[dict] | list[list[dict]],
        return_video_kwargs: bool = False,
    ) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, Optional[dict]]:

        vision_infos = self.extract_vision_info(conversations)
        ## Read images or videos
        image_inputs = []
        video_inputs = []
        video_sample_fps_list = []
        video_timestamps_list = []
        for vision_info in vision_infos:
            if "image" in vision_info or "image_url" in vision_info:
                image_inputs.append(fetch_image(vision_info))
            elif "video" in vision_info:
                video_input, video_sample_fps, video_timestamps = fetch_video(vision_info, return_video_sample_fps=True)
                video_sample_fps_list.append(video_sample_fps)
                video_inputs.append(video_input)
                video_timestamps_list.append(video_timestamps)
            else:
                raise ValueError("image, image_url or video should be in content.")
        if len(image_inputs) == 0:
            image_inputs = None
        if len(video_inputs) == 0:
            video_inputs = None
        if return_video_kwargs:
            return image_inputs, video_inputs, {'fps': video_sample_fps_list, 'timestamps': video_timestamps_list}
        return image_inputs, video_inputs

    def extract_vision_info(self, conversations: list[dict] | list[list[dict]]) -> list[dict]:
        vision_infos = []
        if isinstance(conversations[0], dict):
            conversations = [conversations]
        for conversation in conversations:
            for message in conversation:
                if isinstance(message["content"], list):
                    for ele in message["content"]:
                        if (
                            "image" in ele
                            or "image_url" in ele
                            or "video" in ele
                            or ele["type"] in ("image", "image_url", "video")
                        ):
                            vision_infos.append(ele)
        return vision_infos

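The two helpers above expect Qwen-style conversations in which vision entries live inside each message's `content` list. A sketch of the expected structure, reusing the `processor` instance from the earlier example (paths are placeholders):

```python
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "photos/cat.png"},
            {"type": "video", "video": "clips/demo.mp4", "fps": 2.0},
            {"type": "text", "text": "What happens in the video?"},
        ],
    }
]

images, videos, video_kwargs = processor.process_vision_info(messages, return_video_kwargs=True)
# images -> [PIL.Image], videos -> [frame tensor], video_kwargs -> {"fps": [...], "timestamps": [...]}
```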
    def py_apply_chat_template(self, messages, tokenize=False, add_generation_prompt=False):
        """
        Renders a chat conversation using a custom template with verification of tokens.

        The purpose is to check for the existence of tokens like "<image-1>" or "<video-1>"
        in the message text and skip adding them if they already exist.

        Args:
            messages (list): A list of message dictionaries. Each message should contain:
                - 'role': The role of the speaker (e.g., 'system', 'user', 'assistant').
                - 'content': Either a string or a list of content blocks. In the list each block may contain:
                    * 'type': The type of content, such as 'image' or 'video'.
                    * 'text': The actual text if present.
                    * Other keys such as 'image', 'image_url', or 'video'.
            add_generation_prompt (bool): If True, appends "<|im_start|>assistant" at the end of the rendered string.
            tokenize (bool): If True, tokenize the rendered string.
        Returns:
            str: The final rendered chat string according to the specified template.
        """
        assert tokenize == False, "tokenize is not supported yet"
        result = ""
        image_count = 0
        video_count = 0

        message_text = ""
        for idx, message in enumerate(messages):
            if message.get('role') != 'user': continue
            # If content is a string, simply output it.
            content = message.get('content')
            if isinstance(content, str):
                message_text += content
            elif isinstance(content, list):
                # Process each content item.
                for item in content:
                    # If the block is a dictionary and contains text, add it to message_text.
                    if isinstance(item, dict) and "text" in item:
                        message_text += item["text"]
                    # If an item is already a string in the list, add it directly.
                    elif isinstance(item, str):
                        message_text += item

        for idx, message in enumerate(messages):
            # If the first message is not from the system, prepend a default system message.
            if idx == 0 and message.get('role') != 'system':
                result += "<|im_start|>system\n"
                result += "You are a helpful assistant.\n"
                result += "<|im_end|>\n"

            # Start the current message block with its role.
            result += f"<|im_start|>{message.get('role', '')}\n"
            content = message.get('content')

            # If content is a string, simply output it.
            if isinstance(content, str):
                result += content
                result += "<|im_end|>\n"
            else:
                # Process each content item.
                for item in content:
                    # Check if the item is an image (explicitly by type or by key presence).
                    if (isinstance(item, dict) and (item.get('type') == 'image' or 'image' in item or 'image_url' in item)):
                        image_count += 1
                        candidate_token = f"<image-{image_count}>"
                        # Only add the token if it is not already present in the collected text.
                        if candidate_token not in message_text:
                            result += candidate_token
                    # Check if the item is a video.
                    elif (isinstance(item, dict) and (item.get('type') == 'video' or 'video' in item)):
                        video_count += 1
                        candidate_token = f"<video-{video_count}>"
                        # Only add the token if it is not already present.
                        if candidate_token not in message_text:
                            result += candidate_token
                    # If the item contains text, add it.
                    elif isinstance(item, dict) and 'text' in item:
                        result += item['text']
                    # If the item is a string (and not handled already), add it.
                    elif isinstance(item, str):
                        result += item
                result += "<|im_end|>\n"

        # Optionally add assistant generation prompt at the end.
        if add_generation_prompt:
            result += "<|im_start|>assistant\n"

        return result

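For reference, with the default system prompt injected by `py_apply_chat_template`, a single user turn like the one below should render roughly as shown in the trailing comment (again reusing the `processor` instance from the earlier sketch):

```python
messages = [
    {"role": "user", "content": [
        {"type": "image", "image": "example.jpg"},  # placeholder path
        {"type": "text", "text": "Describe this image."},
    ]}
]
prompt = processor.py_apply_chat_template(messages, add_generation_prompt=True)
# Expected rendering:
# <|im_start|>system
# You are a helpful assistant.
# <|im_end|>
# <|im_start|>user
# <image-1>Describe this image.<|im_end|>
# <|im_start|>assistant
```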
    @classmethod
    def from_args_and_dict(cls, args, processor_dict: dict[str, Any], **kwargs):
        """
        Instantiates a type of [`~processing_utils.ProcessingMixin`] from a Python dictionary of parameters.

        Args:
            processor_dict (`Dict[str, Any]`):
                Dictionary that will be used to instantiate the processor object. Such a dictionary can be
                retrieved from a pretrained checkpoint by leveraging the
                [`~processing_utils.ProcessingMixin.to_dict`] method.
            kwargs (`Dict[str, Any]`):
                Additional parameters from which to initialize the processor object.

        Returns:
            [`~processing_utils.ProcessingMixin`]: The processor object instantiated from those
            parameters.
        """
        processor_dict = processor_dict.copy()
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        # We have to pop some unused (but specific) kwargs and then validate that the dict doesn't contain unused kwargs
        # If we don't pop, some specific kwargs will raise a warning
        if "processor_class" in processor_dict:
            del processor_dict["processor_class"]

        #if "auto_map" in processor_dict:
        #    del processor_dict["auto_map"]

        unused_kwargs = cls.validate_init_kwargs(processor_config=processor_dict, valid_kwargs=cls.valid_kwargs)
        processor = cls(*args, **processor_dict)

        # Update processor with kwargs if needed
        for key in set(kwargs.keys()):
            if hasattr(processor, key):
                setattr(processor, key, kwargs.pop(key))

        kwargs.update(unused_kwargs)
        logger.info(f"Processor {processor}")
        if return_unused_kwargs:
            return processor, kwargs
        else:
            return processor


__all__ = ["Eagle2_5_VLProcessor"]
processor_config.json
ADDED
@@ -0,0 +1,15 @@
{
  "auto_map": {
    "AutoProcessor": "processing_eagle2_5_vl.Eagle2_5_VLProcessor",
    "AutoImageProcessor": "image_processing_eagle2_5_vl_fast.Eagle2_5_VLImageProcessorFast"
  },
  "image_end_token": "</img>",
  "image_placeholder": "image",
  "image_start_token": "<img>",
  "image_token": "<IMG_CONTEXT>",
  "processor_class": "Eagle2_5_VLProcessor",
  "tokens_per_tile": 256,
  "video_placeholder": "video",
  "video_token": "<IMG_CONTEXT>",
  "vision_feature_select_strategy": null
}
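Because `auto_map` points `AutoProcessor` and `AutoImageProcessor` at the modules added in this commit, the processor should be loadable with remote code enabled; the repository id below is a placeholder:

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(
    "path/to/this-repo",       # placeholder: the id of the repository containing these files
    trust_remote_code=True,    # needed so processing_eagle2_5_vl.py is imported from the repo
)
print(type(processor).__name__)  # Eagle2_5_VLProcessor
```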