"""
API Security Module

This module provides security features for the API, including:
1. Authentication using JWT tokens
2. Rate limiting to prevent abuse
3. Role-based access control
4. Request validation
5. Audit logging
"""

import os
import time
import logging
import secrets
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional, Union, Any, Callable

from fastapi import Depends, HTTPException, Security, status, Request
from fastapi.security import OAuth2PasswordBearer, APIKeyHeader
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from pydantic import BaseModel, EmailStr

from src.models.user import User
from src.api.database import get_db

# Configure logging
logger = logging.getLogger(__name__)

# Security configuration
# An unset *or empty* JWT_SECRET_KEY falls back to a random per-process key;
# an empty string must never be used as the HMAC signing key.
SECRET_KEY = os.getenv("JWT_SECRET_KEY")
if not SECRET_KEY:
    logger.warning(
        "JWT_SECRET_KEY is not set; using a random per-process key. "
        "Issued tokens will be rejected after a restart and by other workers."
    )
    SECRET_KEY = secrets.token_hex(32)
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
API_KEY_NAME = "X-API-Key"

# Set up password hashing
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# Set up security schemes
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Non-raising variant used by get_current_user. With auto_error=True, a request
# that carries only an API key would be rejected with 401 before the API key
# branch ever ran.
_optional_oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)

# API keys mapped to the username they authenticate as. Shared by every
# API-key check in this module so the lists cannot drift apart.
# TODO: Replace with database-backed API key storage.
# NOTE(review): a well-known hardcoded key must not ship to production.
_API_KEYS: Dict[str, str] = {
    "test-api-key": "api_user",
    # Add more API keys here
}


# User models
class Token(BaseModel):
    access_token: str
    token_type: str
    expires_at: datetime


class TokenData(BaseModel):
    username: Optional[str] = None
    scopes: List[str] = []


class UserInDB(BaseModel):
    id: int
    username: str
    email: EmailStr
    full_name: Optional[str] = None
    is_active: bool = True
    is_superuser: bool = False
    scopes: List[str] = []

    class Config:
        from_attributes = True


# Rate limiting
class RateLimiter:
    """Simple in-memory sliding-window rate limiter.

    Request timestamps are stored per client key in this process's memory, so
    limits are not shared between worker processes.
    """

    def __init__(self, rate_limit: int = 100, time_window: int = 60):
        """
        Initialize rate limiter.

        Args:
            rate_limit: Maximum number of requests per time window
            time_window: Time window in seconds
        """
        self.rate_limit = rate_limit
        self.time_window = time_window
        # key -> monotonic timestamps of accepted requests, oldest first
        self.requests: Dict[str, List[float]] = {}
        # Monotonic time of the last sweep of idle keys.
        self._last_sweep = time.monotonic()

    def _sweep(self, now: float) -> None:
        """Drop keys with no requests inside the window (at most once per window).

        Without this, every distinct key ever seen stays in memory forever.
        """
        if now - self._last_sweep < self.time_window:
            return
        cutoff = now - self.time_window
        stale = [
            key for key, stamps in self.requests.items()
            if not stamps or stamps[-1] <= cutoff
        ]
        for key in stale:
            del self.requests[key]
        self._last_sweep = now

    def is_rate_limited(self, key: str) -> bool:
        """
        Check if a key is rate limited, recording the request if it is not.

        Args:
            key: Identifier for the client (IP address, API key, etc.)

        Returns:
            True if rate limited, False otherwise
        """
        # Monotonic clock: wall-clock jumps (NTP, DST) must not distort windows.
        now = time.monotonic()
        self._sweep(now)

        # Keep only requests that are still inside the time window.
        cutoff = now - self.time_window
        recent = [t for t in self.requests.get(key, []) if t > cutoff]
        self.requests[key] = recent

        # Check if rate limit is exceeded (rejected requests are not recorded)
        if len(recent) >= self.rate_limit:
            return True

        # Add the current request
        recent.append(now)
        return False


# Create global rate limiter instance
rate_limiter = RateLimiter()

# Role-based access control
# Define roles and permissions
ROLES = {
    "admin": ["read:all", "write:all", "delete:all"],
    "analyst": ["read:all", "write:threats", "write:indicators", "write:reports"],
    "user": ["read:threats", "read:reports", "read:dashboard"],
    "api": ["read:all", "write:threats", "write:indicators"],
}


