Spaces:
Running
Running
File size: 11,508 Bytes
89ae94f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 |
"""
API Security Module

This module provides security features for the API, including:
1. Authentication using JWT bearer tokens or an API key header
2. In-memory rate limiting to prevent abuse
3. Role/scope-based access control
4. Request validation (API key checks)
5. Audit logging of requests and responses
"""
import hmac
import logging
import os
import secrets
import time
from datetime import datetime, timedelta, timezone
from typing import Any, Callable, Dict, List, Optional, Union

from fastapi import Depends, HTTPException, Request, Security, status
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel, EmailStr
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from src.api.database import get_db
from src.models.user import User
# Configure logging
logger = logging.getLogger(__name__)

# Security configuration
# HS256 signing key for JWTs. An unset *or empty* JWT_SECRET_KEY falls back to a
# random per-process key: an empty HMAC key would make every token trivially
# forgeable, and the fallback is logged because it silently breaks sessions.
_configured_secret = os.getenv("JWT_SECRET_KEY")
if _configured_secret:
    SECRET_KEY = _configured_secret
else:
    SECRET_KEY = secrets.token_hex(32)
    logger.warning(
        "JWT_SECRET_KEY is not set; using a random per-process signing key. "
        "Issued tokens will be invalidated on restart and rejected by other "
        "worker processes."
    )
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30  # default token lifetime used by create_access_token
API_KEY_NAME = "X-API-Key"  # request header that carries API keys
# Set up password hashing: bcrypt is the only (and therefore default) scheme;
# deprecated="auto" marks any non-default scheme as deprecated.
pwd_context = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
)
# Set up security schemes.
# Bearer-token scheme; tokens are issued at the "token" URL. With FastAPI's
# default auto_error=True it responds 401 when the Authorization header is missing.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# API key header scheme; auto_error=False yields None instead of raising.
api_key_header = APIKeyHeader(
    name=API_KEY_NAME,
    auto_error=False,
)
# User models
class Token(BaseModel):
    # Token payload for clients; NOTE(review): presumably the response model of a
    # token/login endpoint defined elsewhere — confirm.
    access_token: str  # presumably a JWT from create_access_token — confirm with caller
    token_type: str  # NOTE(review): conventionally "bearer"; not enforced here
    expires_at: datetime  # expiry timestamp; not checked against the JWT "exp" claim here
class TokenData(BaseModel):
    # Holds the "sub" (username) and "scopes" claims of a decoded JWT.
    username: Optional[str] = None
    # Pydantic copies field defaults per instance, so this [] is not shared.
    scopes: List[str] = []
class UserInDB(BaseModel):
    # Authenticated-user view: columns copied from the ``User`` ORM row plus
    # role-derived permission scopes. Note it carries no password hash.
    id: int
    username: str
    email: EmailStr
    full_name: Optional[str] = None
    is_active: bool = True
    is_superuser: bool = False
    # Permission strings such as "read:threats" (see ROLES).
    scopes: List[str] = []

    class Config:
        # Pydantic v2 option allowing validation from object attributes (e.g. ORM
        # rows). NOTE(review): pydantic v1 spells this ``orm_mode`` — confirm version.
        from_attributes = True
