Spaces:
Runtime error
Runtime error
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import torch | |
from dataclasses import dataclass | |
from torch import nn | |
from transformers import AutoTokenizer, CLIPTokenizerFast, CLIPTextModel, T5EncoderModel | |
from typing import List | |
@dataclass
class TextModelOutput:
    """Container for the fused output of ``TextModel.forward``.

    ``forward`` constructs this with keyword arguments, which requires the
    ``@dataclass``-generated ``__init__`` (without it, construction raises
    ``TypeError``).
    """

    # Encoder hidden states concatenated across all configured encoders.
    embeddings: torch.Tensor
    # Attention masks matching ``embeddings`` (concatenated or OR-combined).
    masks: torch.Tensor
    # ``pooler_output`` tensors from the encoders that provide one.
    pooled: List[torch.Tensor]
class TextModel(nn.Module):
    """Encode prompts with one or more Hugging Face text encoders and fuse the outputs.

    Each entry of ``variant`` is a model id/path:

    * ids whose lowercase form contains "clip" are loaded with
      ``CLIPTokenizerFast`` + ``CLIPTextModel``;
    * ids containing "t5" or "ul2" are loaded with ``AutoTokenizer`` +
      ``T5EncoderModel`` in bfloat16.

    Every tokenizer is created with ``model_max_length=77`` and pads to that
    length, so all encoders produce sequences of equal length. That equal length
    is what makes channel-wise concatenation in ``forward`` possible.

    ``mode`` is a collection of flags drawn from ``available_modes``.
    """

    available_modes = [
        "last",                # If present, use last layer.
        "penultimate",         # If present, use penultimate layer.
        "penultimate_nonorm",  # If present, use penultimate layer without final norm.
        "token_cat",           # If present, concat in token dimension, default concat in channel dimension.
        "pad0",                # If present, use 0 padding, default use EOT padding.
        "masked",              # If present, pass attention mask to encoder.
    ]

    def __init__(self, variant: List[str], mode: List[str]):
        """Load one tokenizer/encoder pair per entry of ``variant``.

        Args:
            variant: Hugging Face model ids or local paths.
            mode: flags controlling ``forward``. NOTE(review): flags not in
                ``available_modes`` are stored but silently ignored.

        Raises:
            NotImplementedError: if a variant matches neither "clip" nor "t5"/"ul2".
        """
        super().__init__()
        self.mode = set(mode)
        self.tokenizers = []
        self.models = nn.ModuleList([])
        for v in variant:
            name = v.lower()
            # "clip" is checked first, so an id matching both families is treated as CLIP.
            if "clip" in name:
                self.tokenizers.append(CLIPTokenizerFast.from_pretrained(v, model_max_length=77))
                self.models.append(CLIPTextModel.from_pretrained(v))
            elif "t5" in name or "ul2" in name:
                self.tokenizers.append(AutoTokenizer.from_pretrained(v, model_max_length=77))
                self.models.append(T5EncoderModel.from_pretrained(v, torch_dtype=torch.bfloat16))
            else:
                raise NotImplementedError(f"Unsupported text encoder variant: {v}")

    def get_vaild_token_length(self, text) -> int:
        """Return the token count of ``text`` excluding special tokens, averaged over tokenizers.

        For each tokenizer, the attended positions (sum of ``attention_mask``)
        are counted. The number of special tokens that tokenizer inserts
        around a single sequence is then subtracted. For example, CLIP adds
        ``<sos>`` and ``<eos>``, while T5-style tokenizers add only ``</s>``.

        Counts are capped by truncation to the 77-token window. The
        per-tokenizer counts are averaged and truncated with ``int()``.

        NOTE(review): if ``text`` is a list, the mask is summed over the whole
        batch while special tokens are subtracted only once. This looks
        intended for a single string; confirm with callers.
        """
        lengths = []
        for tokenizer in self.tokenizers:
            # Only the mask sum is needed, so the tokens stay on the CPU
            # instead of being copied to the model's device.
            tokens = tokenizer(
                text=text,
                truncation=True,
                padding="max_length",
                return_tensors="pt"
            )
            # Previously hard-coded as 2 (CLIP's <sos> + <eos>). That under-counted
            # by one for tokenizers that add a single end-of-sequence token.
            num_special = tokenizer.num_special_tokens_to_add(pair=False)
            token_length = tokens["attention_mask"].sum() - num_special
            lengths.append(token_length.item())
        return int(sum(lengths) / len(lengths))

    # Correctly spelled alias; the misspelled name is kept for existing callers.
    get_valid_token_length = get_vaild_token_length

    def forward(self, text: List[str]) -> TextModelOutput:
        """Encode ``text`` with every configured encoder and concatenate the results.

        Args:
            text: prompts to encode (passed straight to each tokenizer).

        Returns:
            TextModelOutput with the following fields:

            * ``embeddings``: the selected hidden states, concatenated along
              dim 1 if "token_cat" is in ``mode`` and along dim 2 otherwise;
            * ``masks``: the attention masks, concatenated along dim 1 in
              token-cat mode, otherwise OR-ed element-wise across encoders;
            * ``pooled``: ``pooler_output`` from encoders whose output has that
              attribute.
        """
        embeddings = []
        masks = []
        pooled = []
        for tokenizer, model in zip(self.tokenizers, self.models):
            tokens = tokenizer(
                text=text,
                truncation=True,
                padding="max_length",
                return_tensors="pt"
            ).to(model.device)
            if "pad0" in self.mode:
                # Replace the ids at masked-out (attention_mask == 0) positions with 0.
                tokens.input_ids *= tokens.attention_mask
            output = model(
                input_ids=tokens.input_ids,
                attention_mask=tokens.attention_mask if "masked" in self.mode else None,
                output_hidden_states=True
            )
            # Several layer flags may be active at once; each appends its own tensor.
            if "last" in self.mode:
                embeddings.append(output.last_hidden_state)
            if "penultimate" in self.mode:
                # NOTE(review): relies on ``model.text_model.final_layer_norm``.
                # Presumably only CLIPTextModel exposes ``text_model``, so this
                # flag likely fails for T5/UL2 encoders. Confirm.
                embeddings.append(model.text_model.final_layer_norm(output.hidden_states[-2]))
            if "penultimate_nonorm" in self.mode:
                embeddings.append(output.hidden_states[-2])
            masks.append(tokens.attention_mask)
            if hasattr(output, "pooler_output"):
                pooled.append(output.pooler_output)
        if "token_cat" in self.mode:
            return TextModelOutput(
                embeddings=torch.cat(embeddings, dim=1),
                masks=torch.cat(masks, dim=1),
                pooled=pooled
            )
        else:
            # Channel concat needs equal sequence lengths across encoders
            # (guaranteed by padding="max_length" with max length 77). Masks
            # are combined as a logical OR: stack, sum, clamp to 1.
            return TextModelOutput(
                embeddings=torch.cat(embeddings, dim=2),
                masks=torch.stack(masks, dim=2).sum(2).clamp_max(1),
                pooled=pooled
            )