# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for Eagle2_5_VL.
Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/processing_llava_onevision.py
"""

import base64
import math
import os
import re
import sys
import time
import warnings
from functools import lru_cache
from io import BytesIO
from typing import Any, Iterable, List, Literal, Optional, Union

import numpy as np
import requests
import torch
import torchvision
from packaging import version
from PIL import Image
from torchvision import io

from transformers.feature_extraction_utils import BatchFeature
from transformers.image_processing_utils import select_best_resolution
from transformers.image_utils import ImageInput, VideoInput, get_image_size, to_numpy_array
from transformers.models.auto import AutoImageProcessor
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging

logger = logging.get_logger(__name__)

FRAME_FACTOR = 2
FPS = 2.0
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 256


def adjust_by_factor(number: int, factor: int, method: Literal['round', 'ceil', 'floor'] = 'round') -> int:
    """Adjust `number` to the nearest, ceiling, or floor multiple of `factor`."""
    op = {'round': round, 'ceil': math.ceil, 'floor': math.floor}[method]
    return op(number / factor) * factor


def to_rgb(pil_image: Image.Image) -> Image.Image:
    if pil_image.mode == 'RGBA':
        white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
        white_background.paste(pil_image, mask=pil_image.split()[3])  # use the alpha channel as mask
        return white_background
    else:
        return pil_image.convert("RGB")


def fetch_image(ele: dict[str, str | Image.Image]) -> Image.Image:
    if "image" in ele:
        image = ele["image"]
    else:
        image = ele["image_url"]
    image_obj = None
    if isinstance(image, Image.Image):
        image_obj = image
    elif image.startswith("http://") or image.startswith("https://"):
        response = requests.get(image, stream=True)
        image_obj = Image.open(BytesIO(response.content))
    elif image.startswith("file://"):
        image_obj = Image.open(image[7:])
    elif image.startswith("data:image"):
        if "base64," in image:
            _, base64_data = image.split("base64,", 1)
            data = base64.b64decode(base64_data)
            image_obj = Image.open(BytesIO(data))
    else:
        image_obj = Image.open(image)
    if image_obj is None:
        raise ValueError(f"Unrecognized image input, supported: local path, http url, base64 and PIL.Image, got {image}")
    image = to_rgb(image_obj)
    if 'scale_factor' in ele:
        scale_factor = ele['scale_factor']
        image = image.resize((image.width * scale_factor, image.height * scale_factor), Image.BILINEAR)
    return image
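

# Illustrative usage of `fetch_image` (a sketch, not part of the upstream module): the helper
# accepts an already-decoded PIL.Image, an http(s) URL, a `file://` URI, a base64 data URI, or a
# plain local path, and always returns an RGB image. The paths/URLs below are hypothetical.
#
#   img = fetch_image({"image": "file:///tmp/example.jpg"})
#   img = fetch_image({"image": "https://example.com/cat.png"})
#   img = fetch_image({"image": Image.new("RGBA", (64, 64)), "scale_factor": 2})  # resized to 128x128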


def smart_nframes(
    ele: dict,
    total_frames: int,
    video_fps: int | float,
) -> int:
    """Calculate the number of frames to sample from a video for model inputs.

    Args:
        ele (dict): a dict containing the configuration of the video. Supports either `fps` or `nframes`:
            - nframes: the number of frames to extract for model inputs.
            - fps: the fps at which to extract frames for model inputs.
                - min_frames: the minimum number of frames to extract, only used when `fps` is provided.
                - max_frames: the maximum number of frames to extract, only used when `fps` is provided.
        total_frames (int): the original total number of frames of the video.
        video_fps (int | float): the original fps of the video.

    Raises:
        ValueError: nframes should be in the interval [FRAME_FACTOR, total_frames].

    Returns:
        int: the number of frames to sample from the video for model inputs.
    """
    assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
    if "nframes" in ele:
        nframes = adjust_by_factor(ele["nframes"], FRAME_FACTOR, method='round')
    else:
        fps = ele.get("fps", FPS)
        min_frames = adjust_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR, method='ceil')
        max_frames = adjust_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR, method='floor')
        nframes = total_frames / video_fps * fps
        if nframes > total_frames:
            logger.warning(f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]")
        nframes = min(min(max(nframes, min_frames), max_frames), total_frames)
        nframes = adjust_by_factor(nframes, FRAME_FACTOR, method='floor')
    if not (FRAME_FACTOR <= nframes <= total_frames):
        nframes = total_frames
        # raise ValueError(f"nframes should be in the interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.")
    return nframes
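

# Worked example for `smart_nframes` (illustrative only, not part of the upstream module): for a
# 10-second clip at video_fps=30 (total_frames=300) and the default fps=2.0, the initial estimate
# is 300 / 30 * 2.0 = 20 frames; 20 already lies inside [FPS_MIN_FRAMES, FPS_MAX_FRAMES] and is a
# multiple of FRAME_FACTOR, so the call returns 20. With fps=1.0 the same clip yields 10 frames.
#
#   assert smart_nframes({}, total_frames=300, video_fps=30) == 20
#   assert smart_nframes({"fps": 1.0}, total_frames=300, video_fps=30) == 10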


def _read_video_torchvision(
    ele: dict,
) -> tuple[torch.Tensor, float, list]:
    """Read a video using torchvision.io.read_video and also return per-frame timestamps."""
    video_path = ele["video"]
    if version.parse(torchvision.__version__) < version.parse("0.19.0"):
        if "http://" in video_path or "https://" in video_path:
            warnings.warn("torchvision < 0.19.0 does not support http/https video paths, please upgrade to 0.19.0.")
        if "file://" in video_path:
            video_path = video_path[7:]
    st = time.time()
    video, audio, info = io.read_video(
        video_path,
        start_pts=ele.get("video_start", 0.0),
        end_pts=ele.get("video_end", None),
        pts_unit="sec",
        output_format="TCHW",
    )
    total_frames, video_fps = video.size(0), info["video_fps"]
    logger.info(f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
    nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
    # Calculate frame indices and the corresponding timestamps (relative to the video start time).
    idx = torch.linspace(0, total_frames - 1, nframes).round().long()
    start_time = ele.get("video_start", 0.0)
    timestamps = (start_time + idx.to(torch.float32) / video_fps).tolist()
    sample_fps = nframes / max(total_frames, 1e-6) * video_fps
    video = video[idx]
    return video, sample_fps, timestamps


def is_decord_available() -> bool:
    import importlib.util

    return importlib.util.find_spec("decord") is not None


def _read_video_decord(
    ele: dict,
) -> tuple[torch.Tensor, float, list]:
    """Read a video using decord.VideoReader and also return per-frame timestamps."""
    import decord
    video_path = ele["video"]
    st = time.time()
    vr = decord.VideoReader(video_path)
    if 'video_start' in ele or 'video_end' in ele:
        raise NotImplementedError("`video_start` and `video_end` are not supported with the decord backend for now.")
    total_frames, video_fps = len(vr), vr.get_avg_fps()
    logger.info(f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
    nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
    idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
    start_time = ele.get("video_start", 0.0)  # TODO: support `video_start`/`video_end` with the decord backend
    timestamps = [start_time + i / video_fps for i in idx]
    video = vr.get_batch(idx).asnumpy()
    video = torch.tensor(video).permute(0, 3, 1, 2)  # convert to TCHW format
    sample_fps = nframes / max(total_frames, 1e-6) * video_fps
    return video, sample_fps, timestamps


VIDEO_READER_BACKENDS = {
    "decord": _read_video_decord,
    "torchvision": _read_video_torchvision,
}


@lru_cache(maxsize=1)
def get_video_reader_backend() -> str:
    if is_decord_available():
        video_reader_backend = "decord"
    else:
        video_reader_backend = "torchvision"
    return video_reader_backend


def fetch_video(ele: dict, return_video_sample_fps: bool = False) -> torch.Tensor | list[Image.Image]:
    if isinstance(ele["video"], str):
        video_reader_backend = get_video_reader_backend()
        try:
            video, sample_fps, timestamps = VIDEO_READER_BACKENDS[video_reader_backend](ele)
        except Exception as e:
            logger.warning(f"video_reader_backend {video_reader_backend} error, use torchvision as default, msg: {e}")
            video, sample_fps, timestamps = VIDEO_READER_BACKENDS["torchvision"](ele)
        nframes, _, height, width = video.shape
        if return_video_sample_fps:
            return video, sample_fps, timestamps
        return video
    else:
        assert isinstance(ele["video"], (list, tuple))
        process_info = ele.copy()
        process_info.pop("type", None)
        process_info.pop("video", None)
        images = [
            fetch_image({"image": video_element, **process_info})
            for video_element in ele["video"]
        ]
        nframes = adjust_by_factor(len(images), FRAME_FACTOR, method='ceil')
        if len(images) < nframes:
            images.extend([images[-1]] * (nframes - len(images)))
        timestamps = [-1 for i in range(nframes)]  # not sure about this
        if return_video_sample_fps:
            return images, process_info.pop("fps", 2.0), timestamps
        return images


class Eagle2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
    # See the processing_utils.ProcessingKwargs documentation for usage.
    _defaults = {
        "text_kwargs": {
            "padding": False,
        },
        "images_kwargs": {},
        "videos_kwargs": {"max_dynamic_tiles": 1},
    }


class Eagle2_5_VLProcessor(ProcessorMixin):
    r"""
    Constructs an Eagle2_5_VL processor which wraps an Eagle2_5_VL video processor, an Eagle2_5_VL image processor
    and an Eagle2_5_VL tokenizer into a single processor.

    [`Eagle2_5_VLProcessor`] offers all the functionalities of [`Eagle2_5_VLVideoProcessor`],
    [`Eagle2_5_VLImageProcessor`] and [`Eagle2_5_VLTokenizer`]. See the [`~Eagle2_5_VLVideoProcessor.__call__`],
    [`~Eagle2_5_VLProcessor.__call__`] and [`~Eagle2_5_VLProcessor.decode`] for more information.

    Args:
        image_processor ([`LlavaOnevisionImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`LlamaTokenizerFast`], *optional*):
            The tokenizer is a required input.
        num_image_tokens (`int`, *optional*):
            Number of image tokens for one image that will be returned by the vision tower.
        vision_feature_select_strategy (`str`, *optional*):
            The feature selection strategy used to select the vision feature from the vision backbone.
            Should be the same as in the model's config.
        chat_template (`str`, *optional*):
            A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string.
        image_token (`str`, *optional*, defaults to `"<image>"`):
            Special token used to denote image location.
        video_token (`str`, *optional*, defaults to `"