Spaces:
Sleeping
Sleeping
""" | |
Utility script to pre-generate embedding pickle files for all models.

This script will:
1. Load each embedding model
2. Generate embeddings for both emotion and event dictionaries
3. Save the embeddings as pickle files in the 'embeddings' directory

Run this script once locally to create all pickle files before uploading to the repository.
""" | |
import os | |
from sentence_transformers import SentenceTransformer | |
from tqdm import tqdm | |
from config import CONFIG, EMBEDDING_MODELS | |
from utils import (logger, kitchen_txt_to_dict, | |
save_embeddings_to_pickle, get_embeddings_pickle_path) | |
def _encode_descriptions(model, descriptions):
    """Encode every description with *model*, keyed by the original dictionary key."""
    return {
        emoji: model.encode(desc)
        for emoji, desc in tqdm(descriptions.items())
    }


def generate_embeddings_for_model(model_key, model_info):
    """Generate and save emotion and event embeddings for a specific model.

    Failures are reported on stdout and signalled through the return value
    instead of being raised, so a caller iterating over several models can
    continue with the next one.

    Args:
        model_key: Key of the model in EMBEDDING_MODELS (used in messages).
        model_info: Model information dictionary. Must contain 'id' (passed
            to SentenceTransformer); 'size' is optional and only displayed.

    Returns:
        Tuple of (success_emotion, success_event): the results returned by
        save_embeddings_to_pickle for the emotion and event pickles, or
        (False, False) if the model configuration is invalid, a dictionary
        could not be loaded, or any step raised an exception.
    """
    # Validate the configuration up front: a malformed entry must fail only
    # this model instead of raising out of the function.
    try:
        model_id = model_info['id']
    except (KeyError, TypeError) as e:
        print(f"Error: invalid configuration for model {model_key}: missing 'id' ({e!r})")
        return False, False
    # 'size' is informational only, so its absence is not an error.
    model_size = model_info.get('size', 'unknown size')

    print(f"\nProcessing model: {model_key} ({model_id}) - {model_size}")
    try:
        # Load the model
        print(f"Loading {model_key} model...")
        model = SentenceTransformer(model_id)

        # Load emoji dictionaries
        print("Loading emoji dictionaries...")
        emotion_dict = kitchen_txt_to_dict(CONFIG["emotion_file"])
        event_dict = kitchen_txt_to_dict(CONFIG["item_file"])
        if not emotion_dict or not event_dict:
            print("Error: Failed to load emoji dictionaries")
            return False, False

        # Generate emotion embeddings
        print(f"Generating {len(emotion_dict)} emotion embeddings...")
        emotion_embeddings = _encode_descriptions(model, emotion_dict)

        # Generate event embeddings
        print(f"Generating {len(event_dict)} event embeddings...")
        event_embeddings = _encode_descriptions(model, event_dict)

        # Save embeddings
        emotion_pickle_path = get_embeddings_pickle_path(model_id, "emotion")
        event_pickle_path = get_embeddings_pickle_path(model_id, "event")
        success_emotion = save_embeddings_to_pickle(emotion_embeddings, emotion_pickle_path)
        success_event = save_embeddings_to_pickle(event_embeddings, event_pickle_path)
        return success_emotion, success_event
    except Exception as e:
        # Deliberately broad: model loading, encoding and saving can each
        # fail in library-specific ways, and one bad model must not stop
        # the rest of the run.
        print(f"Error generating embeddings for model {model_key}: {e}")
        return False, False
def main():
    """Generate embeddings for every model in EMBEDDING_MODELS and print a summary.

    Creates the 'embeddings' directory if needed, runs
    generate_embeddings_for_model for each configured model, and prints one
    summary line per model with the emotion and event results.
    """
    # Create embeddings directory if it doesn't exist
    # NOTE(review): the actual pickle locations come from
    # get_embeddings_pickle_path (not visible here); presumably they live
    # under this directory - confirm.
    os.makedirs('embeddings', exist_ok=True)
    print(f"Generating embeddings for {len(EMBEDDING_MODELS)} models...")

    # Generate embeddings for each model
    results = {}
    for model_key, model_info in EMBEDDING_MODELS.items():
        success_emotion, success_event = generate_embeddings_for_model(model_key, model_info)
        results[model_key] = {
            'emotion': success_emotion,
            'event': success_event,
        }

    # Print summary
    # NOTE(review): the status glyphs were mis-decoded to the same character
    # for both outcomes; restored as check/cross marks - confirm the
    # intended symbols.
    print("\n=== Embedding Generation Summary ===")
    for model_key, result in results.items():
        status_emotion = "✓ Success" if result['emotion'] else "✗ Failed"
        status_event = "✓ Success" if result['event'] else "✗ Failed"
        print(f"{model_key:<10}: Emotion: {status_emotion}, Event: {status_event}")

    print("\nDone! Embedding pickle files are stored in the 'embeddings' directory.")
    print("You can now upload these files to your repository.")
# Script entry point: generate the embeddings only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()