import base64
import copy
import io
import math
import os
import tempfile
import uuid
from typing import Dict, List, Optional, Union
from urllib.parse import urlparse
import av
import cv2
import numpy as np
import requests
import torch
from decord import VideoReader, cpu
from PIL import Image, UnidentifiedImageError
from transformers.image_processing_utils import (
BaseImageProcessor,
BatchFeature,
get_size_dict,
)
from transformers.image_transforms import (
convert_to_rgb,
get_resize_output_image_size,
resize,
to_channel_dimension_format,
)
from transformers.image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
)
from transformers.utils import TensorType, logging
logger = logging.get_logger(__name__)
def determine_possible_resolutions(anyres: bool, max_num_grids: int, grid_size: int, use_1x1_grid: bool = False):
"""
Finds and returns possible resolution combinations with a total number of grids less than or equal to max_num_grids.
For example, if max_num_grids is 4 (and the 1x1 grid is excluded, which is the default), the possible grid combinations are:
[1x2, 1x3, 1x4, 2x1, 2x2, 3x1, 4x1], and the resolutions are calculated accordingly.
Example:
>>> possible_resolutions = determine_possible_resolutions(anyres=True, max_num_grids=4, grid_size=336)
>>> print(possible_resolutions)
[[336, 672], [336, 1008], [336, 1344], [672, 336], [672, 672], [1008, 336], [1344, 336]]
Args:
anyres (bool): Whether to allow any resolution combinations up to the maximum grid count.
max_num_grids (int): The maximum number of grids allowed (height x width must be ≤ this value).
grid_size (int): The size of each grid in pixels (e.g., 336).
use_1x1_grid (bool, optional): Whether to include the 1x1 grid as a valid resolution. Defaults to False.
Returns:
List[List[int]]: A list of possible [height, width] resolution pairs.
"""
possible_resolutions = []
if anyres:
assert max_num_grids > 0
for i in range(1, max_num_grids + 1):
for j in range(1, max_num_grids + 1):
if i == 1 and j == 1 and not use_1x1_grid:
continue
if i * j <= max_num_grids:
possible_resolutions.append([i, j])
possible_resolutions = [[ys * grid_size, xs * grid_size] for ys, xs in possible_resolutions]
return possible_resolutions
def divide_to_grids(image: np.array, grid_size: int, input_data_format=None) -> List[np.array]:
"""
Divides a local image into grids of size (grid_size x grid_size).
Args:
image (np.array): Input image as a NumPy array.
grid_size (int): The size (in pixels) of each square grid.
input_data_format (optional): Optional format specifier (e.g., "channels_first" or "channels_last").
Returns:
List[np.array]: A list of image patches, each of size (grid_size x grid_size).
"""
grids = []
height, width = get_image_size(image, channel_dim=input_data_format)
for i in range(0, height, grid_size):
for j in range(0, width, grid_size):
if input_data_format == ChannelDimension.LAST:
grid = image[i : i + grid_size, j : j + grid_size]
else:
grid = image[:, i : i + grid_size, j : j + grid_size]
grids.append(grid)
return grids
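# Illustrative sketch of the grid split above (assumes a channels-last uint8 array):
# >>> _img = np.zeros((672, 672, 3), dtype=np.uint8)
# >>> _grids = divide_to_grids(_img, grid_size=336, input_data_format=ChannelDimension.LAST)
# >>> len(_grids), _grids[0].shape
# (4, (336, 336, 3))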
def pad(
image: np.array,
target_size: tuple,
background_color=(127, 127, 127),
input_data_format=None,
) -> np.array:
"""
Pads the input image on the sides (top/bottom and left/right) to match the target height and width.
Args:
image (np.array): Input image as a NumPy array.
target_size (tuple): Target size as (target_height, target_width).
background_color (tuple, optional): RGB color value used for padding. Defaults to (127, 127, 127).
input_data_format (optional): Optional format specifier (e.g., "channels_first" or "channels_last").
Returns:
np.array: The padded image with the specified target size.
"""
target_height, target_width = target_size
height, width = get_image_size(image, channel_dim=input_data_format)
# result = np.ones((target_height, target_width, image.shape[2]), dtype=image.dtype) * background_color
result = np.empty((target_height, target_width, image.shape[2]), dtype=image.dtype)
for i in range(image.shape[2]):
result[..., i].fill(background_color[i])
paste_x = (target_width - width) // 2
paste_y = (target_height - height) // 2
result[paste_y : paste_y + height, paste_x : paste_x + width, :] = image
return result
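# Illustrative sketch of the symmetric padding above (channels-last array assumed, since this helper
# indexes image.shape[2] for the channel count):
# >>> _img = np.zeros((100, 80, 3), dtype=np.uint8)
# >>> pad(_img, target_size=(112, 112)).shape
# (112, 112, 3)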
def expand2square(
image: np.array,
bboxes_dict=None,
background_color=(127, 127, 127),
input_data_format=None,
) -> tuple:
"""
Expands the input image to a square shape by placing it at the center of a new square canvas,
with padding added to the shorter side (either top/bottom or left/right).
The image is always centered on the new canvas, and padding is applied symmetrically.
Args:
image (np.array): Input image as a NumPy array.
bboxes_dict (dict, optional): A dictionary of bounding boxes, where each value is an NDArray of shape (N, 4, 2)
with box coordinates in the format [[xtl, ytl], [xtr, ytr], [xbr, ybr], [xbl, ybl]].
Supports multiple categories (e.g., "ocr", "html") simultaneously.
background_color (tuple, optional): RGB color to fill the padding area. Defaults to (127, 127, 127).
input_data_format (optional): Optional format specifier for image data (e.g., "channels_first" or "channels_last").
Returns:
Tuple[np.array, dict]: The square-shaped image with the original centered and padded as needed,
and the bounding-box dictionary with coordinates shifted by the same padding offsets (or None if not provided).
Example:
>>> _img = np.ones((100, 80, 3), dtype=np.uint8) * 100
>>> _bboxes_dict = {"words": np.array([[[10, 10], [20, 10], [20, 20], [10, 20]],
... [[30, 30], [40, 30], [40, 40], [30, 40]]])}
>>> _img, _bboxes_dict = expand2square(_img, _bboxes_dict, (255, 255, 255))
>>> _img.shape
(100, 100, 3)
>>> guessed_ocr_bboxes = np.array([[[20, 10], [30, 10], [30, 20], [20, 20]],
... [[40, 30], [50, 30], [50, 40], [40, 40]]])
>>> np.testing.assert_array_almost_equal(_bboxes_dict["words"], guessed_ocr_bboxes) is None
True
"""
height, width = get_image_size(image, channel_dim=input_data_format)
if width == height:
return image, bboxes_dict
elif width > height:
# result = np.ones((width, width, image.shape[2]), dtype=image.dtype) * background_color
result = np.empty((width, width, image.shape[2]), dtype=image.dtype)
for i in range(image.shape[2]):
result[..., i].fill(background_color[i])
result[(width - height) // 2 : (width - height) // 2 + height, :] = image
if bboxes_dict is not None:
for key in bboxes_dict:
bboxes_dict[key][:, :, 1] += (width - height) // 2
return result, bboxes_dict
else:
# result = np.ones((height, height, image.shape[2]), dtype=image.dtype) * background_color
result = np.empty((height, height, image.shape[2]), dtype=image.dtype)
for i in range(image.shape[2]):
result[..., i].fill(background_color[i])
result[:, (height - width) // 2 : (height - width) // 2 + width] = image
if bboxes_dict is not None:
for key in bboxes_dict:
bboxes_dict[key][:, :, 0] += (height - width) // 2
return result, bboxes_dict
def resize_longside(
image: np.array,
size: int,
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
"""
Resizes the image so that its longer side matches the specified size, maintaining the original aspect ratio.
Args:
image (np.array): Input image as a NumPy array.
size (int): Target size for the longer side of the image.
resample (PILImageResampling, optional): Resampling method to use during resizing. Defaults to BICUBIC.
data_format (str or ChannelDimension, optional): Output data format (e.g., "channels_first" or "channels_last").
input_data_format (str or ChannelDimension, optional): Input data format of the image.
Returns:
np.array: The resized image with its aspect ratio preserved.
"""
height, width = get_image_size(image, channel_dim=input_data_format)
if width == height:
target_height, target_width = size, size
elif width > height:
target_width = size
target_height = math.ceil(height / width * size)
else:
target_width = math.ceil(width / height * size)
target_height = size
return resize(
image,
size=(target_height, target_width),
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
)
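# Illustrative sketch: a 500x1000 (h, w) image resized so that its longer side becomes 512.
# >>> _img = np.zeros((500, 1000, 3), dtype=np.uint8)
# >>> resize_longside(_img, size=512).shape
# (256, 512, 3)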
def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
"""
Selects the best-fit resolution from a list of possible resolutions based on the original image size.
This function evaluates each resolution by computing its effective and wasted area compared to the original size.
The optimal resolution is the one that maximizes the effective area while minimizing unused (wasted) space.
Args:
original_size (tuple): The original image size in the format (height, width).
possible_resolutions (list): A list of candidate resolutions in the format [(height1, width1), (height2, width2), ...].
Returns:
tuple: The best-fit resolution in the format (height, width).
This function includes code adapted from the file image_processing_llava_next.py in the LLaVA-Next
project(https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/llava_next/image_processing_llava_next.py),
which is licensed under apache-2.0.
"""
original_height, original_width = original_size
best_fit = None
max_effective_resolution = 0
min_wasted_resolution = float("inf")
for height, width in possible_resolutions:
scale = min(width / original_width, height / original_height)
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
wasted_resolution = (width * height) - effective_resolution
if effective_resolution > max_effective_resolution or (
effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution
):
max_effective_resolution = effective_resolution
min_wasted_resolution = wasted_resolution
best_fit = (height, width)
return best_fit
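# Illustrative sketch: for a 500x900 (h, w) image, the 336x672 canvas keeps the most image content
# while wasting the least area among the candidates below.
# >>> select_best_resolution((500, 900), [[336, 336], [336, 672], [672, 336]])
# (336, 672)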
def _get_local_grids_output_size(image: np.array, target_resolution: tuple, input_data_format=None):
"""
Computes the output size (height, width) of an image resized to fit within the target resolution
while preserving its aspect ratio. The resized image is later padded and divided into local grids.
Args:
image (np.array): Input image as a NumPy array.
target_resolution (tuple): Target resolution in the format (target_height, target_width).
input_data_format (optional): Optional format specifier (e.g., "channels_first" or "channels_last").
Returns:
tuple: A tuple (new_height, new_width) giving the resized image size in pixels.
"""
original_height, original_width = get_image_size(image, channel_dim=input_data_format)
target_height, target_width = target_resolution
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
new_width = target_width
new_height = min(math.ceil(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.ceil(original_width * scale_h), target_width)
return new_height, new_width
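# Illustrative sketch: fitting a 500x900 (h, w) image into a 336x672 target preserves the aspect ratio,
# so the width is scaled by 336/500 rather than stretched all the way to 672.
# >>> _img = np.zeros((500, 900, 3), dtype=np.uint8)
# >>> _get_local_grids_output_size(_img, (336, 672))
# (336, 605)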
def determine_anyres_num_vision_patches(
num_grids,
image_size,
grid_size,
patch_size,
possible_resolutions,
anyres=False,
unpad=True,
num_queries_vis_abstractor=0,
num_queries_vis_abstractor_slow=0,
is_video=False,
first_last_frames_slow=False, # sample-wise option
is_first_or_last_frames=False, # grid-wise option
):
"""
Computes the number of visual tokens (patches) based on image resolution, grid configuration, and patch size.
This function supports both fixed-size and any-resolution settings, as well as video-specific configurations
such as handling slow frames and frame position flags.
Args:
num_grids (int): Number of grids per image (e.g., 1 for 1x1, 4 for 2x2, etc.).
image_size (tuple): The original image size as (height, width).
grid_size (int): Size of each grid in pixels (e.g., 336).
patch_size (int): Size of each vision patch (e.g., 14 for ViT models).
possible_resolutions (list): List of possible resolution tuples [(h1, w1), (h2, w2), ...].
anyres (bool, optional): Whether to use any-resolution mode. Defaults to False.
unpad (bool, optional): Whether to unpad the image before computing patches. Defaults to True.
num_queries_vis_abstractor (int, optional): Number of query tokens for vision abstractor (fast path).
num_queries_vis_abstractor_slow (int, optional): Number of query tokens for vision abstractor (slow path).
is_video (bool, optional): Whether the input is a video. Defaults to False.
first_last_frames_slow (bool, optional): Whether to treat first/last video frames as "slow". Defaults to False.
is_first_or_last_frames (bool, optional): Whether current grid corresponds to first/last frame. Defaults to False.
Returns:
int: Total number of visual tokens (patches) after processing.
"""
if not anyres:
return num_queries_vis_abstractor if num_queries_vis_abstractor > 0 else (grid_size // patch_size) ** 2
if num_queries_vis_abstractor > 0:
num_patch_per_grid = int(num_queries_vis_abstractor**0.5)
else:
num_patch_per_grid = grid_size // patch_size
num_global_per_grid = num_patch_per_grid
# In anyres mode, a global image is included, so there are always at least 2 grids.
# However, for video inputs, there is no global image, so it's possible to have only 1 grid.
# Therefore, the assertion below is commented out:
# assert num_grids > 1
# Compute the number of vision patches.
height, width = select_best_resolution(image_size, possible_resolutions)
num_patch_height = (height // grid_size) * num_patch_per_grid
num_patch_width = (width // grid_size) * num_patch_per_grid
# local images
if unpad:
original_height, original_width = image_size
original_aspect_ratio = original_width / original_height
current_aspect_ratio = num_patch_width / num_patch_height
if original_aspect_ratio > current_aspect_ratio:
scale_factor = num_patch_width / original_width
new_height = int(original_height * scale_factor)
padding = (num_patch_height - new_height) // 2
num_patch_height = num_patch_height - padding * 2
else:
scale_factor = num_patch_height / original_height
new_width = int(original_width * scale_factor)
padding = (num_patch_width - new_width) // 2
num_patch_width = num_patch_width - padding * 2
num_patches = num_patch_width * num_patch_height + num_patch_height
else:
num_patches = num_patch_width * num_patch_height
# In the "slow" strategy, when applying to first and last frames only, it is applied exclusively to those two frames.
if num_queries_vis_abstractor_slow > 0:
if first_last_frames_slow:
if is_first_or_last_frames:
num_patches += num_queries_vis_abstractor_slow - num_queries_vis_abstractor
else:
num_patches += num_queries_vis_abstractor_slow - num_queries_vis_abstractor
# The slowfast feature is only applicable when unpad is set to False.
assert unpad is False
# Global image is not included for video inputs.
if not is_video:
num_patches += num_global_per_grid**2
return num_patches
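# Illustrative sketch: without anyres and without a visual resampler, the token count is simply the
# number of ViT patches in one grid, (336 // 14) ** 2 = 576.
# >>> determine_anyres_num_vision_patches(
# ...     num_grids=1, image_size=(500, 900), grid_size=336, patch_size=14,
# ...     possible_resolutions=[], anyres=False)
# 576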
class HCXVisionProcessor(BaseImageProcessor):
r"""
Constructs a VLM image processor.
This processor is based on [`CLIPImageProcessor`] and incorporates additional techniques
for handling high-resolution images, such as flexible resolution support (`anyres`), unpadding,
square padding, and multi-grid patching strategies.
Args:
do_resize (bool): Whether to resize the image.
size (Dict[str, int], optional): Target size for resizing, typically with keys `"height"` and `"width"`.
anyres (bool): Whether to enable the any-resolution (`anyres`) feature, which allows flexible resolution handling via grid division.
unpad (bool): When `anyres` is enabled, whether to remove visual tokens corresponding to pure padding regions.
max_num_grids (int): Maximum number of grids allowed per image.
max_image_cnt (int): Maximum number of images that can be processed at once (used for batching).
num_queries_vis_abstractor (int): Number of visual query tokens per grid when using a visual resampler (e.g., Perceiver).
num_queries_vis_abstractor_video_fast (int): Number of visual queries for fast-path video frames.
num_queries_vis_abstractor_video_slow (int): Number of visual queries for slow-path video frames (e.g., first/last).
possible_resolutions (List): List of allowed resolution pairs when `anyres` is enabled. Example: [[336, 336], [336, 672], [672, 336]].
patch_size (int): Patch size for the Vision Transformer (ViT).
pad_to_square (bool): Whether to pad images to a square shape. If `False`, a center crop is applied to fit ViT input.
resample (PILImageResampling): Resampling method to use for resizing. Default is `BICUBIC`.
do_center_crop (bool): Whether to apply center cropping.
crop_size (Dict[str, int], optional): Size for center cropping.
do_rescale (bool): Whether to rescale pixel values.
rescale_factor (float or int): Factor to use for rescaling pixel values (typically `1/255`).
do_normalize (bool): Whether to normalize pixel values using `image_mean` and `image_std`.
image_mean (float or List[float], optional): Mean values for normalization. Can be a single float or list of floats per channel.
image_std (float or List[float], optional): Standard deviation values for normalization. Can be a single float or list of floats per channel.
do_convert_rgb (bool): Whether to convert the input image to RGB.
first_last_frames_slow (bool): Whether to treat the first and last frames of a video as “slow path” (processed differently).
Attributes:
model_input_names (List[str]): Names of the expected model inputs. Defaults to `["pixel_values"]`.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
anyres: bool = False,
unpad: bool = False,
max_num_grids: int = 9,
max_image_cnt: int = 12,
num_queries_vis_abstractor: int = 0,
num_queries_vis_abstractor_video_fast: int = 0,
num_queries_vis_abstractor_video_slow: int = 0,
possible_resolutions: List = [],
patch_size: int = 14,
pad_to_square: bool = True,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_center_crop: bool = True,
crop_size: Dict[str, int] = None,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: bool = True,
first_last_frames_slow: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"shortest_edge": 512}
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else {"height": 512, "width": 512}
crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
self.do_resize = do_resize
self.size = size
self.anyres = anyres
self.unpad = unpad
self.max_num_grids = max_num_grids
self.max_image_cnt = max_image_cnt
self.num_queries_vis_abstractor = num_queries_vis_abstractor
self.num_queries_vis_abstractor_video_fast = num_queries_vis_abstractor_video_fast
self.num_queries_vis_abstractor_video_slow = num_queries_vis_abstractor_video_slow
self.possible_resolutions = [_resolution for _resolution in possible_resolutions]
self.patch_size = patch_size
self.pad_to_square = pad_to_square
self.resample = resample
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
self.do_convert_rgb = do_convert_rgb
self.first_last_frames_slow = first_last_frames_slow
assert self.crop_size["height"] == self.crop_size["width"]
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Resizes the input image to the specified target size.
Args:
image (np.ndarray): The input image to resize.
size (Dict[str, int]): A dictionary specifying the target size with keys `"height"` and `"width"`.
resample (PILImageResampling, optional): The resampling filter to use. Defaults to `BICUBIC`.
data_format (str or ChannelDimension, optional): The desired output data format (e.g., "channels_last").
input_data_format (str or ChannelDimension, optional): The input data format of the image.
**kwargs: Additional keyword arguments, if any.
Returns:
np.ndarray: The resized image as a NumPy array.
"""
default_to_square = True
if "shortest_edge" in size:
size = size["shortest_edge"]
default_to_square = False
elif "height" in size and "width" in size:
size = (size["height"], size["width"])
else:
raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
output_size = get_resize_output_image_size(
image,
size=size,
default_to_square=default_to_square,
input_data_format=input_data_format,
)
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
def _preprocess(
self,
images: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_center_crop: bool = None,
crop_size: int = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> List[np.ndarray]:
"""
Applies a sequence of preprocessing operations to the input image(s), including resizing, cropping, rescaling,
normalization, and format conversion.
This method is typically used internally to prepare images for model input.
Args:
images (ImageInput): A single image or a batch of images to preprocess.
do_resize (bool, optional): Whether to resize the image(s).
size (Dict[str, int], optional): Target size for resizing, with keys `"height"` and `"width"`.
resample (PILImageResampling, optional): Resampling method to use for resizing.
do_center_crop (bool, optional): Whether to apply center cropping.
crop_size (int, optional): Size of the center crop (applied to both height and width).
do_rescale (bool, optional): Whether to rescale the image pixel values.
rescale_factor (float, optional): Factor to use when rescaling pixel values (e.g., 1/255).
do_normalize (bool, optional): Whether to normalize the image using `image_mean` and `image_std`.
image_mean (float or List[float], optional): Mean value(s) used for normalization.
image_std (float or List[float], optional): Standard deviation value(s) used for normalization.
data_format (ChannelDimension, optional): The desired output data format (e.g., `ChannelDimension.FIRST`).
input_data_format (str or ChannelDimension, optional): The format of the input image(s).
Returns:
List[np.ndarray]: The preprocessed images, converted to the requested channel-dimension format and ready for model input.
"""
images = make_list_of_images(images)
if do_resize:
images = [
self.resize(
image=image,
size=size,
resample=resample,
input_data_format=input_data_format,
)
for image in images
]
if do_center_crop:
images = [
self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
]
if do_rescale:
images = [
self.rescale(
image=image,
scale=rescale_factor,
input_data_format=input_data_format,
)
for image in images
]
if do_normalize:
images = [
self.normalize(
image=image,
mean=image_mean,
std=image_std,
input_data_format=input_data_format,
)
for image in images
]
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]
return images
def _resize_for_local_grids(
self,
image: np.array,
target_resolution: tuple,
resample,
input_data_format: ChannelDimension,
) -> np.array:
"""
Resizes the image to the given target resolution for use in local grid processing.
This function ensures that the image is properly resized to match the (height, width) specified
in `target_resolution`, using the provided resampling method. It supports channel-first and
channel-last formats based on `input_data_format`.
Args:
image (np.array): Input image as a NumPy array.
target_resolution (tuple): Target resolution as (height, width) for resizing.
resample: Resampling method to use (e.g., `PILImageResampling.BICUBIC`).
input_data_format (ChannelDimension): Format of the input image (e.g., `ChannelDimension.FIRST` or `LAST`).
Returns:
np.array: The resized image in NumPy array format.
"""
new_height, new_width = _get_local_grids_output_size(image, target_resolution, input_data_format)
# Resize the image
resized_image = resize(
image,
(new_height, new_width),
resample=resample,
input_data_format=input_data_format,
)
return resized_image
def _pad_for_patching(
self,
image: np.array,
target_resolution: tuple,
input_data_format: ChannelDimension,
) -> np.array:
"""
Pads the image to match the target resolution, ensuring compatibility with patch-based models.
This is typically used to make sure the image dimensions are divisible by the patch size or to
meet specific model input requirements. Padding is applied symmetrically where needed.
Args:
image (np.array): Input image as a NumPy array.
target_resolution (tuple): The desired resolution after padding, in the format (height, width).
input_data_format (ChannelDimension): Format of the input image (e.g., `ChannelDimension.FIRST` or `LAST`).
Returns:
np.array: The padded image as a NumPy array.
"""
target_height, target_width = target_resolution
background_color = tuple(int(x * 255) for x in self.image_mean)
padded_image = pad(
image,
target_size=(target_height, target_width),
background_color=background_color,
input_data_format=input_data_format,
)
return padded_image
def get_image_grids(
self,
image: np.array,
possible_resolutions,
grid_size: int,
resample: PILImageResampling,
data_format: ChannelDimension,
input_data_format: ChannelDimension,
) -> List[np.array]:
"""
Splits the input image into multiple local grids based on possible resolutions and grid size.
The function selects the best resolution from the provided list, resizes the image accordingly,
and divides it into non-overlapping grid patches of size (grid_size x grid_size). It is commonly
used for any-resolution (anyres) visual processing.
Args:
image (np.array): Input image as a NumPy array.
possible_resolutions (List[Tuple[int, int]]): List of allowed resolutions to choose from.
grid_size (int): The size of each grid patch (e.g., 336 pixels).
resample (PILImageResampling): Resampling method used during resizing.
data_format (ChannelDimension): Output data format (e.g., `ChannelDimension.FIRST`).
input_data_format (ChannelDimension): Input data format of the image.
Returns:
List[np.array]: A list of grid image patches as NumPy arrays.
"""
if not isinstance(possible_resolutions, list):
raise ValueError("possible_resolutions must be a list of possible resolutions.")
image_size = get_image_size(image, channel_dim=input_data_format)
best_resolution = select_best_resolution(image_size, possible_resolutions)
resized_image = self._resize_for_local_grids(
image,
best_resolution,
resample=resample,
input_data_format=input_data_format,
)
padded_image = self._pad_for_patching(resized_image, best_resolution, input_data_format=input_data_format)
local_grids = divide_to_grids(padded_image, grid_size=grid_size, input_data_format=input_data_format)
# make sure that all patches are in the input data format
local_grids = [
to_channel_dimension_format(grid, channel_dim=data_format, input_channel_dim=input_data_format)
for grid in local_grids
]
return local_grids
def preprocess(
self,
images: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
anyres: bool = None,
unpad: bool = None,
is_video_list: List[bool] = None,
possible_resolutions: List = None,
patch_size: int = None,
pad_to_square: bool = None,
resample: PILImageResampling = None,
do_center_crop: bool = None,
crop_size: int = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: bool = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
is_first_or_last_frames: List[bool] = False,
):
"""
Preprocesses images using HCXVisionProcessor.
This method prepares images for visual language models by applying resizing, padding, cropping,
normalization, and tokenization into visual patches. In video mode, each frame is converted to
a 1D sequence of patches. The `unpad` option is disabled when processing videos.
Args:
images (ImageInput): A single image or a batch of images (PIL, NumPy, or tensor format).
do_resize (bool, optional): Whether to resize the image(s).
size (Dict[str, int], optional): Resize target with keys `"height"` and `"width"`.
anyres (bool, optional): Whether to use any-resolution processing with grid splitting.
unpad (bool, optional): Whether to remove visual tokens that belong to padding areas (only in non-video mode).
is_video_list (List[bool], optional): A list indicating which inputs are video frames.
possible_resolutions (List, optional): List of resolution pairs allowed in `anyres` mode.
patch_size (int, optional): Patch size for the Vision Transformer (ViT).
pad_to_square (bool, optional): Whether to pad the image to a square.
resample (PILImageResampling, optional): Resampling method to use for resizing.
do_center_crop (bool, optional): Whether to apply center cropping.
crop_size (int, optional): Target crop size for center cropping.
do_rescale (bool, optional): Whether to rescale image pixel values.
rescale_factor (float, optional): Factor for pixel rescaling, e.g., `1/255`.
do_normalize (bool, optional): Whether to normalize using mean and std.
image_mean (float or List[float], optional): Mean value(s) for normalization.
image_std (float or List[float], optional): Standard deviation(s) for normalization.
do_convert_rgb (bool, optional): Whether to convert the image to RGB.
return_tensors (str or TensorType, optional): Desired output tensor type (e.g., "pt" for PyTorch).
data_format (ChannelDimension, optional): Output data format (e.g., `ChannelDimension.FIRST`).
input_data_format (str or ChannelDimension, optional): Format of the input image.
is_first_or_last_frames (List[bool], optional): Flags indicating whether each image is a first/last video frame.
Returns:
BatchFeature: A feature dictionary whose entries are each wrapped in an outer batch list:
pixel_values (List[torch.Tensor]): A list of 4D image tensors ready for model input.
image_sizes (List[List[int]]): A list of `[width, height]` pairs giving the original size of each image.
vision_query_lengths (List[int]): The number of visual tokens each image contributes to the LLM input.
is_videos, num_queries_vis_abstractors, num_queries_vis_abstractors_slow, first_last_frames_slows:
Per-image metadata mirroring the processing options that were applied.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = get_size_dict(size, param_name="size", default_to_square=False)
anyres = anyres if anyres is not None else self.anyres
unpad = unpad if unpad is not None else self.unpad
possible_resolutions = possible_resolutions if possible_resolutions is not None else self.possible_resolutions
patch_size = patch_size if patch_size is not None else self.patch_size
pad_to_square = pad_to_square if pad_to_square is not None else self.pad_to_square
resample = resample if resample is not None else self.resample
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
images = make_list_of_images(images)
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if is_scaled_image(images[0]) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
new_images = []
image_sizes = [get_image_size(image, channel_dim=input_data_format) for image in images]
vision_query_lengths = []
assert crop_size["height"] == crop_size["width"]
# Padding operations for the global image can become a bottleneck when the original image width or height is large.
# To mitigate this, the image is first resized such that the longest side is scaled proportionally based on size["shortest_edge"],
# and then padding is applied to reach the target dimensions.
if anyres:
anyres_global_images = copy.deepcopy(images)
if pad_to_square:
background_color = tuple(int(x * 255) for x in self.image_mean)
anyres_global_images = [
resize_longside(
copy.deepcopy(image),
size["shortest_edge"],
resample,
input_data_format,
)
for image in anyres_global_images
]
anyres_global_images = [
expand2square(
image,
background_color=background_color,
input_data_format=input_data_format,
)[0]
for image in anyres_global_images
]
else:
anyres_global_images = [
self.resize(
image=image,
size={
"height": size["shortest_edge"],
"width": size["shortest_edge"],
},
resample=resample,
input_data_format=input_data_format,
)
for image in anyres_global_images
]
else:
anyres_global_images = [None for _ in range(len(images))]
if pad_to_square:
background_color = tuple(int(x * 255) for x in self.image_mean)
images = [
resize_longside(image, size["shortest_edge"], resample, input_data_format) for image in images
]
images = [
expand2square(
image,
background_color=background_color,
input_data_format=input_data_format,
)[0]
for image in images
]
num_queries_vis_abstractors = []
num_queries_vis_abstractors_slow = []
first_last_frames_slows = []
for image, is_video, anyres_global_image, image_size in zip(
images, is_video_list, anyres_global_images, image_sizes
):
if is_video:
num_queries_vis_abstractor = self.num_queries_vis_abstractor_video_fast
num_queries_vis_abstractor_slow = self.num_queries_vis_abstractor_video_slow
else:
num_queries_vis_abstractor = self.num_queries_vis_abstractor
num_queries_vis_abstractor_slow = 0
num_queries_vis_abstractors.append(num_queries_vis_abstractor)
num_queries_vis_abstractors_slow.append(num_queries_vis_abstractor_slow)
first_last_frames_slows.append(self.first_last_frames_slow)
if anyres:
# convert image into a list of grids
# we intentionally keep the same data format as the input data format
image_grids = self.get_image_grids(
image,
possible_resolutions,
grid_size=crop_size["height"],
resample=resample,
data_format=input_data_format,
input_data_format=input_data_format,
)
# Global image (thumbnail) is not used for video inputs.
if not is_video:
image_grids = [anyres_global_image] + image_grids
else:
image_grids = [image]
pixel_values = self._preprocess(
image_grids,
do_resize=do_resize,
size=size,
resample=resample,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
data_format=data_format,
input_data_format=input_data_format,
)
pixel_values = np.array(pixel_values)
new_images.append(pixel_values)
num_grids = pixel_values.shape[0]
vision_query_length = determine_anyres_num_vision_patches(
num_grids=num_grids,
image_size=image_size,
grid_size=crop_size["height"],
patch_size=patch_size,
possible_resolutions=possible_resolutions,
anyres=anyres,
unpad=False if is_video else unpad,
num_queries_vis_abstractor=num_queries_vis_abstractor,
num_queries_vis_abstractor_slow=num_queries_vis_abstractor_slow,
is_video=is_video,
first_last_frames_slow=self.first_last_frames_slow,
is_first_or_last_frames=self.first_last_frames_slow,
)
vision_query_lengths.append(vision_query_length)
data = {
"pixel_values": [[torch.tensor(new_image) for new_image in new_images]],
"image_sizes": [[[image_size[1], image_size[0]] for image_size in image_sizes]],
"vision_query_lengths": [vision_query_lengths],
"is_videos": [is_video_list],
"num_queries_vis_abstractors": [num_queries_vis_abstractors],
"num_queries_vis_abstractors_slow": [num_queries_vis_abstractors_slow],
"first_last_frames_slows": [first_last_frames_slows],
}
return BatchFeature(data=data)
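# Usage sketch for the method above (illustrative; `processor` is a hypothetical HCXVisionProcessor()
# instance with the default config: anyres disabled, 512-pixel crops, 14-pixel patches, so a single
# image yields one 512x512 tensor and (512 // 14) ** 2 = 1296 visual tokens; `pil_image` is a
# hypothetical PIL.Image input):
# >>> out = processor.preprocess(images=[pil_image], is_video_list=[False])
# >>> out["pixel_values"][0][0].shape, out["vision_query_lengths"][0][0]
# (torch.Size([1, 3, 512, 512]), 1296)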
def load_images_videos(self, vlm_chat):
"""
Loads and prepares images or video frames from a VLM chat input.
This function parses the input `vlm_chat` object, extracts image or video sources,
and loads them into memory as PIL or NumPy images, ready for preprocessing.
Args:
vlm_chat: A VLM chat input structure containing multimodal elements
(e.g., images, videos, URLs, or file paths). The format is typically a list of messages
with associated media fields.
Returns:
Tuple[List, List[PIL.Image.Image], List[bool]]:
- The expanded chat, where each video entry is duplicated into one entry per combined frame image.
- A flat list of loaded images (video frames are appended as individual images).
- A list of flags indicating whether each loaded image came from a video.
"""
vlm_chat = copy.deepcopy(vlm_chat)
new_vlm_chat = []
all_images = [] # images + images_from_videos
is_video_list = []
for line in vlm_chat:
if "content" in line:
content = line["content"]
if "image" in content:
if "filename" not in content:
content["filename"] = f"{uuid.uuid4().hex}.jpg"
image_pil = load_image(content["image"])
all_images.append(image_pil)
is_video_list.append(False)
new_vlm_chat.append(line)
elif "video" in content:
video_bytesio = load_video_to_bytesio(content["video"])
pil_img_frames, video_time_stamp = process_video(
video_bytesio, self.max_num_grids, self.max_image_cnt, self.crop_size["width"]
)
all_images.extend(pil_img_frames)
is_video_list.extend([True] * len(pil_img_frames))
if "filename" not in content:
content["filename"] = f"{uuid.uuid4().hex}.mp4"
for i, image_time_stamp in enumerate(video_time_stamp):
new_line = copy.deepcopy(line)
basename, ext = os.path.splitext(content["filename"])
new_line["content"]["filename"] = f"{basename}-{i}{ext}"
new_line["content"]["video_time_stamp"] = image_time_stamp
if i == len(video_time_stamp) - 1:
new_line["content"]["is_final_grid"] = True
for last_frame_target_key in ["lens_keywords", "lens_local_keywords", "speech_to_text"]:
if last_frame_target_key in content:
new_line["content"][last_frame_target_key] = content[last_frame_target_key]
new_vlm_chat.append(new_line)
else:
new_vlm_chat.append(line)
return new_vlm_chat, all_images, is_video_list
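# Usage sketch for the method above (hypothetical chat structure; the content dict mirrors the keys
# parsed in this method, i.e. "image", "video", and an optional "filename"):
# >>> chat = [{"role": "user", "content": {"image": "https://example.com/cat.jpg", "text": "Describe this image."}}]
# >>> new_chat, all_images, is_video_list = processor.load_images_videos(chat)
# >>> len(all_images), is_video_list
# (1, [False])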
def process_video(video_bytesio, max_num_grids, max_image_cnt, vit_input_size):
"""
Processes a video file and extracts frames suitable for vision transformer (ViT) input.
The function reads video data from a BytesIO object, extracts a limited number of frames
based on `max_num_grids` and `max_image_cnt`, and resizes them to the appropriate ViT input size.
Args:
video_bytesio (io.BytesIO): A BytesIO object containing the raw video file data.
max_num_grids (int): The maximum number of grids allowed (e.g., for tiling or patching).
max_image_cnt (int): The maximum number of frames to extract from the video.
vit_input_size (int): The desired input size (height and width) for the ViT model.
Returns:
Tuple[List[PIL.Image.Image], List]: The grid-combined frame images, each sized for ViT input, and the
corresponding time-range labels for each combined image.
"""
frames, time_interval = video_decoder(
video_bytesio, max_num_grids=max_num_grids, max_image_cnt=max_image_cnt, default_interval=0.4
)
pil_img_frames, video_time_stamp = combine_frames_into_images(
frames, time_interval, max_grid_shape=(max_num_grids, 1), vit_input_size=vit_input_size
)
return pil_img_frames, video_time_stamp
def load_image(image_src):
"""
Loads an image from various sources (file path, URL, base64 string, or raw bytes)
and returns it as a PIL Image object.
Args:
image_src (str or bytes): The image source. It can be:
- A local file path
- A URL
- A base64-encoded string
- Raw image bytes
Returns:
PIL.Image.Image: The loaded image as a PIL Image object.
Raises:
ValueError: If the image cannot be loaded or the format is unsupported.
TypeError: If the input is not of type str or bytes.
"""
try:
# 1. If input is bytes type
if isinstance(image_src, bytes):
return Image.open(io.BytesIO(image_src))
# 2. If input is str type (path, URL, base64)
if isinstance(image_src, str):
# 2a. Check if it's a Base64 data URI format ('data:image/...')
if image_src.startswith("data:image"):
try:
# Remove the 'data:image/...;base64,' part and decode
header, encoded = image_src.split(",", 1)
image_bytes = base64.b64decode(encoded)
return Image.open(io.BytesIO(image_bytes))
except (ValueError, base64.binascii.Error) as e:
raise ValueError(f"Invalid base64 data URI format: {e}") from e
# 2b. Check if it's a URL format ('http://' or 'https://')
elif image_src.startswith("http://") or image_src.startswith("https://"):
try:
response = requests.get(image_src, stream=True, timeout=10)
response.raise_for_status() # Raise an exception for HTTP errors
image_bytes = response.content
return Image.open(io.BytesIO(image_bytes))
except requests.exceptions.RequestException as e:
raise ValueError(f"Error loading image from URL '{image_src}': {e}") from e
# 2c. Assume it's a local file path
else:
return Image.open(image_src)
else:
raise TypeError(f"Unsupported image_src type: {type(image_src)}")
# Common exception handling
except FileNotFoundError:
raise ValueError(f"Image loading error: File not found '{image_src}'")
except UnidentifiedImageError:
raise ValueError("Image loading error: Cannot identify image file format.")
except IOError as e:
raise ValueError(f"Image loading error (I/O): {e}") from e
except Exception as e:
raise ValueError(f"Unexpected error during image loading: {e}") from e
def load_video_to_bytesio(video_src):
"""
Loads video data from various sources (file path, URL, base64 string, or raw bytes)
and returns an `io.BytesIO` object containing the raw video content.
Args:
video_src (str or bytes): The video source. Supported formats include:
- Local file path
- URL
- Base64-encoded data URI string
- Raw video bytes
Returns:
io.BytesIO: A `BytesIO` object containing the loaded video data.
Raises:
ValueError: If the video cannot be loaded due to issues such as an invalid path,
URL failure, malformed base64 string, or unsupported format.
TypeError: If the input is not a `str` or `bytes` object.
"""
video_bytes = None
try:
# 1. If input is bytes type
if isinstance(video_src, bytes):
video_bytes = video_src
# 2. If input is str type (path, URL, base64)
elif isinstance(video_src, str):
# 2a. Check if it's a Base64 data URI format ('data:video/...')
if video_src.startswith("data:video"):
try:
# Remove the 'data:video/...;base64,' part and decode
header, encoded = video_src.split(",", 1)
video_bytes = base64.b64decode(encoded)
except (ValueError, base64.binascii.Error) as e:
raise ValueError(f"Invalid base64 data URI format: {e}") from e
# 2b. Check if it looks like a URL
elif urlparse(video_src).scheme in ("http", "https"):
try:
response = requests.get(
video_src, stream=True, timeout=30
) # Increased timeout for potentially large videos
response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx)
# Read all content from the stream into bytes
video_bytes = response.content
except requests.exceptions.MissingSchema:
# If urlparse thinks it's a scheme but requests disagrees (e.g., "http:/example.com")
# Treat it as a potential file path below.
pass
except requests.exceptions.RequestException as e:
raise ValueError(f"Error loading video from URL '{video_src}': {e}") from e
# 2c. Assume it's a local file path if not base64 or confirmed URL
if video_bytes is None: # Only attempt file read if not already loaded as base64 or URL failed gracefully
# Check if it could potentially be a file path
# Note: This check is basic. A string like "http:/path/file" might incorrectly be treated as a path here
# if the requests call failed due to MissingSchema. More robust path validation could be added.
if (
os.path.exists(video_src) or "/" in video_src or "\\" in video_src
): # Basic check if it resembles a path
try:
with open(video_src, "rb") as f:
video_bytes = f.read()
except FileNotFoundError:
raise ValueError(f"Video loading error: File not found at path '{video_src}'")
except IsADirectoryError:
raise ValueError(f"Video loading error: Path '{video_src}' is a directory, not a file.")
except IOError as e:
raise ValueError(f"Video loading error (I/O) for path '{video_src}': {e}") from e
else:
# If it's not base64, not a valid downloadable URL, and doesn't look like a path/doesn't exist
raise ValueError(f"Unsupported string input format or resource not found: '{video_src}'")
# 3. If the type is unsupported
else:
raise TypeError(f"Unsupported video_src type: {type(video_src)}")
# Final check if video_bytes was successfully obtained
if video_bytes is None:
raise ValueError(f"Could not load video data from the provided source: {video_src}")
# Return the bytes wrapped in BytesIO
return io.BytesIO(video_bytes)
# Catch specific exceptions first for better error reporting
except FileNotFoundError as e: # Should be caught above, but as a safeguard
raise ValueError(f"Video loading error: File not found '{video_src}'") from e
except requests.exceptions.RequestException as e: # Already handled, but for clarity
raise ValueError(f"Video loading error (Network): {e}") from e
except (ValueError, TypeError) as e: # Re-raise ValueErrors/TypeErrors raised intentionally within the try block
raise e
except Exception as e:
# Catch any other unexpected errors during processing
raise ValueError(f"Unexpected error during video loading from source '{video_src}': {e}") from e
def video_decoder(video_bytesio, max_num_grids, max_image_cnt, default_interval=0.4):
"""
Decodes video data from a BytesIO object and returns a list of extracted frames.
Args:
video_bytesio (io.BytesIO): A BytesIO object containing the raw video data.
max_num_grids (int): Maximum number of grids allowed per image. Used to determine how many frames to extract.
max_image_cnt (int): Maximum number of frames to extract from the video.
default_interval (float, optional): Default time interval (in seconds) between frames. Used when frame rate info is unavailable. TODO: make configurable.
Returns:
Tuple:
frames (List[PIL.Image.Image]): A list of extracted frames as PIL Images.
time_interval (float): Time interval (in seconds) between selected frames.
"""
error_messages = []
frames = []
# 1. Try decoding the video using Decord.
try:
vr = VideoReader(video_bytesio, ctx=cpu(0), num_threads=8)
fps = vr.get_avg_fps()
play_time = len(vr) / fps
total_frames = len(vr)
frame_indices, time_interval = extract_frame_indices(
play_time, total_frames, fps, max_num_grids, max_image_cnt, default_interval=default_interval
) # Sample every 0.4 seconds; if the video is too long, apply uniform sampling instead.
if frame_indices is None:
frame_indices = range(len(vr)) # Convert all frames.
batch_frames = vr.get_batch(frame_indices).asnumpy()
frames = [Image.fromarray(frame).convert("RGB") for frame in batch_frames]
return frames, time_interval
except Exception as e:
print("error with decord")
error_messages.append(f"Decord 실패: {e}")
# 2. Fallback: Try decoding the video using PyAV.
try:
container = av.open(video_bytesio)
video_stream = container.streams.video[0]
fps = float(video_stream.average_rate)
total_frames = video_stream.frames  # may be 0 if the container does not report a frame count
play_time = total_frames / fps
frame_indices, time_interval = extract_frame_indices(
play_time, total_frames, fps, max_num_grids, max_image_cnt, default_interval=default_interval
) # Sample frames every 0.4 seconds. If the video is long, use uniform sampling to limit the number of frames.
# Even if frame_indices were assigned using Decord, reprocess them to be compatible with PyAV.
target_indices = None if frame_indices is None else set(frame_indices)
frames = []
for i, frame in enumerate(container.decode(video=0)):
if target_indices is not None and i not in target_indices:
continue # Skip frames that are not in the required indices.
pil_frame = Image.fromarray(frame.to_ndarray(format="rgb24")).convert("RGB")
frames.append(pil_frame)
if frames:
return frames, time_interval
else:
raise Exception("Decoding with PyAV succeeded, but no frames were extracted.")
except Exception as e:
error_messages.append(f"PyAV failed: {e}")
# 3. Fallback: Try decoding the video using OpenCV.
try:
# cv2.VideoCapture cannot decode from an in-memory buffer, so write the bytes to a temporary file first.
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
    tmp_file.write(video_bytesio.getvalue())
    tmp_path = tmp_file.name
cap = cv2.VideoCapture(tmp_path)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = cap.get(cv2.CAP_PROP_FPS)
play_time = total_frames / fps
frame_indices, time_interval = extract_frame_indices(
play_time, total_frames, fps, max_num_grids, max_image_cnt, default_interval=default_interval
) # Sample frames every 0.4 seconds; if the video is too long, apply uniform sampling to limit the total number of frames.
if frame_indices is None:
frame_indices = range(total_frames) # Convert all frames.
index_set = set(frame_indices) # Convert to a set for faster lookup.
current_index = 0
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
if current_index in index_set:
frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).convert("RGB"))
current_index += 1
if current_index > max(index_set): # Stop processing once all required indices have been handled.
break
        cap.release()
        os.remove(tmp_file.name)  # Clean up the temporary file used for OpenCV decoding.
if frames:
return frames, time_interval
except Exception as e:
error_messages.append(f"OpenCV failed: {e}")
if error_messages:
raise Exception(f"All decoding attempts have failed.: {error_messages}")
def convert_format_for_multi_image(img, json, convert_key_list=["words", "text", "objects", "entities"]):
"""
Converts the format of image and annotation data from a single-image dataset to a multi-image dataset format.
Single-image datasets typically return a single image and its associated annotation as individual objects.
This function wraps them in a dictionary format used by multi-image datasets.
Args:
img: The input image (e.g., a PIL Image or NumPy array).
json: The annotation data associated with the image.
convert_key_list (List[str], optional): A list of keys to extract and convert from the original JSON.
Defaults to ["words", "text", "objects", "entities"].
    Returns:
        Tuple[bool, Dict, Dict]:
            - A flag indicating whether the input was already in multi-image format.
            - A dictionary mapping image IDs to images (e.g., {"00": img}).
            - A dictionary mapping image IDs to the corresponding annotations (with the listed keys wrapped per image ID).
"""
is_multi_image_dataset = isinstance(img, dict)
if not is_multi_image_dataset:
img = {"00": img}
for convert_key in convert_key_list:
if convert_key in json:
json[convert_key] = {"00": json[convert_key]}
for json_key in json:
if "region" in json_key:
json[json_key] = {"00": json[json_key]}
return is_multi_image_dataset, img, json
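
# Illustrative sketch (hedged) of the wrapping above: a single-image sample and its annotation are
# re-keyed under "00" so the rest of the pipeline can treat single- and multi-image data uniformly.
# `pil_image` is a placeholder for any PIL.Image instance.
#
#   >>> annotation = {"words": ["hello"], "region_qa": [["where?", "top-left"]]}
#   >>> is_multi, imgs, anno = convert_format_for_multi_image(pil_image, annotation)
#   >>> is_multi, list(imgs), list(anno["words"]), list(anno["region_qa"])
#   (False, ['00'], ['00'], ['00'])
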
def convert_tags_for_video(img, json):
"""
Converts <video_00> tags to <image_xx> tags based on the number of video frames.
In video datasets, annotations often use a generic <video_00> tag. This function replaces that tag
with frame-specific tags such as <image_00>, <image_01>, ..., <image_NN> based on the number of frames in `img`.
Args:
img: A list of video frames (e.g., list of PIL Images or NumPy arrays).
json: The annotation data containing <video_00> tags to be replaced.
    Returns:
        Tuple: The (unchanged) list of frames and the updated annotation JSON, with <video_00> replaced by frame-specific <image_xx> tags.
"""
image_tag = "".join([f"<image_{idx:02d}>" for idx in range(len(img))])
# image_tag = "<image_00>" # Use this format to construct and insert image-specific tags.
for json_key in json:
if "qa_pairs" in json_key:
new_qa_pairs = []
for qa_pair in json[json_key]:
question = qa_pair[0]
# Replace <video_00> tags with corresponding <image_xx> tags.
question = question.replace("<video_00>", image_tag)
new_qa_pairs.append([question, qa_pair[1]])
json[json_key] = new_qa_pairs
return img, json
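
# Illustrative sketch (hedged): with three decoded frames, a <video_00> placeholder inside a QA pair
# is expanded into one <image_xx> tag per frame. `frame0`..`frame2` are placeholders for frame images.
#
#   >>> anno = {"qa_pairs": [["<video_00> What happens next?", "A dog runs."]]}
#   >>> _, anno = convert_tags_for_video([frame0, frame1, frame2], anno)
#   >>> anno["qa_pairs"][0][0]
#   '<image_00><image_01><image_02> What happens next?'
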
def split_list(input_list, split_value):
"""
Splits a list into sublists using a specified delimiter value.
Each time `split_value` is encountered in `input_list`, a new sublist is started.
The delimiter itself is not included in the output.
Args:
input_list (List[Any]): The input list to split.
split_value (Any): The value used as the delimiter for splitting.
Returns:
List[List[Any]]: A list of sublists, split by the specified delimiter.
Example:
>>> split_list(["a", "b", "|", "c", "d", "|", "e"], "|")
[['a', 'b'], ['c', 'd'], ['e']]
"""
temp_list = []
result = []
for value in input_list:
if value == split_value:
result.append(temp_list)
temp_list = []
else:
temp_list.append(value)
result.append(temp_list)
return result
def combine_frames_into_images(frames, time_interval, max_grid_shape=(3, 3), vit_input_size=378):
"""
Combines a sequence of video frames into grid-based images and generates corresponding time range labels.
    Frames are grouped and tiled so that each combined image contains up to
    `max_grid_shape[0] * max_grid_shape[1]` frames (currently a single horizontal strip). Each frame is
    resized to the given ViT input size before being pasted into the grid.
Args:
frames (List[PIL.Image.Image]): A list of frames extracted from a video.
time_interval (float): Time interval (in seconds) between consecutive frames.
        max_grid_shape (Tuple[int, int], optional): Maximum grid shape as (cols, rows); the current implementation requires rows == 1, i.e. frames are tiled into one horizontal strip. Defaults to (3, 3).
vit_input_size (int, optional): The target size (height and width) for the Vision Transformer input. Defaults to 378.
Returns:
Tuple:
image_list (List[PIL.Image.Image]): A list of grid-combined images.
image_time_stamps (List[str]): A list of time span labels for each combined image,
e.g., ["0.00s~1.50s", "1.50s~3.00s", ...].
"""
# grid_size = int(np.sqrt(max_num_grids))
# assert grid_size**2 == max_num_grids, "max_num_grids must be a perfect square."
max_num_grids = max_grid_shape[0] * max_grid_shape[1]
    assert (
        max_grid_shape[1] == 1
    ), f"For video processing, frames are concatenated horizontally into a single wide image, so max_grid_shape must be (N, 1); got {max_grid_shape}."
# List to store the resulting combined images.
image_list = []
# Calculate the number of canvases needed.
num_frames = len(frames)
num_canvases = num_frames // max_num_grids
leftover_frames = num_frames % max_num_grids
time_stamp = 0 # second
image_time_stamps = []
for canvas_idx in range(num_canvases):
# Initialize the current canvas.
combined_image = Image.new(
"RGB", (vit_input_size * max_grid_shape[0], vit_input_size * max_grid_shape[1]), color=(0, 0, 0)
)
# Determine the frames to fill in the current canvas.
start_idx = canvas_idx * max_num_grids
end_idx = min(start_idx + max_num_grids, num_frames)
for idx in range(start_idx, end_idx):
img = frames[idx]
# Resize each frame to a square shape.
img_resized = img.resize((vit_input_size, vit_input_size))
# Calculate the (row, column) position to place the frame within the grid layout.
local_idx = idx - start_idx
x_offset = (local_idx % max_grid_shape[0]) * vit_input_size
y_offset = (local_idx // max_grid_shape[0]) * vit_input_size
            # Paste the resized frame at the computed position on the canvas.
combined_image.paste(img_resized, (x_offset, y_offset))
# Append the current canvas to the result list.
image_list.append(combined_image)
frame_cnt = end_idx - start_idx
image_time_stamps.append(f"{time_stamp:.2f}s~{time_stamp + frame_cnt * time_interval:.2f}s")
time_stamp += frame_cnt * time_interval
if leftover_frames > 0:
        # Place the leftover frames on one final canvas. Assign its index explicitly, since the loop
        # above may not have run at all (num_canvases == 0), leaving `canvas_idx` undefined.
        canvas_idx = num_canvases
# Add the remaining frames to the final canvas.
combined_image = Image.new("RGB", (vit_input_size * leftover_frames, vit_input_size * 1), color=(0, 0, 0))
for idx in range(leftover_frames):
img = frames[num_canvases * max_num_grids + idx]
# Resize the frame to a square (equal width and height).
img_resized = img.resize((vit_input_size, vit_input_size))
# Calculate the (row, column) position to place the frame within the grid layout.
x_offset = (idx % leftover_frames) * vit_input_size
y_offset = (idx // leftover_frames) * vit_input_size
            # Paste the resized frame at the computed position on the canvas.
combined_image.paste(img_resized, (x_offset, y_offset))
# Add the current canvas to the list of combined images.
image_list.append(combined_image)
frame_cnt = leftover_frames
image_time_stamps.append(f"{time_stamp:.2f}s~{time_stamp + frame_cnt * time_interval:.2f}s")
time_stamp += frame_cnt * time_interval
return image_list, image_time_stamps
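
# Worked example (hedged sketch) of the grid arithmetic above: 7 frames sampled 0.5 s apart with
# max_grid_shape=(3, 1) produce two full 3-frame canvases plus one leftover 1-frame canvas.
# `frames7` is a placeholder for any list of 7 PIL images.
#
#   >>> grids, spans = combine_frames_into_images(frames7, time_interval=0.5, max_grid_shape=(3, 1))
#   >>> [g.size for g in grids]
#   [(1134, 378), (1134, 378), (378, 378)]
#   >>> spans
#   ['0.00s~1.50s', '1.50s~3.00s', '3.00s~3.50s']
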
def extract_frame_indices(play_time, total_frames, fps, max_num_grids, max_image_cnt, default_interval=0.4):
"""
Extracts specific frame indices from a video based on duration, frame count, and sampling strategy.
The function determines which frames to extract given the video duration (`play_time`),
total frame count, and frame rate. It samples frames at regular intervals (default: 0.4s),
but if the number of frames exceeds the limit defined by `max_num_grids * max_image_cnt`,
it performs uniform sampling to stay within that limit.
Args:
play_time (float): Total play time of the video in seconds.
total_frames (int): Total number of frames in the video.
fps (float): Frames per second of the video.
max_num_grids (int): Maximum number of grids to display.
max_image_cnt (int): Maximum number of images per grid.
default_interval (float, optional): Interval in seconds between frame samples. Defaults to 0.4.
Returns:
Tuple:
frame_indices (List[int]): A list of selected frame indices.
time_interval (float): Time interval between selected frames (in seconds).
"""
    # Calculate how many frames the default interval would produce (at least one, to avoid division by zero for very short clips)
    default_frame_count = max(int(play_time / default_interval), 1)
# Maximum frames allowed based on max_num_grids and max_image_cnt
max_frames_allowed = max_num_grids * max_image_cnt
# Determine whether we can use the default interval or need uniform sampling
    if default_frame_count <= max_frames_allowed:
        # The default interval is sufficient: sample frames roughly `default_interval` seconds apart
        frame_interval = max(int(total_frames / default_frame_count), 1)  # Guard against a zero step for low-frame videos
    else:
        # Too many frames at the default interval: sample uniformly to stay within max_frames_allowed
        frame_interval = max(int(total_frames / max_frames_allowed), 1)
# Extract frame indices at the calculated interval
selected_indices = list(range(0, total_frames, frame_interval))
time_interval = frame_interval / fps
# Ensure the number of selected indices does not exceed max_frames_allowed
return selected_indices[:max_frames_allowed], time_interval
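
# Worked example (hedged) of the sampling arithmetic above: a 100 s video at 30 fps has 3000 frames.
# With default_interval=0.4 the default plan would take 250 frames, but with max_num_grids=9 and
# max_image_cnt=12 only 9 * 12 = 108 frames are allowed, so uniform sampling applies:
# frame_interval = int(3000 / 108) = 27, indices 0, 27, 54, ... (capped at 108 entries), and
# time_interval = 27 / 30 = 0.9 s between selected frames.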