File size: 10,186 Bytes
5e37be9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
import torch
import torch.nn as nn
from torch.optim import AdamW
from transformers import (
    SiglipVisionModel, 
    AutoTokenizer, 
    AutoImageProcessor, 
    AutoModelForCausalLM,
    BitsAndBytesConfig
)
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as transforms
from tqdm import tqdm
import os
from PIL import Image

class LinearProjection(nn.Module):
    """Affine map from SigLIP pooled features to a target dimension.

    The wrapped layer is stored as ``linear`` so its state-dict keys are
    ``linear.weight`` and ``linear.bias`` -- the keys expected in the
    ``linear_projection_final.pth`` checkpoint loaded by ``main``.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Project the last dimension of ``x`` from ``input_dim`` to ``output_dim``."""
        projected = self.linear(x)
        return projected

class ImageTextProjection(nn.Module):
    """Trainable adapter mapping projected image features into Phi's embedding space.

    The layer is stored as ``image_projection``; that name determines the keys
    written to ``image_text_proj.pth`` when ``main`` saves the state dict.
    """

    def __init__(self, image_dim, text_dim):
        super().__init__()
        self.image_projection = nn.Linear(in_features=image_dim, out_features=text_dim)

    def forward(self, x):
        """Map the last dimension of ``x`` from ``image_dim`` to ``text_dim``."""
        text_space = self.image_projection(x)
        return text_space

def get_image_embedding(image, siglip_model, siglip_processor, linear_proj, device):
    """Encode ``image`` with SigLIP and pass the pooled output through ``linear_proj``.

    Args:
        image: Whatever ``siglip_processor`` accepts (a PIL image or a batch of
            image tensors, as used by ``main``).
        siglip_model: Vision model whose output exposes ``pooler_output``.
        siglip_processor: Callable turning ``image`` into model keyword inputs.
        linear_proj: Frozen projection applied to the pooled features.
        device: Device that tensor inputs are moved to before the forward pass.

    Returns:
        The projected features. Everything runs under ``torch.no_grad()``, so
        the result carries no autograd history back into SigLIP or
        ``linear_proj``.
    """
    with torch.no_grad():
        raw_inputs = siglip_processor(image, return_tensors="pt")
        # Only tensors are relocated; any other processor outputs pass through as-is.
        model_inputs = {}
        for name, value in raw_inputs.items():
            model_inputs[name] = value.to(device) if isinstance(value, torch.Tensor) else value
        pooled = siglip_model(**model_inputs).pooler_output
        return linear_proj(pooled)

def _read_answer_lines(image_idx):
    """Return the stripped lines of ``qa_outputs/image_{image_idx}_extr.txt``, or None if unreadable."""
    file_path = f'qa_outputs/image_{image_idx}_extr.txt'
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return [line.strip() for line in f]
    except (OSError, UnicodeDecodeError) as e:
        print(f"Could not read answers from {file_path}: {e}")
        return None


def _load_answers(num_images, num_questions):
    """Load every answer once, indexed as ``table[image_idx][question_idx]``.

    Line ``q`` of an image's answer file is taken as the answer to question
    ``q``; missing lines become ``""`` and unreadable files yield
    ``"No answer available"`` for every question.
    """
    table = []
    for image_idx in range(num_images):
        lines = _read_answer_lines(image_idx)
        if lines is None:
            table.append(["No answer available"] * num_questions)
        else:
            table.append([lines[q] if q < len(lines) else "" for q in range(num_questions)])
    return table


def _encode_answers(tokenizer, answers, max_length, device):
    """Tokenize answers, append EOS, and right-pad them into a batch.

    Padding is done by hand so it is always on the right regardless of the
    tokenizer's ``padding_side``, and so the appended EOS is distinguished from
    padding via the mask even when ``pad_token == eos_token``.

    Returns:
        ``(input_ids, attention_mask)``, both ``[batch, longest]`` long tensors on ``device``.
    """
    sequences = []
    for answer in answers:
        ids = tokenizer(answer, add_special_tokens=False)['input_ids']
        # Reserve one slot for EOS so the model learns where an answer ends.
        sequences.append(ids[:max_length - 1] + [tokenizer.eos_token_id])

    longest = max(len(ids) for ids in sequences)
    input_ids = torch.full((len(sequences), longest), tokenizer.pad_token_id, dtype=torch.long)
    attention_mask = torch.zeros((len(sequences), longest), dtype=torch.long)
    for row, ids in enumerate(sequences):
        input_ids[row, :len(ids)] = torch.tensor(ids, dtype=torch.long)
        attention_mask[row, :len(ids)] = 1
    return input_ids.to(device), attention_mask.to(device)


