"""
    This file computes the clip score given image and text pair
"""
import clip
import torch
from PIL import Image
from sklearn.preprocessing import normalize
from torchvision.transforms import Compose, Normalize, Resize
import torch
import numpy as np

class ClipSocre:
    def __init__(self, device='cuda', prefix='A photo depicts', weight=1.0):  # weight=2.5 in the standard CLIPScore metric
        self.device = device

        # Load CLIP ViT-B/32; clip.load() keeps the model in fp16 on CUDA.
        self.model, _ = clip.load("ViT-B/32", device=device, jit=False)
        self.model.eval()

        # CLIP's channel-wise normalization; inputs are expected to already be
        # resized image tensors with values in [0, 1].
        self.transform = Compose([
            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])

        # Prefix prepended to captions when need_prefix=True.
        self.prefix = prefix
        if self.prefix[-1] != ' ':
            self.prefix += ' '

        self.w = weight
    
    def extract_all_images(self, images):
        # `images` is an (N, 3, H, W) tensor in [0, 1]; apply CLIP normalization.
        images_input = self.transform(images)
        if self.device == 'cuda':
            # The CUDA model is fp16, so match the input dtype.
            images_input = images_input.to(torch.float16)
        image_feature = self.model.encode_image(images_input)
        return image_feature
    
    def extract_all_texts(self, texts, need_prefix):
        # `texts` may be a string or a list of strings; the prefix path assumes
        # a single caption string.
        if need_prefix:
            c_data = clip.tokenize(self.prefix + texts, truncate=True).to(self.device)
        else:
            c_data = clip.tokenize(texts, truncate=True).to(self.device)
        text_feature = self.model.encode_text(c_data)
        return text_feature
    
    def get_clip_score(self, img, text, need_prefix=False):
        # CLIPScore per sample: w * max(cosine(image_feature, text_feature), 0).
        img_f = self.extract_all_images(img)
        text_f = self.extract_all_texts(text, need_prefix)

        images = img_f / torch.sqrt(torch.sum(img_f ** 2, dim=1, keepdim=True))
        candidates = text_f / torch.sqrt(torch.sum(text_f ** 2, dim=1, keepdim=True))

        clip_per = self.w * torch.clamp(torch.sum(images * candidates, dim=1), min=0)

        return clip_per
    
    def get_text_clip_score(self, text_1, text_2, need_prefix=False):
        # Cosine similarity between two texts in CLIP's text embedding space.
        text_1_f = self.extract_all_texts(text_1, need_prefix)
        text_2_f = self.extract_all_texts(text_2, need_prefix)

        candidates_1 = text_1_f / torch.sqrt(torch.sum(text_1_f ** 2, dim=1, keepdim=True))
        candidates_2 = text_2_f / torch.sqrt(torch.sum(text_2_f ** 2, dim=1, keepdim=True))

        per = self.w * torch.clamp(torch.sum(candidates_1 * candidates_2, dim=1), min=0)

        # .item() assumes a single text pair per call.
        results = 'ClipS : ' + format(per.item(), '.4f')

        print(results)

        return per.sum()
    
    def get_img_clip_score(self, img_1, img_2, weight=1):
        # Cosine similarity between two images in CLIP's image embedding space.
        img_f_1 = self.extract_all_images(img_1)
        img_f_2 = self.extract_all_images(img_2)

        images_1 = img_f_1 / torch.sqrt(torch.sum(img_f_1 ** 2, dim=1, keepdim=True))
        images_2 = img_f_2 / torch.sqrt(torch.sum(img_f_2 ** 2, dim=1, keepdim=True))

        # per = self.w * torch.clamp(torch.sum(images_1 * images_2, dim=1), min=0)
        per = weight * torch.clamp(torch.sum(images_1 * images_2, dim=1), min=0)

        return per.sum()


    def calculate_clip_score(self, caption_list, image_unprocessed):
        # Map generator output from [-1, 1] to [0, 1], resize so the shorter
        # side is 224 (CLIP's input resolution), then score against the captions.
        image_unprocessed = 0.5 * (image_unprocessed + 1.)
        image_unprocessed.clamp_(0., 1.)
        img_resize = Resize(224)(image_unprocessed)
        return self.get_clip_score(img_resize, caption_list)
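

# Minimal usage sketch: scores a random stand-in image batch against a caption.
# This is illustrative only; it assumes a CUDA device and the OpenAI `clip`
# package, and the random tensor stands in for a generator/decoder output
# in [-1, 1] as expected by calculate_clip_score().
if __name__ == "__main__":
    scorer = ClipSocre(device='cuda')
    captions = ['a cat sitting on a windowsill']
    # Pretend batch of one 256x256 RGB image with values in [-1, 1].
    fake_batch = torch.rand(1, 3, 256, 256, device='cuda') * 2.0 - 1.0
    with torch.no_grad():
        scores = scorer.calculate_clip_score(captions, fake_batch)
    print('CLIP score per sample:', scores)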