import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import SiglipVisionModel, AutoTokenizer, AutoImageProcessor
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as transforms
from tqdm import tqdm
import os
from PIL import Image
import argparse


def siglip_loss(image_embeddings, text_embeddings, temperature=0.07):
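    """Pairwise sigmoid (SigLIP-style) contrastive loss.

    Matching image/text pairs in the batch (the diagonal of the similarity
    matrix) are positives and every other pair is a negative; each pair is
    scored independently with binary cross-entropy on the logits. Unlike the
    original SigLIP formulation, the temperature here is fixed and there is
    no learnable bias term.
    """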
    # Normalize
    image_embeddings = F.normalize(image_embeddings, dim=-1)
    text_embeddings = F.normalize(text_embeddings, dim=-1)

    # Compute pairwise similarities
    logits = image_embeddings @ text_embeddings.T  # [batch_size, batch_size]
    logits = logits / temperature

    # Ground truth: 1.0 for matching pairs (diagonal), 0.0 for all others
    batch_size = logits.size(0)
    targets = torch.eye(batch_size).to(logits.device)

    # Apply binary cross-entropy with logits
    loss = F.binary_cross_entropy_with_logits(logits, targets)
    return loss


class LinearProjection(nn.Module):
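    """Single linear layer mapping SigLIP image features to the text-target dimension."""
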
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)


def get_text_embedding(text, tokenizer, device, max_length=128):
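    """Return a fixed-length "text embedding" for a caption.

    Note that this is simply the padded sequence of token ids cast to float,
    not a learned embedding from a language model; the script uses it as a
    stand-in target vector of length max_length.
    """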
    # Ensure text is not empty and has minimum content
    if not text or len(text.strip()) == 0:
        text = "This is a placeholder description."

    # Tokenize with fixed-length padding and truncation so every caption
    # produces a tensor of shape [1, max_length]
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding='max_length',
        truncation=True,
        max_length=max_length
    )

    # Move the token ids to the device and cast to float for the loss calculation
    return inputs['input_ids'].to(device).float()


def main(num_images=100, batch_size=32, num_epochs=50, learning_rate=1e-4, load_checkpoint=True, checkpoint_path='linear_projection.pth'):
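    """Train a linear projection from frozen SigLIP image features to caption targets.

    For each of the first num_images CIFAR-10 test images, the script expects a
    caption file qa_outputs/image_{i}_extr.txt with up to five description lines
    (a fallback text is used if a file is missing). Each line is turned into a
    fixed-length target via get_text_embedding, and the projection is trained
    with the pairwise sigmoid loss defined above.
    """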
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load models and processors
    siglip_model = SiglipVisionModel.from_pretrained("google/siglip-so400m-patch14-384")
    siglip_processor = AutoImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")
    tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")

    # Set padding token if not set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Freeze SigLIP model
    for param in siglip_model.parameters():
        param.requires_grad = False
    siglip_model.to(device)

    # Get SigLIP output dimension using a proper dummy image (black image)
    dummy_image = Image.new('RGB', (384, 384), color='black')
    with torch.no_grad():
        siglip_inputs = siglip_processor(dummy_image, return_tensors="pt").to(device)
        siglip_outputs = siglip_model(**siglip_inputs)
        siglip_output_dim = siglip_outputs.pooler_output.shape[-1]

    # Get a sample text to determine the text embedding dimension
    dummy_text = "This is a test."
    dummy_embedding = get_text_embedding(dummy_text, tokenizer, device)
    text_embedding_dim = dummy_embedding.shape[-1]

    print(f"SigLIP output dimension: {siglip_output_dim}")
    print(f"Text embedding dimension: {text_embedding_dim}")

    # Create linear projection layer
    linear_proj = LinearProjection(siglip_output_dim, text_embedding_dim).to(device)

    # Load checkpoint if requested
    if load_checkpoint:
        try:
            checkpoint = torch.load(checkpoint_path, map_location=device)
            linear_proj.load_state_dict(checkpoint)
            print(f"Successfully loaded checkpoint from {checkpoint_path}")
        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            print("Starting training from scratch instead.")

    # Load a subset of the CIFAR-10 test set
    transform = transforms.Compose([
        transforms.Resize((384, 384)),
        transforms.ToTensor(),
    ])
    test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
    subset_indices = list(range(num_images))
    subset_dataset = Subset(test_dataset, subset_indices)
    dataloader = DataLoader(subset_dataset, batch_size=batch_size, shuffle=False)

    # Create the caption files directory if it doesn't exist
    os.makedirs('qa_outputs', exist_ok=True)

    # Optimizer
    optimizer = AdamW(linear_proj.parameters(), lr=learning_rate)
    # Training loop
    for epoch in range(num_epochs):
        total_loss = 0
        linear_proj.train()
        progress_bar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}')

        for batch_idx, (images, labels) in enumerate(progress_bar):
            images = images.to(device)
            current_batch_size = images.size(0)

            # Get image embeddings from the frozen SigLIP encoder.
            # ToTensor() already scales pixels to [0, 1], so disable the
            # processor's own rescaling to avoid scaling them twice.
            with torch.no_grad():
                siglip_inputs = siglip_processor(images, return_tensors="pt", do_rescale=False).to(device)
                siglip_outputs = siglip_model(**siglip_inputs)
                image_features = siglip_outputs.pooler_output

            # Project image features
            projected_image_features = linear_proj(image_features)

            # Process text for each description line (1 to 5)
            total_batch_loss = 0
            for line_num in range(5):
                text_embeddings_list = []

                # Read text from files for the current batch
                for idx in range(current_batch_size):
                    # Offset by the dataloader batch_size (not the possibly
                    # smaller size of the final batch) to index the subset
                    global_idx = batch_idx * batch_size + idx
                    if global_idx < num_images:
                        file_path = f'qa_outputs/image_{global_idx}_extr.txt'
                        try:
                            with open(file_path, 'r') as f:
                                lines = f.readlines()
                                text = lines[line_num].strip() if line_num < len(lines) else ""
                        except Exception:
                            text = "No description available"

                        # Get text embeddings directly from the tokenizer
                        text_embedding = get_text_embedding(text, tokenizer, device)
                        text_embeddings_list.append(text_embedding)

                if text_embeddings_list:
                    # Stack into [batch, max_length]; all embeddings have the same size
                    text_embeddings = torch.stack(text_embeddings_list, dim=0).squeeze(1)
                    loss = siglip_loss(projected_image_features, text_embeddings)
                    total_batch_loss += loss

            # Average loss over all text lines
            avg_batch_loss = total_batch_loss / 5

            # Backpropagation
            optimizer.zero_grad()
            avg_batch_loss.backward()
            optimizer.step()

            total_loss += avg_batch_loss.item()
            progress_bar.set_postfix({'loss': avg_batch_loss.item()})

        avg_epoch_loss = total_loss / len(dataloader)
        print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_epoch_loss:.4f}')

        # Save checkpoint after each epoch
        # checkpoint_dir = 'checkpoints'
        # os.makedirs(checkpoint_dir, exist_ok=True)
        # checkpoint_file = os.path.join(checkpoint_dir, f'linear_projection_epoch_{epoch+1}.pth')
        # torch.save(linear_proj.state_dict(), checkpoint_file)
        # print(f"Saved checkpoint to {checkpoint_file}")

    # Save final model
    torch.save(linear_proj.state_dict(), 'linear_projection_final.pth')
    print("Training completed. Final model saved as 'linear_projection_final.pth'")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train or continue training the linear projection layer')
    parser.add_argument('--num_images', type=int, default=100, help='Number of images to train on')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training')
    parser.add_argument('--num_epochs', type=int, default=50, help='Number of epochs to train')
    parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
    parser.add_argument('--load_checkpoint', action='store_true', help='Whether to load from checkpoint')
    parser.add_argument('--checkpoint_path', type=str, default='linear_projection.pth', help='Path to checkpoint file')
    args = parser.parse_args()

    main(
        num_images=args.num_images,
        batch_size=args.batch_size,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        load_checkpoint=args.load_checkpoint,
        checkpoint_path=args.checkpoint_path
    )
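
# Example invocation (assuming this file is saved as train_projection.py, a
# hypothetical name, and that the qa_outputs/image_{i}_extr.txt caption files
# have already been generated):
#   python train_projection.py --num_images 100 --batch_size 32 --num_epochs 10 \
#       --load_checkpoint --checkpoint_path linear_projection.pth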