File size: 607 Bytes
c4e7950 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
import torch
from diffusers import ConfigMixin, Mel, ModelMixin
class ImageEncoder(ModelMixin, ConfigMixin):
    """Pair an image processor with an encoder model behind one ``encode`` call.

    The processor turns an image into ``pixel_values``; the encoder maps
    those to an output whose ``last_hidden_state`` is sliced at position 0
    of its second axis to produce the returned embedding.
    """

    def __init__(self, image_processor, encoder_model):
        """Store both collaborators and put the module in evaluation mode.

        Args:
            image_processor: Callable invoked as
                ``image_processor(image, return_tensors="pt")``; its result
                must be indexable with ``"pixel_values"``.
            encoder_model: Callable applied to the pixel values. ``encode``
                additionally requires its output to expose a
                ``last_hidden_state`` attribute.
        """
        super().__init__()
        self.processor = image_processor
        self.encoder = encoder_model
        # This wrapper is used for inference only, so leave training mode once.
        self.eval()

    def forward(self, x):
        """Run the wrapped encoder on ``x`` and return its output unchanged."""
        return self.encoder(x)

    @torch.no_grad()
    def encode(self, image):
        """Preprocess ``image`` and return the position-0 hidden-state vector.

        Gradient tracking is disabled for the whole call.

        Args:
            image: Whatever ``self.processor`` accepts as its first argument.

        Returns:
            ``last_hidden_state[:, 0, :]`` of the encoder output, i.e. one
            vector per batch item (presumably the CLS/class-token embedding;
            confirm against the encoder in use).
        """
        pixel_values = self.processor(image, return_tensors="pt")["pixel_values"]

        # The processor builds its tensors independently of where the model's
        # weights live; without this move, calling encode() after
        # `.to("cuda")` fails with a device mismatch. The "meta" device is
        # skipped because moving real data there would discard it (weights on
        # "meta" are expected to be materialized elsewhere, e.g. by offload
        # hooks that place the inputs themselves).
        reference = next(self.parameters(), None)
        if reference is not None and reference.device.type != "meta":
            pixel_values = pixel_values.to(reference.device)

        hidden_states = self(pixel_values).last_hidden_state
        embeddings = hidden_states[:, 0, :]
        return embeddings