# Rate limiting
class RateLimiter:
    """Simple in-memory sliding-window rate limiter.

    For each key, the timestamps of accepted requests inside the last
    ``time_window`` seconds are kept. State lives on this instance only, so
    limits are per process and are not shared between workers.
    """

    def __init__(self, rate_limit: int = 100, time_window: int = 60):
        """
        Initialize rate limiter.

        Args:
            rate_limit: Maximum number of requests per time window
            time_window: Time window in seconds
        """
        self.rate_limit = rate_limit
        self.time_window = time_window
        # key -> time.monotonic() stamps of accepted requests, oldest first.
        self.requests: Dict[str, List[float]] = {}
        # Time of the last stale-key purge (see _purge_stale_keys).
        self._last_sweep = time.monotonic()

    def _purge_stale_keys(self, current_time: float) -> None:
        """Drop keys whose newest accepted request has left the window.

        Without this, every distinct key ever seen (e.g. every client IP)
        would stay in ``self.requests`` forever.
        """
        cutoff = current_time - self.time_window
        stale = [
            key for key, stamps in self.requests.items()
            if not stamps or stamps[-1] <= cutoff
        ]
        for key in stale:
            del self.requests[key]
        self._last_sweep = current_time

    def is_rate_limited(self, key: str) -> bool:
        """
        Check if a key is rate limited, recording the request if it is not.

        Args:
            key: Identifier for the client (IP address, API key, etc.)

        Returns:
            True if rate limited, False otherwise
        """
        # Monotonic clock: wall-clock jumps must not lock clients out or reset limits.
        current_time = time.monotonic()
        # Amortise the O(#keys) purge by running it at most once per window.
        if current_time - self._last_sweep >= self.time_window:
            self._purge_stale_keys(current_time)

        cutoff = current_time - self.time_window
        recent = [t for t in self.requests.get(key, ()) if t > cutoff]
        self.requests[key] = recent
        if len(recent) >= self.rate_limit:
            return True
        # Only accepted requests are recorded.
        recent.append(current_time)
        return False


# Create global rate limiter instance
rate_limiter = RateLimiter()
# Role-based access control
# Role name -> permission scopes granted to that role.
ROLES: Dict[str, List[str]] = dict(
    admin=["read:all", "write:all", "delete:all"],
    analyst=["read:all", "write:threats", "write:indicators", "write:reports"],
    user=["read:threats", "read:reports", "read:dashboard"],
    api=["read:all", "write:threats", "write:indicators"],
)
# Security utility functions
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against a stored hash.

    Args:
        plain_password: Password as supplied by the client.
        hashed_password: Hash previously produced via ``pwd_context``.

    Returns:
        True if the password matches the hash, otherwise False.
    """
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
def get_password_hash(password: str) -> str:
    """Return a storable hash of *password* produced by ``pwd_context`` (bcrypt)."""
    hashed = pwd_context.hash(password)
    return hashed
def _user_in_db_from_row(user_db: User) -> UserInDB:
    """Convert a ``User`` ORM row into a ``UserInDB`` with role-derived scopes."""
    # In a real application, you would look up user roles in a database.
    # For simplicity, superusers get the "admin" role and everyone else "user".
    role = "admin" if user_db.is_superuser else "user"
    return UserInDB(
        id=user_db.id,
        username=user_db.username,
        email=user_db.email,
        full_name=user_db.full_name,
        is_active=user_db.is_active,
        is_superuser=user_db.is_superuser,
        # Copy so mutating one user's scopes can never alter the shared ROLES table.
        scopes=list(ROLES[role]),
    )


async def _fetch_user_row(db: AsyncSession, username: str) -> Optional[User]:
    """Return the ``User`` ORM row for *username*, or None if there is none."""
    result = await db.execute(select(User).where(User.username == username))
    return result.scalars().first()


async def get_user(db: AsyncSession, username: str) -> Optional[UserInDB]:
    """Get a user from the database by username.

    Returns:
        The user with role-derived scopes, or None if no such user exists.
    """
    user_db = await _fetch_user_row(db, username)
    if user_db is None:
        return None
    return _user_in_db_from_row(user_db)


async def authenticate_user(db: AsyncSession, username: str, password: str) -> Optional[UserInDB]:
    """Authenticate a user with username and password.

    The user row is loaded once and used both for the password check and for
    building the returned ``UserInDB``.

    Returns:
        The authenticated user, or None if the user is unknown or the password
        does not match.
    """
    user_db = await _fetch_user_row(db, username)
    if user_db is None:
        # Spend comparable hashing time so response latency does not reveal
        # whether the username exists.
        pwd_context.dummy_verify()
        return None
    if not verify_password(password, user_db.hashed_password):
        return None
    return _user_in_db_from_row(user_db)
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed; the dict is copied, not mutated. Any "exp" key
            in it is overridden.
        expires_delta: Token lifetime. ``None`` means
            ``ACCESS_TOKEN_EXPIRE_MINUTES``; an explicit ``timedelta(0)`` is
            honoured rather than replaced by the default.

    Returns:
        The encoded JWT string, signed with SECRET_KEY using ALGORITHM.
    """
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    # Timezone-aware UTC; datetime.utcnow() is naive and deprecated since Python 3.12.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode = {**data, "exp": expire}
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
# API key -> username mapping.
# SECURITY: this key is committed to source and grants the "api" role's scopes
# to anyone who knows it. TODO: Replace with database-backed storage of hashed
# API keys before deployment.
_API_KEY_USERS: Dict[str, str] = {
    "test-api-key": "api_user",
    # Add more API keys here
}


