# CyberForge / src/database_init.py
# Replit Deployment
# Deployment from Replit
# bb6d7b4
"""
Database initialization for the application.
This script checks if the database is initialized and creates tables if needed.
It's meant to be imported and run at application startup.
"""
import os
import logging
import asyncio
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import sessionmaker
from sqlalchemy.future import select
import subprocess
import sys
# Process-wide logging setup: INFO level, timestamped "name - level - message"
# lines. Per the logging docs, basicConfig() does nothing if the root logger
# already has handlers, so an application that configured logging first keeps
# its own setup.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module logger used by the database helpers below.
logger = logging.getLogger(__name__)
# Database URL from environment
db_url = os.getenv("DATABASE_URL", "")
if db_url.startswith("postgresql://"):
# Remove sslmode parameter if present which causes issues with asyncpg
if "?" in db_url:
base_url, params = db_url.split("?", 1)
param_list = params.split("&")
filtered_params = [p for p in param_list if not p.startswith("sslmode=")]
if filtered_params:
db_url = f"{base_url}?{'&'.join(filtered_params)}"
else:
db_url = base_url
ASYNC_DATABASE_URL = db_url.replace("postgresql://", "postgresql+asyncpg://", 1)
else:
ASYNC_DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/postgres"
async def check_db_initialized(table_name="users"):
    """Report whether the database already contains the application's tables.

    Opens a short-lived async engine on ``ASYNC_DATABASE_URL``. It then asks
    ``information_schema.tables`` whether a table named *table_name* exists.
    The engine is disposed before returning. Errors are logged, never raised.

    Args:
        table_name: Sentinel table whose presence means the schema has been
            created. Defaults to ``"users"``.

    Returns:
        True if the table exists. False if it does not, or if connecting to or
        querying the database failed.
    """
    from sqlalchemy import text

    # Bound parameter rather than string interpolation, so table_name can never
    # alter the SQL.
    # NOTE(review): the lookup is not restricted to a schema, so a matching
    # table in any schema visible to this role counts; confirm that is intended.
    query = text(
        "SELECT EXISTS (SELECT FROM information_schema.tables "
        "WHERE table_name = :table_name)"
    )

    engine = None
    try:
        engine = create_async_engine(ASYNC_DATABASE_URL, echo=False)
        async with engine.connect() as conn:
            try:
                result = await conn.execute(query, {"table_name": table_name})
                exists = bool(result.scalar())
            except Exception as e:
                logger.error("Error checking tables: %s", e)
                return False
    except Exception as e:
        logger.error("Failed to connect to database: %s", e)
        return False
    finally:
        if engine is not None:
            # When this coroutine is driven by asyncio.run(), the loop closes
            # right after it returns. Pooled asyncpg connections must be closed
            # first, or they are torn down on a dead loop.
            try:
                await engine.dispose()
            except Exception as e:
                logger.warning("Failed to dispose database engine: %s", e)

    if exists:
        logger.info("Database is initialized.")
        return True
    logger.warning("Database tables are not initialized.")
    return False
def initialize_database(*, timeout=None):
    """Create the database tables by running ``scripts/init_db.py``.

    The script is looked up in a ``scripts`` directory next to this file. It is
    run in a child process with the current interpreter (``sys.executable``), so
    it sees the same installed packages. Errors are logged, never raised.

    Args:
        timeout: Optional limit in seconds for the script run. ``None`` (the
            default) waits indefinitely.

    Returns:
        True if the script exited with status 0, otherwise False.
    """
    logger.info("Initializing database...")
    current_dir = os.path.dirname(os.path.abspath(__file__))
    script_path = os.path.join(current_dir, "scripts", "init_db.py")

    # Report a missing script directly instead of relying on the interpreter's
    # "can't open file" message in stderr.
    if not os.path.isfile(script_path):
        logger.error(
            "Failed to initialize database: script not found at %s", script_path
        )
        return False

    try:
        result = subprocess.run(
            [sys.executable, script_path],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        logger.error(
            "Failed to initialize database: %s did not finish within %s seconds",
            script_path,
            timeout,
        )
        return False
    except Exception as e:
        logger.error("Error initializing database: %s", e)
        return False

    if result.returncode != 0:
        logger.error(
            "Failed to initialize database (exit code %d): %s",
            result.returncode,
            result.stderr,
        )
        return False

    logger.info("Database initialized successfully.")
    logger.debug(result.stdout)
    return True
def ensure_database_initialized():
"""Ensure the database is initialized with required tables."""
is_initialized = asyncio.run(check_db_initialized())
if not is_initialized:
return initialize_database()
return True