# Spaces: Running Running
"""
Service for threat operations.
"""
from sqlalchemy.ext.asyncio import AsyncSession | |
from sqlalchemy.future import select | |
from sqlalchemy import func, or_, and_ | |
from datetime import datetime, timedelta | |
from typing import List, Optional, Dict, Any, Union | |
from src.models.threat import Threat, ThreatSeverity, ThreatStatus, ThreatCategory | |
from src.models.indicator import Indicator, IndicatorType | |
from src.api.schemas import PaginationParams | |
async def create_threat(
    db: AsyncSession,
    title: str,
    description: str,
    severity: ThreatSeverity,
    category: ThreatCategory,
    status: ThreatStatus = ThreatStatus.NEW,
    source_url: Optional[str] = None,
    source_name: Optional[str] = None,
    source_type: Optional[str] = None,
    affected_entity: Optional[str] = None,
    affected_entity_type: Optional[str] = None,
    confidence_score: float = 0.0,
    risk_score: float = 0.0,
) -> Threat:
    """
    Build a Threat row, persist it, and return the refreshed instance.

    ``discovered_at`` is not caller-supplied: it is stamped with the current
    naive UTC time at creation.

    Args:
        db: Database session
        title: Threat title
        description: Threat description
        severity: Threat severity
        category: Threat category
        status: Threat status (defaults to ``ThreatStatus.NEW``)
        source_url: URL of the source
        source_name: Name of the source
        source_type: Type of source
        affected_entity: Name of affected entity
        affected_entity_type: Type of affected entity
        confidence_score: Confidence score (0-1)
        risk_score: Risk score (0-1)

    Returns:
        Threat: The newly created threat, refreshed from the database
        after commit.
    """
    # Collected in the same order the model constructor receives them.
    threat_fields = {
        "title": title,
        "description": description,
        "severity": severity,
        "category": category,
        "status": status,
        "source_url": source_url,
        "source_name": source_name,
        "source_type": source_type,
        "discovered_at": datetime.utcnow(),
        "affected_entity": affected_entity,
        "affected_entity_type": affected_entity_type,
        "confidence_score": confidence_score,
        "risk_score": risk_score,
    }
    new_threat = Threat(**threat_fields)

    db.add(new_threat)
    await db.commit()
    # Reload so database-populated columns are available on the instance.
    await db.refresh(new_threat)
    return new_threat
async def get_threat_by_id(db: AsyncSession, threat_id: int) -> Optional[Threat]:
    """
    Look up a single threat by its ID.

    Args:
        db: Database session
        threat_id: ID of the threat to fetch

    Returns:
        Optional[Threat]: The first matching threat, or None when no row
        has that ID.
    """
    lookup = select(Threat).filter(Threat.id == threat_id)
    rows = await db.execute(lookup)
    return rows.scalars().first()
def _escape_like(term: str, escape_char: str = "/") -> str:
    """Escape LIKE wildcards in *term* so it is matched as literal text."""
    # The escape character itself must be doubled first, otherwise the
    # escapes added for "%" and "_" would be escaped a second time.
    return (
        term.replace(escape_char, escape_char * 2)
        .replace("%", escape_char + "%")
        .replace("_", escape_char + "_")
    )


def _apply_threat_filters(
    query: Any,
    severity: Optional[List[ThreatSeverity]] = None,
    status: Optional[List[ThreatStatus]] = None,
    category: Optional[List[ThreatCategory]] = None,
    search_query: Optional[str] = None,
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
) -> Any:
    """Add the threat filters shared by get_threats and count_threats to *query*."""
    if severity:
        query = query.filter(Threat.severity.in_(severity))
    if status:
        query = query.filter(Threat.status.in_(status))
    if category:
        query = query.filter(Threat.category.in_(category))
    if search_query:
        # Escape "%" and "_" so user input is searched literally instead of
        # being interpreted as LIKE wildcards.
        pattern = f"%{_escape_like(search_query)}%"
        query = query.filter(
            or_(
                Threat.title.ilike(pattern, escape="/"),
                Threat.description.ilike(pattern, escape="/"),
            )
        )
    if from_date is not None:
        query = query.filter(Threat.discovered_at >= from_date)
    if to_date is not None:
        query = query.filter(Threat.discovered_at <= to_date)
    return query


