# VQA inference: run Visual Question Answering models (BLIP, ViLT) over images.
import io
import os
import traceback
import torch
from PIL import Image, UnidentifiedImageError
from .model_loader import ModelManager
class VQAInference:
    """Run Visual Question Answering (VQA) inference with a pretrained model.

    Supports the BLIP and ViLT model families behind a single ``predict``
    entry point; model weights and processors are obtained through
    ``ModelManager``.
    """

    def __init__(self, model_name="blip", cache_dir=None):
        """Load the requested VQA model and its processor.

        Args:
            model_name (str, optional): Name of model to use ("blip" or
                "vilt"). Defaults to "blip".
            cache_dir (str, optional): Directory to cache models.
                Defaults to None.
        """
        self.model_name = model_name
        self.model_manager = ModelManager(cache_dir=cache_dir)
        self.processor, self.model = self.model_manager.get_model(model_name)
        self.device = self.model_manager.device

    def predict(self, image, question):
        """Answer a question about an image.

        Args:
            image (PIL.Image.Image or str): Image to analyze, or a path to
                an image file on disk.
            question (str): Question to ask about the image.

        Returns:
            str: Answer to the question.

        Raises:
            ValueError: If the image cannot be opened or identified, if
                ``image`` is neither a path nor a PIL image, or if no
                prediction routine exists for the configured model.
        """
        # Accept either a file path or an already-loaded PIL image.
        if isinstance(image, str):
            try:
                if not os.path.exists(image):
                    raise FileNotFoundError(f"Image file not found: {image}")
                try:
                    # Standard loading path.
                    image = Image.open(image).convert("RGB")
                    print(
                        f"Successfully opened image: {image.size}, mode: {image.mode}"
                    )
                except Exception as img_err:
                    print(
                        f"Standard image loading failed: {img_err}, trying alternative method..."
                    )
                    # Fallback: read the raw bytes and decode from memory,
                    # which sidesteps some path/stream quirks in PIL.
                    with open(image, "rb") as img_file:
                        img_data = img_file.read()
                        image = Image.open(io.BytesIO(img_data)).convert("RGB")
                        print(
                            f"Alternative image loading succeeded: {image.size}, mode: {image.mode}"
                        )
            except UnidentifiedImageError as e:
                # Raised when even the fallback decode cannot recognise the
                # image format. Chain the cause so it survives the re-raise.
                raise ValueError(f"Cannot identify image format: {str(e)}") from e
            except Exception as e:
                # Chain the original exception (PEP 3134) so the root cause
                # is preserved alongside the printed traceback.
                error_details = traceback.format_exc()
                print(f"Error details: {error_details}")
                raise ValueError(f"Could not open image file: {str(e)}") from e

        # By now the input must be a PIL image, whatever form it arrived in.
        if not isinstance(image, Image.Image):
            raise ValueError("Image must be a PIL Image or a file path")

        # Dispatch on the case-insensitive model name (computed once).
        model_key = self.model_name.lower()
        if model_key == "blip":
            return self._predict_with_blip(image, question)
        elif model_key == "vilt":
            return self._predict_with_vilt(image, question)
        else:
            raise ValueError(f"Prediction not implemented for model: {self.model_name}")

    def _predict_with_blip(self, image, question):
        """Generate an answer with a BLIP model (seq2seq generation).

        Args:
            image (PIL.Image.Image): Image to analyze.
            question (str): Question to ask about the image.

        Returns:
            str: Answer to the question.

        Raises:
            RuntimeError: If processing or generation fails.
        """
        try:
            # Processor produces both pixel values and tokenized text.
            inputs = self.processor(
                images=image, text=question, return_tensors="pt"
            ).to(self.device)
            # Autoregressive generation; no gradients needed at inference.
            with torch.no_grad():
                outputs = self.model.generate(**inputs)
            # Decode generated token ids back to a plain-text answer.
            answer = self.processor.decode(outputs[0], skip_special_tokens=True)
            return answer
        except Exception as e:
            error_details = traceback.format_exc()
            print(f"Error in BLIP prediction: {str(e)}")
            print(f"Error details: {error_details}")
            # Chain the cause so callers can inspect the underlying failure.
            raise RuntimeError(f"BLIP model prediction failed: {str(e)}") from e

    def _predict_with_vilt(self, image, question):
        """Classify an answer with a ViLT model (answer-vocabulary logits).

        Args:
            image (PIL.Image.Image): Image to analyze.
            question (str): Question to ask about the image.

        Returns:
            str: Answer to the question.

        Raises:
            RuntimeError: If processing or the forward pass fails.
        """
        try:
            encoding = self.processor(images=image, text=question, return_tensors="pt")
            # Move every tensor in the encoding onto the model's device.
            for k, v in encoding.items():
                encoding[k] = v.to(self.device)
            # Single forward pass; ViLT treats VQA as classification.
            with torch.no_grad():
                outputs = self.model(**encoding)
                logits = outputs.logits
            # Highest-scoring class index maps to an answer string.
            idx = logits.argmax(-1).item()
            answer = self.model.config.id2label[idx]
            return answer
        except Exception as e:
            error_details = traceback.format_exc()
            print(f"Error in ViLT prediction: {str(e)}")
            print(f"Error details: {error_details}")
            # Chain the cause so callers can inspect the underlying failure.
            raise RuntimeError(f"ViLT model prediction failed: {str(e)}") from e