Zhiding committed
Commit
3ab0a53
·
1 Parent(s): c906038

upload model arch

README.md CHANGED
@@ -18,7 +18,12 @@ tags:
18
  # Eagle-2
19
 
20
  [\[📂 GitHub\]](https://github.com/NVlabs/EAGLE) [\[📜 Eagle2 Tech Report\]](http://arxiv.org/abs/2501.14818)
21
- [\[🗨️ Chat Demo\]](http://eagle-vlm.xyz/) [\[🤗 HF Demo\]](TODO)
22
  ## Introduction
23
 
24
  We are thrilled to release our latest Eagle2 series Vision-Language Model. Open-source Vision-Language Models (VLMs) have made significant strides in narrowing the gap with proprietary models. However, critical details about data strategies and implementation are often missing, limiting reproducibility and innovation. In this project, we focus on VLM post-training from a data-centric perspective, sharing insights into building effective data strategies from scratch. By combining these strategies with robust training recipes and model design, we introduce Eagle2, a family of performant VLMs. Our work aims to empower the open-source community to develop competitive VLMs with transparent processes.
@@ -66,490 +71,290 @@ We provide the following models:
66
 
67
 
68
 
69
- We provide a [demo inference script](./demo.py) to help you quickly start using the model. We support different input types:
70
  - pure text input
71
  - single image input
72
  - multiple image input
73
  - video input
74
 
75
- ### 0. Install the dependencies
76
 
77
  ```bash
78
  pip install transformers
79
  pip install flash-attn
80
  ```
81
- **Note**: Latest version of transformers is not compatible with the model.
82
 
83
- ### 1. Prepare the Model worker
84
 
85
- <details>
86
- <summary>Click to expand</summary>
87
 
88
  ```python
89
-
90
- """
91
- A model worker executes the model.
92
- Copied and modified from https://github.com/OpenGVLab/InternVL/blob/main/streamlit_demo/model_worker.py
93
- """
94
- # Importing torch before transformers can cause `segmentation fault`
95
- from transformers import AutoModel, AutoTokenizer, TextIteratorStreamer, AutoConfig
96
-
97
- import argparse
98
- import base64
99
- import json
100
- import os
101
- import decord
102
- import threading
103
- import time
104
- from io import BytesIO
105
- from threading import Thread
106
- import math
107
  import requests
 
108
  import torch
109
- import torchvision.transforms as T
110
- from PIL import Image
111
- from torchvision.transforms.functional import InterpolationMode
112
- import numpy as np
113
-
114
-
115
- IMAGENET_MEAN = (0.485, 0.456, 0.406)
116
- IMAGENET_STD = (0.229, 0.224, 0.225)
117
-
118
- SIGLIP_MEAN = (0.5, 0.5, 0.5)
119
- SIGLIP_STD = (0.5, 0.5, 0.5)
120
-
121
-
122
- def get_seq_frames(total_num_frames, desired_num_frames=-1, stride=-1):
123
- """
124
- Calculate the indices of frames to extract from a video.
125
-
126
- Parameters:
127
- total_num_frames (int): Total number of frames in the video.
128
- desired_num_frames (int): Desired number of frames to extract.
129
-
130
- Returns:
131
- list: List of indices of frames to extract.
132
- """
133
-
134
- assert desired_num_frames > 0 or stride > 0 and not (desired_num_frames > 0 and stride > 0)
135
-
136
- if stride > 0:
137
- return list(range(0, total_num_frames, stride))
138
-
139
- # Calculate the size of each segment from which a frame will be extracted
140
- seg_size = float(total_num_frames - 1) / desired_num_frames
141
-
142
- seq = []
143
- for i in range(desired_num_frames):
144
- # Calculate the start and end indices of each segment
145
- start = int(np.round(seg_size * i))
146
- end = int(np.round(seg_size * (i + 1)))
147
-
148
- # Append the middle index of the segment to the list
149
- seq.append((start + end) // 2)
150
-
151
- return seq
152
-
153
- def build_video_prompt(meta_list, num_frames, time_position=False):
154
- # if time_position is True, the frame_timestamp is used.
155
- # 1. pass time_position, 2. use env TIME_POSITION
156
- time_position = os.environ.get("TIME_POSITION", time_position)
157
- prefix = f"This is a video:\n"
158
- for i in range(num_frames):
159
- if time_position:
160
- frame_txt = f"Frame {i+1} sampled at {meta_list[i]:.2f} seconds: <image>\n"
161
- else:
162
- frame_txt = f"Frame {i+1}: <image>\n"
163
- prefix += frame_txt
164
- return prefix
165
-
166
- def load_video(video_path, num_frames=64, frame_cache_root=None):
167
- if isinstance(video_path, str):
168
- video = decord.VideoReader(video_path)
169
- elif isinstance(video_path, dict):
170
- assert False, 'we not support vidoe: "video_path" as input'
171
- fps = video.get_avg_fps()
172
- sampled_frames = get_seq_frames(len(video), num_frames)
173
- samepld_timestamps = [i / fps for i in sampled_frames]
174
- frames = video.get_batch(sampled_frames).asnumpy()
175
- images = [Image.fromarray(frame) for frame in frames]
176
-
177
- return images, build_video_prompt(samepld_timestamps, len(images), time_position=True)
178
-
179
- def load_image(image):
180
- if isinstance(image, str) and os.path.exists(image):
181
- return Image.open(image)
182
- elif isinstance(image, dict):
183
- if 'disk_path' in image:
184
- return Image.open(image['disk_path'])
185
- elif 'base64' in image:
186
- return Image.open(BytesIO(base64.b64decode(image['base64'])))
187
- elif 'url' in image:
188
- response = requests.get(image['url'])
189
- return Image.open(BytesIO(response.content))
190
- elif 'bytes' in image:
191
- return Image.open(BytesIO(image['bytes']))
192
- else:
193
- raise ValueError(f'Invalid image: {image}')
194
- else:
195
- raise ValueError(f'Invalid image: {image}')
196
-
197
- def build_transform(input_size, norm_type='imagenet'):
198
- if norm_type == 'imagenet':
199
- MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
200
- elif norm_type == 'siglip':
201
- MEAN, STD = SIGLIP_MEAN, SIGLIP_STD
202
-
203
- transform = T.Compose([
204
- T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
205
- T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
206
- T.ToTensor(),
207
- T.Normalize(mean=MEAN, std=STD)
208
- ])
209
- return transform
210
-
211
-
212
- def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
213
- """
214
- previous version mainly foucs on ratio.
215
- We also consider area ratio here.
216
- """
217
- best_factor = float('-inf')
218
- best_ratio = (1, 1)
219
- area = width * height
220
- for ratio in target_ratios:
221
- target_aspect_ratio = ratio[0] / ratio[1]
222
- ratio_diff = abs(aspect_ratio - target_aspect_ratio)
223
- area_ratio = (ratio[0]*ratio[1]*image_size*image_size)/ area
224
- """
225
- new area > 60% of original image area is enough.
226
- """
227
- factor_based_on_area_n_ratio = min((ratio[0]*ratio[1]*image_size*image_size)/ area, 0.6)* \
228
- min(target_aspect_ratio/aspect_ratio, aspect_ratio/target_aspect_ratio)
229
-
230
- if factor_based_on_area_n_ratio > best_factor:
231
- best_factor = factor_based_on_area_n_ratio
232
- best_ratio = ratio
233
-
234
- return best_ratio
235
-
236
-
237
- def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
238
- orig_width, orig_height = image.size
239
- aspect_ratio = orig_width / orig_height
240
-
241
- # calculate the existing image aspect ratio
242
- target_ratios = set(
243
- (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
244
- i * j <= max_num and i * j >= min_num)
245
- target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
246
-
247
- # find the closest aspect ratio to the target
248
- target_aspect_ratio = find_closest_aspect_ratio(
249
- aspect_ratio, target_ratios, orig_width, orig_height, image_size)
250
-
251
- # calculate the target width and height
252
- target_width = image_size * target_aspect_ratio[0]
253
- target_height = image_size * target_aspect_ratio[1]
254
- blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
255
-
256
- # resize the image
257
- resized_img = image.resize((target_width, target_height))
258
- processed_images = []
259
- for i in range(blocks):
260
- box = (
261
- (i % (target_width // image_size)) * image_size,
262
- (i // (target_width // image_size)) * image_size,
263
- ((i % (target_width // image_size)) + 1) * image_size,
264
- ((i // (target_width // image_size)) + 1) * image_size
265
- )
266
- # split the image
267
- split_img = resized_img.crop(box)
268
- processed_images.append(split_img)
269
- assert len(processed_images) == blocks
270
- if use_thumbnail and len(processed_images) != 1:
271
- thumbnail_img = image.resize((image_size, image_size))
272
- processed_images.append(thumbnail_img)
273
- return processed_images
274
-
275
- def split_model(model_path, device):
276
-
277
- device_map = {}
278
- world_size = torch.cuda.device_count()
279
- config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
280
- num_layers = config.llm_config.num_hidden_layers
281
-
282
- print('world_size', world_size)
283
- num_layers_per_gpu_ = math.floor(num_layers / (world_size - 1))
284
- num_layers_per_gpu = [num_layers_per_gpu_] * world_size
285
- num_layers_per_gpu[device] = num_layers - num_layers_per_gpu_ * (world_size-1)
286
- print(num_layers_per_gpu)
287
- layer_cnt = 0
288
- for i, num_layer in enumerate(num_layers_per_gpu):
289
- for j in range(num_layer):
290
- device_map[f'language_model.model.layers.{layer_cnt}'] = i
291
- layer_cnt += 1
292
- device_map['vision_model'] = device
293
- device_map['mlp1'] = device
294
- device_map['language_model.model.tok_embeddings'] = device
295
- device_map['language_model.model.embed_tokens'] = device
296
- device_map['language_model.output'] = device
297
- device_map['language_model.model.norm'] = device
298
- device_map['language_model.lm_head'] = device
299
- device_map['language_model.model.rotary_emb'] = device
300
- device_map[f'language_model.model.layers.{num_layers - 1}'] = device
301
- return device_map
302
-
303
- class ModelWorker:
304
- def __init__(self, model_path, model_name,
305
- load_8bit, device):
306
-
307
- if model_path.endswith('/'):
308
- model_path = model_path[:-1]
309
- if model_name is None:
310
- model_paths = model_path.split('/')
311
- if model_paths[-1].startswith('checkpoint-'):
312
- self.model_name = model_paths[-2] + '_' + model_paths[-1]
313
- else:
314
- self.model_name = model_paths[-1]
315
- else:
316
- self.model_name = model_name
317
-
318
- print(f'Loading the model {self.model_name}')
319
-
320
- tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
321
- tokens_to_keep = ['<box>', '</box>', '<ref>', '</ref>']
322
- tokenizer.additional_special_tokens = [item for item in tokenizer.additional_special_tokens if item not in tokens_to_keep]
323
- self.tokenizer = tokenizer
324
- config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
325
- model_type = config.vision_config.model_type
326
- self.device = torch.cuda.current_device()
327
- if model_type == 'siglip_vision_model':
328
- self.norm_type = 'siglip'
329
- elif model_type == 'MOB':
330
- self.norm_type = 'siglip'
331
- else:
332
- self.norm_type = 'imagenet'
333
-
334
- if any(x in model_path.lower() for x in ['34b']):
335
- device_map = split_model(model_path, self.device)
336
- else:
337
- device_map = None
338
-
339
- if device_map is not None:
340
- self.model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16,
341
- low_cpu_mem_usage=True,
342
- device_map=device_map,
343
- trust_remote_code=True,
344
- load_in_8bit=load_8bit).eval()
345
- else:
346
- self.model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16,
347
- trust_remote_code=True,
348
- load_in_8bit=load_8bit).eval()
349
-
350
- if not load_8bit and device_map is None:
351
- self.model = self.model.to(device)
352
- self.load_8bit = load_8bit
353
-
354
- self.model_path = model_path
355
- self.image_size = self.model.config.force_image_size
356
- self.context_len = tokenizer.model_max_length
357
- self.per_tile_len = 256
358
-
359
- def reload_model(self):
360
- del self.model
361
- torch.cuda.empty_cache()
362
- if self.device == 'auto':
363
- os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
364
- # This can make distributed deployment work properly
365
- self.model = AutoModel.from_pretrained(
366
- self.model_path,
367
- load_in_8bit=self.load_8bit,
368
- torch_dtype=torch.bfloat16,
369
- device_map=self.device_map,
370
- trust_remote_code=True).eval()
371
- else:
372
- self.model = AutoModel.from_pretrained(
373
- self.model_path,
374
- load_in_8bit=self.load_8bit,
375
- torch_dtype=torch.bfloat16,
376
- trust_remote_code=True).eval()
377
- if not self.load_8bit and not self.device == 'auto':
378
- self.model = self.model.cuda()
379
-
380
- @torch.inference_mode()
381
- def generate(self, params):
382
- system_message = params['prompt'][0]['content']
383
- send_messages = params['prompt'][1:]
384
- max_input_tiles = params['max_input_tiles']
385
- temperature = params['temperature']
386
- top_p = params['top_p']
387
- max_new_tokens = params['max_new_tokens']
388
- repetition_penalty = params['repetition_penalty']
389
- video_frame_num = params.get('video_frame_num', 64)
390
- do_sample = True if temperature > 0.0 else False
391
-
392
- global_image_cnt = 0
393
- history, pil_images, max_input_tile_list = [], [], []
394
- for message in send_messages:
395
- if message['role'] == 'user':
396
- prefix = ''
397
- if 'image' in message:
398
- for image_data in message['image']:
399
- pil_images.append(load_image(image_data))
400
- prefix = prefix + f'<image {global_image_cnt + 1}><image>\n'
401
- global_image_cnt += 1
402
- max_input_tile_list.append(max_input_tiles)
403
- if 'video' in message:
404
- for video_data in message['video']:
405
- video_frames, tmp_prefix = load_video(video_data, num_frames=video_frame_num)
406
- pil_images.extend(video_frames)
407
- prefix = prefix + tmp_prefix
408
- global_image_cnt += len(video_frames)
409
- max_input_tile_list.extend([1] * len(video_frames))
410
- content = prefix + message['content']
411
- history.append([content, ])
412
- else:
413
- history[-1].append(message['content'])
414
- question, history = history[-1][0], history[:-1]
415
-
416
- if global_image_cnt == 1:
417
- question = question.replace('<image 1><image>\n', '<image>\n')
418
- history = [[item[0].replace('<image 1><image>\n', '<image>\n'), item[1]] for item in history]
419
-
420
-
421
- try:
422
- assert len(max_input_tile_list) == len(pil_images), 'The number of max_input_tile_list and pil_images should be the same.'
423
- except Exception as e:
424
- from IPython import embed; embed()
425
- exit()
426
- print(f'Error: {e}')
427
- print(f'max_input_tile_list: {max_input_tile_list}, pil_images: {pil_images}')
428
- # raise e
429
-
430
- old_system_message = self.model.system_message
431
- self.model.system_message = system_message
432
-
433
- transform = build_transform(input_size=self.image_size, norm_type=self.norm_type)
434
- if len(pil_images) > 0:
435
- max_input_tiles_limited_by_contect = params['max_input_tiles']
436
- while True:
437
- image_tiles = []
438
- for current_max_input_tiles, pil_image in zip(max_input_tile_list, pil_images):
439
- if self.model.config.dynamic_image_size:
440
- tiles = dynamic_preprocess(
441
- pil_image, image_size=self.image_size, max_num=min(current_max_input_tiles, max_input_tiles_limited_by_contect),
442
- use_thumbnail=self.model.config.use_thumbnail)
443
- else:
444
- tiles = [pil_image]
445
- image_tiles += tiles
446
- if (len(image_tiles) * self.per_tile_len < self.context_len):
447
- break
448
- else:
449
- max_input_tiles_limited_by_contect -= 2
450
-
451
- if max_input_tiles_limited_by_contect < 1:
452
- break
453
-
454
- pixel_values = [transform(item) for item in image_tiles]
455
- pixel_values = torch.stack(pixel_values).to(self.model.device, dtype=torch.bfloat16)
456
- print(f'Split images to {pixel_values.shape}')
457
- else:
458
- pixel_values = None
459
-
460
- generation_config = dict(
461
- num_beams=1,
462
- max_new_tokens=max_new_tokens,
463
- do_sample=do_sample,
464
- temperature=temperature,
465
- repetition_penalty=repetition_penalty,
466
- max_length=self.context_len,
467
- top_p=top_p,
468
- )
469
-
470
- response = self.model.chat(
471
- tokenizer=self.tokenizer,
472
- pixel_values=pixel_values,
473
- question=question,
474
- history=history,
475
- return_history=False,
476
- generation_config=generation_config,
477
- )
478
- self.model.system_message = old_system_message
479
- return {'text': response, 'error_code': 0}
480
-
481
-
482
-
483
-
484
-
485
- if __name__ == '__main__':
486
- parser = argparse.ArgumentParser()
487
- parser.add_argument('--model-path', type=str, default='nvidia/Eagle2-2B')
488
- parser.add_argument('--model-name', type=str, default='Eagle2-2B')
489
- parser.add_argument('--device', type=str, default='cuda')
490
- parser.add_argument('--load-8bit', action='store_true')
491
- args = parser.parse_args()
492
- print(f'args: {args}')
493
-
494
- worker = ModelWorker(
495
- args.model_path,
496
- args.model_name,
497
- args.load_8bit,
498
- args.device)
499
  ```
500
- </details>
501
 
502
 
503
- ### 2. Prepare the Prompt
504
 
505
- - Single image input
506
  ```python
507
- prompt = [
508
- {'role': 'system', 'content': 'You are a helpful assistant.'},
509
- {'role': 'user', 'content': 'Describe this image in details.',
510
- 'image':[
511
- {'url': 'https://www.nvidia.com/content/dam/en-zz/Solutions/about-nvidia/logo-and-brand/01-nvidia-logo-vert-500x200-2c50-d@2x.png'}
512
- ],
513
- }
514
- ]
515
  ```
516
 
517
- - Multiple image input
 
518
  ```python
519
- prompt = [
520
- {'role': 'system', 'content': 'You are a helpful assistant.'},
521
- {'role': 'user', 'content': 'Describe these two images in details.',
522
- 'image':[
523
- {'url': 'https://www.nvidia.com/content/dam/en-zz/Solutions/about-nvidia/logo-and-brand/[email protected]'},
524
- {'url': 'https://www.nvidia.com/content/dam/en-zz/Solutions/about-nvidia/logo-and-brand/01-nvidia-logo-vert-500x200-2c50-d@2x.png'}
525
- ],
526
- }
527
- ]
528
  ```
529
 
530
- - Video input
 
531
  ```python
532
- prompt = [
533
- {'role': 'system', 'content': 'You are a helpful assistant.'},
534
- {'role': 'user', 'content': 'Describe this video in details.',
535
- 'video':[
536
- 'path/to/your/video.mp4'
537
- ],
538
- }
539
- ]
540
  ```
541
 
542
- ### 3. Generate the response
 
543
  ```python
544
- params = {
545
- 'prompt': prompt,
546
- 'max_input_tiles': 24,
547
- 'temperature': 0.7,
548
- 'top_p': 1.0,
549
- 'max_new_tokens': 4096,
550
- 'repetition_penalty': 1.0,
551
  }
552
- worker.generate(params)
553
  ```
554
 
555
  ## TODO
 
18
  # Eagle-2
19
 
20
  [\[📂 GitHub\]](https://github.com/NVlabs/EAGLE) [\[📜 Eagle2 Tech Report\]](http://arxiv.org/abs/2501.14818)
21
+ [\[🤗 HF Demo\]](https://huggingface.co/spaces/nvidia/Eagle2-Demo)
22
+
23
+ # News:
24
+ - We updated the model architecture to `eagle_2_5_vl` to support the `generate` feature.
25
+
26
+
27
  ## Introduction
28
 
29
  We are thrilled to release our latest Eagle2 series Vision-Language Model. Open-source Vision-Language Models (VLMs) have made significant strides in narrowing the gap with proprietary models. However, critical details about data strategies and implementation are often missing, limiting reproducibility and innovation. In this project, we focus on VLM post-training from a data-centric perspective, sharing insights into building effective data strategies from scratch. By combining these strategies with robust training recipes and model design, we introduce Eagle2, a family of performant VLMs. Our work aims to empower the open-source community to develop competitive VLMs with transparent processes.
 
71
 
72
 
73
 
74
+ We provide an [inference script](./demo.py) to help you quickly start using the model. We support the following input types:
75
  - pure text input
76
  - single image input
77
  - multiple image input
78
  - video input
79
 
80
+ ### Install the dependencies
81
 
82
  ```bash
83
  pip install transformers
84
  pip install flash-attn
85
  ```
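**Note**: the `config.json` in this commit was exported with `transformers` 4.51.0 (see the `transformers_version` field below). If a newer release misbehaves, pinning close to that version is a reasonable starting point; this is an inference from the config, not an official requirement.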
 
86
 
 
87
 
88
+ ### single image
 
89
 
90
  ```python
91
+ from PIL import Image
92
  import requests
93
+ from transformers import AutoProcessor, AutoModel
94
  import torch
95
+ model = AutoModel.from_pretrained("nvidia/Eagle2-1B",trust_remote_code=True, torch_dtype=torch.bfloat16)
96
+ processor = AutoProcessor.from_pretrained("nvidia/Eagle2-1B", trust_remote_code=True, use_fast=True)
97
+ processor.tokenizer.padding_side = "left"
98
+
99
+ messages = [
100
+ {
101
+ "role": "user",
102
+ "content": [
103
+ {
104
+ "type": "image",
105
+ "image": "https://www.ilankelman.org/stopsigns/australia.jpg",
106
+ },
107
+ {"type": "text", "text": "Describe this image."},
108
+ ],
109
+ }
110
+ ]
111
+
112
+ text_list = [processor.apply_chat_template(
113
+ messages, tokenize=False, add_generation_prompt=True
114
+ )]
115
+ image_inputs, video_inputs = processor.process_vision_info(messages)
116
+ inputs = processor(text = text_list, images=image_inputs, videos=video_inputs, return_tensors="pt", padding=True)
117
+ inputs = inputs.to("cuda")
118
+ model = model.to("cuda")
119
+ generated_ids = model.generate(**inputs, max_new_tokens=1024)
120
+ output_text = processor.batch_decode(
121
+ generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
122
+ )
123
+ print(output_text)
124
  ```
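### pure text

The list of supported inputs above also mentions pure text, which the sections below do not cover, so here is a minimal sketch. It reuses the single-image recipe verbatim and assumes the processor accepts a message list with no vision entries (i.e. `process_vision_info` returns empty image and video inputs); the prompt text is only illustrative.

```python
from transformers import AutoProcessor, AutoModel
import torch

model = AutoModel.from_pretrained("nvidia/Eagle2-1B", trust_remote_code=True, torch_dtype=torch.bfloat16)
processor = AutoProcessor.from_pretrained("nvidia/Eagle2-1B", trust_remote_code=True, use_fast=True)
processor.tokenizer.padding_side = "left"

# A text-only conversation: no "image" or "video" entries in the content list.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Write a short introduction to vision-language models."},
        ],
    }
]

text_list = [processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)]
# Assumed to return empty image/video inputs for a text-only message list.
image_inputs, video_inputs = processor.process_vision_info(messages)
inputs = processor(text=text_list, images=image_inputs, videos=video_inputs, return_tensors="pt", padding=True)
inputs = inputs.to("cuda")
model = model.to("cuda")
generated_ids = model.generate(**inputs, max_new_tokens=1024)
output_text = processor.batch_decode(
    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
```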
125
+
126
+ ### stream generation
127
+
128
+ ```python
129
+ from PIL import Image
130
+ import requests
131
+ from transformers import AutoProcessor, AutoModel, AutoTokenizer
132
+ import torch
133
+
134
+ from transformers import TextIteratorStreamer
135
+ import threading
136
 
137
 
138
+ model = AutoModel.from_pretrained("nvidia/Eagle2-1B",trust_remote_code=True, attn_implementation='flash_attention_2', torch_dtype=torch.bfloat16)
139
+ tokenizer = AutoTokenizer.from_pretrained("nvidia/Eagle2-1B", trust_remote_code=True, use_fast=True)
140
+ processor = AutoProcessor.from_pretrained("nvidia/Eagle2-1B", trust_remote_code=True, use_fast=True)
141
+ processor.tokenizer.padding_side = "left"
142
+
143
+ messages = [
144
+ {
145
+ "role": "user",
146
+ "content": [
147
+ {
148
+ "type": "image",
149
+ "image": "https://www.ilankelman.org/stopsigns/australia.jpg",
150
+ },
151
+ {"type": "text", "text": "Describe this image."},
152
+ ],
153
+ }
154
+ ]
155
+
156
+ text_list = [processor.apply_chat_template(
157
+ messages, tokenize=False, add_generation_prompt=True
158
+ )]
159
+ image_inputs, video_inputs = processor.process_vision_info(messages)
160
+ inputs = processor(text = text_list, images=image_inputs, videos=video_inputs, return_tensors="pt", padding=True)
161
+ inputs = inputs.to("cuda")
162
+ model = model.to("cuda")
163
+
164
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
165
+
166
+ generation_kwargs = dict(
167
+ **inputs,
168
+ streamer=streamer,
169
+ max_new_tokens=1024,
170
+ do_sample=True,
171
+ top_p=0.95,
172
+ temperature=0.8
173
+ )
174
+ thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
175
+ thread.start()
176
+
177
+
178
+ for new_text in streamer:
179
+ print(new_text, end="", flush=True)
180
+ ```
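Once the loop has drained the streamer, generation is finished; adding a `thread.join()` after the loop lets the worker thread exit cleanly before the script ends.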
181
+
182
+ ### multiple images
183
 
 
184
  ```python
185
+ from PIL import Image
186
+ import requests
187
+ from transformers import AutoProcessor, AutoModel
188
+ import torch
189
+ model = AutoModel.from_pretrained("nvidia/Eagle2-1B",trust_remote_code=True, torch_dtype=torch.bfloat16)
190
+ processor = AutoProcessor.from_pretrained("nvidia/Eagle2-1B", trust_remote_code=True, use_fast=True)
191
+ processor.tokenizer.padding_side = "left"
192
+
193
+ messages = [
194
+ {
195
+ "role": "user",
196
+ "content": [
197
+ {
198
+ "type": "image",
199
+ "image": "https://www.ilankelman.org/stopsigns/australia.jpg",
200
+ },
201
+ {
202
+ "type": "image",
203
+ "image": "https://www.nvidia.com/content/dam/en-zz/Solutions/about-nvidia/logo-and-brand/[email protected]",
204
+ },
205
+ {"type": "text", "text": "Describe these two images."},
206
+ ],
207
+ }
208
+ ]
209
+
210
+ text_list = [processor.apply_chat_template(
211
+ messages, tokenize=False, add_generation_prompt=True
212
+ )]
213
+ image_inputs, video_inputs = processor.process_vision_info(messages)
214
+ inputs = processor(text = text_list, images=image_inputs, videos=video_inputs, return_tensors="pt", padding=True)
215
+ inputs = inputs.to("cuda")
216
+ model = model.to("cuda")
217
+ generated_ids = model.generate(**inputs, max_new_tokens=1024)
218
+ output_text = processor.batch_decode(
219
+ generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
220
+ )
221
+ print(output_text)
222
  ```
223
 
224
+ ### single video
225
+
226
  ```python
227
+
228
+ from PIL import Image
229
+ import requests
230
+ from transformers import AutoProcessor, AutoModel
231
+ import torch
232
+ model = AutoModel.from_pretrained("nvidia/Eagle2-1B",trust_remote_code=True, torch_dtype=torch.bfloat16)
233
+ processor = AutoProcessor.from_pretrained("nvidia/Eagle2-1B", trust_remote_code=True, use_fast=True)
234
+ processor.tokenizer.padding_side = "left"
235
+
236
+ messages = [
237
+ {
238
+ "role": "user",
239
+ "content": [
240
+ {
241
+ "type": "video",
242
+ "video": "../Eagle2-8B/space_woaudio.mp4",
243
+ },
244
+ {"type": "text", "text": "Describe this video."},
245
+ ],
246
+ }
247
+ ]
248
+
249
+ text_list = [processor.apply_chat_template(
250
+ messages, tokenize=False, add_generation_prompt=True
251
+ )]
252
+ image_inputs, video_inputs, video_kwargs = processor.process_vision_info(messages, return_video_kwargs=True)
253
+
254
+ inputs = processor(text = text_list, images=image_inputs, videos=video_inputs, return_tensors="pt", padding=True, videos_kwargs=video_kwargs)
255
+ inputs = inputs.to("cuda")
256
+ model = model.to("cuda")
257
+ generated_ids = model.generate(**inputs, max_new_tokens=1024)
258
+ output_text = processor.batch_decode(
259
+ generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
260
+ )
261
+ print(output_text)
262
+
263
  ```
264
 
265
+ ### multiple videos
266
+
267
  ```python
268
+ from PIL import Image
269
+ import requests
270
+ from transformers import AutoProcessor, AutoModel
271
+ import torch
272
+ model = AutoModel.from_pretrained("nvidia/Eagle2-1B",trust_remote_code=True, torch_dtype=torch.bfloat16)
273
+ processor = AutoProcessor.from_pretrained("nvidia/Eagle2-1B", trust_remote_code=True, use_fast=True)
274
+ processor.tokenizer.padding_side = "left"
275
+
276
+ messages = [
277
+ {
278
+ "role": "user",
279
+ "content": [
280
+ {
281
+ "type": "video",
282
+ "video": "../Eagle2-8B/space_woaudio.mp4",
283
+ "nframes": 10,
284
+ },
285
+ {
286
+ "type": "video",
287
+ "video": "../Eagle2-8B/video_ocr.mp4",
288
+ "nframes": 10,
289
+ },
290
+ {"type": "text", "text": "Describe these two videos respectively."},
291
+ ],
292
+ }
293
+ ]
294
+
295
+ text_list = [processor.apply_chat_template(
296
+ messages, tokenize=False, add_generation_prompt=True
297
+ )]
298
+ image_inputs, video_inputs, video_kwargs = processor.process_vision_info(messages, return_video_kwargs=True)
299
+ inputs = processor(text = text_list, images=image_inputs, videos=video_inputs, return_tensors="pt", padding=True, videos_kwargs=video_kwargs)
300
+ inputs = inputs.to("cuda")
301
+ model = model.to("cuda")
302
+ generated_ids = model.generate(**inputs, max_new_tokens=1024)
303
+ output_text = processor.batch_decode(
304
+ generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
305
+ )
306
+ print(output_text)
307
  ```
308
 
309
+ ### batch inference
310
+
311
  ```python
312
+ from PIL import Image
313
+ import requests
314
+ from transformers import AutoProcessor, AutoModel
315
+ import torch
316
+ model = AutoModel.from_pretrained("nvidia/Eagle2-1B",trust_remote_code=True, torch_dtype=torch.bfloat16)
317
+ processor = AutoProcessor.from_pretrained("nvidia/Eagle2-1B", trust_remote_code=True, use_fast=True)
318
+ processor.tokenizer.padding_side = "left"
319
+
320
+ messages1 = [
321
+ {
322
+ "role": "user",
323
+ "content": [
324
+ {
325
+ "type": "image",
326
+ "image": "https://www.ilankelman.org/stopsigns/australia.jpg",
327
+ },
328
+ {"type": "text", "text": "Describe this image."},
329
+ ],
330
+ }
331
+ ]
332
+
333
+ messages2 = [
334
+ {
335
+ "role": "user",
336
+ "content": [
337
+ {
338
+ "type": "image",
339
+ "image": "https://www.nvidia.com/content/dam/en-zz/Solutions/about-nvidia/logo-and-brand/[email protected]",
340
+ },
341
+ {"type": "text", "text": "Describe this image."},
342
+ ],
343
  }
344
+ ]
345
+
346
+ text_list = [processor.apply_chat_template(
347
+ messages, tokenize=False, add_generation_prompt=True
348
+ ) for messages in [messages1, messages2]]
349
+ image_inputs, video_inputs = processor.process_vision_info([messages1, messages2])
350
+ inputs = processor(text = text_list, images=image_inputs, videos=video_inputs, return_tensors="pt", padding=True)
351
+ inputs = inputs.to("cuda")
352
+ model = model.to("cuda")
353
+ generated_ids = model.generate(**inputs, max_new_tokens=1024)
354
+ output_text = processor.batch_decode(
355
+ generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
356
+ )
357
+ print(output_text)
358
  ```
359
 
360
  ## TODO
chat_template.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"chat_template": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}<image {{ image_count.value }}>{% endif %}<image-{{ image_count.value }}>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}<video {{ video_count.value }}>{% endif %}<video-{{ video_count.value }}>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"}
config.json CHANGED
@@ -1,32 +1,33 @@
1
  {
2
- "_commit_hash": null,
 
3
  "_name_or_path": "",
4
  "architectures": [
5
- "Eagle2ChatModel"
6
  ],
7
  "auto_map": {
8
- "AutoConfig": "configuration_eagle_chat.Eagle2ChatConfig",
9
- "AutoModel": "modeling_eagle_chat.Eagle2ChatModel",
10
- "AutoModelForCausalLM": "modeling_eagle_chat.Eagle2ChatModel"
11
- },
12
  "downsample_ratio": 0.5,
13
  "dynamic_image_size": true,
14
- "efficient_loss": true,
15
  "force_image_size": 448,
16
- "keep_aspect_ratio": false,
17
- "llm_config": {
 
18
  "_name_or_path": "./pretrained/Qwen2_5-1_5B-Instruct",
19
  "add_cross_attention": false,
20
  "architectures": [
21
  "Qwen2ForCausalLM"
22
  ],
23
  "attention_dropout": 0.0,
24
- "attn_implementation": "flash_attention_2",
25
- "auto_map": {
26
- "AutoConfig": "configuration_qwen2.Qwen2Config",
27
- "AutoModel": "modeling_qwen2.Qwen2Model",
28
- "AutoModelForCausalLM": "modeling_qwen2.Qwen2ForCausalLM"
29
- },
30
  "bad_words_ids": null,
31
  "begin_suppress_tokens": null,
32
  "bos_token_id": 151643,
@@ -102,105 +103,25 @@
102
  "use_sliding_window": false,
103
  "vocab_size": 151674
104
  },
105
- "loss_version": "v4",
106
- "max_dynamic_patch": 12,
107
- "min_dynamic_patch": 1,
108
- "mlp_checkpoint": true,
109
- "model_type": "eagle_chat",
110
- "pad2square": false,
111
- "pre_feature_reduction": false,
112
- "ps_version": "v2",
113
- "select_layer": -1,
114
- "template": "qwen2-chat",
115
  "torch_dtype": "bfloat16",
116
- "transformers_version": null,
117
  "use_backbone_lora": 0,
118
  "use_llm_lora": 0,
119
  "use_thumbnail": true,
120
  "vision_config": {
121
- "_name_or_path": "",
122
- "add_cross_attention": false,
123
- "architectures": [
124
- "SiglipVisionModel"
125
- ],
126
  "attention_dropout": 0.0,
127
- "auto_map": {
128
- "AutoConfig": "configuration_siglip.SiglipVisionConfig",
129
- "AutoModel": "modeling_siglip.SiglipVisionModel"
130
- },
131
- "bad_words_ids": null,
132
- "begin_suppress_tokens": null,
133
- "bos_token_id": null,
134
- "chunk_size_feed_forward": 0,
135
- "cross_attention_hidden_size": null,
136
- "decoder_start_token_id": null,
137
- "diversity_penalty": 0.0,
138
- "do_sample": false,
139
  "drop_path_rate": 0.1,
140
- "early_stopping": false,
141
- "encoder_no_repeat_ngram_size": 0,
142
- "eos_token_id": null,
143
- "exponential_decay_length_penalty": null,
144
- "finetuning_task": null,
145
- "forced_bos_token_id": null,
146
- "forced_eos_token_id": null,
147
  "hidden_act": "gelu_pytorch_tanh",
148
  "hidden_size": 1152,
149
- "id2label": {
150
- "0": "LABEL_0",
151
- "1": "LABEL_1"
152
- },
153
  "image_size": 448,
154
  "intermediate_size": 4304,
155
- "is_decoder": false,
156
- "is_encoder_decoder": false,
157
- "label2id": {
158
- "LABEL_0": 0,
159
- "LABEL_1": 1
160
- },
161
  "layer_norm_eps": 1e-06,
162
- "length_penalty": 1.0,
163
- "max_length": 20,
164
- "min_length": 0,
165
  "model_type": "siglip_vision_model",
166
- "no_repeat_ngram_size": 0,
167
  "num_attention_heads": 16,
168
- "num_beam_groups": 1,
169
- "num_beams": 1,
170
  "num_channels": 3,
171
  "num_hidden_layers": 27,
172
- "num_image_tokens": 1024,
173
- "num_return_sequences": 1,
174
- "output_attentions": false,
175
- "output_hidden_states": false,
176
- "output_scores": false,
177
- "pad_token_id": null,
178
  "patch_size": 14,
179
- "prefix": null,
180
- "problem_type": null,
181
- "projection_dim": 2048,
182
- "projector_hidden_act": "gelu_fast",
183
- "pruned_heads": {},
184
- "remove_invalid_values": false,
185
- "repetition_penalty": 1.0,
186
- "return_dict": true,
187
- "return_dict_in_generate": false,
188
- "sep_token_id": null,
189
- "suppress_tokens": null,
190
- "task_specific_params": null,
191
- "temperature": 1.0,
192
- "tf_legacy_loss": false,
193
- "tie_encoder_decoder": false,
194
- "tie_word_embeddings": true,
195
- "tokenizer_class": null,
196
- "top_k": 50,
197
- "top_p": 1.0,
198
- "torch_dtype": "bfloat16",
199
- "torchscript": false,
200
- "transformers_version": "4.37.2",
201
- "typical_p": 1.0,
202
- "use_bfloat16": false,
203
- "vision_use_head": false,
204
- "_attn_implementation": "flash_attention_2"
205
- }
206
  }
 
1
  {
2
+ "_attn_implementation": "flash_attention_2",
3
+ "_attn_implementation_autoset": false,
4
  "_name_or_path": "",
5
  "architectures": [
6
+ "Eagle2_5_VLForConditionalGeneration"
7
  ],
8
  "auto_map": {
9
+ "AutoConfig": "configuration_eagle2_5_vl.Eagle2_5_VLConfig",
10
+ "AutoModel": "modeling_eagle2_5_vl.Eagle2_5_VLForConditionalGeneration"
11
+ },
 
12
  "downsample_ratio": 0.5,
13
  "dynamic_image_size": true,
 
14
  "force_image_size": 448,
15
+ "image_token_index": 151667,
16
+ "max_dynamic_tiles": 12,
17
+ "min_dynamic_tiles": 1,
18
+ "mlp_checkpoint": false,
19
+ "model_type": "eagle_2_5_vl",
20
+ "pad2square": false,
21
+ "pre_feature_reduction": false,
22
+ "select_layer": -1,
23
+ "template": "qwen2-chat",
24
+ "text_config": {
25
  "_name_or_path": "./pretrained/Qwen2_5-1_5B-Instruct",
26
  "add_cross_attention": false,
27
  "architectures": [
28
  "Qwen2ForCausalLM"
29
  ],
30
  "attention_dropout": 0.0,
31
  "bad_words_ids": null,
32
  "begin_suppress_tokens": null,
33
  "bos_token_id": 151643,
 
103
  "use_sliding_window": false,
104
  "vocab_size": 151674
105
  },
106
+ "tie_word_embeddings": true,
107
  "torch_dtype": "bfloat16",
108
+ "transformers_version": "4.51.0",
109
  "use_backbone_lora": 0,
110
  "use_llm_lora": 0,
111
  "use_thumbnail": true,
112
  "vision_config": {
 
 
 
 
 
113
  "attention_dropout": 0.0,
114
  "drop_path_rate": 0.1,
115
  "hidden_act": "gelu_pytorch_tanh",
116
  "hidden_size": 1152,
 
 
 
 
117
  "image_size": 448,
118
  "intermediate_size": 4304,
119
  "layer_norm_eps": 1e-06,
 
 
 
120
  "model_type": "siglip_vision_model",
 
121
  "num_attention_heads": 16,
 
 
122
  "num_channels": 3,
123
  "num_hidden_layers": 27,
124
  "patch_size": 14,
125
+ "torch_dtype": "bfloat16"
126
+ }
127
  }
configuration_eagle_chat.py → configuration_eagle2_5_vl.py RENAMED
@@ -1,80 +1,88 @@
1
  # --------------------------------------------------------
2
- # Eagle2
3
  # Copyright (c) 2025 NVIDIA
4
- # Licensed under The Apache License [see LICENSE for details]
5
  # --------------------------------------------------------
6
 
7
  import copy
8
 
9
- from transformers import AutoConfig, LlamaConfig
 
10
  from transformers.configuration_utils import PretrainedConfig
11
  from transformers.utils import logging
12
  from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
13
- from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
14
  logger = logging.get_logger(__name__)
15
 
16
- class Eagle2ChatConfig(PretrainedConfig):
17
- model_type = 'eagle_chat'
18
- is_composition = True
19
 
 
 
 
 
20
  def __init__(
21
  self,
22
  vision_config=None,
23
- llm_config=None,
24
  use_backbone_lora=0,
25
  use_llm_lora=0,
26
- select_layer=-1,
 
27
  force_image_size=None,
28
  downsample_ratio=0.5,
29
  template=None,
30
  dynamic_image_size=False,
31
  use_thumbnail=False,
32
- min_dynamic_patch=1,
33
- max_dynamic_patch=6,
34
- mlp_checkpoint=True,
35
- pre_feature_reduction=False,
36
- keep_aspect_ratio=False,
37
- vocab_size=-1,
 
 
 
38
  **kwargs):
39
  super().__init__(**kwargs)
40
 
41
  if vision_config is None:
42
- vision_config = {}
43
- logger.info('vision_config is None. Initializing Vision Encoders with default values.')
44
 
45
- if llm_config is None:
46
- llm_config = {}
47
- logger.info('llm_config is None. Initializing the LLM config with default values')
48
 
49
  if vision_config['model_type'] == 'siglip_vision_model':
50
  self.vision_config = SiglipVisionConfig(**vision_config)
51
  else:
52
  raise ValueError('Unsupported model_type: {}'.format(vision_config['model_type']))
53
 
54
- if llm_config['architectures'][0] == 'LlamaForCausalLM':
55
- self.llm_config = LlamaConfig(**llm_config)
56
- elif llm_config['architectures'][0] == 'Qwen2ForCausalLM':
57
- self.llm_config = Qwen2Config(**llm_config)
 
58
  else:
59
- raise ValueError('Unsupported architecture: {}'.format(llm_config['architectures'][0]))
60
  self.use_backbone_lora = use_backbone_lora
61
  self.use_llm_lora = use_llm_lora
 
 
62
  self.select_layer = select_layer
63
  self.force_image_size = force_image_size
64
  self.downsample_ratio = downsample_ratio
65
  self.template = template
66
  self.dynamic_image_size = dynamic_image_size
67
  self.use_thumbnail = use_thumbnail
68
- self.min_dynamic_patch = min_dynamic_patch
69
- self.max_dynamic_patch = max_dynamic_patch
70
- self.mlp_checkpoint = mlp_checkpoint
71
- self.pre_feature_reduction = pre_feature_reduction
72
- self.keep_aspect_ratio = keep_aspect_ratio
73
- self.vocab_size = self.llm_config.vocab_size
74
- logger.info(f'keep_aspect_ratio: {self.keep_aspect_ratio}')
75
- logger.info(f'vision_select_layer: {self.select_layer}')
76
- logger.info(f'min_dynamic_patch: {self.min_dynamic_patch}')
77
- logger.info(f'max_dynamic_patch: {self.max_dynamic_patch}')
78
 
79
  def to_dict(self):
80
  """
@@ -85,18 +93,21 @@ class Eagle2ChatConfig(PretrainedConfig):
85
  """
86
  output = copy.deepcopy(self.__dict__)
87
  output['vision_config'] = self.vision_config.to_dict()
88
- output['llm_config'] = self.llm_config.to_dict()
89
  output['model_type'] = self.__class__.model_type
90
  output['use_backbone_lora'] = self.use_backbone_lora
91
  output['use_llm_lora'] = self.use_llm_lora
 
92
  output['select_layer'] = self.select_layer
93
  output['force_image_size'] = self.force_image_size
94
  output['downsample_ratio'] = self.downsample_ratio
95
  output['template'] = self.template
96
  output['dynamic_image_size'] = self.dynamic_image_size
97
  output['use_thumbnail'] = self.use_thumbnail
98
- output['min_dynamic_patch'] = self.min_dynamic_patch
99
- output['max_dynamic_patch'] = self.max_dynamic_patch
100
- output['keep_aspect_ratio'] = self.keep_aspect_ratio
 
 
101
 
102
  return output
 
1
  # --------------------------------------------------------
2
+ # NVIDIA
3
  # Copyright (c) 2025 NVIDIA
4
+ # Licensed under The MIT License [see LICENSE for details]
5
  # --------------------------------------------------------
6
 
7
  import copy
8
 
9
+ from transformers.models.llama.configuration_llama import LlamaConfig
10
+ from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
11
  from transformers.configuration_utils import PretrainedConfig
12
  from transformers.utils import logging
13
  from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
 
14
  logger = logging.get_logger(__name__)
15
 
 
 
 
16
 
17
+ class Eagle2_5_VLConfig(PretrainedConfig):
18
+ model_type = 'eagle_2_5_vl'
19
+ is_composition = True
20
+ sub_configs = {"vision_config": SiglipVisionConfig, "text_config": Qwen2Config}
21
  def __init__(
22
  self,
23
  vision_config=None,
24
+ text_config=None,
25
  use_backbone_lora=0,
26
  use_llm_lora=0,
27
+ pad2square=False,
28
+ select_layer=-4,
29
  force_image_size=None,
30
  downsample_ratio=0.5,
31
  template=None,
32
  dynamic_image_size=False,
33
  use_thumbnail=False,
34
+ loss_version='v1',
35
+ min_dynamic_tiles=1,
36
+ max_dynamic_tiles=6,
37
+ mlp_checkpoint=False,
38
+ initializer_range=0.02,
39
+ _attn_implementation='flash_attention_2',
40
+ _attn_implementation_autoset=False,
41
+ llm_config=None,
42
+ image_token_index=None,
43
  **kwargs):
44
  super().__init__(**kwargs)
45
 
46
  if vision_config is None:
47
+ vision_config = {'model_type': 'siglip_vision_model'}
48
+ logger.info('vision_config is None. Initializing the SiglipVisionConfig with default values.')
49
 
50
+ if text_config is None:
51
+ text_config = {'architectures': ['Qwen2ForCausalLM']}
52
+ logger.info('text_config is None. Initializing the text config with default values (`Qwen2Config`).')
53
 
54
  if vision_config['model_type'] == 'siglip_vision_model':
55
  self.vision_config = SiglipVisionConfig(**vision_config)
56
  else:
57
  raise ValueError('Unsupported model_type: {}'.format(vision_config['model_type']))
58
 
59
+
60
+ if text_config['architectures'][0] == 'LlamaForCausalLM':
61
+ self.text_config = LlamaConfig(**text_config)
62
+ elif text_config['architectures'][0] == 'Qwen2ForCausalLM':
63
+ self.text_config = Qwen2Config(**text_config)
64
  else:
65
+ raise ValueError('Unsupported architecture: {}'.format(text_config['architectures'][0]))
66
  self.use_backbone_lora = use_backbone_lora
67
  self.use_llm_lora = use_llm_lora
68
+ self.mlp_checkpoint = mlp_checkpoint
69
+ self.pad2square = pad2square
70
  self.select_layer = select_layer
71
  self.force_image_size = force_image_size
72
  self.downsample_ratio = downsample_ratio
73
  self.template = template
74
  self.dynamic_image_size = dynamic_image_size
75
  self.use_thumbnail = use_thumbnail
76
+ self.loss_version = loss_version
77
+ self.initializer_range = initializer_range
78
+ self.min_dynamic_tiles = min_dynamic_tiles
79
+ self.max_dynamic_tiles = max_dynamic_tiles
80
+ self.tie_word_embeddings = self.text_config.tie_word_embeddings
81
+ self._attn_implementation = _attn_implementation
82
+ self._attn_implementation_autoset = _attn_implementation_autoset
83
+ self.image_token_index = image_token_index
84
+ logger.info(f'min_dynamic_tiles: {self.min_dynamic_tiles}')
85
+ logger.info(f'max_dynamic_tiles: {self.max_dynamic_tiles}')
86
 
87
  def to_dict(self):
88
  """
 
93
  """
94
  output = copy.deepcopy(self.__dict__)
95
  output['vision_config'] = self.vision_config.to_dict()
96
+ output['text_config'] = self.text_config.to_dict()
97
  output['model_type'] = self.__class__.model_type
98
  output['use_backbone_lora'] = self.use_backbone_lora
99
  output['use_llm_lora'] = self.use_llm_lora
100
+ output['pad2square'] = self.pad2square
101
  output['select_layer'] = self.select_layer
102
  output['force_image_size'] = self.force_image_size
103
  output['downsample_ratio'] = self.downsample_ratio
104
  output['template'] = self.template
105
  output['dynamic_image_size'] = self.dynamic_image_size
106
  output['use_thumbnail'] = self.use_thumbnail
107
+ output['min_dynamic_tiles'] = self.min_dynamic_tiles
108
+ output['max_dynamic_tiles'] = self.max_dynamic_tiles
109
+ output['tie_word_embeddings'] = self.tie_word_embeddings
110
+ output['_attn_implementation'] = self._attn_implementation
111
+ output['_attn_implementation_autoset'] = self._attn_implementation_autoset
112
 
113
  return output
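As a quick check on the renamed fields, the sketch below loads the composed config through `AutoConfig` with `trust_remote_code=True` (the same remote-code path the README uses) and prints the attributes introduced by the `eagle_2_5_vl` layout; the expected values in the comments come from the `config.json` diff above.

```python
from transformers import AutoConfig

# Resolves to configuration_eagle2_5_vl.Eagle2_5_VLConfig via the auto_map in config.json.
config = AutoConfig.from_pretrained("nvidia/Eagle2-1B", trust_remote_code=True)

print(config.model_type)                  # eagle_2_5_vl (previously eagle_chat)
print(type(config.text_config).__name__)  # Qwen2Config (previously exposed as llm_config)
print(config.vision_config.model_type)    # siglip_vision_model
print(config.min_dynamic_tiles, config.max_dynamic_tiles)  # 1 12 (previously min/max_dynamic_patch)
print(config.image_token_index)           # 151667
```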
demo.py CHANGED
@@ -1,428 +1,55 @@
1
-
2
- """
3
- A model worker executes the model.
4
- """
5
- from transformers import AutoModel, AutoTokenizer, TextIteratorStreamer, AutoConfig
6
- import argparse
7
- import base64
8
- import json
9
- import os
10
- import decord
11
- import threading
12
- import time
13
- from io import BytesIO
14
- from threading import Thread
15
- import math
16
  import requests
 
17
  import torch
18
- import torchvision.transforms as T
19
- from PIL import Image
20
- from torchvision.transforms.functional import InterpolationMode
21
-
22
- import numpy as np
23
-
24
- IMAGENET_MEAN = (0.485, 0.456, 0.406)
25
- IMAGENET_STD = (0.229, 0.224, 0.225)
26
-
27
- SIGLIP_MEAN = (0.5, 0.5, 0.5)
28
- SIGLIP_STD = (0.5, 0.5, 0.5)
29
-
30
-
31
- def get_seq_frames(total_num_frames, desired_num_frames=-1, stride=-1):
32
- """
33
- Calculate the indices of frames to extract from a video.
34
-
35
- Parameters:
36
- total_num_frames (int): Total number of frames in the video.
37
- desired_num_frames (int): Desired number of frames to extract.
38
-
39
- Returns:
40
- list: List of indices of frames to extract.
41
- """
42
-
43
- assert desired_num_frames > 0 or stride > 0 and not (desired_num_frames > 0 and stride > 0)
44
-
45
- if stride > 0:
46
- return list(range(0, total_num_frames, stride))
47
-
48
- # Calculate the size of each segment from which a frame will be extracted
49
- seg_size = float(total_num_frames - 1) / desired_num_frames
50
-
51
- seq = []
52
- for i in range(desired_num_frames):
53
- # Calculate the start and end indices of each segment
54
- start = int(np.round(seg_size * i))
55
- end = int(np.round(seg_size * (i + 1)))
56
-
57
- # Append the middle index of the segment to the list
58
- seq.append((start + end) // 2)
59
-
60
- return seq
61
-
62
- def build_video_prompt(meta_list, num_frames, time_position=False):
63
- # if time_position is True, the frame_timestamp is used.
64
- # 1. pass time_position, 2. use env TIME_POSITION
65
- time_position = os.environ.get("TIME_POSITION", time_position)
66
- prefix = f"This is a video:\n"
67
- for i in range(num_frames):
68
- if time_position:
69
- frame_txt = f"Frame {i+1} sampled at {meta_list[i]:.2f} seconds: <image>\n"
70
- else:
71
- frame_txt = f"Frame {i+1}: <image>\n"
72
- prefix += frame_txt
73
- return prefix
74
-
75
- def load_video(video_path, num_frames=64, frame_cache_root=None):
76
- if isinstance(video_path, str):
77
- video = decord.VideoReader(video_path)
78
- elif isinstance(video_path, dict):
79
- assert False, 'we not support vidoe: "video_path" as input'
80
- fps = video.get_avg_fps()
81
- sampled_frames = get_seq_frames(len(video), num_frames)
82
- samepld_timestamps = [i / fps for i in sampled_frames]
83
- frames = video.get_batch(sampled_frames).asnumpy()
84
- images = [Image.fromarray(frame) for frame in frames]
85
-
86
- return images, build_video_prompt(samepld_timestamps, len(images), time_position=True)
87
-
88
- def load_image(image):
89
- if isinstance(image, str) and os.path.exists(image):
90
- return Image.open(image)
91
- elif isinstance(image, dict):
92
- if 'disk_path' in image:
93
- return Image.open(image['disk_path'])
94
- elif 'base64' in image:
95
- return Image.open(BytesIO(base64.b64decode(image['base64'])))
96
- elif 'url' in image:
97
- response = requests.get(image['url'])
98
- return Image.open(BytesIO(response.content))
99
- elif 'bytes' in image:
100
- return Image.open(BytesIO(image['bytes']))
101
- else:
102
- raise ValueError(f'Invalid image: {image}')
103
- else:
104
- raise ValueError(f'Invalid image: {image}')
105
-
106
- def build_transform(input_size, norm_type='imagenet'):
107
- if norm_type == 'imagenet':
108
- MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
109
- elif norm_type == 'siglip':
110
- MEAN, STD = SIGLIP_MEAN, SIGLIP_STD
111
-
112
- transform = T.Compose([
113
- T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
114
- T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
115
- T.ToTensor(),
116
- T.Normalize(mean=MEAN, std=STD)
117
- ])
118
- return transform
119
-
120
-
121
- def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
122
- """
123
- previous version mainly foucs on ratio.
124
- We also consider area ratio here.
125
- """
126
- best_factor = float('-inf')
127
- best_ratio = (1, 1)
128
- area = width * height
129
- for ratio in target_ratios:
130
- target_aspect_ratio = ratio[0] / ratio[1]
131
- ratio_diff = abs(aspect_ratio - target_aspect_ratio)
132
- area_ratio = (ratio[0]*ratio[1]*image_size*image_size)/ area
133
- """
134
- new area > 60% of original image area is enough.
135
- """
136
- factor_based_on_area_n_ratio = min((ratio[0]*ratio[1]*image_size*image_size)/ area, 0.6)* \
137
- min(target_aspect_ratio/aspect_ratio, aspect_ratio/target_aspect_ratio)
138
-
139
- if factor_based_on_area_n_ratio > best_factor:
140
- best_factor = factor_based_on_area_n_ratio
141
- best_ratio = ratio
142
-
143
- return best_ratio
144
-
145
-
146
- def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
147
- orig_width, orig_height = image.size
148
- aspect_ratio = orig_width / orig_height
149
-
150
- # calculate the existing image aspect ratio
151
- target_ratios = set(
152
- (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
153
- i * j <= max_num and i * j >= min_num)
154
- target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
155
-
156
- # find the closest aspect ratio to the target
157
- target_aspect_ratio = find_closest_aspect_ratio(
158
- aspect_ratio, target_ratios, orig_width, orig_height, image_size)
159
-
160
- # calculate the target width and height
161
- target_width = image_size * target_aspect_ratio[0]
162
- target_height = image_size * target_aspect_ratio[1]
163
- blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
164
-
165
- # resize the image
166
- resized_img = image.resize((target_width, target_height))
167
- processed_images = []
168
- for i in range(blocks):
169
- box = (
170
- (i % (target_width // image_size)) * image_size,
171
- (i // (target_width // image_size)) * image_size,
172
- ((i % (target_width // image_size)) + 1) * image_size,
173
- ((i // (target_width // image_size)) + 1) * image_size
174
- )
175
- # split the image
176
- split_img = resized_img.crop(box)
177
- processed_images.append(split_img)
178
- assert len(processed_images) == blocks
179
- if use_thumbnail and len(processed_images) != 1:
180
- thumbnail_img = image.resize((image_size, image_size))
181
- processed_images.append(thumbnail_img)
182
- return processed_images
183
-
184
- def split_model(model_path, device):
185
-
186
- device_map = {}
187
- world_size = torch.cuda.device_count()
188
- config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
189
- num_layers = config.llm_config.num_hidden_layers
190
-
191
- num_layers_per_gpu_ = math.floor(num_layers / (world_size - 1))
192
- num_layers_per_gpu = [num_layers_per_gpu_] * world_size
193
-     num_layers_per_gpu[device] = num_layers - num_layers_per_gpu_ * (world_size-1)
-     layer_cnt = 0
-     for i, num_layer in enumerate(num_layers_per_gpu):
-         for j in range(num_layer):
-             device_map[f'language_model.model.layers.{layer_cnt}'] = i
-             layer_cnt += 1
-     device_map['vision_model'] = device
-     device_map['mlp1'] = device
-     device_map['language_model.model.tok_embeddings'] = device
-     device_map['language_model.model.embed_tokens'] = device
-     device_map['language_model.output'] = device
-     device_map['language_model.model.norm'] = device
-     device_map['language_model.lm_head'] = device
-     device_map['language_model.model.rotary_emb'] = device
-     device_map[f'language_model.model.layers.{num_layers - 1}'] = device
-     return device_map
-
-
- class ModelWorker:
-     def __init__(self, model_path, model_name,
-                  load_8bit, device):
-
-         if model_path.endswith('/'):
-             model_path = model_path[:-1]
-         if model_name is None:
-             model_paths = model_path.split('/')
-             if model_paths[-1].startswith('checkpoint-'):
-                 self.model_name = model_paths[-2] + '_' + model_paths[-1]
-             else:
-                 self.model_name = model_paths[-1]
-         else:
-             self.model_name = model_name
-
-         print(f'Loading the model {self.model_name}')
-
-         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
-         tokens_to_keep = ['<box>', '</box>', '<ref>', '</ref>']
-         tokenizer.additional_special_tokens = [item for item in tokenizer.additional_special_tokens if item not in tokens_to_keep]
-         self.tokenizer = tokenizer
-         config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
-         model_type = config.vision_config.model_type
-         self.device = torch.cuda.current_device()
-         if model_type == 'siglip_vision_model':
-             self.norm_type = 'siglip'
-         elif model_type == 'MOB':
-             self.norm_type = 'siglip'
-         else:
-             self.norm_type = 'imagenet'
-         print('norm_type: ', self.norm_type)
-         if any(x in model_path.lower() for x in ['34b']):
-             device_map = split_model(model_path, self.device)
-         else:
-             device_map = None
-
-         if device_map is not None:
-             self.model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16,
-                                                    low_cpu_mem_usage=True,
-                                                    device_map=device_map,
-                                                    trust_remote_code=True,
-                                                    load_in_8bit=load_8bit).eval()
-         else:
-             self.model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16,
-                                                    trust_remote_code=True,
-                                                    load_in_8bit=load_8bit).eval()
-         if not load_8bit and device_map is None:
-             self.model = self.model.to(device)
-         self.load_8bit = load_8bit
-
-         self.model_path = model_path
-         self.image_size = self.model.config.force_image_size
-         self.context_len = tokenizer.model_max_length
-         self.per_tile_len = 256
-         print(self.model)
-
-     def reload_model(self):
-         del self.model
-         torch.cuda.empty_cache()
-         if self.device == 'auto':
-             os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
-             # This can make distributed deployment work properly
-             self.model = AutoModel.from_pretrained(
-                 self.model_path,
-                 load_in_8bit=self.load_8bit,
-                 torch_dtype=torch.bfloat16,
-                 device_map=self.device_map,
-                 trust_remote_code=True).eval()
-         else:
-             self.model = AutoModel.from_pretrained(
-                 self.model_path,
-                 load_in_8bit=self.load_8bit,
-                 torch_dtype=torch.bfloat16,
-                 trust_remote_code=True).eval()
-         if not self.load_8bit and not self.device == 'auto':
-             self.model = self.model.cuda()
-
-     @torch.inference_mode()
-     def generate(self, params):
-         system_message = params['prompt'][0]['content']
-         send_messages = params['prompt'][1:]
-         max_input_tiles = params['max_input_tiles']
-         temperature = params['temperature']
-         top_p = params['top_p']
-         max_new_tokens = params['max_new_tokens']
-         repetition_penalty = params['repetition_penalty']
-         video_frame_num = params.get('video_frame_num', 64)
-         do_sample = True if temperature > 0.0 else False
-
-         global_image_cnt = 0
-         history, pil_images, max_input_tile_list = [], [], []
-
-         for message in send_messages:
-             if message['role'] == 'user':
-                 prefix = ''
-                 if 'image' in message:
-                     for image_data in message['image']:
-                         pil_images.append(load_image(image_data))
-                         prefix = prefix + f'<image {global_image_cnt + 1}><image>\n'
-                         global_image_cnt += 1
-                         max_input_tile_list.append(max_input_tiles)
-                 if 'video' in message:
-                     for video_data in message['video']:
-                         video_frames, tmp_prefix = load_video(video_data, num_frames=video_frame_num)
-                         pil_images.extend(video_frames)
-                         prefix = prefix + tmp_prefix
-                         global_image_cnt += len(video_frames)
-                         max_input_tile_list.extend([1] * len(video_frames))
-                 content = prefix + message['content']
-                 history.append([content, ])
-             else:
-                 history[-1].append(message['content'])
-         question, history = history[-1][0], history[:-1]
-
-         if global_image_cnt == 1:
-             question = question.replace('<image 1><image>\n', '<image>\n')
-             history = [[item[0].replace('<image 1><image>\n', '<image>\n'), item[1]] for item in history]
-
-         try:
-             assert len(max_input_tile_list) == len(pil_images), 'The number of max_input_tile_list and pil_images should be the same.'
-         except Exception as e:
-             from IPython import embed; embed()
-             exit()
-             print(f'Error: {e}')
-             print(f'max_input_tile_list: {max_input_tile_list}, pil_images: {pil_images}')
-             # raise e
-
-         old_system_message = self.model.system_message
-         self.model.system_message = system_message
-
-         transform = build_transform(input_size=self.image_size, norm_type=self.norm_type)
-         if len(pil_images) > 0:
-             max_input_tiles_limited_by_contect = params['max_input_tiles']
-             while True:
-                 image_tiles = []
-                 num_patches_list = []
-                 for current_max_input_tiles, pil_image in zip(max_input_tile_list, pil_images):
-                     if self.model.config.dynamic_image_size:
-                         tiles = dynamic_preprocess(
-                             pil_image, image_size=self.image_size, max_num=min(current_max_input_tiles, max_input_tiles_limited_by_contect),
-                             use_thumbnail=self.model.config.use_thumbnail)
-                     else:
-                         tiles = [pil_image]
-                     num_patches_list.append(len(tiles))
-                     image_tiles += tiles
-                 if (len(image_tiles) * self.per_tile_len < self.context_len):
-                     break
-                 else:
-                     max_input_tiles_limited_by_contect -= 2
-
-                 if max_input_tiles_limited_by_contect < 1:
-                     break
-
-             pixel_values = [transform(item) for item in image_tiles]
-             pixel_values = torch.stack(pixel_values).to(self.model.device, dtype=torch.bfloat16)
-         else:
-             pixel_values = None
-
-         generation_config = dict(
-             num_beams=1,
-             max_new_tokens=max_new_tokens,
-             do_sample=do_sample,
-             temperature=temperature,
-             repetition_penalty=repetition_penalty,
-             max_length=self.context_len,
-             top_p=top_p,
-         )
-         print(f'pixel_values: {pixel_values.shape}')
-         response = self.model.chat(
-             tokenizer=self.tokenizer,
-             pixel_values=pixel_values,
-             question=question,
-             history=history,
-             return_history=False,
-             num_patches_list=num_patches_list,
-             generation_config=generation_config,
-         )
-         self.model.system_message = old_system_message
-         return {'text': response, 'error_code': 0}
-
-
- if __name__ == '__main__':
-     parser = argparse.ArgumentParser()
-     parser.add_argument('--model-path', type=str, default='/home/zhidingy/workspace/eagle-next/internvl_chat/work_dirs/release/test/Eagle2-2B')
-     parser.add_argument('--model-name', type=str, default='Eagle2')
-     parser.add_argument('--device', type=str, default='cuda')
-     parser.add_argument('--load-8bit', action='store_true')
-     args = parser.parse_args()
-     print(f'args: {args}')
-
-     worker = ModelWorker(
-         args.model_path,
-         args.model_name,
-         args.load_8bit,
-         args.device)
-     prompt = [
-         {'role': 'system', 'content': 'You are a helpful assistant.'},
-         {'role': 'user', 'content': 'Describe these two images in details respectively.',
-          'image':[
-             {'url': 'https://www.nvidia.com/content/dam/en-zz/Solutions/about-nvidia/logo-and-brand/[email protected]'},
-             {'url': "https://www.google.com.hk/images/branding/googlelogo/2x/googlelogo_color_272x92dp.png"}
-          ]
-         }
-     ]
-     params = {
-         'prompt': prompt,
-         'max_input_tiles': 24,
-         'temperature': 0.7,
-         'top_p': 1.0,
-         'max_new_tokens': 4096,
-         'repetition_penalty': 1.0,
      }
-     print(worker.generate(params))
+ from PIL import Image
  import requests
+ from transformers import AutoProcessor, AutoModel, AutoTokenizer
  import torch
+ from transformers import TextIteratorStreamer
+ import threading
+
+ model = AutoModel.from_pretrained("/home/zhidingy/workspace/eagle-next/internvl_chat/work_dirs/release/test/Eagle2-2B", trust_remote_code=True, attn_implementation='flash_attention_2', torch_dtype=torch.bfloat16)
+ tokenizer = AutoTokenizer.from_pretrained("/home/zhidingy/workspace/eagle-next/internvl_chat/work_dirs/release/test/Eagle2-2B", trust_remote_code=True, use_fast=True)
+ processor = AutoProcessor.from_pretrained("/home/zhidingy/workspace/eagle-next/internvl_chat/work_dirs/release/test/Eagle2-2B", trust_remote_code=True, use_fast=True)
+ processor.tokenizer.padding_side = "left"
+
+ messages = [
+     {
+         "role": "user",
+         "content": [
+             {
+                 "type": "image",
+                 "image": "https://www.ilankelman.org/stopsigns/australia.jpg",
+             },
+             {"type": "text", "text": "Describe this image."},
+         ],
      }
+ ]
+
+ text_list = [processor.apply_chat_template(
+     messages, tokenize=False, add_generation_prompt=True
+ )]
+ image_inputs, video_inputs = processor.process_vision_info(messages)
+ inputs = processor(text=text_list, images=image_inputs, videos=video_inputs, return_tensors="pt", padding=True)
+ inputs = inputs.to("cuda")
+ model = model.to("cuda")
+
+ # Non-streaming generation:
+ # generated_ids = model.generate(**inputs, max_new_tokens=1024)
+ # output_text = processor.batch_decode(
+ #     generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+ # )
+ # print(output_text)
+
+ # Streaming generation:
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+ generation_kwargs = dict(
+     **inputs,
+     streamer=streamer,
+     max_new_tokens=1024,
+     do_sample=True,
+     top_p=0.95,
+     temperature=0.8
+ )
+ thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
+ thread.start()
+
+ for new_text in streamer:
+     print(new_text, end="", flush=True)
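The committed snippet above covers single-image streaming. As a minimal sketch (not part of this commit), the same processor API should extend to multi-image prompts; it assumes the `model` and `processor` objects created above are already on CUDA, and it reuses image URLs that already appear elsewhere in this repo's examples:

```python
# Hypothetical multi-image prompt, reusing `model` and `processor` from the snippet above.
multi_image_messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "https://www.ilankelman.org/stopsigns/australia.jpg"},
            {"type": "image", "image": "https://www.google.com.hk/images/branding/googlelogo/2x/googlelogo_color_272x92dp.png"},
            {"type": "text", "text": "Describe these two images respectively."},
        ],
    }
]

text_list = [processor.apply_chat_template(multi_image_messages, tokenize=False, add_generation_prompt=True)]
image_inputs, video_inputs = processor.process_vision_info(multi_image_messages)
inputs = processor(text=text_list, images=image_inputs, videos=video_inputs, return_tensors="pt", padding=True).to("cuda")

# Non-streaming decoding, mirroring the commented-out path in the snippet above.
generated_ids = model.generate(**inputs, max_new_tokens=1024)
print(processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])
```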
image_processing_eagle2.py ADDED
@@ -0,0 +1,715 @@
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for LLaVa-Onevision."""
16
+
17
+ import math
18
+ from typing import Dict, Iterable, List, Optional, Tuple, Union
19
+
20
+ import numpy as np
21
+
22
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict, select_best_resolution
23
+ from transformers.image_transforms import (
24
+ PaddingMode,
25
+ convert_to_rgb,
26
+ pad,
27
+ resize,
28
+ to_channel_dimension_format,
29
+ )
30
+ from transformers.image_utils import (
31
+ OPENAI_CLIP_MEAN,
32
+ OPENAI_CLIP_STD,
33
+ IMAGENET_STANDARD_MEAN, # 0.5, 0.5, 0.5
34
+ IMAGENET_STANDARD_STD, # 0.5, 0.5, 0.5
35
+ ChannelDimension,
36
+ ImageInput,
37
+ PILImageResampling,
38
+ get_image_size,
39
+ infer_channel_dimension_format,
40
+ is_scaled_image,
41
+ make_flat_list_of_images,
42
+ to_numpy_array,
43
+ valid_images,
44
+ validate_preprocess_arguments,
45
+ )
46
+ from transformers.utils import TensorType, is_vision_available, logging
47
+
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+
52
+ if is_vision_available():
53
+ from PIL import Image
54
+
55
+ def crop(img: np.ndarray, left: int, top: int, right: int, bottom: int, input_data_format: ChannelDimension) -> np.ndarray:
56
+ """Crop the given numpy array.
57
+
58
+ Args:
59
+ img (np.ndarray): Image to be cropped. Format should be (H, W, C) or (H, W).
60
+ left (int): The left coordinate of the crop box.
61
+ top (int): The top coordinate of the crop box.
62
+ right (int): The right coordinate of the crop box.
63
+ bottom (int): The bottom coordinate of the crop box.
64
+
65
+ Returns:
66
+ np.ndarray: Cropped image.
67
+ """
68
+ if not isinstance(img, np.ndarray):
69
+ raise TypeError('img should be numpy array. Got {}'.format(type(img)))
70
+
71
+ if img.ndim not in [2, 3]:
72
+ raise ValueError('Image should have 2 or 3 dimensions. Got {}'.format(img.ndim))
73
+
74
+ if input_data_format == ChannelDimension.LAST:
75
+ img_height = img.shape[0]
76
+ img_width = img.shape[1]
77
+ else:
78
+ img_height = img.shape[1]
79
+ img_width = img.shape[2]
80
+
81
+ if top < 0 or left < 0 or bottom > img_height or right > img_width:
82
+ raise ValueError('Crop coordinates out of bounds')
83
+
84
+ if top >= bottom or left >= right:
85
+ raise ValueError('Invalid crop coordinates')
86
+ if input_data_format == ChannelDimension.LAST:
87
+ return img[top:bottom, left:right, :]
88
+ else:
89
+ return img[:, top:bottom, left:right]
90
+
91
+ # Copied from transformers.models.llava_next.image_processing_llava_next.divide_to_patches
92
+ def divide_to_patches(image: np.array, patch_size: int, input_data_format) -> List[np.array]:
93
+ """
94
+ Divides an image into patches of a specified size.
95
+
96
+ Args:
97
+ image (`np.array`):
98
+ The input image.
99
+ patch_size (`int`):
100
+ The size of each patch.
101
+ input_data_format (`ChannelDimension` or `str`):
102
+ The channel dimension format of the input image.
103
+
104
+ Returns:
105
+ list: A list of np.array representing the patches.
106
+ """
107
+ patches = []
108
+ height, width = get_image_size(image, channel_dim=input_data_format)
109
+ for i in range(0, height, patch_size):
110
+ for j in range(0, width, patch_size):
111
+ if input_data_format == ChannelDimension.LAST:
112
+ patch = image[i : i + patch_size, j : j + patch_size]
113
+ else:
114
+ patch = image[:, i : i + patch_size, j : j + patch_size]
115
+ patches.append(patch)
116
+
117
+ return patches
118
+
119
+
120
+ # Copied from transformers.models.llava_next.image_processing_llava_next.expand_to_square
121
+ def expand_to_square(image: np.array, background_color, input_data_format) -> np.array:
122
+ """
123
+ Expands an image to a square by adding a background color.
124
+ """
125
+
126
+ height, width = get_image_size(image, channel_dim=input_data_format)
127
+ if width == height:
128
+ return image
129
+ elif width > height:
130
+ result = np.ones((width, width, image.shape[2]), dtype=image.dtype) * background_color
131
+ result[(width - height) // 2 : (width - height) // 2 + height, :] = image
132
+ return result
133
+ else:
134
+ result = np.ones((height, height, image.shape[2]), dtype=image.dtype) * background_color
135
+ result[:, (height - width) // 2 : (height - width) // 2 + width] = image
136
+ return result
137
+
138
+
139
+ # Copied from transformers.models.llava_next.image_processing_llava_next._get_patch_output_size
140
+ def _get_patch_output_size(image, target_resolution, input_data_format):
141
+ original_height, original_width = get_image_size(image, channel_dim=input_data_format)
142
+ target_height, target_width = target_resolution
143
+
144
+ scale_w = target_width / original_width
145
+ scale_h = target_height / original_height
146
+
147
+ if scale_w < scale_h:
148
+ new_width = target_width
149
+ new_height = min(math.ceil(original_height * scale_w), target_height)
150
+ else:
151
+ new_height = target_height
152
+ new_width = min(math.ceil(original_width * scale_h), target_width)
153
+
154
+ return new_height, new_width
155
+
156
+
157
+ class Eagle2ImageProcessor(BaseImageProcessor):
158
+ r"""
159
+ Constructs an Eagle2 image processor. Based on [`SiglipImageProcessor`], with dynamic tiling and support for processing each video frame.
160
+
161
+ Args:
162
+ do_resize (`bool`, *optional*, defaults to `True`):
163
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
164
+ `do_resize` in the `preprocess` method.
165
+ size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`):
166
+ Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
167
+ the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
168
+ method.
169
+ image_grid_pinpoints (`List` *optional*, defaults to `[[672, 336], [336, 672], [672, 672], [336, 1008], [1008, 336]]`):
170
+ A list of possible resolutions to use for processing high resolution images. The best resolution is selected
171
+ based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
172
+ method. Not used for processing videos.
173
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
174
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
175
+ do_rescale (`bool`, *optional*, defaults to `True`):
176
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
177
+ the `preprocess` method.
178
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
179
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
180
+ method.
181
+ do_normalize (`bool`, *optional*, defaults to `True`):
182
+ Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
183
+ image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
184
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
185
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
186
+ image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
187
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
188
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
189
+ Can be overridden by the `image_std` parameter in the `preprocess` method.
190
+ do_pad (`bool`, *optional*, defaults to `True`):
191
+ Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
192
+ number of patches in the batch. Padding will be applied to the bottom and right with zeros.
193
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
194
+ Whether to convert the image to RGB.
195
+ """
196
+
197
+ model_input_names = ["pixel_values_videos"]
198
+
199
+ def __init__(
200
+ self,
201
+ do_resize: bool = True,
202
+ size: Dict[str, int] = None,
203
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
204
+ do_rescale: bool = True,
205
+ rescale_factor: Union[int, float] = 1 / 255,
206
+ do_normalize: bool = True,
207
+ image_mean: Optional[Union[float, List[float]]] = None,
208
+ image_std: Optional[Union[float, List[float]]] = None,
209
+ do_pad: Optional[bool] = True,
210
+ do_convert_rgb: bool = True,
211
+ min_dynamic_tiles: int = 1,
212
+ max_dynamic_tiles: int = 12,
213
+ use_thumbnail: bool = True,
214
+ pad_during_tiling: bool = False,
215
+ **kwargs,
216
+ ) -> None:
217
+ super().__init__(**kwargs)
218
+ size = size if size is not None else {"height": 384, "width": 384}
219
+ size = get_size_dict(size, default_to_square=False)
220
+
221
+ self.do_resize = do_resize
222
+ self.size = size
223
+ self.resample = resample
224
+ self.do_rescale = do_rescale
225
+ self.rescale_factor = rescale_factor
226
+ self.do_normalize = do_normalize
227
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
228
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
229
+ self.do_pad = do_pad
230
+ self.do_convert_rgb = do_convert_rgb
231
+ self.min_dynamic_tiles = min_dynamic_tiles
232
+ self.max_dynamic_tiles = max_dynamic_tiles
233
+ self.use_thumbnail = use_thumbnail
234
+ self.pad_during_tiling = pad_during_tiling
235
+
236
+ # Copied from transformers.models.llava_next.image_processing_llava_next.LlavaNextImageProcessor.pad
237
+ def pad(
238
+ self,
239
+ image: np.ndarray,
240
+ padding: Union[int, Tuple[int, int], Iterable[Tuple[int, int]]],
241
+ mode: PaddingMode = PaddingMode.CONSTANT,
242
+ constant_values: Union[float, Iterable[float]] = 0.0,
243
+ data_format: Optional[Union[str, ChannelDimension]] = None,
244
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
245
+ ) -> np.ndarray:
246
+ """
247
+ Pads the `image` with the specified `padding` and `mode`. Padding can be in the (`height`, `width`)
248
+ dimension of in the (`num_patches`) dimension. In the second case an iterable if tuples is expected
249
+ as input.
250
+
251
+ Args:
252
+ image (`np.ndarray`):
253
+ The image to pad.
254
+ padding (`int` or `Tuple[int, int]` or `Iterable[Tuple[int, int]]`):
255
+ Padding to apply to the edges of the height, width axes. Can be one of three formats:
256
+ - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
257
+ - `((before, after),)` yields same before and after pad for height and width.
258
+ - `(pad,)` or int is a shortcut for before = after = pad width for all axes.
259
+ mode (`PaddingMode`):
260
+ The padding mode to use. Can be one of:
261
+ - `"constant"`: pads with a constant value.
262
+ - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
263
+ vector along each axis.
264
+ - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
265
+ - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
266
+ constant_values (`float` or `Iterable[float]`, *optional*):
267
+ The value to use for the padding if `mode` is `"constant"`.
268
+ data_format (`str` or `ChannelDimension`, *optional*):
269
+ The channel dimension format for the output image. Can be one of:
270
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
271
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
272
+ If unset, will use same as the input image.
273
+ input_data_format (`str` or `ChannelDimension`, *optional*):
274
+ The channel dimension format for the input image. Can be one of:
275
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
276
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
277
+ If unset, will use the inferred format of the input image.
278
+
279
+ Returns:
280
+ `np.ndarray`: The padded image.
281
+
282
+ """
283
+
284
+ # call the general `pad` if padding on `height`/`width`, otherwise it's the `num_patches` dim
285
+ if isinstance(padding, int) or len(padding) != 4:
286
+ return pad(image, padding, mode, constant_values, data_format, input_data_format)
287
+
288
+ if input_data_format is None:
289
+ input_data_format = infer_channel_dimension_format(image)
290
+ if mode == PaddingMode.CONSTANT:
291
+ image = np.pad(image, padding, mode="constant", constant_values=constant_values)
292
+ elif mode == PaddingMode.REFLECT:
293
+ image = np.pad(image, padding, mode="reflect")
294
+ elif mode == PaddingMode.REPLICATE:
295
+ image = np.pad(image, padding, mode="edge")
296
+ elif mode == PaddingMode.SYMMETRIC:
297
+ image = np.pad(image, padding, mode="symmetric")
298
+ else:
299
+ raise ValueError(f"Invalid padding mode: {mode}")
300
+ image = (
301
+ to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
302
+ )
303
+ return image
304
+
305
+ # Copied from transformers.models.llava_next.image_processing_llava_next.LlavaNextImageProcessor._resize_for_patching
306
+ def _resize_for_patching(
307
+ self, image: np.array, target_resolution: tuple, resample, input_data_format: ChannelDimension
308
+ ) -> np.array:
309
+ """
310
+ Resizes an image to a target resolution while maintaining aspect ratio.
311
+
312
+ Args:
313
+ image (np.array):
314
+ The input image.
315
+ target_resolution (tuple):
316
+ The target resolution (height, width) of the image.
317
+ resample (`PILImageResampling`):
318
+ Resampling filter to use if resizing the image.
319
+ input_data_format (`ChannelDimension` or `str`):
320
+ The channel dimension format of the input image.
321
+
322
+ Returns:
323
+ np.array: The resized and padded image.
324
+ """
325
+
326
+ new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
327
+ # Resize the image
328
+ resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)
329
+
330
+ return resized_image
331
+
332
+ # Copied from transformers.models.llava_next.image_processing_llava_next.LlavaNextImageProcessor._pad_for_patching
333
+ def _pad_for_patching(
334
+ self, image: np.array, target_resolution: tuple, input_data_format: ChannelDimension
335
+ ) -> np.array:
336
+ """
337
+ Pad an image to a target resolution while maintaining aspect ratio.
338
+ """
339
+ target_height, target_width = target_resolution
340
+ new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
341
+
342
+ paste_x = (target_width - new_width) // 2
343
+ paste_y = (target_height - new_height) // 2
344
+
345
+ padded_image = self.pad(image, padding=((paste_y, paste_y), (paste_x, paste_x)))
346
+
347
+ return padded_image
348
+
349
+
350
+ def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
351
+ """
352
+ The previous version mainly focused on the aspect ratio; here we also take the area ratio into account.
354
+ """
355
+ best_factor = float('-inf')
356
+ best_ratio = (1, 1)
357
+ area = width * height
358
+ for ratio in target_ratios:
359
+ target_aspect_ratio = ratio[0] / ratio[1]
360
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
361
+ area_ratio = (ratio[0]*ratio[1]*image_size*image_size)/ area
362
+ """
363
+ new area > 60% of original image area is enough.
364
+ """
365
+ factor_based_on_area_n_ratio = min((ratio[0]*ratio[1]*image_size*image_size)/ area, 0.6)* \
366
+ min(target_aspect_ratio/aspect_ratio, aspect_ratio/target_aspect_ratio)
367
+
368
+ if factor_based_on_area_n_ratio > best_factor:
369
+ best_factor = factor_based_on_area_n_ratio
370
+ best_ratio = ratio
371
+
372
+ return best_ratio
373
+
374
+
375
+ def get_image_patches(
376
+ self,
377
+ image: np.array,
378
+ min_num: int,
379
+ max_num: int,
380
+ size: tuple,
381
+ tile_size: int,
382
+ use_thumbnail: bool,
383
+ resample: PILImageResampling,
384
+ data_format: ChannelDimension,
385
+ input_data_format: ChannelDimension,
386
+ ):
387
+ image_size = get_image_size(image, channel_dim=input_data_format)
388
+ orig_height, orig_width = image_size
389
+ aspect_ratio = orig_width / orig_height
390
+
391
+ # calculate the existing image aspect ratio
392
+ target_ratios = set(
393
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
394
+ i * j <= max_num and i * j >= min_num)
395
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
396
+
397
+ # find the closest aspect ratio to the target
398
+ target_aspect_ratio = self.find_closest_aspect_ratio(
399
+ aspect_ratio, target_ratios, orig_width, orig_height, tile_size)
400
+
401
+ # calculate the target width and height
402
+ target_width = tile_size * target_aspect_ratio[0]
403
+ target_height = tile_size * target_aspect_ratio[1]
404
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
405
+ if self.pad_during_tiling:
406
+ resized_image = self._resize_for_patching(
407
+ image, (target_height, target_width), resample=resample, input_data_format=input_data_format
408
+ )
409
+ padded_image = self._pad_for_patching(resized_image, (target_height, target_width), input_data_format=input_data_format)
410
+ image_used_to_split = padded_image
411
+ else:
412
+ image_used_to_split = resize(image, (target_height, target_width), resample=resample, input_data_format=input_data_format)
413
+
414
+ processed_tiles = []
415
+ for i in range(blocks):
416
+ box = (
417
+ (i % (target_width // tile_size)) * tile_size,
418
+ (i // (target_width // tile_size)) * tile_size,
419
+ ((i % (target_width // tile_size)) + 1) * tile_size,
420
+ ((i // (target_width // tile_size)) + 1) * tile_size
421
+ )
422
+ # split the image
423
+ split_img = crop(image_used_to_split, box[0], box[1], box[2], box[3], input_data_format)
424
+ processed_tiles.append(split_img)
425
+ assert len(processed_tiles) == blocks
426
+
427
+ if use_thumbnail and len(processed_tiles) != 1:
428
+ thumbnail_img = resize(image, (tile_size, tile_size), resample=resample, input_data_format=input_data_format)
429
+ processed_tiles.append(thumbnail_img)
430
+
431
+ # make sure that all patches are in the input data format
432
+ processed_tiles = [
433
+ to_channel_dimension_format(tile, channel_dim=data_format, input_channel_dim=input_data_format)
434
+ for tile in processed_tiles
435
+ ]
436
+ return processed_tiles
437
+
438
+
439
+ # Copied from transformers.models.llava_next.image_processing_llava_next.LlavaNextImageProcessor._pad_for_batching
440
+ def _pad_for_batching(
441
+ self,
442
+ pixel_values: List[np.ndarray],
443
+ data_format: Optional[Union[str, ChannelDimension]] = None,
444
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
445
+ ):
446
+ """
447
+ Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
448
+
449
+ Args:
450
+ pixel_values (`List[np.ndarray]`):
451
+ An array of pixel values of each images of shape (`batch_size`, `num_patches`, `image_in_3D`)
452
+ data_format (`str` or `ChannelDimension`, *optional*):
453
+ The channel dimension format for the output image. Can be one of:
454
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
455
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
456
+ If unset, will use same as the input image.
457
+ input_data_format (`str` or `ChannelDimension`, *optional*):
458
+ The channel dimension format for the input image. Can be one of:
459
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
460
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
461
+ If unset, will use the inferred format of the input image.
462
+
463
+ Returns:
464
+ List[`np.ndarray`]: The padded images.
465
+ """
466
+ max_patch = max(len(x) for x in pixel_values)
467
+ pixel_values = [
468
+ self.pad(
469
+ image,
470
+ padding=((0, max_patch - image.shape[0]), (0, 0), (0, 0), (0, 0)),
471
+ data_format=data_format,
472
+ input_data_format=input_data_format,
473
+ )
474
+ for image in pixel_values
475
+ ]
476
+
477
+ return pixel_values
478
+
479
+ def _preprocess(
480
+ self,
481
+ images: ImageInput,
482
+ do_resize: Optional[bool] = None,
483
+ size: Dict[str, int] = None,
484
+ resample: PILImageResampling = None,
485
+ do_rescale: Optional[bool] = None,
486
+ rescale_factor: Optional[float] = None,
487
+ do_normalize: Optional[bool] = None,
488
+ image_mean: Optional[Union[float, List[float]]] = None,
489
+ image_std: Optional[Union[float, List[float]]] = None,
490
+ do_convert_rgb: Optional[bool] = None,
491
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
492
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
493
+ ) -> Image.Image:
494
+ """
495
+ Args:
496
+ images (`ImageInput`):
497
+ Batch of frames (one video) to preprocess. Expects a batch of frames with pixel values ranging from 0 to 255. If
498
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
499
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
500
+ Whether to resize the image.
501
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
502
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
503
+ the longest edge resized to keep the input aspect ratio.
504
+ resample (`int`, *optional*, defaults to `self.resample`):
505
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
506
+ has an effect if `do_resize` is set to `True`.
507
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
508
+ Whether to rescale the image.
509
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
510
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
511
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
512
+ Whether to normalize the image.
513
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
514
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
515
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
516
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
517
+ `True`.
518
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
519
+ The channel dimension format for the output image. Can be one of:
520
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
521
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
522
+ - Unset: Use the channel dimension format of the input image.
523
+ input_data_format (`ChannelDimension` or `str`, *optional*):
524
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
525
+ from the input image. Can be one of:
526
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
527
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
528
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
529
+ """
530
+ if do_resize:
531
+ assert False, 'do_resize is not supported'
532
+ images = [
533
+ resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
534
+ for image in images
535
+ ]
536
+
537
+ if do_rescale:
538
+ images = [
539
+ self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
540
+ for image in images
541
+ ]
542
+
543
+ if do_normalize:
544
+ images = [
545
+ self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
546
+ for image in images
547
+ ]
548
+
549
+ images = [
550
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
551
+ ]
552
+
553
+ return images
554
+
555
+ def preprocess(
556
+ self,
557
+ images: ImageInput,
558
+ do_resize: Optional[bool] = None,
559
+ size: Dict[str, int] = None,
560
+ resample: PILImageResampling = None,
561
+ do_rescale: Optional[bool] = None,
562
+ rescale_factor: Optional[float] = None,
563
+ do_normalize: Optional[bool] = None,
564
+ image_mean: Optional[Union[float, List[float]]] = None,
565
+ image_std: Optional[Union[float, List[float]]] = None,
566
+ do_pad: Optional[bool] = None,
567
+ do_convert_rgb: Optional[bool] = None,
568
+ return_tensors: Optional[Union[str, TensorType]] = None,
569
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
570
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
571
+ ):
572
+ """
573
+ Args:
574
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
575
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
576
+ tensor. Both channels-first and channels-last formats are supported.
577
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
578
+ Whether to resize the image.
579
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
580
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
581
+ the longest edge resized to keep the input aspect ratio.
582
+ resample (`int`, *optional*, defaults to `self.resample`):
583
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
584
+ has an effect if `do_resize` is set to `True`.
585
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
586
+ Whether to rescale the image.
587
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
588
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
589
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
590
+ Whether to normalize the image.
591
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
592
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
593
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
594
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
595
+ `True`.
596
+ do_pad (`bool`, *optional*, defaults to `self.do_pad`):
597
+ Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
598
+ number of patches in the batch. Padding will be applied to the bottom and right with zeros.
599
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
600
+ Whether to convert the image to RGB.
601
+ return_tensors (`str` or `TensorType`, *optional*):
602
+ The type of tensors to return. Can be one of:
603
+ - Unset: Return a list of `np.ndarray`.
604
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
605
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
606
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
607
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
608
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
609
+ The channel dimension format for the output image. Can be one of:
610
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
611
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
612
+ - Unset: Use the channel dimension format of the input image.
613
+ input_data_format (`ChannelDimension` or `str`, *optional*):
614
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
615
+ from the input image. Can be one of:
616
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
617
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
618
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
619
+
620
+ """
621
+ do_resize = do_resize if do_resize is not None else self.do_resize
622
+ size = size if size is not None else self.size
623
+ size = get_size_dict(size, default_to_square=False)
624
+ resample = resample if resample is not None else self.resample
625
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
626
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
627
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
628
+ image_mean = image_mean if image_mean is not None else self.image_mean
629
+ image_std = image_std if image_std is not None else self.image_std
630
+ do_pad = do_pad if do_pad is not None else self.do_pad
631
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
632
+
633
+ images = make_flat_list_of_images(images)
634
+
635
+ if not valid_images(images):
636
+ raise ValueError(
637
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
638
+ "torch.Tensor, tf.Tensor or jax.ndarray."
639
+ )
640
+
641
+ validate_preprocess_arguments(
642
+ do_rescale=do_rescale,
643
+ rescale_factor=rescale_factor,
644
+ do_normalize=do_normalize,
645
+ image_mean=image_mean,
646
+ image_std=image_std,
647
+ do_resize=do_resize,
648
+ size=size,
649
+ resample=resample,
650
+ )
651
+
652
+ if do_convert_rgb:
653
+ images = [convert_to_rgb(image) for image in images]
654
+
655
+ # All transformations expect numpy arrays.
656
+ images = [to_numpy_array(image) for image in images]
657
+
658
+ if do_rescale and is_scaled_image(images[0]):
659
+ logger.warning_once(
660
+ "It looks like you are trying to rescale already rescaled images. If the input"
661
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
662
+ )
663
+
664
+ if input_data_format is None:
665
+ # We assume that all images have the same channel dimension format.
666
+ input_data_format = infer_channel_dimension_format(images[0])
667
+
668
+ processed_images = []
669
+ image_sizes = [get_image_size(image, channel_dim=input_data_format) for image in images]
670
+ for image in images:
671
+ # convert image into a list of patches
672
+ # we intentionally use the same data format as the input data format
673
+ size_tuple = (
674
+ (size["height"], size["width"])
675
+ if "height" in size and "width" in size
676
+ else (size["shortest_edge"], size["shortest_edge"])
677
+ )
678
+ image_patches = self.get_image_patches(
679
+ image,
680
+ min_num=self.min_dynamic_tiles,
681
+ max_num=self.max_dynamic_tiles,
682
+ size=size_tuple,
683
+ tile_size=size["height"],
684
+ resample=resample,
685
+ data_format=input_data_format,
686
+ input_data_format=input_data_format,
687
+ use_thumbnail=self.use_thumbnail,
688
+ )
689
+
690
+ # preprocess patches
691
+ pixel_values = self._preprocess(
692
+ image_patches,
693
+ do_resize=do_resize,
694
+ size=size_tuple,
695
+ resample=resample,
696
+ do_rescale=do_rescale,
697
+ rescale_factor=rescale_factor,
698
+ do_normalize=do_normalize,
699
+ image_mean=image_mean,
700
+ image_std=image_std,
701
+ data_format=data_format,
702
+ input_data_format=input_data_format,
703
+ )
704
+ pixel_values = np.array(pixel_values)
705
+ processed_images.append(pixel_values)
706
+
707
+ if do_pad:
708
+ processed_images = self._pad_for_batching(processed_images)
709
+
710
+ return BatchFeature(
711
+ data={"pixel_values": processed_images, "image_sizes": image_sizes}, tensor_type=return_tensors
712
+ )
713
+
714
+
715
+ __all__ = ["Eagle2ImageProcessor"]
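To sanity-check the tiling logic in isolation, here is a minimal sketch (not part of this commit). It assumes the file above is saved locally as `image_processing_eagle2.py` and that `numpy` and `Pillow` are installed; note that `_preprocess` asserts when `do_resize` is set, so resizing is disabled here and left to the tiling path:

```python
import numpy as np
from PIL import Image

from image_processing_eagle2 import Eagle2ImageProcessor  # the file added above

# Defaults: 384x384 tiles, between 1 and 12 dynamic tiles, plus a thumbnail tile.
processor = Eagle2ImageProcessor(do_resize=False)

# A wide dummy image forces a multi-tile split along the width.
image = Image.fromarray(np.random.randint(0, 256, (448, 1344, 3), dtype=np.uint8))

out = processor(images=image, return_tensors="pt")
print(out["pixel_values"].shape)  # (num_images, num_tiles, 3, 384, 384)
print(out["image_sizes"])         # original (height, width) of each input image
```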
image_processing_eagle2_5_vl_fast.py ADDED
@@ -0,0 +1,458 @@
+ # --------------------------------------------------------
2
+ # NVIDIA
3
+ # Copyright (c) 2025 NVIDIA
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ # copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
8
+ from typing import List, Optional, Union
9
+
10
+ from transformers.image_processing_utils import BatchFeature, get_patch_output_size, select_best_resolution
11
+ from transformers.image_processing_utils_fast import (
12
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
13
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
14
+ BaseImageProcessorFast,
15
+ DefaultFastImageProcessorKwargs,
16
+ divide_to_patches,
17
+ group_images_by_shape,
18
+ reorder_images,
19
+ )
20
+ from transformers.image_utils import (
21
+ OPENAI_CLIP_MEAN,
22
+ OPENAI_CLIP_STD,
23
+ IMAGENET_STANDARD_MEAN, # 0.5, 0.5, 0.5
24
+ IMAGENET_STANDARD_STD, # 0.5, 0.5, 0.5
25
+ ChannelDimension,
26
+ ImageInput,
27
+ VideoInput,
28
+ PILImageResampling,
29
+ SizeDict,
30
+ get_image_size,
31
+ make_flat_list_of_images,
32
+ make_batched_videos,
33
+ validate_kwargs
34
+ )
35
+ from transformers.processing_utils import Unpack
36
+ from transformers.utils import TensorType, add_start_docstrings, is_torch_available, is_torchvision_v2_available
37
+
38
+
39
+ if is_torch_available():
40
+ import torch
41
+ if is_torchvision_v2_available():
42
+ from transformers.image_utils import pil_torch_interpolation_mapping
43
+
44
+ from torchvision.transforms.v2 import functional as F
45
+ else:
46
+ from torchvision.transforms import functional as F
47
+
48
+ def crop(img: torch.Tensor, left: int, top: int, right: int, bottom: int) -> torch.Tensor:
49
+ """Crop the given numpy array.
50
+
51
+ Args:
52
+ img (torch.Tensor): Image to be cropped. Format should be (C, H, W).
53
+ left (int): The left coordinate of the crop box.
54
+ top (int): The top coordinate of the crop box.
55
+ right (int): The right coordinate of the crop box.
56
+ bottom (int): The bottom coordinate of the crop box.
57
+
58
+ Returns:
59
+ torch.Tensor: Cropped image.
60
+ """
61
+ if not isinstance(img, torch.Tensor):
62
+ raise TypeError('img should be torch.Tensor. Got {}'.format(type(img)))
63
+
64
+ if img.ndim not in [2, 3]:
65
+ raise ValueError('Image should have 2 or 3 dimensions. Got {}'.format(img.ndim))
66
+
67
+ img_height = img.shape[1]
68
+ img_width = img.shape[2]
69
+ if top < 0 or left < 0 or bottom > img_height or right > img_width:
70
+ raise ValueError('Crop coordinates out of bounds')
71
+
72
+ if top >= bottom or left >= right:
73
+ raise ValueError('Invalid crop coordinates')
74
+
75
+ return img[:, top:bottom, left:right]
76
+
77
+
78
+ class Eagle2_5_VLFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
79
+ max_dynamic_tiles: Optional[int]
80
+ min_dynamic_tiles: Optional[int]
81
+ use_thumbnail: Optional[bool]
82
+ pad_during_tiling: Optional[bool]
83
+ do_pad: Optional[bool]
84
+
85
+
86
+ @add_start_docstrings(
87
+ "Constructs a fast ConvNeXT image processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame.",
88
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
89
+ """
90
+ image_grid_pinpoints (`List[List[int]]`, *optional*):
91
+ A list of possible resolutions to use for processing high resolution images. The best resolution is selected
92
+ based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
93
+ method. Not used for processing videos.
94
+ do_pad (`bool`, *optional*):
95
+ Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
96
+ number of patches in the batch. Padding will be applied to the bottom and right with zeros.
97
+ """,
98
+ )
99
+ class Eagle2_5_VLImageProcessorFast(BaseImageProcessorFast):
100
+ resample = PILImageResampling.BICUBIC
101
+ image_mean = IMAGENET_STANDARD_MEAN
102
+ image_std = IMAGENET_STANDARD_STD
103
+ size = {"height": 448, "width": 448}
104
+ default_to_square = False
105
+ crop_size = None
106
+ do_resize = True
107
+ do_center_crop = None
108
+ do_rescale = True
109
+ do_normalize = True
110
+ do_convert_rgb = True
111
+ do_pad = True
112
+ max_dynamic_tiles = 12
113
+ min_dynamic_tiles = 1
114
+ use_thumbnail = True
115
+ pad_during_tiling = False
116
+ valid_kwargs = Eagle2_5_VLFastImageProcessorKwargs
117
+ model_input_names = ["pixel_values_videos"]
118
+
119
+ def __init__(self, **kwargs: Unpack[Eagle2_5_VLFastImageProcessorKwargs]):
120
+ super().__init__(**kwargs)
121
+
122
+ @add_start_docstrings(
123
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
124
+ """
125
+ max_dynamic_tiles (`int`, *optional*):
126
+ The maximum number of dynamic tiles to use for processing high resolution images.
127
+ min_dynamic_tiles (`int`, *optional*):
128
+ The minimum number of dynamic tiles to use for processing high resolution images.
129
+ use_thumbnail (`bool`, *optional*):
130
+ Whether to use a thumbnail for processing high resolution images.
131
+ pad_during_tiling (`bool`, *optional*):
132
+ Whether to pad the image during tiling.
133
+ do_pad (`bool`, *optional*):
134
+ Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
135
+ number of patches in the batch. Padding will be applied to the bottom and right with zeros.
136
+ """,
137
+ )
138
+ def preprocess(self, images: ImageInput, **kwargs: Unpack[Eagle2_5_VLFastImageProcessorKwargs]) -> BatchFeature:
139
+ return super().preprocess(images, **kwargs)
140
+
141
+ def _prepare_images_structure(
142
+ self,
143
+ images: ImageInput,
144
+ ) -> ImageInput:
145
+ """
146
+ Prepare the images structure for processing.
147
+
148
+ Args:
149
+ images (`ImageInput`):
150
+ The input images to process.
151
+
152
+ Returns:
153
+ `ImageInput`: The images with a valid nesting.
154
+ """
155
+ return make_flat_list_of_images(images)
156
+
157
+ def _prepare_videos_structure(self, videos: VideoInput) -> VideoInput:
158
+ return self._prepare_images_structure(videos)
159
+
160
+ def _prepare_input_videos(
161
+ self,
162
+ videos: VideoInput,
163
+ do_convert_rgb: Optional[bool] = None,
164
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
165
+ device: Optional["torch.device"] = None,
166
+ ) -> list["torch.Tensor"]:
167
+ """
168
+ Prepare the input images for processing.
169
+ """
170
+ videos = self._prepare_videos_structure(videos)
171
+ process_video_fn = partial(
172
+ self._process_image,
173
+ do_convert_rgb=do_convert_rgb,
174
+ input_data_format=input_data_format,
175
+ device=device,
176
+ )
177
+ # todo: yoni - check if we can parallelize this efficiently
178
+ processed_videos = []
179
+ for video in videos:
180
+ processed_videos.append(process_video_fn(video))
181
+
182
+ return processed_videos
183
+
184
+ def _resize_for_patching(
185
+ self,
186
+ image: "torch.Tensor",
187
+ target_resolution: tuple,
188
+ interpolation: "F.InterpolationMode",
189
+ input_data_format: ChannelDimension,
190
+ ) -> "torch.Tensor":
191
+ """
192
+ Resizes an image to a target resolution while maintaining aspect ratio.
193
+
194
+ Args:
195
+ image ("torch.Tensor"):
196
+ The input image.
197
+ target_resolution (tuple):
198
+ The target resolution (height, width) of the image.
199
+ interpolation (`InterpolationMode`):
200
+ Resampling filter to use if resizing the image.
201
+ input_data_format (`ChannelDimension` or `str`):
202
+ The channel dimension format of the input image.
203
+
204
+ Returns:
205
+ "torch.Tensor": The resized and padded image.
206
+ """
207
+ new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
208
+
209
+ # Resize the image
210
+ resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation)
211
+
212
+ return resized_image
213
+
214
+ def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
215
+ """
216
+ previous version mainly foucs on ratio.
217
+ We also consider area ratio here.
218
+ """
219
+ best_factor = float('-inf')
220
+ best_ratio = (1, 1)
221
+ area = width * height
222
+ for ratio in target_ratios:
223
+ target_aspect_ratio = ratio[0] / ratio[1]
224
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
225
+ area_ratio = (ratio[0]*ratio[1]*image_size*image_size)/ area
226
+ """
227
+ new area > 60% of original image area is enough.
228
+ """
229
+ factor_based_on_area_n_ratio = min((ratio[0]*ratio[1]*image_size*image_size)/ area, 0.6)* \
230
+ min(target_aspect_ratio/aspect_ratio, aspect_ratio/target_aspect_ratio)
231
+
232
+ if factor_based_on_area_n_ratio > best_factor:
233
+ best_factor = factor_based_on_area_n_ratio
234
+ best_ratio = ratio
235
+
236
+ return best_ratio
237
+
238
+ def _pad_for_patching(
239
+ self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension
240
+ ) -> "torch.Tensor":
241
+ """
242
+ Pad an image to a target resolution while maintaining aspect ratio.
243
+ """
244
+ target_height, target_width = target_resolution
245
+ new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
246
+
247
+ paste_x = (target_width - new_width) // 2
248
+ paste_y = (target_height - new_height) // 2
249
+
250
+ padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x, paste_y])
251
+
252
+ return padded_image
253
+
254
+ def _get_image_patches(
255
+ self,
256
+ image: "torch.Tensor",
257
+ min_num: int,
258
+ max_num: int,
259
+ size: tuple,
260
+ tile_size: int,
261
+ use_thumbnail: bool,
262
+ interpolation: "F.InterpolationMode",
263
+ pad_during_tiling: bool,
264
+ ) -> List["torch.Tensor"] :
265
+ image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
266
+ orig_height, orig_width = image_size
267
+ aspect_ratio = orig_width / orig_height
268
+
269
+ # calculate the existing image aspect ratio
270
+ target_ratios = set(
271
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
272
+ i * j <= max_num and i * j >= min_num)
273
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
274
+
275
+ # find the closest aspect ratio to the target
276
+ target_aspect_ratio = self.find_closest_aspect_ratio(
277
+ aspect_ratio, target_ratios, orig_width, orig_height, tile_size)
278
+
279
+ # calculate the target width and height
280
+ target_width = tile_size * target_aspect_ratio[0]
281
+ target_height = tile_size * target_aspect_ratio[1]
282
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
283
+ if pad_during_tiling:
284
+ resized_image = self._resize_for_patching(
285
+ image, (target_height, target_width), interpolation=interpolation, input_data_format=ChannelDimension.FIRST
286
+ )
287
+ padded_image = self._pad_for_patching(resized_image, (target_height, target_width), input_data_format=ChannelDimension.FIRST)
288
+ image_used_to_split = padded_image
289
+ else:
290
+ image_used_to_split = F.resize(image, (target_height, target_width), interpolation=interpolation)
291
+
292
+ processed_tiles = []
293
+ for i in range(blocks):
294
+ box = (
295
+ (i % (target_width // tile_size)) * tile_size,
296
+ (i // (target_width // tile_size)) * tile_size,
297
+ ((i % (target_width // tile_size)) + 1) * tile_size,
298
+ ((i // (target_width // tile_size)) + 1) * tile_size
299
+ )
300
+ # split the image
301
+ split_img = crop(image_used_to_split, box[0], box[1], box[2], box[3])
302
+ processed_tiles.append(split_img)
303
+ assert len(processed_tiles) == blocks
304
+
305
+ if use_thumbnail and len(processed_tiles) != 1:
306
+ thumbnail_img = F.resize(image, (tile_size, tile_size), interpolation=interpolation)
307
+ processed_tiles.append(thumbnail_img)
308
+
309
+ return processed_tiles
310
+
311
+ def _pad_for_batching(
312
+ self,
313
+ pixel_values: List["torch.Tensor"],
314
+ ) -> List["torch.Tensor"]:
315
+ """
316
+ Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
317
+
318
+ Args:
319
+ pixel_values (`List[torch.Tensor]`):
320
+ An array of pixel values of each images of shape (`batch_size`, `num_patches`, `image_in_3D`)
321
+
322
+ Returns:
323
+ List[`torch.Tensor`]: The padded images.
324
+ """
325
+ max_patch = max(len(x) for x in pixel_values)
326
+ pixel_values = [
327
+ torch.nn.functional.pad(image, pad=[0, 0, 0, 0, 0, 0, 0, max_patch - image.shape[0]])
328
+ for image in pixel_values
329
+ ]
330
+
331
+ return pixel_values
332
+
333
+ def _preprocess(
334
+ self,
335
+ images: List["torch.Tensor"],
336
+ do_resize: bool,
337
+ size: SizeDict,
338
+ max_dynamic_tiles: int,
339
+ min_dynamic_tiles: int,
340
+ use_thumbnail: bool,
341
+ pad_during_tiling: bool,
342
+ interpolation: Optional["F.InterpolationMode"],
343
+ do_center_crop: bool,
344
+ crop_size: SizeDict,
345
+ do_rescale: bool,
346
+ rescale_factor: float,
347
+ do_normalize: bool,
348
+ image_mean: Optional[Union[float, List[float]]],
349
+ image_std: Optional[Union[float, List[float]]],
350
+ do_pad: bool,
351
+ return_tensors: Optional[Union[str, TensorType]],
352
+ ) -> BatchFeature:
353
+ processed_images = []
354
+ image_sizes = []
355
+ # Determine the size tuple
356
+ if size and size.height and size.width:
357
+ size_tuple = (size.height, size.width)
358
+ else:
359
+ size_tuple = (size.shortest_edge, size.shortest_edge)
360
+
361
+ # Determine the tile size
362
+ if crop_size and crop_size.height:
363
+ tile_size = crop_size.height
364
+ elif size and size.height:
365
+ tile_size = size.height
366
+ else:
367
+ tile_size = size.shortest_edge
368
+
369
+ for image in images:
370
+ image_patches = self._get_image_patches(
371
+ image,
372
+ min_num=min_dynamic_tiles,
373
+ max_num=max_dynamic_tiles,
374
+ size=size_tuple,
375
+ tile_size=tile_size,
376
+ use_thumbnail=use_thumbnail,
377
+ interpolation=interpolation,
378
+ pad_during_tiling=pad_during_tiling,
379
+ )
380
+
381
+ # Group images by size for batched processing
382
+ processed_image_patches_grouped = {}
383
+ grouped_image_patches, grouped_image_patches_index = group_images_by_shape(image_patches)
384
+
385
+ for shape, stacked_image_patches in grouped_image_patches.items():
386
+ if do_resize:
387
+ stacked_image_patches = self.resize(
388
+ image=stacked_image_patches,
389
+ size=size,
390
+ interpolation=interpolation,
391
+ )
392
+ if do_center_crop:
393
+ stacked_image_patches = self.center_crop(stacked_image_patches, crop_size)
394
+ # Fused rescale and normalize
395
+ stacked_image_patches = self.rescale_and_normalize(
396
+ stacked_image_patches, do_rescale, rescale_factor, do_normalize, image_mean, image_std
397
+ )
398
+ processed_image_patches_grouped[shape] = stacked_image_patches
399
+ processed_image_patches = reorder_images(processed_image_patches_grouped, grouped_image_patches_index)
400
+ processed_image_patches = (
401
+ torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches
402
+ )
403
+ processed_images.append(processed_image_patches)
404
+ image_sizes.append(get_image_size(image, ChannelDimension.FIRST))
405
+
406
+ if do_pad:
407
+ processed_images = self._pad_for_batching(processed_images)
408
+
409
+ # processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
410
+ processed_images = torch.cat(processed_images, dim=0) if return_tensors else processed_images
411
+ return BatchFeature(
412
+ data={"pixel_values": processed_images, "image_sizes": image_sizes}, tensor_type=return_tensors
413
+ )
414
+
415
+
416
+ def preprocess(self, images: ImageInput, videos: VideoInput=None, **kwargs: Unpack[Eagle2_5_VLFastImageProcessorKwargs]) -> BatchFeature:
417
+ validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
418
+ # Set default kwargs from self. This ensures that if a kwarg is not provided
419
+ # by the user, it gets its default value from the instance, or is set to None.
420
+ for kwarg_name in self.valid_kwargs.__annotations__:
421
+ kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
422
+
423
+ # Extract parameters that are only used for preparing the input images
424
+ do_convert_rgb = kwargs.pop("do_convert_rgb")
425
+ input_data_format = kwargs.pop("input_data_format")
426
+ device = kwargs.pop("device")
427
+ # Prepare input images
428
+ if images is not None:
429
+ images = self._prepare_input_images(
430
+ images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
431
+ )
432
+
433
+ if videos is not None:
434
+ videos = self._prepare_input_images(
435
+ images=videos, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
436
+ )
437
+
438
+ # Update kwargs that need further processing before being validated
439
+ kwargs = self._further_process_kwargs(**kwargs)
440
+
441
+ # Validate kwargs
442
+ self._validate_preprocess_kwargs(**kwargs)
443
+
444
+ # torch resize uses interpolation instead of resample
445
+ resample = kwargs.pop("resample")
446
+ kwargs["interpolation"] = (
447
+ pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
448
+ )
449
+
450
+ # Pop kwargs that are not needed in _preprocess
451
+ kwargs.pop("default_to_square")
452
+ kwargs.pop("data_format")
453
+ if images is not None:
454
+ return self._preprocess(images, **kwargs)
455
+ elif videos is not None:
456
+ return self._preprocess(videos, **kwargs)
457
+
458
+ __all__ = ["Eagle2_5_VLImageProcessorFast"]
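
The dynamic-tiling path above (`_get_image_patches`) chooses a tile grid whose aspect ratio best matches the input image, splits the (optionally padded) resized image into `tile_size`×`tile_size` crops, and appends a thumbnail tile when `use_thumbnail` is set. The standalone sketch below illustrates only the grid-selection step; the tie-breaking rule follows the common InternVL-style `find_closest_aspect_ratio` and may differ in detail from the class's own helper, so treat it as an approximation rather than the shipped implementation.

```python
# Illustrative sketch of dynamic-tile grid selection (not the shipped code).
# Assumption: the tie-break mirrors the usual InternVL-style find_closest_aspect_ratio.
from typing import Tuple


def closest_tile_grid(orig_width: int, orig_height: int,
                      min_num: int = 1, max_num: int = 12,
                      tile_size: int = 448) -> Tuple[int, int]:
    """Return the (cols, rows) grid whose aspect ratio best matches the image."""
    aspect_ratio = orig_width / orig_height
    # Enumerate every grid (i columns x j rows) whose tile count lies in [min_num, max_num].
    target_ratios = sorted(
        {(i, j) for n in range(min_num, max_num + 1)
         for i in range(1, n + 1) for j in range(1, n + 1)
         if min_num <= i * j <= max_num},
        key=lambda x: x[0] * x[1],
    )
    best, best_diff = (1, 1), float("inf")
    area = orig_width * orig_height
    for i, j in target_ratios:
        diff = abs(aspect_ratio - i / j)
        if diff < best_diff:
            best, best_diff = (i, j), diff
        elif diff == best_diff and area > 0.5 * tile_size * tile_size * i * j:
            # On a tie, prefer the larger grid when the image is big enough to fill it.
            best = (i, j)
    return best


if __name__ == "__main__":
    # A 1920x1080 image maps to a 4x2 grid: 8 tiles, plus a thumbnail when enabled.
    print(closest_tile_grid(1920, 1080))  # (4, 2)
```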
modeling_eagle_chat.py → modeling_eagle2_5_vl.py RENAMED
@@ -1,84 +1,118 @@
1
  # --------------------------------------------------------
2
- # Eagle2
3
  # Copyright (c) 2025 NVIDIA
4
- # Licensed under The Apache License [see LICENSE for details]
5
  # --------------------------------------------------------
6
 
7
  import warnings
 
8
  from typing import Any, List, Optional, Tuple, Union
9
-
10
- import torch.utils.checkpoint
11
- import transformers
12
  from torch import nn
 
13
  from torch.nn import CrossEntropyLoss
14
- from transformers import (AutoModel, GenerationConfig,
15
- LlamaTokenizer, LlamaForCausalLM)
 
 
 
 
 
 
16
  from transformers.modeling_outputs import CausalLMOutputWithPast
17
  from transformers.modeling_utils import PreTrainedModel
18
  from transformers.utils import ModelOutput, logging
19
- from peft import LoraConfig, get_peft_model
20
- from transformers.models.siglip.modeling_siglip import SiglipVisionModel
21
-
22
- from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
23
 
24
  logger = logging.get_logger(__name__)
25
- from .configuration_eagle_chat import Eagle2ChatConfig
26
-
27
- def version_cmp(v1, v2, op='eq'):
28
- import operator
29
 
30
- from packaging import version
31
- op_func = getattr(operator, op)
32
- return op_func(version.parse(v1), version.parse(v2))
33
 
34
-
35
- class Eagle2ChatModel(PreTrainedModel):
36
- config_class = Eagle2ChatConfig
37
- main_input_name = 'pixel_values'
38
- _no_split_modules = ['LlamaDecoderLayer']
 
 
 
 
39
  _supports_flash_attn_2 = True
 
 
 
40
  _supports_sdpa = True
41
- _supports_flex_attn = False
42
- _supports_cache_class = False
43
- _supports_quantized_cache = False
44
- _supports_static_cache = False
45
- _supports_attention_backend = False
46
 
47
- def __init__(self, config: Eagle2ChatConfig, vision_model=None, language_model=None):
 
 
 
 
48
  super().__init__(config)
49
 
50
  image_size = config.force_image_size or config.vision_config.image_size
51
-
52
  patch_size = config.vision_config.patch_size
53
  self.patch_size = patch_size
54
  self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
55
 
56
  self.select_layer = config.select_layer
57
- self.template = config.template
58
  self.downsample_ratio = config.downsample_ratio
59
-
 
 
60
  logger.info(f'num_image_token: {self.num_image_token}')
 
61
  if vision_model is not None:
62
  self.vision_model = vision_model
63
  else:
64
  if config.vision_config.model_type == 'siglip_vision_model':
65
- if version_cmp(transformers.__version__, '4.43.0', 'le'):
66
- config.vision_config._attn_implementation = 'eager'
67
  self.vision_model = SiglipVisionModel(config.vision_config)
 
 
68
 
69
  if language_model is not None:
70
  self.language_model = language_model
71
  else:
72
- if config.llm_config.architectures[0] == 'LlamaForCausalLM':
73
- self.language_model = LlamaForCausalLM(config.llm_config)
74
- elif config.llm_config.architectures[0] == 'Qwen2ForCausalLM':
75
- self.language_model = Qwen2ForCausalLM(config.llm_config)
 
76
  else:
77
- raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')
78
 
79
  vit_hidden_size = config.vision_config.hidden_size
80
-
81
- llm_hidden_size = config.llm_config.hidden_size
82
 
83
  self.mlp1 = nn.Sequential(
84
  nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
@@ -86,26 +120,39 @@ class Eagle2ChatModel(PreTrainedModel):
86
  nn.GELU(),
87
  nn.Linear(llm_hidden_size, llm_hidden_size)
88
  )
89
- self.img_context_token_id = None
90
- self.system_message = 'You are a helpful assistant.' # Default system message
 
91
 
92
  if config.use_backbone_lora:
93
  self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)
94
 
 
95
  if config.use_llm_lora:
96
  self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)
97
-
 
 
 
98
  def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
99
  lora_config = LoraConfig(
100
  r=r,
101
- target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'],
 
102
  lora_alpha=lora_alpha,
103
  lora_dropout=lora_dropout,
104
  )
105
  self.vision_model = get_peft_model(self.vision_model, lora_config)
106
  self.vision_model.print_trainable_parameters()
107
 
108
- def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
109
  lora_config = LoraConfig(
110
  r=r,
111
  target_modules=['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
@@ -117,8 +164,8 @@ class Eagle2ChatModel(PreTrainedModel):
117
  self.language_model = get_peft_model(self.language_model, lora_config)
118
  self.language_model.enable_input_require_grads()
119
  self.language_model.print_trainable_parameters()
120
-
121
-
122
  def forward(
123
  self,
124
  pixel_values: torch.FloatTensor,
@@ -132,7 +179,7 @@ class Eagle2ChatModel(PreTrainedModel):
132
  output_attentions: Optional[bool] = None,
133
  output_hidden_states: Optional[bool] = None,
134
  return_dict: Optional[bool] = None,
135
- num_patches_list: Optional[List[torch.Tensor]] = None,
136
  ) -> Union[Tuple, CausalLMOutputWithPast]:
137
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
138
 
@@ -154,7 +201,7 @@ class Eagle2ChatModel(PreTrainedModel):
154
  print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')
155
 
156
  input_ids = input_ids.reshape(B * N)
157
- selected = (input_ids == self.img_context_token_id)
158
  try:
159
  input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
160
  except Exception as e:
@@ -211,165 +258,36 @@ class Eagle2ChatModel(PreTrainedModel):
211
  # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
212
  x = x.view(n, int(h * scale_factor), int(w * scale_factor),
213
  int(c / (scale_factor * scale_factor)))
 
214
  x = x.permute(0, 2, 1, 3).contiguous()
215
  return x
216
 
217
  def extract_feature(self, pixel_values):
218
- """
219
- """
220
-
221
  if self.select_layer == -1:
222
  vit_embeds = self.vision_model(
223
  pixel_values=pixel_values,
224
  output_hidden_states=False,
225
  return_dict=True)
226
- # if there is vit_embeds.last_hidden_state, use it.
227
  if hasattr(vit_embeds, 'last_hidden_state'):
228
  vit_embeds = vit_embeds.last_hidden_state
 
229
  else:
230
  vit_embeds = self.vision_model(
231
  pixel_values=pixel_values,
232
  output_hidden_states=True,
233
  return_dict=True).hidden_states[self.select_layer]
234
- if type(self.vision_model) == SiglipVisionModel:
235
- pass
236
- else:
237
- vit_embeds = vit_embeds[:, 1:, :] # torch.Size([B, 1024, 1024])
238
-
239
- if self.training and self.neftune_alpha is not None:
240
- vit_embeds = self.noised_embed(vit_embeds, self.neftune_alpha)
241
-
242
-
243
  h = w = int(vit_embeds.shape[1] ** 0.5)
244
  vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
245
-
246
  vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) # torch.Size([B, 1024, 1024]) -> torch.Size([B, 16, 16, 4096])
247
  vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) # torch.Size([B, 16, 16, 4096]) -> torch.Size([B, 256, 4096])
248
- vit_embeds = self.mlp1(vit_embeds)#.to(pixel_values.device)
 
 
 
249
 
250
  return vit_embeds
251
-
252
- def batch_chat(self,
253
- tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
254
- history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
255
- IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
256
- if history is not None or return_history:
257
- print('Now multi-turn chat is not supported in batch_chat.')
258
- raise NotImplementedError
259
-
260
- if image_counts is not None:
261
- num_patches_list = image_counts
262
- print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')
263
-
264
- img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
265
- self.img_context_token_id = img_context_token_id
266
-
267
- if verbose and pixel_values is not None:
268
- image_bs = pixel_values.shape[0]
269
- print(f'dynamic ViT batch size: {image_bs}')
270
-
271
- queries = []
272
- for idx, num_patches in enumerate(num_patches_list):
273
- question = questions[idx]
274
- if pixel_values is not None and '<image>' not in question:
275
- question = '<image>\n' + question
276
- template_messages = []
277
- sep = tokenizer.eos_token
278
- template_messages.append(('<|im_start|>user', question))
279
- template_messages.append(('<|im_end|>assistant', None))
280
- query = self.get_prompt(self.system_message, template_messages, sep)
281
-
282
- image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
283
- query = query.replace('<image>', image_tokens, 1)
284
- queries.append(query)
285
-
286
- tokenizer.padding_side = 'left'
287
- model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
288
- input_ids = model_inputs['input_ids'].cuda()
289
- attention_mask = model_inputs['attention_mask'].cuda()
290
- eos_token_id = tokenizer.convert_tokens_to_ids(sep)
291
- generation_config['eos_token_id'] = eos_token_id
292
- generation_output = self.generate(
293
- pixel_values=pixel_values,
294
- input_ids=input_ids,
295
- attention_mask=attention_mask,
296
- **generation_config
297
- )
298
- responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
299
- responses = [response.split(sep)[0].strip() for response in responses]
300
- return responses
301
-
302
- def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
303
- num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
304
- verbose=False, llm_only=False):
305
-
306
- if history is None and pixel_values is not None and '<image>' not in question:
307
- question = '<image>\n' + question
308
-
309
- if num_patches_list is None:
310
- num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
311
- assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
312
-
313
- img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
314
- self.img_context_token_id = img_context_token_id
315
-
316
- template_messages = []
317
- system_message = f'<|im_start|>system\n{self.system_message}'
318
- sep = tokenizer.eos_token
319
- eos_token_id = tokenizer.convert_tokens_to_ids(sep)
320
-
321
- history = [] if history is None else history
322
- for (old_question, old_answer) in history:
323
- template_messages.append(('<|im_start|>user', old_question))
324
- template_messages.append(('<|im_start|>assistant', old_answer))
325
- template_messages.append(('<|im_start|>user', question))
326
- template_messages.append(('<|im_end|>assistant', None))
327
- query = self.get_prompt(system_message, template_messages, sep)
328
-
329
- if verbose and pixel_values is not None:
330
- image_bs = pixel_values.shape[0]
331
- print(f'dynamic ViT batch size: {image_bs}')
332
-
333
- for num_patches in num_patches_list:
334
- image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
335
- if llm_only:
336
- query = query.replace('<image>', '', 1)
337
- else:
338
- query = query.replace('<image>', image_tokens, 1)
339
-
340
- model_inputs = tokenizer(query, return_tensors='pt')
341
- input_ids = model_inputs['input_ids'].cuda()
342
- attention_mask = model_inputs['attention_mask'].cuda()
343
- generation_config['eos_token_id'] = eos_token_id
344
- generation_output = self.generate(
345
- pixel_values=pixel_values,
346
- input_ids=input_ids,
347
- attention_mask=attention_mask,
348
- **generation_config
349
- )
350
- response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
351
- response = response.split(sep)[0].strip()
352
- history.append((question, response))
353
- if return_history:
354
- return response, history
355
- else:
356
- query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
357
- query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
358
- if verbose:
359
- print(query_to_print, response)
360
- return response
361
-
362
- def get_prompt(self, system_prompt, messages, sep) -> str:
363
- """Get the prompt for generation."""
364
-
365
- ret = '' if system_prompt == '' else system_prompt + sep + '\n'
366
- for role, message in messages:
367
- if message:
368
- ret += role + '\n' + message + sep + '\n'
369
- else:
370
- ret += role + '\n'
371
- return ret
372
-
373
  @torch.no_grad()
374
  def generate(
375
  self,
@@ -379,11 +297,10 @@ class Eagle2ChatModel(PreTrainedModel):
379
  visual_features: Optional[torch.FloatTensor] = None,
380
  generation_config: Optional[GenerationConfig] = None,
381
  output_hidden_states: Optional[bool] = None,
382
- return_dict: Optional[bool] = None,
383
  **generate_kwargs,
384
  ) -> torch.LongTensor:
385
 
386
- assert self.img_context_token_id is not None
387
  if pixel_values is not None:
388
  if visual_features is not None:
389
  vit_embeds = visual_features
@@ -395,7 +312,7 @@ class Eagle2ChatModel(PreTrainedModel):
395
  input_embeds = input_embeds.reshape(B * N, C)
396
 
397
  input_ids = input_ids.reshape(B * N)
398
- selected = (input_ids == self.img_context_token_id)
399
  assert selected.sum() != 0
400
  input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
401
 
@@ -413,9 +330,28 @@ class Eagle2ChatModel(PreTrainedModel):
413
  )
414
 
415
  return outputs
416
-
 
417
  def get_input_embeddings(self):
418
  return self.language_model.get_input_embeddings()
419
-
 
 
 
 
 
420
  def get_output_embeddings(self):
421
- return self.language_model.get_output_embeddings()
 
 
 
1
  # --------------------------------------------------------
2
+ # NVIDIA
3
  # Copyright (c) 2025 NVIDIA
4
+ # Licensed under The MIT License [see LICENSE for details]
5
  # --------------------------------------------------------
6
 
7
  import warnings
8
+ import inspect
9
  from typing import Any, List, Optional, Tuple, Union
10
+ import torch
 
 
11
  from torch import nn
12
+ import torch.distributed as dist
13
  from torch.nn import CrossEntropyLoss
14
+ import torch.nn.functional as F
15
+ from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
16
+ from transformers.models.llama.modeling_llama import LlamaForCausalLM
17
+ import torch.utils.checkpoint as cp
18
+ from transformers.models.siglip.modeling_siglip import SiglipVisionModel
19
+ from peft import LoraConfig, get_peft_model
20
+ from transformers.generation import GenerationMixin
21
+ from transformers import GenerationConfig
22
  from transformers.modeling_outputs import CausalLMOutputWithPast
23
  from transformers.modeling_utils import PreTrainedModel
24
  from transformers.utils import ModelOutput, logging
25
+ from .configuration_eagle2_5_vl import Eagle2_5_VLConfig
26
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 
 
27
 
28
  logger = logging.get_logger(__name__)
 
 
 
 
29
 
 
 
 
30
 
31
+ # copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/modeling_llava_onevision.py#L241C1-L280C1
32
+ EAGLE2_5_VL_START_DOCSTRING = r"""
33
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
34
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
35
+ etc.)
36
+
37
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
38
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
39
+ and behavior.
40
+
41
+ Parameters:
42
+ config ([`Eagle2_5_VLConfig`]):
43
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
44
+ load the weights associated with the model, only the configuration. Check out the
45
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
46
+ """
47
+
48
+ @add_start_docstrings(
49
+ "The bare Eagle2_5_VL Model outputting raw hidden-states without any specific head on top.",
50
+ EAGLE2_5_VL_START_DOCSTRING,
51
+ )
52
+ class Eagle2_5_VLPreTrainedModel(PreTrainedModel):
53
+ config_class = Eagle2_5_VLConfig
54
+ base_model_prefix = "model"
55
+ main_input_name = 'input_ids'
56
+ supports_gradient_checkpointing = True
57
+ _no_split_modules = ["Qwen2DecoderLayer", "LlamaDecoderLayer", "Siglip2EncoderLayer", "SiglipEncoderLayer"]
58
+ _skip_keys_device_placement = "past_key_values"
59
  _supports_flash_attn_2 = True
60
+ _supports_cache_class = True
61
+ _supports_static_cache = True
62
+ _supports_quantized_cache = True
63
  _supports_sdpa = True
 
 
 
 
 
64
 
65
+ def _init_weights(self, module):
66
+ std = self.config.initializer_range
67
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
68
+ module.weight.data.normal_(mean=0.0, std=std)
69
+ if module.bias is not None:
70
+ module.bias.data.zero_()
71
+ elif isinstance(module, nn.Embedding):
72
+ module.weight.data.normal_(mean=0.0, std=std)
73
+ if module.padding_idx is not None:
74
+ module.weight.data[module.padding_idx].zero_()
75
+
76
+
77
+ class Eagle2_5_VLForConditionalGeneration(Eagle2_5_VLPreTrainedModel, GenerationMixin):
78
+ config_class = Eagle2_5_VLConfig
79
+ def __init__(self, config: Eagle2_5_VLConfig, vision_model=None, language_model=None):
80
  super().__init__(config)
81
 
82
  image_size = config.force_image_size or config.vision_config.image_size
 
83
  patch_size = config.vision_config.patch_size
84
  self.patch_size = patch_size
85
  self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
86
 
87
  self.select_layer = config.select_layer
 
88
  self.downsample_ratio = config.downsample_ratio
89
+ self.loss_version = config.loss_version
90
+ self.mlp_checkpoint = config.mlp_checkpoint
91
+
92
  logger.info(f'num_image_token: {self.num_image_token}')
93
+ logger.info(f'mlp_checkpoint: {self.mlp_checkpoint}')
94
  if vision_model is not None:
95
  self.vision_model = vision_model
96
  else:
97
  if config.vision_config.model_type == 'siglip_vision_model':
98
+ config.vision_config._attn_implementation = 'flash_attention_2'
 
99
  self.vision_model = SiglipVisionModel(config.vision_config)
100
+ else:
101
+ raise NotImplementedError(f'{config.vision_config.model_type} is not implemented.')
102
 
103
  if language_model is not None:
104
  self.language_model = language_model
105
  else:
106
+ if config.text_config.architectures[0] == 'LlamaForCausalLM':
107
+ self.language_model = LlamaForCausalLM(config.text_config)
108
+ elif config.text_config.architectures[0] == 'Qwen2ForCausalLM':
109
+ # assert config.text_config._attn_implementation == 'flash_attention_2', f"Qwen2 must use flash_attention_2 but got {config.text_config._attn_implementation}"
110
+ self.language_model = Qwen2ForCausalLM(config.text_config)
111
  else:
112
+ raise NotImplementedError(f'{config.text_config.architectures[0]} is not implemented.')
113
 
114
  vit_hidden_size = config.vision_config.hidden_size
115
+ llm_hidden_size = config.text_config.hidden_size
 
116
 
117
  self.mlp1 = nn.Sequential(
118
  nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
 
120
  nn.GELU(),
121
  nn.Linear(llm_hidden_size, llm_hidden_size)
122
  )
123
+ self.image_token_index = config.image_token_index
124
+ self.neftune_alpha = None
125
+
126
 
127
  if config.use_backbone_lora:
128
  self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)
129
 
130
+ self.use_llm_lora = config.use_llm_lora
131
  if config.use_llm_lora:
132
  self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)
133
+
134
+ self.check_forward_kwargs()
135
+
136
+ def check_forward_kwargs(self):
137
+ # We intentionally avoid using **kwargs in forward because Hugging Face Transformers
138
+ # has special handling for functions with **kwargs parameters that would affect
139
+ # how our model is processed during training and inference.
140
+ forward_params = inspect.signature(self.forward).parameters
141
+ assert not any(k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values())
142
+
143
+
144
  def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
145
  lora_config = LoraConfig(
146
  r=r,
147
+ target_modules=['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj',
148
+ 'mlp.fc1', 'mlp.fc2'],
149
  lora_alpha=lora_alpha,
150
  lora_dropout=lora_dropout,
151
  )
152
  self.vision_model = get_peft_model(self.vision_model, lora_config)
153
  self.vision_model.print_trainable_parameters()
154
 
155
+ def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
156
  lora_config = LoraConfig(
157
  r=r,
158
  target_modules=['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
 
164
  self.language_model = get_peft_model(self.language_model, lora_config)
165
  self.language_model.enable_input_require_grads()
166
  self.language_model.print_trainable_parameters()
167
+ self.use_llm_lora = True
168
+
169
  def forward(
170
  self,
171
  pixel_values: torch.FloatTensor,
 
179
  output_attentions: Optional[bool] = None,
180
  output_hidden_states: Optional[bool] = None,
181
  return_dict: Optional[bool] = None,
182
+ num_tiles_list: Optional[List[torch.Tensor]] = None,
183
  ) -> Union[Tuple, CausalLMOutputWithPast]:
184
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
185
 
 
201
  print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')
202
 
203
  input_ids = input_ids.reshape(B * N)
204
+ selected = (input_ids == self.image_token_index)
205
  try:
206
  input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
207
  except Exception as e:
 
258
  # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
259
  x = x.view(n, int(h * scale_factor), int(w * scale_factor),
260
  int(c / (scale_factor * scale_factor)))
261
+
262
  x = x.permute(0, 2, 1, 3).contiguous()
263
  return x
264
 
265
  def extract_feature(self, pixel_values):
 
 
 
266
  if self.select_layer == -1:
267
  vit_embeds = self.vision_model(
268
  pixel_values=pixel_values,
269
  output_hidden_states=False,
270
  return_dict=True)
 
271
  if hasattr(vit_embeds, 'last_hidden_state'):
272
  vit_embeds = vit_embeds.last_hidden_state
273
+
274
  else:
275
  vit_embeds = self.vision_model(
276
  pixel_values=pixel_values,
277
  output_hidden_states=True,
278
  return_dict=True).hidden_states[self.select_layer]
279
+
 
 
 
 
 
 
 
 
280
  h = w = int(vit_embeds.shape[1] ** 0.5)
281
  vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
 
282
  vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) # torch.Size([B, 1024, 1024]) -> torch.Size([B, 16, 16, 4096])
283
  vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) # torch.Size([B, 16, 16, 4096]) -> torch.Size([B, 256, 4096])
284
+ if self.mlp_checkpoint and vit_embeds.requires_grad:
285
+ vit_embeds = cp.checkpoint(self.mlp1, vit_embeds)
286
+ else:
287
+ vit_embeds = self.mlp1(vit_embeds)
288
 
289
  return vit_embeds
290
+
 
 
 
291
  @torch.no_grad()
292
  def generate(
293
  self,
 
297
  visual_features: Optional[torch.FloatTensor] = None,
298
  generation_config: Optional[GenerationConfig] = None,
299
  output_hidden_states: Optional[bool] = None,
300
+ image_sizes: Optional[List[Tuple[int, int]]] = None,
301
  **generate_kwargs,
302
  ) -> torch.LongTensor:
303
 
 
304
  if pixel_values is not None:
305
  if visual_features is not None:
306
  vit_embeds = visual_features
 
312
  input_embeds = input_embeds.reshape(B * N, C)
313
 
314
  input_ids = input_ids.reshape(B * N)
315
+ selected = (input_ids == self.config.image_token_index)
316
  assert selected.sum() != 0
317
  input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
318
 
 
330
  )
331
 
332
  return outputs
333
+
334
+ # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_input_embeddings
335
  def get_input_embeddings(self):
336
  return self.language_model.get_input_embeddings()
337
+
338
+ # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_input_embeddings
339
+ def set_input_embeddings(self, value):
340
+ self.language_model.set_input_embeddings(value)
341
+
342
+ # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_output_embeddings
343
  def get_output_embeddings(self):
344
+ return self.language_model.get_output_embeddings()
345
+
346
+ # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_output_embeddings
347
+ def set_output_embeddings(self, new_embeddings):
348
+ self.language_model.set_output_embeddings(new_embeddings)
349
+
350
+ # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_decoder
351
+ def set_decoder(self, decoder):
352
+ self.language_model.set_decoder(decoder)
353
+
354
+ # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_decoder
355
+ def get_decoder(self):
356
+ return self.language_model.get_decoder()
357
+
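
A quick note on the per-tile token budget set up in `Eagle2_5_VLForConditionalGeneration.__init__` and realized in `extract_feature`: `num_image_token = (image_size // patch_size) ** 2 * downsample_ratio ** 2`, where `pixel_shuffle` performs the spatial downsampling before `mlp1`. The worked example below assumes 448 px tiles, 14 px ViT patches, and `downsample_ratio = 0.5`; these values are inferred from the shape comments above (1024 patches → 256 tokens) and should be read from the checkpoint's config in practice.

```python
# Worked example of the per-tile token budget (assumed values, see note above).
image_size = 448        # tile side in pixels
patch_size = 14         # ViT patch side in pixels (assumption)
downsample_ratio = 0.5  # pixel_shuffle keeps (0.5)**2 = 1/4 of the tokens (assumption)

patches_per_tile = (image_size // patch_size) ** 2               # 32 * 32 = 1024
num_image_token = int(patches_per_tile * downsample_ratio ** 2)  # 1024 / 4 = 256

print(patches_per_tile, num_image_token)  # 1024 256
```

The result matches the `"tokens_per_tile": 256` entry in the `preprocessor_config.json` added below.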
preprocessor_config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_eagle2_5_vl.Eagle2_5_VLProcessor",
4
+ "AutoImageProcessor": "image_processing_eagle2_5_vl_fast.Eagle2_5_VLImageProcessorFast"
5
+ },
6
+ "image_processor_type": "Eagle2_5_VLImageProcessorFast",
7
+ "processor_class": "Eagle2_5_VLProcessor",
8
+ "image_mean": [0.5, 0.5, 0.5],
9
+ "image_std": [0.5, 0.5, 0.5],
10
+ "do_resize": false,
11
+ "size": {
12
+ "height": 448,
13
+ "width": 448
14
+ },
15
+ "max_dynamic_tiles": 12,
16
+ "min_dynamic_tiles": 1,
17
+ "tokens_per_tile": 256,
18
+ "use_thumbnail": true,
19
+ "do_rescale": true,
20
+ "do_normalize": true,
21
+ "do_pad": false,
22
+ "do_convert_rgb": true
23
+ }
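
Because `auto_map` in this config points at the custom processor and image-processor classes in this repo, loading them requires `trust_remote_code=True`. The sketch below is a minimal, non-authoritative usage example: the repository id is a placeholder, and the `<image-1>` prompt placeholder follows the `replace_media_placeholder` convention in the processing file added below.

```python
# Minimal usage sketch (assumptions: placeholder repo id, <image-1> prompt convention).
from PIL import Image
from transformers import AutoProcessor

repo_id = "nvidia/Eagle2.5-VL"  # hypothetical id; substitute the actual checkpoint path
processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)

image = Image.new("RGB", (1920, 1080), "white")  # stand-in for a real image
text = "<image-1>\nDescribe this image."

inputs = processor(images=[image], text=[text], return_tensors="pt")
# pixel_values holds one 3x448x448 tensor per tile (plus a thumbnail tile when enabled).
print(inputs["pixel_values"].shape)
print(inputs["input_ids"].shape)
```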
processing_eagle2_5_vl.py ADDED
@@ -0,0 +1,738 @@
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Processor class for Eagle2_5_VL.
17
+ copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/processing_llava_onevision.py
18
+ """
19
+
20
+ import math
21
+ import os
22
+ from typing import Iterable, List, Union, Literal
23
+ import base64
24
+ import sys
25
+ import time
26
+ import warnings
27
+ from functools import lru_cache
28
+ from io import BytesIO
29
+ import re
30
+ import requests
31
+ import torch
32
+ import torchvision
33
+ from packaging import version
34
+ from PIL import Image
35
+ from torchvision import io
36
+ from typing import Optional, Any
37
+ import numpy as np
38
+
39
+ from transformers.feature_extraction_utils import BatchFeature
40
+ from transformers.image_processing_utils import select_best_resolution
41
+ from transformers.image_utils import ImageInput, VideoInput, get_image_size, to_numpy_array
42
+ from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
43
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
44
+ from transformers.utils import logging
45
+ from transformers.models.auto import AutoImageProcessor
46
+
47
+ logger = logging.get_logger(__name__)
48
+
49
+
50
+
51
+ FRAME_FACTOR = 2
52
+ FPS = 2.0
53
+ FPS_MIN_FRAMES = 4
54
+ FPS_MAX_FRAMES = 256
55
+
56
+
57
+
58
+ def adjust_by_factor(number: int, factor: int, method: Literal['round', 'ceil', 'floor'] = 'round') -> int:
59
+ """Adjusts 'number' to the nearest, ceiling, or floor multiple of 'factor'."""
60
+ op = {'round': round, 'ceil': math.ceil, 'floor': math.floor}[method]
61
+ return op(number / factor) * factor
62
+
63
+
64
+ def to_rgb(pil_image: Image.Image) -> Image.Image:
65
+ if pil_image.mode == 'RGBA':
66
+ white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
67
+ white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
68
+ return white_background
69
+ else:
70
+ return pil_image.convert("RGB")
71
+
72
+
73
+ def fetch_image(ele: dict[str, str | Image.Image]) -> Image.Image:
74
+ if "image" in ele:
75
+ image = ele["image"]
76
+ else:
77
+ image = ele["image_url"]
78
+ image_obj = None
79
+ if isinstance(image, Image.Image):
80
+ image_obj = image
81
+ elif image.startswith("http://") or image.startswith("https://"):
82
+ response = requests.get(image, stream=True)
83
+ image_obj = Image.open(BytesIO(response.content))
84
+ elif image.startswith("file://"):
85
+ image_obj = Image.open(image[7:])
86
+ elif image.startswith("data:image"):
87
+ if "base64," in image:
88
+ _, base64_data = image.split("base64,", 1)
89
+ data = base64.b64decode(base64_data)
90
+ image_obj = Image.open(BytesIO(data))
91
+ else:
92
+ image_obj = Image.open(image)
93
+ if image_obj is None:
94
+ raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
95
+ image = to_rgb(image_obj)
96
+ if 'scale_factor' in ele:
97
+ scale_factor = ele['scale_factor']
98
+ image = image.resize((image.width * scale_factor, image.height * scale_factor), Image.BILINEAR)
99
+ return image
100
+
101
+
102
+ def smart_nframes(
103
+ ele: dict,
104
+ total_frames: int,
105
+ video_fps: int | float,
106
+ ) -> int:
107
+ """calculate the number of frames for video used for model inputs.
108
+
109
+ Args:
110
+ ele (dict): a dict contains the configuration of video.
111
+ support either `fps` or `nframes`:
112
+ - nframes: the number of frames to extract for model inputs.
113
+ - fps: the fps to extract frames for model inputs.
114
+ - min_frames: the minimum number of frames of the video, only used when fps is provided.
115
+ - max_frames: the maximum number of frames of the video, only used when fps is provided.
116
+ total_frames (int): the original total number of frames of the video.
117
+ video_fps (int | float): the original fps of the video.
118
+
119
+ Raises:
120
+ ValueError: nframes should be in the interval [FRAME_FACTOR, total_frames].
121
+
122
+ Returns:
123
+ int: the number of frames for video used for model inputs.
124
+ """
125
+ assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
126
+ if "nframes" in ele:
127
+ nframes = adjust_by_factor(ele["nframes"], FRAME_FACTOR, method='round')
128
+ else:
129
+ fps = ele.get("fps", FPS)
130
+ min_frames = adjust_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR, method='ceil')
131
+ max_frames = adjust_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR, method='floor')
132
+ nframes = total_frames / video_fps * fps
133
+ if nframes > total_frames:
134
+ logger.warning(f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]")
135
+ nframes = min(min(max(nframes, min_frames), max_frames), total_frames)
136
+ nframes = adjust_by_factor(nframes, FRAME_FACTOR, method='floor')
137
+ if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
138
+ nframes = total_frames
139
+ # raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.")
140
+ return nframes
141
+
142
+
143
+ def _read_video_torchvision(
144
+ ele: dict,
145
+ ) -> (torch.Tensor, float, list):
146
+ """read video using torchvision.io.read_video and return also per-frame timestamps"""
147
+ video_path = ele["video"]
148
+ if version.parse(torchvision.__version__) < version.parse("0.19.0"):
149
+ if "http://" in video_path or "https://" in video_path:
150
+ warnings.warn("torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0.")
151
+ if "file://" in video_path:
152
+ video_path = video_path[7:]
153
+ st = time.time()
154
+ video, audio, info = io.read_video(
155
+ video_path,
156
+ start_pts=ele.get("video_start", 0.0),
157
+ end_pts=ele.get("video_end", None),
158
+ pts_unit="sec",
159
+ output_format="TCHW",
160
+ )
161
+ total_frames, video_fps = video.size(0), info["video_fps"]
162
+ logger.info(f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
163
+ nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
164
+ # Calculate frame indices and corresponding timestamps (based on video start time)
165
+ idx = torch.linspace(0, total_frames - 1, nframes).round().long()
166
+ start_time = ele.get("video_start", 0.0)
167
+ timestamps = (start_time + idx.to(torch.float32) / video_fps).tolist()
168
+ sample_fps = nframes / max(total_frames, 1e-6) * video_fps
169
+ video = video[idx]
170
+ return video, sample_fps, timestamps
171
+
172
+
173
+
174
+ def is_decord_available() -> bool:
175
+ import importlib.util
176
+
177
+ return importlib.util.find_spec("decord") is not None
178
+
179
+ def _read_video_decord(
180
+ ele: dict,
181
+ ) -> (torch.Tensor, float, list):
182
+ """read video using decord.VideoReader and return also per-frame timestamps"""
183
+ import decord
184
+ video_path = ele["video"]
185
+ st = time.time()
186
+ vr = decord.VideoReader(video_path)
187
+ if 'video_start' in ele or 'video_end' in ele:
188
+ raise NotImplementedError("not support start_pts and end_pts in decord for now.")
189
+ total_frames, video_fps = len(vr), vr.get_avg_fps()
190
+ logger.info(f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
191
+ nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
192
+ idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
193
+ start_time = ele.get("video_start", 0.0) # TODO:
194
+ timestamps = [start_time + i / video_fps for i in idx]
195
+ video = vr.get_batch(idx).asnumpy()
196
+ video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format
197
+ sample_fps = nframes / max(total_frames, 1e-6) * video_fps
198
+ return video, sample_fps, timestamps
199
+
200
+
201
+ VIDEO_READER_BACKENDS = {
202
+ "decord": _read_video_decord,
203
+ "torchvision": _read_video_torchvision,
204
+ }
205
+
206
+
207
+ @lru_cache(maxsize=1)
208
+ def get_video_reader_backend() -> str:
209
+ if is_decord_available():
210
+ video_reader_backend = "decord"
211
+ else:
212
+ video_reader_backend = "torchvision"
213
+ return video_reader_backend
214
+
215
+
216
+
217
+
218
+ def fetch_video(ele: dict, return_video_sample_fps: bool = False) -> torch.Tensor | list[Image.Image]:
219
+ if isinstance(ele["video"], str):
220
+ video_reader_backend = get_video_reader_backend()
221
+ try:
222
+ video, sample_fps, timestamps = VIDEO_READER_BACKENDS[video_reader_backend](ele)
223
+ except Exception as e:
224
+ logger.warning(f"video_reader_backend {video_reader_backend} error, use torchvision as default, msg: {e}")
225
+ video, sample_fps, timestamps = VIDEO_READER_BACKENDS["torchvision"](ele)
226
+
227
+ nframes, _, height, width = video.shape
228
+
229
+ if return_video_sample_fps:
230
+ return video, sample_fps, timestamps
231
+ return video
232
+ else:
233
+ assert isinstance(ele["video"], (list, tuple))
234
+ process_info = ele.copy()
235
+ process_info.pop("type", None)
236
+ process_info.pop("video", None)
237
+ images = [
238
+ fetch_image({"image": video_element, **process_info})
239
+ for video_element in ele["video"]
240
+ ]
241
+ nframes = adjust_by_factor(len(images), FRAME_FACTOR, method='ceil')
242
+ if len(images) < nframes:
243
+ images.extend([images[-1]] * (nframes - len(images)))
244
+
245
+ timestamps = [-1 for i in range(nframes)] # not sure about this
246
+ if return_video_sample_fps:
247
+ return images, process_info.pop("fps", 2.0), timestamps
248
+ return images
249
+
250
+ class Eagle2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
251
+ # see processing_utils.ProcessingKwargs documentation for usage.
252
+ _defaults = {
253
+ "text_kwargs": {
254
+ "padding": False,
255
+ },
256
+ "images_kwargs": {},
257
+ "videos_kwargs": {"max_dynamic_tiles": 1},
258
+ }
259
+
260
+
261
+ class Eagle2_5_VLProcessor(ProcessorMixin):
262
+ r"""
263
+ Constructs an Eagle2_5_VL processor which wraps an Eagle2_5_VL video processor, an Eagle2_5_VL image processor and an Eagle2_5_VL tokenizer into a single processor.
264
+
265
+ [`Eagle2_5_VLProcessor`] offers all the functionalities of [`Eagle2_5_VLVideoProcessor`], [`Eagle2_5_VLImageProcessor`] and [`Eagle2_5_VLTokenizer`]. See the
266
+ [`~Eagle2_5_VLVideoProcessor.__call__`], [`~Eagle2_5_VLProcessor.__call__`] and [`~Eagle2_5_VLProcessor.decode`] for more information.
267
+
268
+ Args:
269
+ image_processor ([`LlavaOnevisionImageProcessor`], *optional*):
270
+ The image processor is a required input.
271
+ tokenizer ([`LlamaTokenizerFast`], *optional*):
272
+ The tokenizer is a required input.
273
+ num_image_tokens (`int`, *optional*):
274
+ Number of image tokens for one image that will be returned by the vision tower.
275
+ vision_feature_select_strategy (`str`, *optional*):
276
+ The feature selection strategy used to select the vision feature from the vision backbone.
277
+ Should be the same as in the model's config
278
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
279
+ in a chat into a tokenizable string.
280
+ image_token (`str`, *optional*, defaults to `"<image>"`):
281
+ Special token used to denote image location.
282
+ video_token (`str`, *optional*, defaults to `"<video>"`):
283
+ Special token used to denote video location.
284
+ """
285
+
286
+ attributes = ["image_processor", "tokenizer"]
287
+ valid_kwargs = [
288
+ "chat_template",
289
+ "num_image_tokens",
290
+ "vision_feature_select_strategy",
291
+ "image_token",
292
+ "video_token",
293
+ "images_kwargs",
294
+ "videos_kwargs",
295
+ "text_kwargs",
296
+ ]
297
+ image_processor_class = "AutoImageProcessor"
298
+ tokenizer_class = "AutoTokenizer"
299
+
300
+ def __init__(
301
+ self,
302
+ image_processor=None,
303
+ tokenizer=None,
304
+ vision_feature_select_strategy=None,
305
+ chat_template=None,
306
+ image_token='<IMG_CONTEXT>',
307
+ video_token='<IMG_CONTEXT>',
308
+ tokens_per_tile=256,
309
+ image_placeholder='image',
310
+ video_placeholder='video',
311
+ image_start_token='<img>',
312
+ image_end_token='</img>',
313
+ **kwargs,
314
+ ):
315
+ self.vision_feature_select_strategy = vision_feature_select_strategy
316
+ self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
317
+ self.video_token = tokenizer.video_token if hasattr(tokenizer, "video_token") else video_token
318
+ self.image_token_id = (
319
+ tokenizer.image_token_id
320
+ if getattr(tokenizer, "image_token_id", None)
321
+ else tokenizer.convert_tokens_to_ids(self.image_token)
322
+ )
323
+ self.video_token_id = (
324
+ tokenizer.video_token_id
325
+ if getattr(tokenizer, "video_token_id", None)
326
+ else tokenizer.convert_tokens_to_ids(self.video_token)
327
+ )
328
+ self.image_placeholder = image_placeholder
329
+ self.video_placeholder = video_placeholder
330
+ self.tokens_per_tile = tokens_per_tile
331
+ self.image_start_token = image_start_token
332
+ self.image_end_token = image_end_token
333
+ if 'auto_map' in kwargs:
334
+ self.auto_map = kwargs['auto_map']
335
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
336
+
337
+
338
+ def replace_media_placeholder(self, text, image_list, video_list, timestamps_list, fps_list, **output_kwargs):
339
+
340
+ num_of_images_in_this_sample = 0
341
+ num_of_videos_in_this_sample = 0
342
+ # Regular expression pattern to match formats like <image-1> or <video-2>
343
+ pattern = re.compile(rf"<({self.image_placeholder}|{self.video_placeholder})-(\d+)>")
344
+ unified_frame_list = []
345
+
346
+ image_min_dynamic_tiles = output_kwargs['images_kwargs'].get("min_dynamic_tiles", self.image_processor.min_dynamic_tiles)
347
+ image_max_dynamic_tiles = output_kwargs['images_kwargs'].get("max_dynamic_tiles", self.image_processor.max_dynamic_tiles)
348
+ image_use_thumbnail = output_kwargs['images_kwargs'].get("use_thumbnail", self.image_processor.use_thumbnail)
349
+ video_min_dynamic_tiles = output_kwargs['videos_kwargs'].get("min_dynamic_tiles", self.image_processor.min_dynamic_tiles)
350
+ video_max_dynamic_tiles = output_kwargs['videos_kwargs'].get("max_dynamic_tiles", self.image_processor.max_dynamic_tiles)
351
+ video_use_thumbnail = output_kwargs['videos_kwargs'].get("use_thumbnail", self.image_processor.use_thumbnail)
352
+
353
+ tile_size = self.image_processor.size.get("height", 448)
354
+
355
+
356
+ # Function to replace tags in a single text
357
+ def replace_in_text(text):
358
+ # repl callback function for each match replacement operation
359
+ def repl(match):
360
+ nonlocal unified_frame_list
361
+ nonlocal num_of_images_in_this_sample
362
+ nonlocal num_of_videos_in_this_sample
363
+ media_type = match.group(1) # 'image' or 'video'
364
+ idx_in_list = int(match.group(2)) - 1 # Convert to list index (0-based)
365
+ # Select the corresponding path based on media type
366
+ idx_mapper = {0: "first", 1: "second", 2: "third", 3: "fourth", 4: "fifth", 5: "sixth", 6: "seventh", 7: "eighth", 8: "ninth", 9: "tenth"}
367
+ if media_type == 'image':
368
+ image_inputs = self.image_processor(images=[image_list[idx_in_list]], videos=None, **output_kwargs["images_kwargs"])
369
+ num_all_tiles = image_inputs["pixel_values"].shape[0]
370
+ special_placeholder = f"<image {idx_in_list+1}>{self.image_start_token}{self.image_token * num_all_tiles * self.tokens_per_tile}{self.image_end_token}"
371
+ unified_frame_list.append(image_inputs)
372
+ num_of_images_in_this_sample += 1
373
+
374
+ elif media_type == 'video':
375
+ video_inputs = self.image_processor(images=None, videos=[video_list[idx_in_list]], **output_kwargs["videos_kwargs"])
376
+ num_all_tiles = video_inputs["pixel_values"].shape[0]
377
+ image_sizes = video_inputs["image_sizes"]
378
+ if timestamps_list is not None and -1 not in timestamps_list:
379
+ frame_timestamps = timestamps_list[idx_in_list]
380
+ else:
381
+ frame_timestamps = None
382
+ sampled_fps = fps_list[idx_in_list] if fps_list is not None else None
383
+
384
+ num_of_tiles_each_frame = [
385
+ self.get_number_tiles_based_on_image_size(image_size, video_min_dynamic_tiles, video_max_dynamic_tiles, video_use_thumbnail, tile_size)
386
+ for image_size in image_sizes
387
+ ]
388
+ assert sum(num_of_tiles_each_frame) == num_all_tiles, f"The number of tiles in each frame is not equal to the total number of tiles: {sum(num_of_tiles_each_frame)} != {num_all_tiles}"
389
+
390
+ if frame_timestamps is not None:
391
+ assert len(frame_timestamps) == len(num_of_tiles_each_frame), f"The number of timestamps is not equal to the number of frames: {len(frame_timestamps)} != {len(num_of_tiles_each_frame)}"
392
+ special_placeholder = [f"Frame {i+1} sample at {frame_timestamps[i]:.2f}s: {self.image_start_token}{self.image_token * num_of_tiles * self.tokens_per_tile}{self.image_end_token}" for i, num_of_tiles in enumerate(num_of_tiles_each_frame)]
393
+ else:
394
+ special_placeholder = [f"Frame {i+1}: {self.image_start_token}{self.image_token * num_of_tiles * self.tokens_per_tile}{self.image_end_token}" for i, num_of_tiles in enumerate(num_of_tiles_each_frame)]
395
+
396
+ if sampled_fps is not None:
397
+ special_placeholder = f"The {idx_mapper[idx_in_list]} video sampled with {sampled_fps:.2f} fps: " + "".join(special_placeholder)
398
+ else:
399
+ special_placeholder = f"The {idx_mapper[idx_in_list]} video: " + "".join(special_placeholder)
400
+ unified_frame_list.append(video_inputs)
401
+ num_of_videos_in_this_sample += 1
402
+ else:
403
+ raise ValueError(f'Unknown media type: {media_type}')
404
+ return special_placeholder
405
+ return pattern.sub(repl, text)
406
+ text = replace_in_text(text)
407
+ if len(unified_frame_list) > 0:
408
+ pixel_values = torch.cat([frame["pixel_values"] for frame in unified_frame_list])
409
+ image_sizes = torch.cat([frame["image_sizes"] for frame in unified_frame_list])
410
+ else:
411
+ pixel_values = None
412
+ image_sizes = None
413
+ return text, pixel_values, image_sizes, num_of_images_in_this_sample, num_of_videos_in_this_sample
414
+
415
+ def __call__(
416
+ self,
417
+ images: ImageInput = None,
418
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
419
+ audio=None,
420
+ videos: VideoInput = None,
421
+ **kwargs: Unpack[Eagle2_5_VLProcessorKwargs],
422
+ ) -> BatchFeature:
423
+ """
424
+ Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
425
+ and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
426
+ the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
427
+ LlavaNextImageProcessor's [`~LlavaNextImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
428
+ of the above two methods for more information.
429
+
430
+ Args:
431
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
432
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
433
+ tensor. Both channels-first and channels-last formats are supported.
434
+ text (`str`, `List[str]`, `List[List[str]]`):
435
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
436
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
437
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
438
+ videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
439
+ The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch tensor.
440
+
441
+ Returns:
442
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
443
+
444
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
445
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
446
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
447
+ `None`).
448
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
449
+ - **pixel_values_videos** -- Pixel values of a video input to be fed to a model. Returned when `videos` is not `None`.
450
+ - **image_sizes** -- Size of each image that will be used to unpad an image. Returned when `images` is not `None`.
451
+ """
452
+
453
+
454
+ output_kwargs = self._merge_kwargs(
455
+ Eagle2_5_VLProcessorKwargs,
456
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
457
+ **kwargs,
458
+ )
459
+
460
+ if isinstance(text, str):
461
+ text_list = [text]
462
+ elif not isinstance(text, list) and not isinstance(text[0], str):
463
+ raise ValueError("Invalid input text. Please provide a string, or a list of strings")
464
+ elif isinstance(text, list) and isinstance(text[0], str):
465
+ text_list = text
466
+
467
+ if images is None: images = []
468
+ if videos is None: videos = []
469
+
470
+ pixel_values_list = []
471
+ image_sizes_list = []
472
+ new_sample_list = []
473
+ image_start_idx = 0
474
+ video_start_idx = 0
475
+ timestamps_batch = output_kwargs['videos_kwargs'].pop("timestamps", None)
476
+ fps_batch = output_kwargs['videos_kwargs'].pop("fps", None)
477
+ for sample in text_list:
478
+ timestamps_list = timestamps_batch[video_start_idx:] if timestamps_batch is not None else None
479
+ fps_list = fps_batch[video_start_idx:] if fps_batch is not None else None
480
+ sample, pixel_values, image_sizes, num_of_images_in_this_sample, num_of_videos_in_this_sample = self.replace_media_placeholder(sample, images[image_start_idx:], videos[video_start_idx:], timestamps_list, fps_list, **output_kwargs)
481
+ new_sample_list.append(sample)
482
+ if pixel_values is not None:
483
+ pixel_values_list.append(pixel_values)
484
+ image_sizes_list.append(image_sizes)
485
+ image_start_idx += num_of_images_in_this_sample
486
+ video_start_idx += num_of_videos_in_this_sample
487
+
488
+ if len(pixel_values_list) > 0:
489
+ image_inputs = {"pixel_values": torch.cat(pixel_values_list), "image_sizes": torch.cat(image_sizes_list)}
490
+ else:
491
+ image_inputs = {}
492
+ video_inputs = {}
493
+ text_inputs = self.tokenizer(new_sample_list, **output_kwargs["text_kwargs"])
494
+ return BatchFeature(data={**text_inputs, **image_inputs, **video_inputs})
495
+
496
+ def get_number_tiles_based_on_image_size(self, image_size: tuple, min_num: int, max_num: int, use_thumbnail: bool, tile_size: int) -> int:
497
+ """
498
+ Get the number of tiles based on the image size.
499
+ """
500
+ orig_height, orig_width = image_size
501
+ aspect_ratio = orig_width / orig_height
502
+ # calculate the existing image aspect ratio
503
+ target_ratios = set(
504
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
505
+ i * j <= max_num and i * j >= min_num)
506
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
507
+
508
+ # find the closest aspect ratio to the target
509
+ target_aspect_ratio = self.image_processor.find_closest_aspect_ratio(
510
+ aspect_ratio, target_ratios, orig_width, orig_height, tile_size)
511
+ tiles_num = target_aspect_ratio[0] * target_aspect_ratio[1]
512
+ if use_thumbnail and tiles_num > 1:
513
+ tiles_num += 1
514
+ return tiles_num
515
+
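For intuition, the block below is a rough, self-contained re-implementation of the tile-count heuristic above. It inlines a simple "closest aspect ratio" search instead of calling `self.image_processor.find_closest_aspect_ratio` (which also uses the tile size for tie-breaking), so treat it as an approximation rather than the exact code path; the default `min_num`/`max_num` values are illustrative.

```python
# Approximate, standalone version of the tiling heuristic (illustrative only).
def approx_num_tiles(orig_height: int, orig_width: int,
                     min_num: int = 1, max_num: int = 12,
                     use_thumbnail: bool = True) -> int:
    aspect_ratio = orig_width / orig_height
    # all (cols, rows) grids whose tile count falls within [min_num, max_num]
    grids = sorted(
        {(i, j)
         for n in range(min_num, max_num + 1)
         for i in range(1, n + 1)
         for j in range(1, n + 1)
         if min_num <= i * j <= max_num},
        key=lambda g: g[0] * g[1],
    )
    # choose the grid whose aspect ratio is closest to the image's
    cols, rows = min(grids, key=lambda g: abs(aspect_ratio - g[0] / g[1]))
    tiles = cols * rows
    if use_thumbnail and tiles > 1:
        tiles += 1  # one extra globally downscaled thumbnail tile
    return tiles

print(approx_num_tiles(1080, 1920))  # a wide image picks a wide grid plus a thumbnail
```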
516
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
517
+ def batch_decode(self, *args, **kwargs):
518
+ """
519
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
520
+ refer to the docstring of this method for more information.
521
+ """
522
+ return self.tokenizer.batch_decode(*args, **kwargs)
523
+
524
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
525
+ def decode(self, *args, **kwargs):
526
+ """
527
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
528
+ the docstring of this method for more information.
529
+ """
530
+ return self.tokenizer.decode(*args, **kwargs)
531
+
532
+ @property
533
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
534
+ def model_input_names(self):
535
+ tokenizer_input_names = self.tokenizer.model_input_names
536
+ image_processor_input_names = self.image_processor.model_input_names
537
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
538
+
539
+ # override to save video-config in a separate config file
540
+ def save_pretrained(self, save_directory, **kwargs):
541
+ if os.path.isfile(save_directory):
542
+ raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
543
+ os.makedirs(save_directory, exist_ok=True)
544
+
545
+ outputs = super().save_pretrained(save_directory, **kwargs)
546
+ return outputs
547
+
548
+ # override to load video-config from a separate config file
549
+ @classmethod
550
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
551
+ processor = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
552
+
553
+ # when return_unused_kwargs=True, the base class returns a tuple whose second element holds the unused kwargs
554
+ if isinstance(processor, tuple):
555
+ processor = processor[0]
556
+ return processor
557
+
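A short round-trip sketch for the two overrides above, assuming `processor` is an instantiated `Eagle2_5_VLProcessor` and that the class has been imported from `processing_eagle2_5_vl`; the directory name is a placeholder.

```python
# Hypothetical round trip; "./eagle2_5_vl_processor" is a placeholder directory.
processor.save_pretrained("./eagle2_5_vl_processor")

# The from_pretrained override unwraps the (processor, unused_kwargs) tuple that the
# base class returns when return_unused_kwargs=True, so callers always get a processor.
reloaded = Eagle2_5_VLProcessor.from_pretrained("./eagle2_5_vl_processor")
```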
558
+ # Copied from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
559
+ def process_vision_info(
560
+ self,
561
+ conversations: list[dict] | list[list[dict]],
562
+ return_video_kwargs: bool = False,
563
+ ) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, Optional[dict]]:
564
+
565
+ vision_infos = self.extract_vision_info(conversations)
566
+ ## Read images or videos
567
+ image_inputs = []
568
+ video_inputs = []
569
+ video_sample_fps_list = []
570
+ video_timestamps_list = []
571
+ for vision_info in vision_infos:
572
+ if "image" in vision_info or "image_url" in vision_info:
573
+ image_inputs.append(fetch_image(vision_info))
574
+ elif "video" in vision_info:
575
+ video_input, video_sample_fps, video_timestamps = fetch_video(vision_info, return_video_sample_fps=True)
576
+ video_sample_fps_list.append(video_sample_fps)
577
+ video_inputs.append(video_input)
578
+ video_timestamps_list.append(video_timestamps)
579
+ else:
580
+ raise ValueError("image, image_url or video should be in content.")
581
+ if len(image_inputs) == 0:
582
+ image_inputs = None
583
+ if len(video_inputs) == 0:
584
+ video_inputs = None
585
+ if return_video_kwargs:
586
+ return image_inputs, video_inputs, {'fps': video_sample_fps_list, 'timestamps': video_timestamps_list}
587
+ return image_inputs, video_inputs
588
+
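The sketch below shows the Qwen-VL-style conversation format that `process_vision_info` (and `extract_vision_info` below) expects, assuming `processor` is an already-loaded `Eagle2_5_VLProcessor`; the image path is a placeholder handled by `fetch_image`.

```python
# Hypothetical conversation; "demo.jpg" is a placeholder.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "demo.jpg"},
            {"type": "text", "text": "What is shown in this image?"},
        ],
    }
]

images, videos, video_kwargs = processor.process_vision_info(messages, return_video_kwargs=True)
# images -> [PIL.Image], videos -> None, video_kwargs -> {"fps": [], "timestamps": []}
```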
589
+ def extract_vision_info(self, conversations: list[dict] | list[list[dict]]) -> list[dict]:
590
+ vision_infos = []
591
+ if isinstance(conversations[0], dict):
592
+ conversations = [conversations]
593
+ for conversation in conversations:
594
+ for message in conversation:
595
+ if isinstance(message["content"], list):
596
+ for ele in message["content"]:
597
+ if (
598
+ "image" in ele
599
+ or "image_url" in ele
600
+ or "video" in ele
601
+ or ele["type"] in ("image", "image_url", "video")
602
+ ):
603
+ vision_infos.append(ele)
604
+ return vision_infos
605
+
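For the same `messages` example, `extract_vision_info` returns only the media blocks (a hedged illustration):

```python
vision_infos = processor.extract_vision_info(messages)
# -> [{"type": "image", "image": "demo.jpg"}]
```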
606
+ def py_apply_chat_template(self, messages, tokenize=False, add_generation_prompt=False):
607
+ """
608
+ Renders a chat conversation using a custom template with verification of tokens.
609
+
610
+ The purpose is to check for the existence of tokens like "<image-1>" or "<video-1>"
611
+ in the message text and skip adding them if they already exist.
612
+
613
+ Args:
614
+ messages (list): A list of message dictionaries. Each message should contain:
615
+ - 'role': The role of the speaker (e.g., 'system', 'user', 'assistant').
616
+ - 'content': Either a string or a list of content blocks. When it is a list, each block may contain:
617
+ * 'type': The type of content, such as 'image' or 'video'.
618
+ * 'text': The actual text if present.
619
+ * Other keys such as 'image', 'image_url', or 'video'.
620
+ add_generation_prompt (bool): If True, appends "<|im_start|>assistant" at the end of the rendered string.
621
+ tokenize (bool): If True, tokenize the rendered string.
622
+ Returns:
623
+ str: The final rendered chat string according to the specified template.
624
+ """
625
+ assert not tokenize, "tokenize is not supported yet"
626
+ result = ""
627
+ image_count = 0
628
+ video_count = 0
629
+
630
+ message_text = ""
631
+ for idx, message in enumerate(messages):
632
+ if message.get('role') != 'user': continue
633
+ # If content is a string, simply output it.
634
+ content = message.get('content')
635
+ if isinstance(content, str):
636
+ message_text += content
637
+ elif isinstance(content, list):
638
+ # Process each content item.
639
+ for item in content:
640
+ # If the block is a dictionary and contains text, add it to message_text.
641
+ if isinstance(item, dict) and "text" in item:
642
+ message_text += item["text"]
643
+ # If an item is already a string in the list, add it directly.
644
+ elif isinstance(item, str):
645
+ message_text += item
646
+
647
+ for idx, message in enumerate(messages):
648
+ # If the first message is not from the system, prepend a default system message.
649
+ if idx == 0 and message.get('role') != 'system':
650
+ result += "<|im_start|>system\n"
651
+ result += "You are a helpful assistant.\n"
652
+ result += "<|im_end|>\n"
653
+
654
+ # Start the current message block with its role.
655
+ result += f"<|im_start|>{message.get('role', '')}\n"
656
+ content = message.get('content')
657
+
658
+ # If content is a string, simply output it.
659
+ if isinstance(content, str):
660
+ result += content
661
+ result += "<|im_end|>\n"
662
+ else:
663
+ # Process each content item.
664
+ for item in content:
665
+ # Check if the item is an image (explicitly by type or by key presence).
666
+ if (isinstance(item, dict) and (item.get('type') == 'image' or 'image' in item or 'image_url' in item)):
667
+ image_count += 1
668
+ candidate_token = f"<image-{image_count}>"
669
+ # Only add the token if it is not already present in the collected text.
670
+ if candidate_token not in message_text:
671
+ result += candidate_token
672
+ # Check if the item is a video.
673
+ elif (isinstance(item, dict) and (item.get('type') == 'video' or 'video' in item)):
674
+ video_count += 1
675
+ candidate_token = f"<video-{video_count}>"
676
+ # Only add the token if it is not already present.
677
+ if candidate_token not in message_text:
678
+ result += candidate_token
679
+ # If the item contains text, add it.
680
+ elif isinstance(item, dict) and 'text' in item:
681
+ result += item['text']
682
+ # If the item is a string (and not handled already), add it.
683
+ elif isinstance(item, str):
684
+ result += item
685
+ result += "<|im_end|>\n"
686
+
687
+ # Optionally add assistant generation prompt at the end.
688
+ if add_generation_prompt:
689
+ result += "<|im_start|>assistant\n"
690
+
691
+ return result
692
+
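Continuing the conversation sketch above, `py_apply_chat_template` would render it roughly as follows (a sketch of the expected output, not an authoritative transcript):

```python
prompt = processor.py_apply_chat_template(messages, add_generation_prompt=True)
# Expected rendering for the single-image conversation used earlier:
#
# <|im_start|>system
# You are a helpful assistant.
# <|im_end|>
# <|im_start|>user
# <image-1>What is shown in this image?<|im_end|>
# <|im_start|>assistant
```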
693
+
694
+ @classmethod
695
+ def from_args_and_dict(cls, args, processor_dict: dict[str, Any], **kwargs):
696
+ """
697
+ Instantiates a type of [`~processing_utils.ProcessingMixin`] from a Python dictionary of parameters.
698
+
699
+ Args:
700
+ processor_dict (`Dict[str, Any]`):
701
+ Dictionary that will be used to instantiate the processor object. Such a dictionary can be
702
+ retrieved from a pretrained checkpoint by leveraging the
703
+ [`~processing_utils.ProcessingMixin.to_dict`] method.
704
+ kwargs (`Dict[str, Any]`):
705
+ Additional parameters from which to initialize the processor object.
706
+
707
+ Returns:
708
+ [`~processing_utils.ProcessingMixin`]: The processor object instantiated from those
709
+ parameters.
710
+ """
711
+ processor_dict = processor_dict.copy()
712
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
713
+
714
+ # Pop specific kwargs (such as "processor_class") before validation so they are not flagged as unused;
715
+ # otherwise they would trigger a spurious warning
716
+ if "processor_class" in processor_dict:
717
+ del processor_dict["processor_class"]
718
+
719
+ #if "auto_map" in processor_dict:
720
+ # del processor_dict["auto_map"]
721
+
722
+ unused_kwargs = cls.validate_init_kwargs(processor_config=processor_dict, valid_kwargs=cls.valid_kwargs)
723
+ processor = cls(*args, **processor_dict)
724
+
725
+ # Update processor with kwargs if needed
726
+ for key in set(kwargs.keys()):
727
+ if hasattr(processor, key):
728
+ setattr(processor, key, kwargs.pop(key))
729
+
730
+ kwargs.update(unused_kwargs)
731
+ logger.info(f"Processor {processor}")
732
+ if return_unused_kwargs:
733
+ return processor, kwargs
734
+ else:
735
+ return processor
736
+
737
+
738
+ __all__ = ["Eagle2_5_VLProcessor"]
processor_config.json ADDED
@@ -0,0 +1,15 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_eagle2_5_vl.Eagle2_5_VLProcessor",
4
+ "AutoImageProcessor": "image_processing_eagle2_5_vl_fast.Eagle2_5_VLImageProcessorFast"
5
+ },
6
+ "image_end_token": "</img>",
7
+ "image_placeholder": "image",
8
+ "image_start_token": "<img>",
9
+ "image_token": "<IMG_CONTEXT>",
10
+ "processor_class": "Eagle2_5_VLProcessor",
11
+ "tokens_per_tile": 256,
12
+ "video_placeholder": "video",
13
+ "video_token": "<IMG_CONTEXT>",
14
+ "vision_feature_select_strategy": null
15
+ }
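The `auto_map` entries above route `AutoProcessor` and `AutoImageProcessor` to the custom classes shipped in this repo, so loading requires `trust_remote_code=True`; `tokens_per_tile: 256` suggests each image tile is expanded into 256 `<IMG_CONTEXT>` placeholder tokens between `<img>` and `</img>`. A minimal loading sketch (the path is a placeholder):

```python
from transformers import AutoProcessor

# "<repo-or-local-path>" stands for this model repository or a local clone of it.
processor = AutoProcessor.from_pretrained(
    "<repo-or-local-path>",
    trust_remote_code=True,  # lets auto_map resolve the custom processor classes
)
print(type(processor).__name__)  # Eagle2_5_VLProcessor
```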