# Phi3-VLM-On-Cifar10 / process_cifar10.py
# Commit 5e37be9 by chbsaikiran: Initial Commit
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import os
from transformers import AutoProcessor, AutoModelForImageTextToText
from tqdm import tqdm
# Load the SmolVLM2 processor and model. Weights are loaded in bfloat16 and the
# model is placed on the GPU when CUDA is available, otherwise on the CPU.
model_path = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained(model_path)
model = AutoModelForImageTextToText.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    # _attn_implementation="flash_attention_2",  # optional; needs the flash-attn package
).to(device)
# Make sure the folder that receives the per-image Q/A text files is present;
# an already-existing folder is not an error.
os.makedirs(name="SigLIP_Training/qa_outputs", exist_ok=True)
# Load the CIFAR-10 test split (train=False) rather than the training split.
# No transform is applied: torchvision's CIFAR10 already yields PIL images, which
# is what process_image() passes to the chat template. The previous
# ToTensor() -> ToPILImage() pipeline converted every sample to a float tensor
# and straight back to PIL. That was redundant per-sample work, and a float
# round-trip is not guaranteed to be bit-exact.
transform = None
testset = torchvision.datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)
# Prompts asked about every image, in this order. process_image() numbers them
# from 1 (Q1/A1, Q2/A2, ...) in each output file.
questions = [
    "Give a description of the image?",
    "How does the main object in the image look like?",
    "How can the main object in the image be useful to humans?",
    "What is the color of the main object in the image?",
    "Describe the setting of the image?",
]
def process_image(image, image_idx):
    """Ask every prompt in ``questions`` about one image and save the answers.

    Writes ``SigLIP_Training/qa_outputs/image_{image_idx}.txt`` (overwriting any
    existing file). For each question, numbered from 1, the file gets a
    ``Q{n}: <question>`` line followed by an ``A{n}: <answer>`` line.

    Args:
        image: Image placed in the ``"image"`` content entry of the chat
            message (a PIL image from the CIFAR-10 dataset in this script).
        image_idx: Index used to name the output file.
    """
    output_file = f"SigLIP_Training/qa_outputs/image_{image_idx}.txt"
    with open(output_file, "w", encoding="utf-8") as f:
        for q_idx, question in enumerate(questions, 1):
            # A single user turn carrying the image plus the question text.
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": question},
                    ],
                }
            ]
            inputs = processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            ).to(model.device, dtype=torch.bfloat16)

            # Greedy decoding, capped at 64 new tokens.
            generated_ids = model.generate(**inputs, do_sample=False, max_new_tokens=64)

            # generate() returns the prompt tokens followed by the continuation.
            # Drop the prompt so only the model's reply is decoded; otherwise the
            # "answer" would repeat the entire user turn, question included.
            prompt_len = inputs["input_ids"].shape[-1]
            answer = processor.batch_decode(
                generated_ids[:, prompt_len:], skip_special_tokens=True
            )[0]
            # Collapse whitespace/newlines so every answer stays on its own line,
            # keeping the Q/A file line-oriented.
            answer = " ".join(answer.split())

            f.write(f"Q{q_idx}: {question}\n")
            f.write(f"A{q_idx}: {answer}\n")
# Process the CIFAR-10 test set. Set MAX_IMAGES to an int (e.g. 1000) to process
# only the first MAX_IMAGES images; None processes the whole split.
# (The previous commented-out `if idx >= 1000: break` ran after process_image,
# so it would have processed 1001 images, not 1000.)
MAX_IMAGES = None

print("Starting to process CIFAR-10 test set images...")
num_images = len(testset) if MAX_IMAGES is None else min(MAX_IMAGES, len(testset))
for idx in tqdm(range(num_images)):
    image, _ = testset[idx]  # the label is not needed for Q/A generation
    process_image(image, idx)
print("Processing complete! Check the SigLIP_Training/qa_outputs directory for results.")