Amarthya7 committed
Commit e2a4738 · verified · 1 parent: 4a5eae5

Upload 6 files

models/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .model_loader import ModelManager
+ from .vqa_inference import VQAInference
+
+ __all__ = ["ModelManager", "VQAInference"]
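
The package `__init__.py` re-exports the two public classes, so downstream code can import them straight from the package. A minimal sketch (the top-level package name `models` is taken from the file paths in this commit; nothing else is assumed):

    from models import ModelManager, VQAInference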
models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (325 Bytes)
 
models/__pycache__/model_loader.cpython-311.pyc ADDED
Binary file (3.92 kB)
 
models/__pycache__/vqa_inference.cpython-311.pyc ADDED
Binary file (7.98 kB)
 
models/model_loader.py ADDED
@@ -0,0 +1,98 @@
+ import torch
+ from transformers import (
+     BlipForQuestionAnswering,
+     BlipProcessor,
+     ViltForQuestionAnswering,
+     ViltProcessor,
+ )
+
+
+ class ModelManager:
+     """
+     Class to manage loading and caching of various VQA models from Hugging Face
+     """
+
+     def __init__(self, cache_dir=None):
+         """
+         Initialize the model manager
+
+         Args:
+             cache_dir (str, optional): Directory to cache models. Defaults to None.
+         """
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.cache_dir = cache_dir
+         self.models = {}
+         self.processors = {}
+
+         # Print device being used
+         print(f"Using device: {self.device}")
+
+     def load_blip(self):
+         """
+         Load BLIP model for VQA
+
+         Returns:
+             tuple: (processor, model)
+         """
+         if "blip" not in self.models:
+             print("Loading BLIP model for visual question answering...")
+
+             # Load processor and model, caching downloads in cache_dir if provided
+             processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base", cache_dir=self.cache_dir)
+             model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base", cache_dir=self.cache_dir)
+
+             # Move model to appropriate device
+             model.to(self.device)
+
+             # Store model and processor
+             self.models["blip"] = model
+             self.processors["blip"] = processor
+
+             print("BLIP model loaded successfully!")
+
+         return self.processors["blip"], self.models["blip"]
+
+     def load_vilt(self):
+         """
+         Load ViLT model for VQA
+
+         Returns:
+             tuple: (processor, model)
+         """
+         if "vilt" not in self.models:
+             print("Loading ViLT model for visual question answering...")
+
+             # Load processor and model, caching downloads in cache_dir if provided
+             processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa", cache_dir=self.cache_dir)
+             model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa", cache_dir=self.cache_dir)
+
+             # Move model to appropriate device
+             model.to(self.device)
+
+             # Store model and processor
+             self.models["vilt"] = model
+             self.processors["vilt"] = processor
+
+             print("ViLT model loaded successfully!")
+
+         return self.processors["vilt"], self.models["vilt"]
+
+     def get_model(self, model_name="blip"):
+         """
+         Get a model by name
+
+         Args:
+             model_name (str, optional): Name of model to load. Defaults to "blip".
+                 Options: "blip", "vilt"
+
+         Returns:
+             tuple: (processor, model)
+         """
+         if model_name.lower() == "blip":
+             return self.load_blip()
+         elif model_name.lower() == "vilt":
+             return self.load_vilt()
+         else:
+             raise ValueError(
+                 f"Unknown model: {model_name}. Available models: blip, vilt"
+             )
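
For reference, a minimal usage sketch of ModelManager as added above (the cache directory "./hf_cache" is a hypothetical example, not part of the commit; repeated calls return the in-memory cached objects, as the load_* methods show):

    from models.model_loader import ModelManager

    manager = ModelManager(cache_dir="./hf_cache")  # cache_dir is optional; "./hf_cache" is a placeholder path
    processor, model = manager.get_model("blip")    # downloads and loads on the first call
    processor2, model2 = manager.get_model("blip")  # later calls reuse the cached instances
    assert model is model2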
models/vqa_inference.py ADDED
@@ -0,0 +1,152 @@
+ import io
+ import os
+ import traceback
+ import torch
+ from PIL import Image, UnidentifiedImageError
+ from .model_loader import ModelManager
+
+
+ class VQAInference:
+     """
+     Class to perform inference with Visual Question Answering models
+     """
+
+     def __init__(self, model_name="blip", cache_dir=None):
+         """
+         Initialize the VQA inference
+
+         Args:
+             model_name (str, optional): Name of model to use. Defaults to "blip".
+             cache_dir (str, optional): Directory to cache models. Defaults to None.
+         """
+         self.model_name = model_name
+         self.model_manager = ModelManager(cache_dir=cache_dir)
+         self.processor, self.model = self.model_manager.get_model(model_name)
+         self.device = self.model_manager.device
+
+     def predict(self, image, question):
+         """
+         Perform VQA prediction on an image with a question
+
+         Args:
+             image (PIL.Image.Image or str): Image to analyze or path to image
+             question (str): Question to ask about the image
+
+         Returns:
+             str: Answer to the question
+         """
+         # Handle image input - could be a file path or PIL Image
+         if isinstance(image, str):
+             try:
+                 # Check if file exists
+                 if not os.path.exists(image):
+                     raise FileNotFoundError(f"Image file not found: {image}")
+
+                 # Try multiple approaches to load the image
+                 try:
+                     # Try the standard approach first
+                     image = Image.open(image).convert("RGB")
+                     print(
+                         f"Successfully opened image: {image.size}, mode: {image.mode}"
+                     )
+                 except Exception as img_err:
+                     print(
+                         f"Standard image loading failed: {img_err}, trying alternative method..."
+                     )
+
+                     # Try alternative approach with binary mode explicitly
+                     with open(image, "rb") as img_file:
+                         img_data = img_file.read()
+                         image = Image.open(io.BytesIO(img_data)).convert("RGB")
+                         print(
+                             f"Alternative image loading succeeded: {image.size}, mode: {image.mode}"
+                         )
+
+             except UnidentifiedImageError as e:
+                 # Specific error when image format cannot be identified
+                 raise ValueError(f"Cannot identify image format: {str(e)}")
+             except Exception as e:
+                 # Provide detailed error information
+                 error_details = traceback.format_exc()
+                 print(f"Error details: {error_details}")
+                 raise ValueError(f"Could not open image file: {str(e)}")
+
+         # Make sure image is a PIL Image
+         if not isinstance(image, Image.Image):
+             raise ValueError("Image must be a PIL Image or a file path")
+
+         # Process based on model type
+         if self.model_name.lower() == "blip":
+             return self._predict_with_blip(image, question)
+         elif self.model_name.lower() == "vilt":
+             return self._predict_with_vilt(image, question)
+         else:
+             raise ValueError(f"Prediction not implemented for model: {self.model_name}")
+
+     def _predict_with_blip(self, image, question):
+         """
+         Perform prediction with BLIP model
+
+         Args:
+             image (PIL.Image.Image): Image to analyze
+             question (str): Question to ask about the image
+
+         Returns:
+             str: Answer to the question
+         """
+         try:
+             # Process image and text inputs
+             inputs = self.processor(
+                 images=image, text=question, return_tensors="pt"
+             ).to(self.device)
+
+             # Generate answer
+             with torch.no_grad():
+                 outputs = self.model.generate(**inputs)
+
+             # Decode the output to text
+             answer = self.processor.decode(outputs[0], skip_special_tokens=True)
+
+             return answer
+         except Exception as e:
+             error_details = traceback.format_exc()
+             print(f"Error in BLIP prediction: {str(e)}")
+             print(f"Error details: {error_details}")
+             raise RuntimeError(f"BLIP model prediction failed: {str(e)}")
+
+     def _predict_with_vilt(self, image, question):
+         """
+         Perform prediction with ViLT model
+
+         Args:
+             image (PIL.Image.Image): Image to analyze
+             question (str): Question to ask about the image
+
+         Returns:
+             str: Answer to the question
+         """
+         try:
+             # Process image and text inputs
+             encoding = self.processor(images=image, text=question, return_tensors="pt")
+
+             # Move inputs to device
+             for k, v in encoding.items():
+                 encoding[k] = v.to(self.device)
+
+             # Forward pass
+             with torch.no_grad():
+                 outputs = self.model(**encoding)
+                 logits = outputs.logits
+
+             # Get the predicted answer idx
+             idx = logits.argmax(-1).item()
+
+             # Convert to answer text
+             answer = self.model.config.id2label[idx]
+
+             return answer
+         except Exception as e:
+             error_details = traceback.format_exc()
+             print(f"Error in ViLT prediction: {str(e)}")
+             print(f"Error details: {error_details}")
+             raise RuntimeError(f"ViLT model prediction failed: {str(e)}")
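
Finally, a minimal end-to-end sketch of the inference class added above (the image path and question are placeholders, not part of the commit; note from the two _predict_with_* methods that BLIP generates its answer as free-form text, while ViLT selects the highest-scoring entry from its fixed id2label answer set):

    from models import VQAInference

    vqa = VQAInference(model_name="blip")  # or "vilt"
    answer = vqa.predict("example.jpg", "What color is the car?")  # hypothetical image path and question
    print(answer)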