def _build_labels(prefix_len, answer_ids, answer_mask):
    """Labels aligned with ``[prefix | answer]`` inputs; -100 on the prefix and on answer padding.

    Causal-LM heads shift labels internally, so labels are aligned 1:1 with the
    input positions rather than pre-shifted here.
    """
    ignore_prefix = torch.full(
        (answer_ids.size(0), prefix_len), -100, dtype=answer_ids.dtype, device=answer_ids.device
    )
    answer_labels = answer_ids.masked_fill(answer_mask == 0, -100)
    return torch.cat([ignore_prefix, answer_labels], dim=1)


def main(
    num_images=100,
    batch_size=4,  # Smaller batch size due to memory constraints
    num_epochs=100,
    learning_rate=2e-4,
    questions=None  # List of 5 questions to be provided
):
    """Fine-tune Phi-3 (LoRA) plus an image->text adapter on CIFAR-10 Q&A pairs.

    Each sample is ``[image embedding | question tokens | answer tokens]``; the
    loss is computed on the answer tokens only (teacher forcing).

    Reads ``linear_projection_final.pth`` and ``qa_outputs/image_{i}_extr.txt``
    (line ``q`` answers ``questions[q]``), downloads CIFAR-10 into ``./data``,
    and writes the LoRA adapter to ``phi_model_trained`` and the adapter
    weights to ``image_text_proj.pth``.

    Args:
        num_images: Number of leading CIFAR-10 test images to train on.
        batch_size: Images per batch.
        num_epochs: Passes over the image subset.
        learning_rate: AdamW learning rate.
        questions: Exactly 5 question strings; otherwise prints a message and returns.
    """
    if questions is None or len(questions) != 5:
        print("Please provide exactly 5 questions!")
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load SigLIP model and processor
    siglip_model = SiglipVisionModel.from_pretrained("google/siglip-so400m-patch14-384").to(device)
    siglip_processor = AutoImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")

    # Probe SigLIP once to learn the size of its pooled output.
    dummy_image = Image.new('RGB', (384, 384), color='black')
    with torch.no_grad():
        siglip_inputs = siglip_processor(dummy_image, return_tensors="pt")
        siglip_inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in siglip_inputs.items()}
        siglip_outputs = siglip_model(**siglip_inputs)
        siglip_output_dim = siglip_outputs.pooler_output.shape[-1]

    # Read the checkpoint first so the projection is built with the saved output size.
    checkpoint = torch.load('linear_projection_final.pth', map_location=device)
    output_dim = checkpoint['linear.weight'].shape[0]
    print(f"Loading linear projection with output dimension: {output_dim}")

    linear_proj = LinearProjection(siglip_output_dim, output_dim).to(device)
    try:
        linear_proj.load_state_dict(checkpoint)
        print("Successfully loaded linear projection weights")
    except Exception as e:
        print(f"Error loading linear projection weights: {e}")
        return

    # Load Phi model with 4-bit quantization
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=False
    )

    phi_model = AutoModelForCausalLM.from_pretrained(
        "microsoft/Phi-3-mini-4k-instruct",
        quantization_config=bnb_config,
        device_map="auto"
    )
    # The KV cache is not needed for teacher-forced training.
    phi_model.config.use_cache = False
    phi_tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")

    if phi_tokenizer.pad_token is None:
        phi_tokenizer.pad_token = phi_tokenizer.eos_token

    phi_embed_dim = phi_model.get_input_embeddings().weight.shape[1]
    image_text_proj = ImageTextProjection(output_dim, phi_embed_dim).to(device)

    phi_model = prepare_model_for_kbit_training(phi_model)

    # Module names used by the Hugging Face Phi-3 implementation. The previous
    # dense_h_to_4h / dense_4h_to_h / self_attn.dense names do not exist in
    # Phi-3, so only qkv_proj was actually being adapted.
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["self_attn.qkv_proj", "self_attn.o_proj", "mlp.gate_up_proj", "mlp.down_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )
    phi_model = get_peft_model(phi_model, lora_config)

    # Freeze SigLIP and linear projection
    for param in siglip_model.parameters():
        param.requires_grad = False
    for param in linear_proj.parameters():
        param.requires_grad = False

    # PILToTensor keeps uint8 pixels in [0, 255]; the SigLIP processor performs
    # the 1/255 rescale itself. ToTensor() would have scaled twice.
    transform = transforms.Compose([
        transforms.Resize((384, 384)),
        transforms.PILToTensor(),
    ])

    test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
    subset_dataset = Subset(test_dataset, list(range(num_images)))
    dataloader = DataLoader(subset_dataset, batch_size=batch_size, shuffle=False)

    # Answers never change between epochs, so read each file once up front.
    answers_table = _load_answers(num_images, len(questions))

    # Only LoRA adapters (and the image adapter) are trainable.
    trainable_phi_params = [p for p in phi_model.parameters() if p.requires_grad]
    optimizer = AdamW([
        {'params': trainable_phi_params},
        {'params': image_text_proj.parameters()}
    ], lr=learning_rate)

    embedding_layer = phi_model.get_input_embeddings()
    steps_per_epoch = len(dataloader) * len(questions)

    for epoch in range(num_epochs):
        total_loss = 0.0
        phi_model.train()
        image_text_proj.train()

        progress_bar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}')
        for batch_idx, (images, _) in enumerate(progress_bar):
            # Do not rebind `batch_size`: the last batch may be shorter, but the
            # dataset offset of batch b is always b * batch_size (shuffle=False).
            current_batch_size = images.size(0)
            first_idx = batch_idx * batch_size
            batch_answers = answers_table[first_idx:first_idx + current_batch_size]

            # get_image_embedding moves the processed inputs to `device` itself.
            image_embeddings = get_image_embedding(images, siglip_model, siglip_processor, linear_proj, device)

            for q_idx, question in enumerate(questions):
                answers = [row[q_idx] for row in batch_answers]

                question_tokens = phi_tokenizer(
                    [question] * current_batch_size,
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors="pt"
                ).to(device)
                answer_ids, answer_mask = _encode_answers(phi_tokenizer, answers, 512, device)

                question_embeds = embedding_layer(question_tokens['input_ids'])
                answer_embeds = embedding_layer(answer_ids)
                image_embeds = image_text_proj(image_embeddings).unsqueeze(1).to(question_embeds.dtype)

                # [image | question | answer]: answers are fed as inputs (teacher forcing).
                combined_embedding = torch.cat([image_embeds, question_embeds, answer_embeds], dim=1)
                attention_mask = torch.cat([
                    torch.ones((current_batch_size, 1), dtype=torch.long, device=device),
                    question_tokens['attention_mask'],
                    answer_mask,
                ], dim=1)
                labels = _build_labels(1 + question_embeds.size(1), answer_ids, answer_mask)

                outputs = phi_model(
                    inputs_embeds=combined_embedding,
                    attention_mask=attention_mask,
                    labels=labels
                )

                loss = outputs.loss
                total_loss += loss.item()

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                progress_bar.set_postfix({'loss': loss.item()})

        # `loss` is already a per-step mean, so average over optimizer steps.
        avg_epoch_loss = total_loss / max(steps_per_epoch, 1)
        print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_epoch_loss:.4f}')

    # Save the trained models
    phi_model.save_pretrained('phi_model_trained')
    torch.save(image_text_proj.state_dict(), 'image_text_proj.pth')
    print("Training completed. Models saved as 'phi_model_trained' and 'image_text_proj.pth'")

if __name__ == "__main__":
    # Sample prompts; replace them with the questions your answer files were
    # generated for (line q of each answer file answers question q).
    example_questions = [
        "Give a description of the image?",
        "How does the main object in the image look like?",
        "How can the main object in the image be useful to humans?",
        "What is the color of the main object in the image?",
        "Describe the setting of the image?",
    ]

    main(questions=example_questions)