async def get_threats(
    db: AsyncSession,
    pagination: PaginationParams,
    severity: Optional[List[ThreatSeverity]] = None,
    status: Optional[List[ThreatStatus]] = None,
    category: Optional[List[ThreatCategory]] = None,
    search_query: Optional[str] = None,
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
) -> List[Threat]:
    """
    Get threats with filtering and pagination.

    Results are ordered newest first by ``discovered_at``; rows sharing the
    same timestamp are ordered by ``id`` so that page boundaries are stable.

    Args:
        db: Database session
        pagination: Pagination parameters (``page`` is 1-based, ``size`` is
            the page length)
        severity: Filter by severity (any of the given values)
        status: Filter by status (any of the given values)
        category: Filter by category (any of the given values)
        search_query: Case-insensitive literal substring to look for in
            title or description
        from_date: Filter by discovered_at >= from_date
        to_date: Filter by discovered_at <= to_date

    Returns:
        List[Threat]: The requested page of threats
    """
    query = _apply_threat_filters(
        select(Threat),
        severity=severity,
        status=status,
        category=category,
        search_query=search_query,
        from_date=from_date,
        to_date=to_date,
    )
    # discovered_at is not unique; without the id tie-breaker the database may
    # order equal timestamps differently per request, so rows could repeat or
    # go missing between pages.
    query = query.order_by(Threat.discovered_at.desc(), Threat.id.desc())
    query = query.offset((pagination.page - 1) * pagination.size).limit(pagination.size)
    result = await db.execute(query)
    return list(result.scalars().all())


async def count_threats(
    db: AsyncSession,
    severity: Optional[List[ThreatSeverity]] = None,
    status: Optional[List[ThreatStatus]] = None,
    category: Optional[List[ThreatCategory]] = None,
    search_query: Optional[str] = None,
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
) -> int:
    """
    Count threats matching the same filters that get_threats applies.

    Args:
        db: Database session
        severity: Filter by severity (any of the given values)
        status: Filter by status (any of the given values)
        category: Filter by category (any of the given values)
        search_query: Case-insensitive literal substring to look for in
            title or description
        from_date: Filter by discovered_at >= from_date
        to_date: Filter by discovered_at <= to_date

    Returns:
        int: Count of matching threats (0 when none match)
    """
    query = _apply_threat_filters(
        select(func.count(Threat.id)),
        severity=severity,
        status=status,
        category=category,
        search_query=search_query,
        from_date=from_date,
        to_date=to_date,
    )
    result = await db.execute(query)
    # Guard the declared int return type against a None scalar.
    return result.scalar() or 0
async def update_threat(
    db: AsyncSession,
    threat_id: int,
    title: Optional[str] = None,
    description: Optional[str] = None,
    severity: Optional[ThreatSeverity] = None,
    status: Optional[ThreatStatus] = None,
    category: Optional[ThreatCategory] = None,
    affected_entity: Optional[str] = None,
    affected_entity_type: Optional[str] = None,
    confidence_score: Optional[float] = None,
    risk_score: Optional[float] = None,
) -> Optional[Threat]:
    """
    Partially update an existing threat.

    Only arguments that are not None are written; passing None leaves the
    stored value untouched. ``updated_at`` is set to the current naive UTC
    time whenever the threat exists, even if no field was supplied.

    Args:
        db: Database session
        threat_id: ID of the threat to update
        title: New title
        description: New description
        severity: New severity
        status: New status
        category: New category
        affected_entity: New affected entity
        affected_entity_type: New affected entity type
        confidence_score: New confidence score
        risk_score: New risk score

    Returns:
        Optional[Threat]: The refreshed threat, or None if no threat has
        the given ID
    """
    threat = await get_threat_by_id(db, threat_id)
    if not threat:
        return None

    # Attribute name -> candidate value, in the order they are applied.
    candidate_values = {
        "title": title,
        "description": description,
        "severity": severity,
        "status": status,
        "category": category,
        "affected_entity": affected_entity,
        "affected_entity_type": affected_entity_type,
        "confidence_score": confidence_score,
        "risk_score": risk_score,
    }
    for field_name, new_value in candidate_values.items():
        # None means "not provided", so the stored value is kept.
        if new_value is not None:
            setattr(threat, field_name, new_value)

    threat.updated_at = datetime.utcnow()
    await db.commit()
    await db.refresh(threat)
    return threat
