# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from dataclasses import dataclass
from torch import nn
from transformers import AutoTokenizer, CLIPTokenizerFast, CLIPTextModel, T5EncoderModel
from typing import List


@dataclass
class TextModelOutput:
    embeddings: torch.Tensor        # [B, L, C] token embeddings, concatenated across encoders.
    masks: torch.Tensor             # [B, L] attention mask over the (possibly concatenated) tokens.
    pooled: List[torch.Tensor]      # Pooled outputs (e.g. CLIP's pooler_output), one per encoder that provides one.


class TextModel(nn.Module):
    available_modes = [
        "last",                 # If present, use the last layer.
        "penultimate",          # If present, use the penultimate layer.
        "penultimate_nonorm",   # If present, use the penultimate layer without the final norm.
        "token_cat",            # If present, concatenate in the token dimension; default is the channel dimension.
        "pad0",                 # If present, use 0 padding; default is EOT padding.
        "masked",               # If present, pass the attention mask to the encoder.
    ]
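
    # Example (a hypothetical combination, not prescribed by this file): mode=["penultimate", "masked"]
    # would take each encoder's penultimate hidden layer (with the final norm applied) and pass the
    # attention mask through to the encoder.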

    def __init__(self, variant: List[str], mode: List[str]):
        super().__init__()
        self.mode = set(mode)
        self.tokenizers = []
        self.models = nn.ModuleList([])
        for v in variant:
            if "clip" in v.lower():
                self.tokenizers.append(CLIPTokenizerFast.from_pretrained(v, model_max_length=77))
                self.models.append(CLIPTextModel.from_pretrained(v))
            elif "t5" in v.lower() or "ul2" in v.lower():
                self.tokenizers.append(AutoTokenizer.from_pretrained(v, model_max_length=77))
                self.models.append(T5EncoderModel.from_pretrained(v, torch_dtype=torch.bfloat16))
            else:
                raise NotImplementedError(f"Unsupported text encoder variant: {v}")

    def get_valid_token_length(self, text):
        # Return the length of the BPE encoding of the text, excluding <sos> and <eos>,
        # averaged across the loaded tokenizers.
        lengths = []
        for tokenizer, model in zip(self.tokenizers, self.models):
            tokens = tokenizer(
                text=text,
                truncation=True,
                padding="max_length",
                return_tensors="pt"
            ).to(model.device)
            # Both SOS and EOS (the first PAD) have a value of 1 in the attention mask, hence the -2.
            token_length = tokens["attention_mask"].sum() - 2
            lengths.append(token_length.item())
        length = int(sum(lengths) / len(lengths))
        return length

    def forward(self, text: List[str]) -> TextModelOutput:
        embeddings = []
        masks = []
        pooled = []
        for tokenizer, model in zip(self.tokenizers, self.models):
            tokens = tokenizer(
                text=text,
                truncation=True,
                padding="max_length",
                return_tensors="pt"
            ).to(model.device)
            if "pad0" in self.mode:
                # Zero out the ids at padded positions (0 padding instead of EOT padding).
                tokens.input_ids *= tokens.attention_mask
            output = model(
                input_ids=tokens.input_ids,
                attention_mask=tokens.attention_mask if "masked" in self.mode else None,
                output_hidden_states=True
            )
            if "last" in self.mode:
                embeddings.append(output.last_hidden_state)
            if "penultimate" in self.mode:
                # Note: final_layer_norm is a CLIPTextModel attribute; this branch assumes a CLIP encoder.
                embeddings.append(model.text_model.final_layer_norm(output.hidden_states[-2]))
            if "penultimate_nonorm" in self.mode:
                embeddings.append(output.hidden_states[-2])
            masks.append(tokens.attention_mask)
            if hasattr(output, "pooler_output"):
                pooled.append(output.pooler_output)
        if "token_cat" in self.mode:
            # Concatenate along the token dimension.
            return TextModelOutput(
                embeddings=torch.cat(embeddings, dim=1),
                masks=torch.cat(masks, dim=1),
                pooled=pooled
            )
        else:
            # Concatenate along the channel dimension; keep a token if any encoder attends to it.
            return TextModelOutput(
                embeddings=torch.cat(embeddings, dim=2),
                masks=torch.stack(masks, dim=2).sum(2).clamp_max(1),
                pooled=pooled
            )
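

# A minimal usage sketch, not part of the original file. The checkpoint names and the mode
# combination below are hypothetical stand-ins; substitute whichever CLIP / T5 checkpoints and
# modes your project actually uses. Channel concatenation (the default) requires all encoders
# to produce the same token length (77 here), while "token_cat" would require equal hidden sizes.
if __name__ == "__main__":
    model = TextModel(
        variant=["openai/clip-vit-base-patch32", "t5-small"],  # hypothetical choice of encoders
        mode=["last", "masked"],                               # last hidden layer, mask passed to encoders
    )
    with torch.no_grad():
        out = model(["a photo of an astronaut riding a horse"])
    # e.g. embeddings [1, 77, 512 + 512], masks [1, 77], one pooled entry (from CLIP only)
    print(out.embeddings.shape, out.masks.shape, len(out.pooled))
    print(model.get_valid_token_length("a photo of an astronaut riding a horse"))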