# Security utility functions
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Verify a password against a hash"""
    return pwd_context.verify(plain_password, hashed_password)


def get_password_hash(password: str) -> str:
    """Hash a password for storage"""
    return pwd_context.hash(password)


def _scopes_for(user_db: User) -> List[str]:
    """Return a fresh list of the scopes granted to a database user."""
    # In a real application, you would look up user roles in a database.
    # For simplicity, superusers get the "admin" role and everyone else "user".
    role = "admin" if user_db.is_superuser else "user"
    # Copy so that mutating a UserInDB's scopes can never alter ROLES globally.
    return list(ROLES[role])


def _to_user_in_db(user_db: User) -> UserInDB:
    """Convert an ORM User row into the public UserInDB model."""
    return UserInDB(
        id=user_db.id,
        username=user_db.username,
        email=user_db.email,
        full_name=user_db.full_name,
        is_active=user_db.is_active,
        is_superuser=user_db.is_superuser,
        scopes=_scopes_for(user_db),
    )


async def _fetch_user_row(db: AsyncSession, username: str) -> Optional[User]:
    """Load the ORM User row for *username*, or None if it does not exist."""
    result = await db.execute(select(User).filter(User.username == username))
    return result.scalars().first()


async def get_user(db: AsyncSession, username: str) -> Optional[UserInDB]:
    """Get a user from the database by username"""
    user_db = await _fetch_user_row(db, username)
    if user_db is None:
        return None
    return _to_user_in_db(user_db)


async def authenticate_user(
    db: AsyncSession, username: str, password: str
) -> Optional[UserInDB]:
    """Authenticate a user with username and password.

    Returns the user on success, or None if the user does not exist or the
    password is wrong (callers cannot tell the two apart).
    """
    # Single query: the row carries both the profile and the password hash.
    user_db = await _fetch_user_row(db, username)
    if user_db is None:
        # Spend roughly the time of a real hash check, so response timing
        # does not reveal which usernames exist.
        pwd_context.dummy_verify()
        return None

    if not verify_password(password, user_db.hashed_password):
        return None

    return _to_user_in_db(user_db)


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to encode (copied, not mutated).
        expires_delta: Token lifetime; defaults to ACCESS_TOKEN_EXPIRE_MINUTES.

    Returns:
        The encoded JWT string with an ``exp`` claim.
    """
    to_encode = data.copy()
    # `is not None` so an explicit timedelta(0) is honored, not replaced.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    # Timezone-aware "now"; datetime.utcnow() is deprecated since Python 3.12.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)


def _lookup_api_key(api_key: Optional[str]) -> Optional[str]:
    """Return the username bound to *api_key*, or None if it is not a known key.

    Uses a constant-time comparison so the check does not leak key prefixes
    through timing.
    """
    if not api_key:
        return None
    # compare_digest requires ASCII for str inputs; compare UTF-8 bytes instead.
    presented = api_key.encode("utf-8")
    for known_key, username in _API_KEYS.items():
        if secrets.compare_digest(presented, known_key.encode("utf-8")):
            return username
    return None


async def get_api_key_user(
    api_key: str,
    db: AsyncSession
) -> Optional[UserInDB]:
    """Get user associated with an API key, with the "api" role's scopes."""
    username = _lookup_api_key(api_key)
    if username is None:
        return None

    user = await get_user(db, username)
    if user is None:
        return None

    # Override scopes with API role scopes (copied; see _scopes_for)
    user.scopes = list(ROLES["api"])
    return user


def _client_host(request: Request) -> str:
    """Best-effort client address; request.client can be None (e.g. in tests)."""
    return request.client.host if request.client else "unknown"


# Dependencies for FastAPI
async def rate_limit(request: Request):
    """Rate limiting dependency.

    Buckets by API key only when the header holds a *recognized* key.
    Otherwise it buckets by client IP. A client could defeat a limit keyed on
    an arbitrary header value simply by sending a new value with every request.

    Raises:
        HTTPException: 429 when the client has exceeded the limit.
    """
    api_key = request.headers.get(API_KEY_NAME)
    key_owner = _lookup_api_key(api_key)
    if key_owner is not None:
        client_key = f"key:{api_key}"
        # Never write the raw API key to the logs.
        label = f"API key of user {key_owner}"
    else:
        host = _client_host(request)
        client_key = f"ip:{host}"
        label = host

    if rate_limiter.is_rate_limited(client_key):
        logger.warning("Rate limit exceeded for %s", label)
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail="Rate limit exceeded. Please try again later."
        )

    return True


