Spaces:
Running
Running
""" | |
API Security Module | |
This module provides security features for the API, including: | |
1. Authentication using JWT tokens | |
2. Rate limiting to prevent abuse | |
3. Role-based access control | |
4. Request validation | |
5. Audit logging | |
""" | |
import logging
import os
import secrets
import time
from collections import deque
from datetime import datetime, timedelta, timezone
from typing import Any, Callable, Dict, List, Optional, Union

from fastapi import Depends, HTTPException, Request, Security, status
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel, EmailStr
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from src.models.user import User
from src.api.database import get_db
# Configure logging
logger = logging.getLogger(__name__)

# Security configuration
# `os.getenv(name, default)` only falls back when the variable is *missing*;
# an exported-but-empty JWT_SECRET_KEY would yield an empty HMAC key and make
# every token trivially forgeable. Treat empty the same as unset.
_configured_secret = os.getenv("JWT_SECRET_KEY")
if _configured_secret:
    SECRET_KEY = _configured_secret
else:
    # A random key is generated per process, so tokens are not portable
    # across restarts or between separate worker processes.
    logger.warning(
        "JWT_SECRET_KEY is not set; using a random per-process signing key. "
        "Issued tokens will stop validating after a restart and are not "
        "shared between worker processes."
    )
    SECRET_KEY = secrets.token_hex(32)
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
API_KEY_NAME = "X-API-Key"
# Set up password hashing (bcrypt is the only configured scheme)
pwd_context = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
)

# Set up security schemes: OAuth2 bearer tokens (token URL "token") and an
# API key read from the X-API-Key header without automatic rejection.
# NOTE(review): oauth2_scheme keeps FastAPI's default auto_error, so routes
# depending on it presumably reject requests lacking a bearer token before
# any API-key fallback can run — confirm before relying on API-key-only access.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
api_key_header = APIKeyHeader(auto_error=False, name=API_KEY_NAME)
# User models
class Token(BaseModel):
    # Access-token payload; presumably returned by the token endpoint.
    # (Comments instead of a docstring: a pydantic docstring would become
    # the schema description.)
    access_token: str  # signed token string
    token_type: str  # token scheme name
    expires_at: datetime  # moment the token stops being valid
class TokenData(BaseModel):
    # Claims pulled out of a decoded access token.
    username: Optional[str] = None  # the "sub" claim, if any
    scopes: List[str] = []  # the "scopes" claim; empty when absent
class UserInDB(BaseModel):
    # API-facing view of a user account plus the scopes granted to it.
    # There is intentionally no password-hash field on this model.
    id: int
    username: str
    email: EmailStr
    full_name: Optional[str] = None
    is_active: bool = True
    is_superuser: bool = False
    scopes: List[str] = []  # permission strings such as "read:threats"

    class Config:
        # Allow building the model from objects via attribute access.
        from_attributes = True
# Rate limiting
class RateLimiter:
    """Sliding-window, in-memory rate limiter.

    Each key may make at most ``rate_limit`` accepted requests within any
    trailing ``time_window`` seconds. State is held in a plain dict on the
    instance, so it is per process and lost on restart.
    """

    def __init__(
        self,
        rate_limit: int = 100,
        time_window: int = 60,
        clock: Callable[[], float] = time.monotonic,
    ):
        """
        Initialize rate limiter.

        Args:
            rate_limit: Maximum number of requests per time window
            time_window: Time window in seconds
            clock: Zero-argument callable returning the current time in
                seconds. Must be non-decreasing. Defaults to
                ``time.monotonic`` so wall-clock adjustments cannot reset or
                stretch windows; injectable for tests.
        """
        self.rate_limit = rate_limit
        self.time_window = time_window
        self._clock = clock
        # key -> deque of accepted-request timestamps, oldest on the left
        self.requests = {}
        self._last_sweep = clock()

    def is_rate_limited(self, key: str) -> bool:
        """
        Check if a key is rate limited, recording the request if it is not.

        Args:
            key: Identifier for the client (IP address, API key, etc.)

        Returns:
            True if rate limited, False otherwise. Rejected requests are not
            recorded, so they do not extend the client's penalty.
        """
        now = self._clock()
        cutoff = now - self.time_window

        # Once per window, drop keys with no recent traffic; without this,
        # every distinct client ever seen would stay in memory forever.
        if now - self._last_sweep >= self.time_window:
            self._evict_idle_keys(cutoff)
            self._last_sweep = now

        window = self.requests.get(key)
        if window is None:
            window = self.requests[key] = deque()
        else:
            # Timestamps are appended in order, so expired ones are leftmost.
            while window and window[0] <= cutoff:
                window.popleft()

        if len(window) >= self.rate_limit:
            return True

        window.append(now)
        return False

    def _evict_idle_keys(self, cutoff: float) -> None:
        """Remove keys whose newest recorded request is at or before ``cutoff``."""
        idle = [k for k, w in self.requests.items() if not w or w[-1] <= cutoff]
        for k in idle:
            del self.requests[k]
