import torch.nn as nn
from transformers import CLIPTokenizer, CLIPTextModel


class AbstractEncoder(nn.Module):
    """Abstract base class for text encoders; subclasses must implement encode()."""

    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from Hugging Face)"""

    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)

        # Return the per-token hidden states, plus a dict of auxiliary outputs
        # (pooled embedding, attention mask, and the raw token ids).
        z = outputs.last_hidden_state
        return z, {
            "token_embedding": outputs.last_hidden_state,
            "pooler_output": outputs.pooler_output,
            "token_mask": batch_encoding["attention_mask"].to(self.device),
            "tokens": tokens,
        }

    def encode_from_token(self, tokens):
        # Encode pre-tokenized input ids directly, bypassing the tokenizer.
        tokens = tokens.to(self.device)
        outputs = self.transformer(input_ids=tokens)
        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)


class FrozenCLIPTokenizer(AbstractEncoder):
    """Tokenizes text with the CLIP tokenizer (from Hugging Face) without running the text transformer."""

    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.max_length = max_length
        self.freeze()

    def freeze(self):
        # The tokenizer holds no nn.Module parameters, so this is effectively a
        # no-op kept for interface parity with FrozenCLIPEmbedder.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        # Return the padded/truncated token ids as a CPU tensor; callers move them to a device as needed.
        tokens = batch_encoding["input_ids"]
        return tokens

    def encode(self, text):
        return self(text)
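

# A minimal usage sketch (not part of the original module): it assumes the
# "openai/clip-vit-large-patch14" weights can be downloaded from Hugging Face
# and falls back to CPU when CUDA is unavailable. Shapes follow the CLIP
# ViT-L/14 text encoder (77 tokens, 768-dim hidden states).
if __name__ == "__main__":
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    prompts = ["a photograph of an astronaut riding a horse"]

    tokenizer = FrozenCLIPTokenizer(device=device)
    embedder = FrozenCLIPEmbedder(device=device).to(device)

    tokens = tokenizer.encode(prompts)                  # (1, 77) token ids on CPU
    z, extras = embedder.encode(prompts)                # (1, 77, 768) hidden states
    z_from_tokens = embedder.encode_from_token(tokens)  # same shape as z

    print("tokens:", tuple(tokens.shape))
    print("last_hidden_state:", tuple(z.shape))
    print("pooler_output:", tuple(extras["pooler_output"].shape))
    print("encode_from_token:", tuple(z_from_tokens.shape))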