async def get_current_user(
    token: Optional[str] = Depends(_optional_oauth2_scheme),
    api_key: Optional[str] = Security(api_key_header),
    db: AsyncSession = Depends(get_db)
) -> UserInDB:
    """
    Get the current user from either JWT token or API key.

    A valid API key wins; otherwise a valid Bearer JWT whose ``sub`` names an
    existing user is required. This dependency can be used to require
    authentication for endpoints.

    Raises:
        HTTPException: 401 if neither credential is valid.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Check API key first
    if api_key:
        user = await get_api_key_user(api_key, db)
        if user:
            return user

    # Then check JWT token (absent when only an API key was sent)
    if not token:
        raise credentials_exception

    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as err:
        raise credentials_exception from err

    username = payload.get("sub")
    if not isinstance(username, str) or not username:
        raise credentials_exception

    # NOTE: scopes embedded in the token are ignored; scopes always come from
    # the user's current role in the database.
    user = await get_user(db, username=username)
    if user is None:
        raise credentials_exception

    return user


async def get_current_active_user(
    current_user: UserInDB = Depends(get_current_user)
) -> UserInDB:
    """
    Get the current active user.

    This dependency can be used to require an active user for endpoints.

    Raises:
        HTTPException: 400 if the user is inactive.
    """
    if not current_user.is_active:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user")
    return current_user


def has_scope(required_scopes: List[str]):
    """
    Create a dependency that checks if the user has the required scopes.

    Args:
        required_scopes: List of required scopes

    Returns:
        A dependency function that checks if the user has the required scopes
        and raises 403 naming the first one that is missing.
    """
    async def _has_scope(
        current_user: UserInDB = Depends(get_current_active_user)
    ) -> UserInDB:
        for scope in required_scopes:
            if scope not in current_user.scopes:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail=f"Permission denied. Required scope: {scope}"
                )
        return current_user

    return _has_scope


def admin_only(
    current_user: UserInDB = Depends(get_current_active_user)
) -> UserInDB:
    """
    Dependency that requires an admin (superuser).

    Raises:
        HTTPException: 403 if the current user is not a superuser.
    """
    if not current_user.is_superuser:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Permission denied. Admin access required."
        )
    return current_user


def _username_from_state(request: Request) -> str:
    """Username stored on request.state.user by other code, or "Anonymous"."""
    user = getattr(request.state, "user", None)
    return user.username if user else "Anonymous"


# Audit logging middleware
async def audit_log_middleware(request: Request, call_next):
    """
    Middleware for audit logging.

    Records details about API requests and their responses, including
    requests whose handler raised.
    """
    # Get request details
    method = request.method
    path = request.url.path
    client_host = _client_host(request)
    user_agent = request.headers.get("User-Agent", "Unknown")

    # NOTE(review): this module never sets request.state.user; presumably
    # other middleware does. Verify, otherwise this is always "Anonymous".
    username = _username_from_state(request)

    logger.info(
        "API Request: %s %s | User: %s | Client: %s | User-Agent: %s",
        method, path, username, client_host, user_agent,
    )

    # Process the request
    start_time = time.perf_counter()
    try:
        response = await call_next(request)
    except Exception:
        logger.exception(
            "API Error: %s %s | Time: %.4fs | User: %s",
            method, path, time.perf_counter() - start_time, username,
        )
        raise
    process_time = time.perf_counter() - start_time

    # Re-read in case downstream code populated request.state.user.
    username = _username_from_state(request)

    logger.info(
        "API Response: %s %s | Status: %s | Time: %.4fs | User: %s",
        method, path, response.status_code, process_time, username,
    )

    return response


# API key validation middleware
def validate_api_key(request: Request):
    """
    Validate the API key header; usable as a dependency for FastAPI routes.

    Raises:
        HTTPException: 401 if the header is missing or the key is unknown.
    """
    api_key = request.headers.get(API_KEY_NAME)
    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="API key required",
            headers={"WWW-Authenticate": f"{API_KEY_NAME}"},
        )

    # Same key table as get_api_key_user (TODO: database lookup).
    if _lookup_api_key(api_key) is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
            headers={"WWW-Authenticate": f"{API_KEY_NAME}"},
        )

    return True