File size: 5,702 Bytes
e2a4738
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import io
import os
import traceback
import torch
from PIL import Image, UnidentifiedImageError
from .model_loader import ModelManager


class VQAInference:
    """Run inference with Visual Question Answering (VQA) models.

    Wraps a processor/model pair obtained from ``ModelManager`` and
    dispatches prediction to a model-specific implementation:
    BLIP (generative answer decoding) or ViLT (answer classification).
    """

    def __init__(self, model_name="blip", cache_dir=None):
        """Initialize the VQA inference wrapper.

        Args:
            model_name (str, optional): Name of model to use
                ("blip" or "vilt"). Defaults to "blip".
            cache_dir (str, optional): Directory to cache models.
                Defaults to None (ModelManager's default location).
        """
        self.model_name = model_name
        self.model_manager = ModelManager(cache_dir=cache_dir)
        self.processor, self.model = self.model_manager.get_model(model_name)
        self.device = self.model_manager.device

    @staticmethod
    def _load_image(path):
        """Load an image file from disk as an RGB PIL image.

        Tries ``Image.open`` on the path first; on failure, falls back to
        reading the raw bytes and decoding from an in-memory buffer.

        Args:
            path (str): Path to the image file.

        Returns:
            PIL.Image.Image: The loaded image, converted to RGB.

        Raises:
            ValueError: If the file is missing, unreadable, or cannot be
                identified as an image.
        """
        try:
            # Check if file exists
            if not os.path.exists(path):
                raise FileNotFoundError(f"Image file not found: {path}")

            # Try multiple approaches to load the image
            try:
                # Try the standard approach first
                img = Image.open(path).convert("RGB")
                print(
                    f"Successfully opened image: {img.size}, mode: {img.mode}"
                )
                return img
            except Exception as img_err:
                print(
                    f"Standard image loading failed: {img_err}, trying alternative method..."
                )

                # Try alternative approach with binary mode explicitly
                with open(path, "rb") as img_file:
                    img_data = img_file.read()
                img = Image.open(io.BytesIO(img_data)).convert("RGB")
                print(
                    f"Alternative image loading succeeded: {img.size}, mode: {img.mode}"
                )
                return img

        except UnidentifiedImageError as e:
            # Specific error when image format cannot be identified
            raise ValueError(f"Cannot identify image format: {str(e)}") from e
        except Exception as e:
            # Provide detailed error information; chain the cause so the
            # original traceback is preserved for callers.
            error_details = traceback.format_exc()
            print(f"Error details: {error_details}")
            raise ValueError(f"Could not open image file: {str(e)}") from e

    def predict(self, image, question):
        """Perform VQA prediction on an image with a question.

        Args:
            image (PIL.Image.Image or str): Image to analyze or path to image.
            question (str): Question to ask about the image.

        Returns:
            str: Answer to the question.

        Raises:
            ValueError: If the image cannot be loaded, is neither a PIL
                image nor a path, or the configured model has no
                prediction implementation.
            RuntimeError: If the underlying model prediction fails.
        """
        # Handle image input - could be a file path or PIL Image
        if isinstance(image, str):
            image = self._load_image(image)

        # Make sure image is a PIL Image
        if not isinstance(image, Image.Image):
            raise ValueError("Image must be a PIL Image or a file path")

        # Process based on model type (case-insensitive dispatch)
        model = self.model_name.lower()
        if model == "blip":
            return self._predict_with_blip(image, question)
        if model == "vilt":
            return self._predict_with_vilt(image, question)
        raise ValueError(f"Prediction not implemented for model: {self.model_name}")

    def _predict_with_blip(self, image, question):
        """Perform prediction with the BLIP model (generative decoding).

        Args:
            image (PIL.Image.Image): Image to analyze.
            question (str): Question to ask about the image.

        Returns:
            str: Answer to the question.

        Raises:
            RuntimeError: If processing or generation fails.
        """
        try:
            # Process image and text inputs
            inputs = self.processor(
                images=image, text=question, return_tensors="pt"
            ).to(self.device)

            # Generate answer (no gradients needed for inference)
            with torch.no_grad():
                outputs = self.model.generate(**inputs)

            # Decode the output to text
            return self.processor.decode(outputs[0], skip_special_tokens=True)
        except Exception as e:
            error_details = traceback.format_exc()
            print(f"Error in BLIP prediction: {str(e)}")
            print(f"Error details: {error_details}")
            raise RuntimeError(f"BLIP model prediction failed: {str(e)}") from e

    def _predict_with_vilt(self, image, question):
        """Perform prediction with the ViLT model (answer classification).

        Args:
            image (PIL.Image.Image): Image to analyze.
            question (str): Question to ask about the image.

        Returns:
            str: Answer to the question.

        Raises:
            RuntimeError: If processing or the forward pass fails.
        """
        try:
            # Process image and text inputs
            encoding = self.processor(images=image, text=question, return_tensors="pt")

            # Move inputs to device without mutating the encoding while
            # iterating it (reassigning keys mid-iteration is fragile).
            encoding = {k: v.to(self.device) for k, v in encoding.items()}

            # Forward pass (no gradients needed for inference)
            with torch.no_grad():
                outputs = self.model(**encoding)
                logits = outputs.logits

            # Get the predicted answer idx
            idx = logits.argmax(-1).item()

            # Convert class index to answer text via the model's label map
            return self.model.config.id2label[idx]
        except Exception as e:
            error_details = traceback.format_exc()
            print(f"Error in ViLT prediction: {str(e)}")
            print(f"Error details: {error_details}")
            raise RuntimeError(f"ViLT model prediction failed: {str(e)}") from e