# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import List

import torch
from torch import nn
from transformers import AutoTokenizer, CLIPTokenizerFast, CLIPTextModel, T5EncoderModel

@dataclass
class TextModelOutput:
    embeddings: torch.Tensor        # (batch, seq_len, channels), concatenated across encoders
    masks: torch.Tensor             # (batch, seq_len)
    pooled: List[torch.Tensor]      # one pooled embedding per encoder that provides one


class TextModel(nn.Module):
    available_modes = [
        "last",                 # If present, use the last hidden layer.
        "penultimate",          # If present, use the penultimate hidden layer.
        "penultimate_nonorm",   # If present, use the penultimate hidden layer without the final norm.
        "token_cat",            # If present, concatenate along the token dimension; default is the channel dimension.
        "pad0",                 # If present, pad with 0; default is EOT padding.
        "masked",               # If present, pass the attention mask to the encoder.
    ]
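    # Example (illustrative): mode=["penultimate", "masked"] takes the penultimate
    # hidden layer with attention masking; adding "token_cat" would concatenate
    # multiple encoders along the sequence axis instead of the channel axis.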

    def __init__(self, variant: List[str], mode: List[str]):
        super().__init__()
        self.mode = set(mode)
        self.tokenizers = []
        self.models = nn.ModuleList([])

        for v in variant:
            if "clip" in v.lower():
                self.tokenizers.append(CLIPTokenizerFast.from_pretrained(v, model_max_length=77))
                self.models.append(CLIPTextModel.from_pretrained(v))
            elif "t5" in v.lower() or "ul2" in v.lower():
                self.tokenizers.append(AutoTokenizer.from_pretrained(v, model_max_length=77))
                self.models.append(T5EncoderModel.from_pretrained(v, torch_dtype=torch.bfloat16))
            else:
                raise NotImplementedError(f"Unsupported text encoder variant: {v}")

    def get_valid_token_length(self, text):
        """Return the length of the BPE encoding of the text, excluding `<sos>` and
        `<eos>`, averaged across all encoders."""
        lengths = []
        for tokenizer, model in zip(self.tokenizers, self.models):
            tokens = tokenizer(
                text=text,
                truncation=True,
                padding="max_length",
                return_tensors="pt"
            ).to(model.device)
            # Both the SOS and EOS (first PAD) positions carry a 1 in the attention
            # mask, so subtract 2. (This assumes the tokenizer adds both tokens;
            # CLIP does, while T5 adds only an EOS token.)
            token_length = tokens["attention_mask"].sum() - 2
            lengths.append(token_length.item())
        return int(sum(lengths) / len(lengths))
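    # Illustrative example: with the CLIP tokenizer, "a photo of a cat" encodes
    # as <sos> a photo of a cat <eos>, so get_valid_token_length returns 5.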

    def forward(self, text: List[str]) -> TextModelOutput:
        embeddings = []
        masks = []
        pooled = []

        for tokenizer, model in zip(self.tokenizers, self.models):

            tokens = tokenizer(
                text=text,
                truncation=True,
                padding="max_length",
                return_tensors="pt"
            ).to(model.device)

            if "pad0" in self.mode:
                tokens.input_ids *= tokens.attention_mask

            output = model(
                input_ids=tokens.input_ids,
                attention_mask=tokens.attention_mask if "masked" in self.mode else None,
                output_hidden_states=True
            )

            if "last" in self.mode:
                embeddings.append(output.last_hidden_state)
            if "penultimate" in self.mode:
                embeddings.append(model.text_model.final_layer_norm(output.hidden_states[-2]))
            if "penultimate_nonorm" in self.mode:
                embeddings.append(output.hidden_states[-2])
            masks.append(tokens.attention_mask)
            if hasattr(output, "pooler_output"):
                pooled.append(output.pooler_output)

        if "token_cat" in self.mode:
            return TextModelOutput(
                embeddings=torch.cat(embeddings, dim=1),
                masks=torch.cat(masks, dim=1),
                pooled=pooled
            )
        else:
            return TextModelOutput(
                embeddings=torch.cat(embeddings, dim=2),
                masks=torch.stack(masks, dim=2).sum(2).clamp_max(1),
                pooled=pooled
            )
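

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). The checkpoint ID
    # below is a public Hugging Face model chosen purely for illustration; the
    # encoders this repo actually uses are configured elsewhere.
    text_model = TextModel(
        variant=["openai/clip-vit-large-patch14"],
        mode=["penultimate", "masked"],
    )
    with torch.no_grad():
        output = text_model(["a photo of a cat"])
    print(output.embeddings.shape)  # torch.Size([1, 77, 768]) for CLIP-L
    print(output.masks.shape)       # torch.Size([1, 77])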