# Spaces: Runtime error
import torch | |
from transformers import ( | |
BlipForQuestionAnswering, | |
BlipProcessor, | |
ViltForQuestionAnswering, | |
ViltProcessor, | |
) | |
class ModelManager: | |
""" | |
Class to manage loading and caching of various VQA models from Hugging Face | |
""" | |
def __init__(self, cache_dir=None): | |
""" | |
Initialize the model manager | |
Args: | |
cache_dir (str, optional): Directory to cache models. Defaults to None. | |
""" | |
self.device = "cuda" if torch.cuda.is_available() else "cpu" | |
self.cache_dir = cache_dir | |
self.models = {} | |
self.processors = {} | |
# Print device being used | |
print(f"Using device: {self.device}") | |
def load_blip(self): | |
""" | |
Load BLIP model for VQA | |
Returns: | |
tuple: (processor, model) | |
""" | |
if "blip" not in self.models: | |
print("Loading BLIP model for visual question answering...") | |
# Load processor and model | |
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base") | |
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base") | |
# Move model to appropriate device | |
model.to(self.device) | |
# Store model and processor | |
self.models["blip"] = model | |
self.processors["blip"] = processor | |
print("BLIP model loaded successfully!") | |
return self.processors["blip"], self.models["blip"] | |
def load_vilt(self): | |
""" | |
Load ViLT model for VQA | |
Returns: | |
tuple: (processor, model) | |
""" | |
if "vilt" not in self.models: | |
print("Loading ViLT model for visual question answering...") | |
# Load processor and model | |
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-vqa") | |
model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-vqa") | |
# Move model to appropriate device | |
model.to(self.device) | |
# Store model and processor | |
self.models["vilt"] = model | |
self.processors["vilt"] = processor | |
print("ViLT model loaded successfully!") | |
return self.processors["vilt"], self.models["vilt"] | |
def get_model(self, model_name="blip"): | |
""" | |
Get a model by name | |
Args: | |
model_name (str, optional): Name of model to load. Defaults to "blip". | |
Options: "blip", "vilt" | |
Returns: | |
tuple: (processor, model) | |
""" | |
if model_name.lower() == "blip": | |
return self.load_blip() | |
elif model_name.lower() == "vilt": | |
return self.load_vilt() | |
else: | |
raise ValueError( | |
f"Unknown model: {model_name}. Available models: blip, vilt" | |
) | |