async def get_api_key_user(
    api_key: str,
    db: AsyncSession
) -> Optional[UserInDB]:
    """Get the user associated with an API key.

    Args:
        api_key: Raw value of the API key header.
        db: Database session used to load the mapped user.

    Returns:
        The mapped user with its scopes replaced by the "api" role's scopes,
        or None if the key is empty/unknown or its user does not exist.
    """
    if not api_key:
        return None

    presented = api_key.encode("utf-8")
    username = None
    # Constant-time comparison against every key (no early break) so timing
    # does not reveal how much of a key matched.
    for known_key, owner in _API_KEY_USERS.items():
        if hmac.compare_digest(presented, known_key.encode("utf-8")):
            username = owner
    if username is None:
        return None

    user = await get_user(db, username)
    if user is None:
        return None
    # Override scopes with a copy of the API role scopes; assigning the ROLES
    # list itself would let later mutation of user.scopes corrupt the table.
    user.scopes = list(ROLES["api"])
    return user
# Dependencies for FastAPI
async def rate_limit(request: Request):
    """Rate limiting dependency.

    Requests are bucketed by client IP address. The API key header is
    deliberately not used as the bucket key: at this point it is unverified
    and client-controlled, so keying on it let a client bypass the limit by
    sending a different value on every request (and put raw keys in logs).

    Returns:
        True when the request is within the limit.

    Raises:
        HTTPException: 429 when the client has exceeded the limit.
    """
    # request.client is None when the ASGI server does not report a peer address.
    client_key = request.client.host if request.client else "unknown"
    if rate_limiter.is_rate_limited(client_key):
        logger.warning("Rate limit exceeded for %s", client_key)
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail="Rate limit exceeded. Please try again later."
        )
    return True
# Bearer scheme that yields None instead of raising 401 when the Authorization
# header is absent, so API-key-only requests reach the API key check below.
_optional_oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)