# Create global rate limiter instance, shared by every use of the
# `rate_limit` dependency in this module; limits are spelled out explicitly.
rate_limiter = RateLimiter(rate_limit=100, time_window=60)
# Role-based access control
# Role name -> scopes granted, each written as "<action>:<resource>".
ROLES = dict(
    admin=["read:all", "write:all", "delete:all"],
    analyst=["read:all", "write:threats", "write:indicators", "write:reports"],
    user=["read:threats", "read:reports", "read:dashboard"],
    api=["read:all", "write:threats", "write:indicators"],
)
# Security utility functions
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return whether ``plain_password`` matches the stored ``hashed_password``."""
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
def get_password_hash(password: str) -> str:
    """Return the module password context's hash of ``password``, for storage."""
    hashed = pwd_context.hash(password)
    return hashed
async def get_user(db: AsyncSession, username: str) -> Optional[UserInDB]:
    """
    Load a user by username and attach the scopes of their role.

    Returns:
        The user as a UserInDB, or None when no row matches ``username``.
    """
    query = select(User).filter(User.username == username)
    row = (await db.execute(query)).scalars().first()
    if not row:
        return None

    # Simplified role lookup: superusers get the "admin" scopes and everyone
    # else the "user" scopes. A real application would store roles per user.
    role = "admin" if row.is_superuser else "user"

    return UserInDB(
        id=row.id,
        username=row.username,
        email=row.email,
        full_name=row.full_name,
        is_active=row.is_active,
        is_superuser=row.is_superuser,
        scopes=ROLES[role],
    )
async def authenticate_user(db: AsyncSession, username: str, password: str) -> Optional[UserInDB]:
    """
    Authenticate a user with username and password.

    Returns:
        The user (as built by get_user) when the password matches, otherwise
        None. None is returned both for an unknown username and for a wrong
        password.
    """
    # Read the ORM row first: it carries the password hash that UserInDB lacks.
    result = await db.execute(select(User).filter(User.username == username))
    user_db = result.scalars().first()
    if user_db is None:
        # Simulate the cost of a real hash check so response timing does not
        # reveal whether the username exists.
        pwd_context.dummy_verify()
        return None

    if not verify_password(password, user_db.hashed_password):
        return None

    # Build the API-facing model (with role scopes) only after the password
    # has been verified.
    return await get_user(db, username)
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a signed JWT access token.

    Args:
        data: Claims to embed (e.g. ``{"sub": username}``). The dict is not
            mutated; an "exp" claim is added to a copy.
        expires_delta: Token lifetime. ``None`` selects the default of
            ACCESS_TOKEN_EXPIRE_MINUTES. Any explicit value is honored,
            including ``timedelta(0)``, which a plain truthiness check would
            have replaced with the default.

    Returns:
        The encoded token string.
    """
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    # Timezone-aware UTC instead of the deprecated, naive datetime.utcnow().
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode = {**data, "exp": expire}
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
async def get_api_key_user(
    api_key: str,
    db: AsyncSession
) -> Optional[UserInDB]:
    """
    Get the user associated with an API key.

    The returned user's scopes are replaced by those of the "api" role.

    Returns:
        The user, or None if the key is unknown or its user does not exist.
    """
    # In a real application, you would look up API keys in a database.
    # For simplicity, we'll use a simple hardcoded mapping.
    # TODO: Replace with database-backed API key storage; the hard-coded
    # "test-api-key" grants API access to anyone who knows it.
    API_KEYS = {
        "test-api-key": "api_user",
        # Add more API keys here
    }
    username = API_KEYS.get(api_key)
    if username is None:
        return None

    user = await get_user(db, username)
    if user is None:
        return None

    # Assign a fresh copy. UserInDB configures no assignment validation, so
    # assigning ROLES["api"] directly would alias the module-level list, and
    # any in-place edit of this user's scopes would corrupt the role table.
    user.scopes = list(ROLES["api"])
    return user
# Dependencies for FastAPI
async def rate_limit(request: Request):
    """
    Rate limiting dependency.

    Raises:
        HTTPException: 429 once the calling client exceeds the quota of the
            module-level ``rate_limiter``.

    Returns:
        True when the request is allowed.
    """
    # Key on the client address only. The X-API-Key header has not been
    # validated at this point, so keying on it let any caller dodge the limit
    # by sending a fresh random value per request. request.client may be
    # None (no client info in the ASGI scope), so use a shared fallback bucket.
    client_key = request.client.host if request.client else "unknown"
    if rate_limiter.is_rate_limited(client_key):
        # Only the address is logged; raw API keys must never reach the logs.
        logger.warning("Rate limit exceeded for %s", client_key)
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail="Rate limit exceeded. Please try again later."
        )
    return True
# Bearer-token extractor that yields None instead of raising when no
# "Authorization: Bearer ..." header is present. The shared `oauth2_scheme`
# uses FastAPI's default auto_error, which rejects such requests before this
# dependency's body runs, making API-key-only authentication impossible.
_optional_oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)


async def get_current_user(
    token: Optional[str] = Depends(_optional_oauth2_scheme),
    api_key: str = Security(api_key_header),
    db: AsyncSession = Depends(get_db)
) -> UserInDB:
    """
    Get the current user from either an API key or a JWT bearer token.

    The API key (X-API-Key header) is tried first. If it is absent or
    unknown, the bearer token is decoded and its "sub" claim is looked up.
    Scopes always come from the stored user, never from the token.

    Raises:
        HTTPException: 401 when neither credential identifies an existing user.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Check API key first
    if api_key:
        user = await get_api_key_user(api_key, db)
        if user:
            return user

    # Then check JWT token; it may be missing now that API-key-only requests
    # reach this point.
    if not token:
        raise credentials_exception
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise credentials_exception from None

    username = payload.get("sub")
    # Reject a missing or non-string subject rather than querying with it.
    if not isinstance(username, str):
        raise credentials_exception

    user = await get_user(db, username=username)
    if user is None:
        raise credentials_exception
    return user
