Spaces:
Runtime error
Runtime error
File size: 2,967 Bytes
e2a4738 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import torch
from transformers import (
BlipForQuestionAnswering,
BlipProcessor,
ViltForQuestionAnswering,
ViltProcessor,
)
class ModelManager:
    """
    Class to manage loading and caching of various VQA models from Hugging Face.

    Each processor/model pair is loaded the first time it is requested and
    then kept in memory, so later requests for the same model return the
    same objects without reloading them.
    """

    def __init__(self, cache_dir=None):
        """
        Initialize the model manager.

        Args:
            cache_dir (str, optional): Directory to cache models. Forwarded as
                ``cache_dir`` to every ``from_pretrained`` call made by this
                manager. Defaults to None (Hugging Face's default location).
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.cache_dir = cache_dir
        # Loaded objects, keyed by short model name ("blip", "vilt").
        self.models = {}
        self.processors = {}
        # Print device being used
        print(f"Using device: {self.device}")

    def _load(self, key, label, processor_cls, model_cls, checkpoint):
        """
        Load the processor/model pair stored under ``key`` once, then reuse it.

        Args:
            key (str): Cache key used in ``self.models`` / ``self.processors``.
            label (str): Human-readable model name used in progress messages.
            processor_cls: Class providing ``from_pretrained`` for the processor.
            model_cls: Class providing ``from_pretrained`` for the model.
            checkpoint (str): Hugging Face Hub identifier of the checkpoint.

        Returns:
            tuple: (processor, model)
        """
        if key not in self.models:
            print(f"Loading {label} model for visual question answering...")
            # Honour the user-supplied cache directory (previously ignored).
            processor = processor_cls.from_pretrained(
                checkpoint, cache_dir=self.cache_dir
            )
            model = model_cls.from_pretrained(checkpoint, cache_dir=self.cache_dir)
            # Move model to the device selected in __init__.
            model.to(self.device)
            self.models[key] = model
            self.processors[key] = processor
            print(f"{label} model loaded successfully!")
        return self.processors[key], self.models[key]

    def load_blip(self):
        """
        Load BLIP model for VQA (``Salesforce/blip-vqa-base``), cached after first use.

        Returns:
            tuple: (processor, model)
        """
        return self._load(
            "blip",
            "BLIP",
            BlipProcessor,
            BlipForQuestionAnswering,
            "Salesforce/blip-vqa-base",
        )

    def load_vilt(self):
        """
        Load ViLT model for VQA (``dandelin/vilt-b32-vqa``), cached after first use.

        Returns:
            tuple: (processor, model)
        """
        return self._load(
            "vilt",
            "ViLT",
            ViltProcessor,
            ViltForQuestionAnswering,
            "dandelin/vilt-b32-vqa",
        )

    def get_model(self, model_name="blip"):
        """
        Get a model by name (case-insensitive).

        Args:
            model_name (str, optional): Name of model to load. Defaults to "blip".
                Options: "blip", "vilt"

        Returns:
            tuple: (processor, model)

        Raises:
            ValueError: If ``model_name`` is not one of the available models.
        """
        loaders = {"blip": self.load_blip, "vilt": self.load_vilt}
        loader = loaders.get(model_name.lower())
        if loader is None:
            raise ValueError(
                f"Unknown model: {model_name}. Available models: blip, vilt"
            )
        return loader()
|