async def get_current_user(
    token: Optional[str] = Depends(_optional_oauth2_scheme),
    api_key: Optional[str] = Security(api_key_header),
    db: AsyncSession = Depends(get_db)
) -> UserInDB:
    """
    Get the current user from either an API key or a JWT bearer token.

    A valid API key takes precedence; otherwise the bearer token is decoded
    and its "sub" claim is looked up in the database. Either credential may
    be absent. This dependency can be used to require authentication for
    endpoints.

    Raises:
        HTTPException: 401 if neither credential identifies an existing user.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Check API key first
    if api_key:
        user = await get_api_key_user(api_key, db)
        if user is not None:
            return user

    # Then check JWT token
    if not token:
        raise credentials_exception
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as err:
        raise credentials_exception from err

    username = payload.get("sub")
    if not isinstance(username, str) or not username:
        raise credentials_exception

    # Scopes embedded in the token are not used for authorization; the user's
    # current scopes are reloaded from the database here.
    user = await get_user(db, username=username)
    if user is None:
        raise credentials_exception
    return user
async def get_current_active_user(
    current_user: UserInDB = Depends(get_current_user)
) -> UserInDB:
    """
    Resolve the authenticated user and reject accounts flagged inactive.

    Use this dependency on endpoints that require an active user.

    Raises:
        HTTPException: 400 ("Inactive user") when ``is_active`` is false.
    """
    if current_user.is_active:
        return current_user
    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user")
def _scope_granted(granted_scopes: List[str], required_scope: str) -> bool:
    """Return True if *required_scope* is held exactly or via an ``<action>:all`` wildcard."""
    if required_scope in granted_scopes:
        return True
    action, separator, _resource = required_scope.partition(":")
    return bool(separator) and f"{action}:all" in granted_scopes


def has_scope(required_scopes: List[str]):
    """
    Create a dependency that checks if the user has the required scopes.

    A required scope ``"<action>:<resource>"`` is satisfied by that exact
    scope or by the wildcard ``"<action>:all"``. ROLES grants such wildcards,
    e.g. admin's "write:all", which previously failed a "write:threats" check
    that the "analyst" role passed.

    Args:
        required_scopes: List of required scopes

    Returns:
        A dependency function that returns the current active user, or raises
        HTTPException 403 naming the first missing scope.
    """
    async def _has_scope(
        current_user: UserInDB = Depends(get_current_active_user)
    ) -> UserInDB:
        for scope in required_scopes:
            if not _scope_granted(current_user.scopes, scope):
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail=f"Permission denied. Required scope: {scope}"
                )
        return current_user

    return _has_scope
def admin_only(
    current_user: UserInDB = Depends(get_current_active_user)
) -> UserInDB:
    """
    Dependency that lets the request through only for superusers.

    Returns:
        The current user, unchanged, when ``is_superuser`` is set.

    Raises:
        HTTPException: 403 for any non-superuser.
    """
    if current_user.is_superuser:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Permission denied. Admin access required.",
    )
# Audit logging middleware
async def audit_log_middleware(request: Request, call_next):
    """
    Middleware for audit logging.

    Logs one line when a request arrives and one when its response is ready
    (status and processing time). If the downstream handler raises, the
    failure is logged with its traceback and the exception is re-raised.
    """
    # Get request details
    method = request.method
    path = request.url.path
    # request.client is None when the ASGI server does not report a peer address.
    client_host = request.client.host if request.client else "unknown"
    user_agent = request.headers.get("User-Agent", "Unknown")

    # Get user details if available. Nothing in this module sets
    # request.state.user; presumably another middleware does — confirm.
    user = getattr(request.state, "user", None)
    username = user.username if user else "Anonymous"

    logger.info(
        "API Request: %s %s | User: %s | Client: %s | User-Agent: %s",
        method, path, username, client_host, user_agent,
    )

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    try:
        response = await call_next(request)
    except Exception:
        logger.exception(
            "API Error: %s %s | Time: %.4fs | User: %s",
            method, path, time.perf_counter() - start_time, username,
        )
        raise
    process_time = time.perf_counter() - start_time

    logger.info(
        "API Response: %s %s | Status: %s | Time: %.4fs | User: %s",
        method, path, response.status_code, process_time, username,
    )
    return response
# API key validation dependency
# Accepted API keys.
# SECURITY: committed to source; replace with a database lookup of hashed keys.
_VALID_API_KEYS = ("test-api-key",)


def validate_api_key(request: Request):
    """
    Dependency that requires a valid API key in the API key header.

    Returns:
        True when the presented key is accepted.

    Raises:
        HTTPException: 401 if the header is missing/empty or the key is not accepted.
    """
    api_key = request.headers.get(API_KEY_NAME)
    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="API key required",
            headers={"WWW-Authenticate": API_KEY_NAME},
        )

    # Constant-time comparison against every key (no early exit) so response
    # timing does not reveal how many leading characters of a key matched.
    presented = api_key.encode("utf-8")
    is_valid = False
    for known_key in _VALID_API_KEYS:
        if hmac.compare_digest(presented, known_key.encode("utf-8")):
            is_valid = True

    if not is_valid:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
            headers={"WWW-Authenticate": API_KEY_NAME},
        )
    return True