# This file is modified from https://github.com/haotian-liu/LLaVA/
import torch

from llava.model.multimodal_encoder.vision_encoder import VisionTower
from transformers import (
    PretrainedConfig,
    CLIPVisionModel,
    CLIPImageProcessor,
)


class CLIPVisionTower(VisionTower):
    def __init__(self, model_name_or_path: str, config: PretrainedConfig):
        super().__init__(model_name_or_path, config)
        # Image preprocessor (resize / crop / normalize) matching the CLIP checkpoint.
        self.image_processor = CLIPImageProcessor.from_pretrained(model_name_or_path)
        # `config.model_dtype` is a string such as "torch.float16"; eval() resolves it
        # to the corresponding torch dtype before loading the vision encoder weights.
        self.vision_tower = CLIPVisionModel.from_pretrained(
            model_name_or_path, torch_dtype=eval(config.model_dtype)
        )
        self.is_loaded = True
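

# Minimal usage sketch (illustrative only, not part of the original module).
# The checkpoint name and the `model_dtype` value are assumptions; the exact
# config fields required depend on the VisionTower base class.
#
#   from transformers import PretrainedConfig
#
#   config = PretrainedConfig(model_dtype="torch.float16")
#   tower = CLIPVisionTower("openai/clip-vit-large-patch14-336", config)
#   pixel_values = tower.image_processor(images=pil_image, return_tensors="pt")["pixel_values"]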