async def get_current_active_user(
    current_user: UserInDB = Depends(get_current_user)
) -> UserInDB:
    """
    Return the authenticated user, rejecting deactivated accounts.

    Use as a dependency on endpoints that require an active user.

    Raises:
        HTTPException: 400 "Inactive user" when ``is_active`` is false.
    """
    if current_user.is_active:
        return current_user
    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user")
def has_scope(required_scopes: List[str]):
    """
    Create a dependency that checks if the user has the required scopes.

    A granted "<action>:all" scope satisfies any required
    "<action>:<resource>" scope with the same action. ROLES grants
    "read:all", "write:all" and "delete:all"; without this rule an admin
    holding "read:all" would be refused an endpoint requiring "read:threats".

    Args:
        required_scopes: Scopes the current active user must all hold.

    Returns:
        An async dependency that returns the user, or raises HTTP 403 naming
        the first scope that is not satisfied.
    """
    # Snapshot so later mutation of the caller's list cannot alter the check.
    needed = tuple(required_scopes)

    async def _has_scope(
        current_user: UserInDB = Depends(get_current_active_user)
    ) -> UserInDB:
        granted = set(current_user.scopes)
        for scope in needed:
            action = scope.split(":", 1)[0]
            if scope not in granted and f"{action}:all" not in granted:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail=f"Permission denied. Required scope: {scope}"
                )
        return current_user

    return _has_scope
def admin_only(
    current_user: UserInDB = Depends(get_current_active_user)
) -> UserInDB:
    """
    Dependency that lets only superusers through.

    Raises:
        HTTPException: 403 when the active user is not a superuser.
    """
    if current_user.is_superuser:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Permission denied. Admin access required."
    )
# Audit logging middleware
async def audit_log_middleware(request: Request, call_next):
    """
    Middleware for audit logging.

    Logs one line when a request arrives and one when it completes (status
    and duration). If the downstream handler raises, an error line with the
    elapsed time is logged and the exception is re-raised unchanged.
    """
    method = request.method
    path = request.url.path
    # request.client may be None when the ASGI scope has no client info.
    client_host = request.client.host if request.client else "unknown"
    user_agent = request.headers.get("User-Agent", "Unknown")

    # Use a user stored on request.state by an earlier layer, if any.
    user = getattr(request.state, "user", None)
    username = user.username if user else "Anonymous"

    logger.info(
        "API Request: %s %s | User: %s | Client: %s | User-Agent: %s",
        method, path, username, client_host, user_agent,
    )

    # perf_counter is monotonic and high resolution, unlike time.time().
    start = time.perf_counter()
    try:
        response = await call_next(request)
    except Exception:
        # Boundary handler: record the failure for the audit trail, then let
        # the exception propagate untouched.
        logger.exception(
            "API Error: %s %s | Time: %.4fs | User: %s",
            method, path, time.perf_counter() - start, username,
        )
        raise
    process_time = time.perf_counter() - start

    logger.info(
        "API Response: %s %s | Status: %s | Time: %.4fs | User: %s",
        method, path, response.status_code, process_time, username,
    )
    return response
# API key validation middleware
def validate_api_key(request: Request):
    """
    Require a known API key in the X-API-Key header.

    Can be used as a dependency for FastAPI routes.

    Returns:
        True when the key is valid.

    Raises:
        HTTPException: 401 when the header is missing or the key is unknown.
    """
    api_key = request.headers.get(API_KEY_NAME)
    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="API key required",
            headers={"WWW-Authenticate": API_KEY_NAME},
        )

    # In a real application, you would validate the API key against a database.
    # TODO: replace this hard-coded list with a database lookup.
    valid_keys = ["test-api-key"]

    # Constant-time comparison: a plain `in` check compares strings
    # character by character, leaking through timing how much of a guess
    # matched. Compare bytes because compare_digest rejects non-ASCII str.
    presented = api_key.encode()
    if not any(secrets.compare_digest(presented, key.encode()) for key in valid_keys):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
            headers={"WWW-Authenticate": API_KEY_NAME},
        )
    return True