async def add_indicator_to_threat(
    db: AsyncSession,
    threat_id: int,
    value: str,
    indicator_type: IndicatorType,
    description: Optional[str] = None,
    is_verified: bool = False,
    context: Optional[str] = None,
    source: Optional[str] = None,
    confidence_score: float = 0.0,
) -> Indicator:
    """
    Add an indicator to an existing threat.

    Args:
        db: Database session
        threat_id: ID of the threat the indicator belongs to
        value: Indicator value
        indicator_type: Indicator type
        description: Description of the indicator
        is_verified: Whether the indicator is verified
        context: Context of the indicator
        source: Source of the indicator
        confidence_score: Confidence score (0-1)

    Returns:
        Indicator: The created indicator, refreshed from the database
        after commit

    Raises:
        ValueError: If no threat with ``threat_id`` exists
    """
    # Refuse to create an indicator that points at a missing threat.
    threat = await get_threat_by_id(db, threat_id)
    if not threat:
        raise ValueError(f"Threat with ID {threat_id} not found")

    # Read the clock once: a new indicator has been seen exactly one time, so
    # first_seen and last_seen must be the same instant. Two separate
    # utcnow() calls would differ by a few microseconds.
    seen_at = datetime.utcnow()
    db_indicator = Indicator(
        threat_id=threat_id,
        value=value,
        indicator_type=indicator_type,
        description=description,
        is_verified=is_verified,
        context=context,
        source=source,
        confidence_score=confidence_score,
        first_seen=seen_at,
        last_seen=seen_at,
    )
    db.add(db_indicator)
    await db.commit()
    await db.refresh(db_indicator)
    return db_indicator
async def _count_threats_in_range(
    db: AsyncSession,
    from_date: datetime,
    to_date: datetime,
    *conditions: Any,
) -> int:
    """Count threats discovered in [from_date, to_date] that also match *conditions*."""
    query = select(func.count(Threat.id)).filter(
        and_(
            Threat.discovered_at >= from_date,
            Threat.discovered_at <= to_date,
            *conditions,
        )
    )
    result = await db.execute(query)
    return result.scalar() or 0


async def get_threat_statistics(
    db: AsyncSession,
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
) -> Dict[str, Any]:
    """
    Get threat statistics for a time range.

    Every figure, including each per-day bucket, is restricted to
    ``from_date <= discovered_at <= to_date``, so the daily counts add up
    to ``total_count``.

    Args:
        db: Database session
        from_date: Filter by discovered_at >= from_date (defaults to 30 days
            before ``to_date``)
        to_date: Filter by discovered_at <= to_date (defaults to the current
            naive UTC time)

    Returns:
        Dict[str, Any]: Mapping with the keys

        - ``total_count``: number of threats in the range
        - ``severity_counts`` / ``status_counts`` / ``category_counts``:
          enum value -> count, with an entry (possibly 0) for every member
        - ``time_series``: one ``{"date", "count"}`` entry per calendar day
          from ``from_date`` to ``to_date`` inclusive
        - ``from_date`` / ``to_date``: the effective range in ISO format
    """
    # Set default time range if not provided
    if not to_date:
        to_date = datetime.utcnow()
    if not from_date:
        from_date = to_date - timedelta(days=30)

    # One count per enum member so members with no threats still report 0.
    severity_counts = {
        severity.value: await _count_threats_in_range(
            db, from_date, to_date, Threat.severity == severity
        )
        for severity in ThreatSeverity
    }
    status_counts = {
        status.value: await _count_threats_in_range(
            db, from_date, to_date, Threat.status == status
        )
        for status in ThreatStatus
    }
    category_counts = {
        category.value: await _count_threats_in_range(
            db, from_date, to_date, Threat.category == category
        )
        for category in ThreatCategory
    }
    total_count = await _count_threats_in_range(db, from_date, to_date)

    # Daily buckets. Each bucket is a half-open calendar day intersected with
    # the requested range. Without that intersection the first and last days
    # would also count threats discovered before from_date or after to_date.
    time_series = []
    midnight = datetime.min.time()
    # Day boundaries carry from_date's tzinfo (None for naive input) so they
    # are comparable with the range bounds.
    # NOTE(review): assumes from_date and to_date share the same
    # timezone-awareness - confirm against callers.
    day_tz = from_date.tzinfo
    current_date = from_date.date()
    end_date = to_date.date()
    while current_date <= end_date:
        next_date = current_date + timedelta(days=1)
        count = await _count_threats_in_range(
            db,
            from_date,
            to_date,
            Threat.discovered_at >= datetime.combine(current_date, midnight, tzinfo=day_tz),
            Threat.discovered_at < datetime.combine(next_date, midnight, tzinfo=day_tz),
        )
        time_series.append({
            "date": current_date.isoformat(),
            "count": count,
        })
        current_date = next_date

    return {
        "total_count": total_count,
        "severity_counts": severity_counts,
        "status_counts": status_counts,
        "category_counts": category_counts,
        "time_series": time_series,
        "from_date": from_date.isoformat(),
        "to_date": to_date.isoformat(),
    }