File size: 3,767 Bytes
cfb0d15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""
Utility script to pre-generate embedding pickle files for all models.

This script will:
1. Load each embedding model
2. Generate embeddings for both emotion and event dictionaries
3. Save the embeddings as pickle files in the 'embeddings' directory

Run this script once locally to create all pickle files before uploading to the repository.
"""

import os
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

from config import CONFIG, EMBEDDING_MODELS
from utils import (logger, kitchen_txt_to_dict, 
                  save_embeddings_to_pickle, get_embeddings_pickle_path)

def _encode_descriptions(model, descriptions, label):
    """Encode every description of a dictionary with ``model``.

    Args:
        model: Loaded SentenceTransformer used for encoding.
        descriptions: Mapping returned by ``kitchen_txt_to_dict``; its keys
            are used as output keys and its values are passed to
            ``model.encode``.
        label: Name of the dictionary ("emotion" / "event"), used only for
            progress output.

    Returns:
        Dict mapping each key of ``descriptions`` to ``model.encode(value)``.
    """
    print(f"Generating {len(descriptions)} {label} embeddings...")
    # Encode one entry at a time so each stored value is exactly what
    # model.encode returns for that single description.
    return {
        emoji: model.encode(desc)
        for emoji, desc in tqdm(descriptions.items(), desc=label)
    }


def generate_embeddings_for_model(model_key, model_info):
    """Generate and save emotion and event embeddings for a single model.

    Args:
        model_key: Key of the model in EMBEDDING_MODELS (used in log output).
        model_info: Model information dictionary; must provide 'id' (passed
            to SentenceTransformer) and 'size' (printed only).

    Returns:
        Tuple ``(success_emotion, success_event)`` with the results of the two
        ``save_embeddings_to_pickle`` calls. It is ``(False, False)`` when
        either dictionary is empty or fails to load, or when any exception
        is raised.
    """
    model_id = model_info['id']
    print(f"\nProcessing model: {model_key} ({model_id}) - {model_info['size']}")

    try:
        # Load the dictionaries before the model. Constructing a
        # SentenceTransformer can be expensive, so bail out early if there
        # is nothing to encode.
        print("Loading emoji dictionaries...")
        emotion_dict = kitchen_txt_to_dict(CONFIG["emotion_file"])
        event_dict = kitchen_txt_to_dict(CONFIG["item_file"])

        if not emotion_dict or not event_dict:
            print("Error: Failed to load emoji dictionaries")
            return False, False

        print(f"Loading {model_key} model...")
        model = SentenceTransformer(model_id)

        emotion_embeddings = _encode_descriptions(model, emotion_dict, "emotion")
        event_embeddings = _encode_descriptions(model, event_dict, "event")

        # Save embeddings; each pickle path is derived from the model id and
        # the dictionary kind.
        emotion_pickle_path = get_embeddings_pickle_path(model_id, "emotion")
        event_pickle_path = get_embeddings_pickle_path(model_id, "event")

        success_emotion = save_embeddings_to_pickle(emotion_embeddings, emotion_pickle_path)
        success_event = save_embeddings_to_pickle(event_embeddings, event_pickle_path)

        return success_emotion, success_event
    except Exception as e:
        # Best-effort per model: report and let the caller continue with the
        # remaining models.
        print(f"Error generating embeddings for model {model_key}: {e}")
        return False, False

def main():
    """Generate emotion and event embeddings for every model in EMBEDDING_MODELS.

    Creates the 'embeddings' directory if needed and calls
    ``generate_embeddings_for_model`` once per configured model. It then
    prints a per-model success/failure summary and lists any models that
    did not fully succeed.
    """
    # Create embeddings directory if it doesn't exist
    os.makedirs('embeddings', exist_ok=True)

    print(f"Generating embeddings for {len(EMBEDDING_MODELS)} models...")

    # model_key -> {'emotion': bool-ish, 'event': bool-ish} save results
    results = {}

    # Generate embeddings for each model
    for model_key, model_info in EMBEDDING_MODELS.items():
        success_emotion, success_event = generate_embeddings_for_model(model_key, model_info)
        results[model_key] = {
            'emotion': success_emotion,
            'event': success_event
        }

    # Print summary
    print("\n=== Embedding Generation Summary ===")
    for model_key, result in results.items():
        status_emotion = "✓ Success" if result['emotion'] else "✗ Failed"
        status_event = "✓ Success" if result['event'] else "✗ Failed"
        print(f"{model_key:<10}: Emotion: {status_emotion}, Event: {status_event}")

    # Make partial failures explicit instead of only printing "Done!".
    failed = [key for key, result in results.items()
              if not (result['emotion'] and result['event'])]
    if failed:
        print(f"\nWarning: {len(failed)} model(s) had failures: {', '.join(failed)}")

    print("\nDone! Embedding pickle files are stored in the 'embeddings' directory.")
    print("You can now upload these files to your repository.")

if "__main__" == __name__:
    # Run the generator only when executed as a script, not when imported.
    main()