Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) 2025 Ye Liu. Licensed under the BSD-3-Clause License. | |
import warnings | |
import torch | |
from torch.nn.utils.rnn import pad_sequence | |
from videomind.constants import IGNORE_INDEX | |
class HybridDataCollator(object):
    """Collate tokenized (optionally multimodal) samples into one padded batch.

    ``input_ids`` and ``labels`` are right-padded to the longest sample and
    truncated to the tokenizer's ``model_max_length``. Visual tensors are
    concatenated along dim 0. Auxiliary annotations are passed through as
    per-sample lists.
    """

    # Visual inputs concatenated along dim 0 across every sample that has them.
    _VISUAL_KEYS = ('pixel_values', 'pixel_values_videos', 'image_grid_thw', 'video_grid_thw')

    # Per-sample annotations kept as Python lists (one entry per sample).
    _AUX_KEYS = ('timestamps', 'saliency', 'pos_clip')

    def __init__(self, tokenizer):
        """Store the tokenizer used for padding and length limits.

        Args:
            tokenizer: Object exposing ``pad_token_id`` (padding value for
                ``input_ids``) and ``model_max_length`` (truncation limit).
        """
        self.tokenizer = tokenizer

    def __call__(self, batch):
        """Build a training batch from a list of sample dicts.

        Args:
            batch: Non-empty list of dicts. Each dict has ``input_ids`` and
                ``labels`` tensors of identical size (assumed 1-D token
                sequences). It may also carry any of the visual keys
                (``pixel_values``, ``pixel_values_videos``, ``image_grid_thw``,
                ``video_grid_thw``) and auxiliary keys (``timestamps``,
                ``saliency``, ``pos_clip``).

        Returns:
            dict with:
              - ``input_ids``: padded with ``tokenizer.pad_token_id``, shape
                (batch, seq_len).
              - ``labels``: padded with ``IGNORE_INDEX``, shape
                (batch, seq_len).
              - ``attention_mask``: bool tensor, True on real (non-padding)
                positions.
              - each visual key present in any sample, concatenated in batch
                order.
              - each auxiliary key present in ``batch[0]``, as a list with one
                entry per sample (every sample must then provide it).

        Raises:
            ValueError: if a sample's ``input_ids`` and ``labels`` differ in
                size.
        """
        input_ids = [d['input_ids'] for d in batch]
        labels = [d['labels'] for d in batch]

        # Validate per sample. Comparing only the padded batch shapes misses a
        # sample whose labels are shorter than its input_ids whenever another
        # sample sets the padded width, which silently misaligns supervision.
        for i, (ids, lbl) in enumerate(zip(input_ids, labels)):
            if ids.size() != lbl.size():
                raise ValueError(
                    f'Sample {i}: input_ids size {tuple(ids.size())} does not match '
                    f'labels size {tuple(lbl.size())}')

        # True (unpadded) length of each sample, recorded before padding.
        lengths = torch.tensor([ids.size(0) for ids in input_ids])

        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)

        seq_len, max_len = input_ids.size(1), self.tokenizer.model_max_length
        if seq_len > max_len:
            warnings.warn(f'The length of input sequence is exceeding model max length: {seq_len} > {max_len}')
            input_ids, labels = input_ids[:, :max_len], labels[:, :max_len]

        # Build the mask from sample lengths instead of `input_ids != pad_token_id`,
        # so real tokens whose id happens to equal the pad id are still attended.
        positions = torch.arange(input_ids.size(1), device=input_ids.device)
        attention_mask = positions.unsqueeze(0) < lengths.to(input_ids.device).unsqueeze(1)

        data = dict(input_ids=input_ids, labels=labels, attention_mask=attention_mask)

        # Gather from every sample that carries the key, not only when batch[0]
        # does. Otherwise a batch where only some samples have images/videos
        # would either drop their visual inputs or raise KeyError.
        for key in self._VISUAL_KEYS:
            present = [d[key] for d in batch if key in d]
            if present:
                data[key] = torch.cat(present)

        for key in self._AUX_KEYS:
            if key in batch[0]:
                data[key] = [d[key] for d in batch]

        return data