"""Generate CLIP text embeddings for descriptions stored in a CSV file.

Reads the first column of every data row of ``--input_csv`` (the header row is
skipped), encodes the texts with the open_clip model
``hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K`` and writes the stacked
embeddings to ``--output_file`` with ``numpy.save``.
"""
import argparse
import csv
import os
import sys
from typing import List, Sequence

import numpy as np
import torch

MODEL_NAME = 'hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K'


def load_descriptions(file_path: str, encoding: str = 'utf-8') -> List[str]:
    """Load text descriptions from the first column of a CSV file.

    The first row is treated as a header and skipped. Completely blank lines
    (which ``csv.reader`` yields as empty lists) are ignored instead of
    raising ``IndexError``.

    Args:
        file_path: Path to the CSV file.
        encoding: Text encoding of the file (default ``'utf-8'``).

    Returns:
        The first-column value of each non-blank data row, in file order.
        An empty list if the file has no data rows or is empty.
    """
    # newline='' is what the csv module expects, so quoted fields containing
    # line breaks are parsed correctly.
    with open(file_path, 'r', encoding=encoding, newline='') as file:
        csv_reader = csv.reader(file)
        # Skip the header; the default avoids a bare StopIteration on an
        # empty file.
        next(csv_reader, None)
        return [row[0] for row in csv_reader if row]


def generate_embeddings(descriptions: Sequence[str], model, tokenizer, device,
                        batch_size: int) -> np.ndarray:
    """Encode descriptions into text embeddings, batch by batch.

    Args:
        descriptions: Strings to encode.
        model: Object exposing ``encode_text(tokens)`` that returns a tensor.
        tokenizer: Callable mapping a list of strings to a tensor of tokens.
        device: Device the tokens are moved to before encoding
            (e.g. ``'cuda:0'`` or ``'cpu'``).
        batch_size: Number of descriptions per forward pass; must be >= 1.

    Returns:
        A numpy array stacking the per-batch ``encode_text`` outputs along
        axis 0 (presumably one row per description — depends on the model).

    Raises:
        ValueError: If ``descriptions`` is empty or ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if len(descriptions) == 0:
        raise ValueError("No descriptions to embed.")

    final_embeddings = []
    # Inference only: no_grad stops autograd from recording the graph, which
    # would otherwise keep intermediate activations alive and waste memory.
    with torch.no_grad():
        for start in range(0, len(descriptions), batch_size):
            batch_desc = descriptions[start:start + batch_size]
            texts = tokenizer(batch_desc).to(device)
            batch_embeddings = model.encode_text(texts)
            final_embeddings.append(batch_embeddings.cpu().numpy())
            # Drop the device tensors and release cached allocator blocks
            # between batches to keep peak GPU memory low (no-op without CUDA).
            del texts, batch_embeddings
            torch.cuda.empty_cache()
    return np.vstack(final_embeddings)


def save_embeddings(output_file: str, embeddings: np.ndarray) -> None:
    """Save embeddings to a .npy file.

    ``numpy.save`` appends ``.npy`` to ``output_file`` if it lacks that suffix.
    """
    np.save(output_file, embeddings)


def main():
    """Parse CLI arguments, embed the CSV descriptions and save the result."""
    # Imported here rather than at module level so the helpers above can be
    # imported (e.g. for testing) without open_clip installed.
    import open_clip

    parser = argparse.ArgumentParser(description="Generate text embeddings using CLIP.")
    parser.add_argument("--input_csv", type=str, required=True, help="Path to the input CSV file containing text descriptions.")
    parser.add_argument("--output_file", type=str, required=True, help="Path to save the output .npy file.")
    parser.add_argument("--batch_size", type=int, default=100, help="Batch size for processing embeddings.")
    parser.add_argument("--device", type=str, default="cuda:0", help="Device to run the model on (e.g., 'cuda:0' or 'cpu').")
    args = parser.parse_args()

    # Load the CLIP model and tokenizer
    model, _, _ = open_clip.create_model_and_transforms(MODEL_NAME)
    model.to(args.device)
    # open_clip's documented usage calls eval() after creation; the model may
    # otherwise be in train mode (dropout / stochastic depth active).
    model.eval()
    tokenizer = open_clip.get_tokenizer(MODEL_NAME)

    # Load descriptions from CSV
    descriptions = load_descriptions(args.input_csv)

    # Generate embeddings
    embeddings = generate_embeddings(descriptions, model, tokenizer, args.device, args.batch_size)

    # Save embeddings to output file
    save_embeddings(args.output_file, embeddings)


if __name__ == "__main__":
    main()