Spaces:
Running
Running

Implement asynchronous database initialization: add init_db function to create collections and indexes in MongoDB, update get_db to support async, and modify main.py to await database initialization during startup.
9b31d36
"""MongoDB database configuration.""" | |
import os | |
import logging | |
from typing import Optional | |
from motor.motor_asyncio import AsyncIOMotorClient | |
from pymongo.errors import ConnectionFailure | |
from dotenv import load_dotenv | |
import certifi | |
from datetime import datetime | |
# Populate os.environ from a .env file (when one is present) before any
# settings are read below.
load_dotenv()

logger = logging.getLogger(__name__)

# Connection settings, overridable through the environment; the fallbacks
# point at a MongoDB server on localhost and a database named "tts_api".
MONGO_URI: str = os.environ.get("MONGO_URI", "mongodb://localhost:27017")
DB_NAME: str = os.environ.get("DB_NAME", "tts_api")

# Process-wide Motor client. It stays None until connect_to_mongo() sets it.
client: Optional[AsyncIOMotorClient] = None
async def init_db():
    """Ensure the application's collections and their indexes exist.

    Missing collections are created explicitly. Indexes are then asserted for
    every collection, not only for newly created ones. ``create_index`` is a
    no-op when an identical index already exists, so a collection that was
    created earlier (e.g. implicitly by an insert, or by an interrupted
    startup) still ends up with its unique/lookup indexes.

    Raises:
        Exception: any error from connecting, listing, or creating
            collections/indexes is logged (with traceback) and re-raised.
    """
    from pymongo.errors import CollectionInvalid

    # collection name -> [(indexed field, extra create_index options), ...]
    index_specs = {
        AUDIOBOOKS_COLLECTION: [
            ("id", {"unique": True}),
            ("created_at", {}),
            ("status", {}),
        ],
        VOICES_COLLECTION: [
            ("id", {"unique": True}),
            ("name", {}),
            ("type", {}),
        ],
        AUDIO_CACHE_COLLECTION: [
            ("hash", {"unique": True}),
            ("created_at", {}),
        ],
    }

    try:
        db = await get_db()
        existing = set(await db.list_collection_names())

        for name, indexes in index_specs.items():
            if name not in existing:
                logger.info("Creating collection: %s", name)
                try:
                    await db.create_collection(name)
                except CollectionInvalid:
                    # Another worker created it after we listed collections;
                    # that is the state we wanted, so carry on.
                    pass
            # Always (re)assert indexes: idempotent, and repairs collections
            # that exist without them.
            for field, options in indexes:
                await db[name].create_index(field, **options)

        logger.info("Database initialization completed successfully")
    except Exception as e:
        logger.exception("Error initializing database: %s", e)
        raise
async def connect_to_mongo():
    """Create the shared Motor client and verify it with a ``ping``.

    The module-level ``client`` is assigned only after the ping succeeds. On
    any failure the freshly built client is closed and ``client`` is left
    untouched, so a later ``get_db()`` retries instead of reusing a client
    that never connected.

    Raises:
        pymongo.errors.ConnectionFailure: the server could not be reached
            (logged before re-raising).
        Exception: any other error from building the client or from the
            ping propagates unchanged.
    """
    global client
    # TLS with certifi's CA bundle is forced for every URI (configured for Atlas).
    # NOTE(review): the default "mongodb://localhost:27017" URI typically
    # targets a server without TLS, so the out-of-the-box default likely fails
    # to connect. Confirm whether TLS should instead follow the URI.
    new_client = AsyncIOMotorClient(
        MONGO_URI,
        tls=True,
        tlsCAFile=certifi.where(),
        serverSelectionTimeoutMS=5000,
        connectTimeoutMS=10000,
    )
    try:
        # Verify the connection before publishing the client.
        await new_client.admin.command("ping")
    except ConnectionFailure as e:
        logger.error("Could not connect to MongoDB: %s", e)
        new_client.close()
        raise
    except BaseException:
        # Auth failures, cancellation, etc.: don't leak or cache the client.
        new_client.close()
        raise

    client = new_client
    logger.info("Successfully connected to MongoDB")
async def close_mongo_connection():
    """Close the shared Motor client, if one exists, and forget it.

    Resetting ``client`` to ``None`` lets a subsequent ``get_db()`` open a
    fresh connection instead of returning a database bound to a closed
    client. Calling this when no client exists does nothing.
    """
    global client
    if client is not None:
        client.close()
        client = None
        logger.info("MongoDB connection closed")
async def get_db():
    """Return the ``DB_NAME`` database, connecting lazily on first use.

    Returns:
        The database handle obtained from the shared Motor client.

    Raises:
        Whatever ``connect_to_mongo()`` raises when a connection has to be
        established and fails.
    """
    # Explicit sentinel check: the client object proxies attribute access,
    # so rely on identity rather than its truthiness.
    if client is None:
        await connect_to_mongo()
    return client[DB_NAME]
# Names of the MongoDB collections used by the application.
AUDIOBOOKS_COLLECTION, VOICES_COLLECTION, AUDIO_CACHE_COLLECTION = (
    "audiobooks",
    "voices",
    "audio_cache",
)
# Reference document shapes, mapping field name -> expected Python type.
# Nothing visible in this module validates documents against these dicts.
AUDIOBOOK_SCHEMA = dict(
    id=str,  # UUID string
    title=str,
    author=str,
    voice_id=str,
    status=str,  # pending, processing, completed, failed
    created_at=str,  # ISO format datetime
    updated_at=str,  # ISO format datetime
    duration=float,
    file_path=str,
    error=str,
    meta_data=dict,
)

VOICE_SCHEMA = dict(
    id=str,  # UUID string
    name=str,
    type=str,  # standard, cloned
    speaker_id=int,
    created_at=str,  # ISO format datetime
    is_active=bool,
    meta_data=dict,
)

AUDIO_CACHE_SCHEMA = dict(
    id=str,  # UUID string
    hash=str,  # hash of the input parameters
    format=str,  # audio format (mp3, wav, etc.)
    created_at=str,  # ISO format datetime
    file_path=str,
    meta_data=dict,
)