File size: 607 Bytes
c4e7950
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
from diffusers import ConfigMixin, Mel, ModelMixin

class ImageEncoder(ModelMixin, ConfigMixin):
    """Pair an image processor with an encoder model to produce embeddings.

    The module is put into evaluation mode at construction time, and
    :meth:`encode` runs with gradient tracking disabled.
    """

    def __init__(self, image_processor, encoder_model):
        """Store the preprocessing callable and the encoder.

        Args:
            image_processor: Callable invoked as
                ``image_processor(image, return_tensors="pt")``; the result
                is indexed with ``'pixel_values'``.
            encoder_model: Callable applied to the pixel values. Its output
                must expose a ``last_hidden_state`` attribute for
                :meth:`encode` to work.
        """
        super().__init__()
        self.processor = image_processor
        self.encoder = encoder_model
        # Inference-only wrapper: switch to evaluation mode up front.
        self.eval()

    def forward(self, x):
        """Run the wrapped encoder on ``x`` and return its raw output."""
        return self.encoder(x)

    @torch.no_grad()
    def encode(self, image):
        """Preprocess ``image`` and return one embedding vector per item.

        Args:
            image: Input accepted by the stored image processor.

        Returns:
            The slice ``last_hidden_state[:, 0, :]`` of the encoder output,
            i.e. the entry at index 0 along the second axis (presumably the
            class/CLS token of a transformer encoder - confirm against the
            encoder actually used).
        """
        pixel_values = self.processor(image, return_tensors="pt")['pixel_values']

        # The processor creates its tensors independently of the encoder, so
        # they are not guaranteed to live on the encoder's device. Move them
        # there to avoid a device-mismatch error after e.g. ``.to("cuda")``.
        # Encoders without parameters (or with not-yet-materialized "meta"
        # parameters) are left alone, which preserves the previous behavior.
        get_parameters = getattr(self.encoder, "parameters", None)
        first_param = next(get_parameters(), None) if callable(get_parameters) else None
        if first_param is not None and first_param.device.type != "meta":
            pixel_values = pixel_values.to(first_param.device)

        hidden_states = self(pixel_values).last_hidden_state
        return hidden_states[:, 0, :]