import torch
from transformers import (
    BlipForQuestionAnswering,
    BlipProcessor,
    ViltForQuestionAnswering,
    ViltProcessor,
)


class ModelManager:
    """

    Class to manage loading and caching of various VQA models from Hugging Face

    """

    def __init__(self, cache_dir=None):
        """

        Initialize the model manager



        Args:

            cache_dir (str, optional): Directory to cache models. Defaults to None.

        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.cache_dir = cache_dir
        self.models = {}
        self.processors = {}

        # Print device being used
        print(f"Using device: {self.device}")

    def load_blip(self):
        """

        Load BLIP model for VQA



        Returns:

            tuple: (processor, model)

        """
        if "blip" not in self.models:
            print("Loading BLIP model for visual question answering...")

            # Load processor and model
            processor = BlipProcessor.from_pretrained(
                "Salesforce/blip-vqa-base", cache_dir=self.cache_dir
            )
            model = BlipForQuestionAnswering.from_pretrained(
                "Salesforce/blip-vqa-base", cache_dir=self.cache_dir
            )

            # Move model to appropriate device
            model.to(self.device)

            # Store model and processor
            self.models["blip"] = model
            self.processors["blip"] = processor

            print("BLIP model loaded successfully!")

        return self.processors["blip"], self.models["blip"]

    def load_vilt(self):
        """

        Load ViLT model for VQA



        Returns:

            tuple: (processor, model)

        """
        if "vilt" not in self.models:
            print("Loading ViLT model for visual question answering...")

            # Load processor and model
            processor = ViltProcessor.from_pretrained(
                "dandelin/vilt-b32-finetuned-vqa", cache_dir=self.cache_dir
            )
            model = ViltForQuestionAnswering.from_pretrained(
                "dandelin/vilt-b32-finetuned-vqa", cache_dir=self.cache_dir
            )

            # Move model to appropriate device
            model.to(self.device)

            # Store model and processor
            self.models["vilt"] = model
            self.processors["vilt"] = processor

            print("ViLT model loaded successfully!")

        return self.processors["vilt"], self.models["vilt"]

    def get_model(self, model_name="blip"):
        """

        Get a model by name



        Args:

            model_name (str, optional): Name of model to load. Defaults to "blip".

                                       Options: "blip", "vilt"



        Returns:

            tuple: (processor, model)

        """
        if model_name.lower() == "blip":
            return self.load_blip()
        elif model_name.lower() == "vilt":
            return self.load_vilt()
        else:
            raise ValueError(
                f"Unknown model: {model_name}. Available models: blip, vilt"
            )
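The load-once caching that `load_blip` and `load_vilt` implement can be shown in isolation, without downloading any weights. The `LazyCache` class and names below are a hypothetical, dependency-free sketch of the same pattern, not part of this module:

```python
class LazyCache:
    """Minimal sketch of the load-once caching pattern used by ModelManager."""

    def __init__(self):
        self.loaded = {}     # name -> cached object (cf. self.models)
        self.load_count = 0  # how many times the expensive loader ran

    def get(self, name, loader):
        # Run the (expensive) loader only on the first request for `name`;
        # later calls return the cached object unchanged.
        if name not in self.loaded:
            self.load_count += 1
            self.loaded[name] = loader()
        return self.loaded[name]


cache = LazyCache()
first = cache.get("blip", lambda: object())
second = cache.get("blip", lambda: object())
print(first is second, cache.load_count)  # → True 1
```

`ModelManager` follows the same shape: `self.models` plays the role of `loaded`, and `from_pretrained` is the expensive loader that should run at most once per model name.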