Spaces:
Running
Running
""" | |
Database integration for Streamlit application. | |
This module provides functions to interact with the database for the Streamlit frontend. | |
It wraps the async database functions in sync functions for Streamlit compatibility. | |
""" | |
import os | |
import asyncio | |
import pandas as pd | |
from typing import List, Dict, Any, Optional, Union, Tuple | |
from datetime import datetime, timedelta | |
from sqlalchemy.orm import sessionmaker | |
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession | |
# Import database models | |
from src.models.threat import Threat, ThreatSeverity, ThreatStatus, ThreatCategory | |
from src.models.indicator import Indicator, IndicatorType | |
from src.models.dark_web_content import DarkWebContent, DarkWebMention, ContentType, ContentStatus | |
from src.models.alert import Alert, AlertStatus, AlertCategory | |
from src.models.report import Report, ReportType, ReportStatus | |
# Import service functions | |
from src.api.services.dark_web_content_service import ( | |
create_content, get_content_by_id, get_contents, count_contents, | |
create_mention, get_mentions, create_threat_from_content | |
) | |
from src.api.services.alert_service import ( | |
create_alert, get_alert_by_id, get_alerts, count_alerts, | |
update_alert_status, mark_alert_as_read, get_alert_counts_by_severity | |
) | |
from src.api.services.threat_service import ( | |
create_threat, get_threat_by_id, get_threats, count_threats, | |
update_threat, add_indicator_to_threat, get_threat_statistics | |
) | |
from src.api.services.report_service import ( | |
create_report, get_report_by_id, get_reports, count_reports, | |
update_report, add_threat_to_report, publish_report | |
) | |
# Import schemas | |
from src.api.schemas import PaginationParams | |
# Get database URL from environment | |
db_url = os.getenv("DATABASE_URL", "") | |
if db_url.startswith("postgresql://"): | |
# Remove sslmode parameter if present which causes issues with asyncpg | |
if "?" in db_url: | |
base_url, params = db_url.split("?", 1) | |
param_list = params.split("&") | |
filtered_params = [p for p in param_list if not p.startswith("sslmode=")] | |
if filtered_params: | |
db_url = f"{base_url}?{'&'.join(filtered_params)}" | |
else: | |
db_url = base_url | |
ASYNC_DATABASE_URL = db_url.replace("postgresql://", "postgresql+asyncpg://", 1) | |
else: | |
ASYNC_DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/postgres" | |
# Create async engine | |
engine = create_async_engine( | |
ASYNC_DATABASE_URL, | |
echo=False, | |
future=True, | |
pool_size=5, | |
max_overflow=10 | |
) | |
# Create async session factory | |
async_session = sessionmaker( | |
engine, | |
class_=AsyncSession, | |
expire_on_commit=False | |
) | |
def run_async(coro): | |
"""Run an async function in a sync context.""" | |
try: | |
loop = asyncio.get_event_loop() | |
except RuntimeError: | |
loop = asyncio.new_event_loop() | |
asyncio.set_event_loop(loop) | |
return loop.run_until_complete(coro) | |
async def get_session(): | |
"""Get an async database session.""" | |
async with async_session() as session: | |
yield session | |
def get_db_session(): | |
"""Get a database session for use in Streamlit.""" | |
try: | |
session_gen = get_session().__aiter__() | |
return run_async(session_gen.__anext__()) | |
except StopAsyncIteration: | |
return None | |
async def get_async_session(): | |
""" | |
Async context manager for database sessions. | |
Usage: | |
async with get_async_session() as session: | |
# Use session here | |
""" | |
session = async_session() | |
try: | |
yield session | |
await session.commit() | |
except Exception as e: | |
await session.rollback() | |
raise e | |
finally: | |
await session.close() | |
# Dark Web Content functions | |
def get_dark_web_contents( | |
page: int = 1, | |
size: int = 10, | |
content_type: Optional[List[ContentType]] = None, | |
content_status: Optional[List[ContentStatus]] = None, | |
source_name: Optional[str] = None, | |
search_query: Optional[str] = None, | |
from_date: Optional[datetime] = None, | |
to_date: Optional[datetime] = None, | |
) -> pd.DataFrame: | |
""" | |
Get dark web contents as a DataFrame. | |
Args: | |
page: Page number | |
size: Page size | |
content_type: Filter by content type | |
content_status: Filter by content status | |
source_name: Filter by source name | |
search_query: Search in title and content | |
from_date: Filter by scraped_at >= from_date | |
to_date: Filter by scraped_at <= to_date | |
Returns: | |
pd.DataFrame: DataFrame with dark web contents | |
""" | |
session = get_db_session() | |
if not session: | |
return pd.DataFrame() | |
contents = run_async(get_contents( | |
db=session, | |
pagination=PaginationParams(page=page, size=size), | |
content_type=content_type, | |
content_status=content_status, | |
source_name=source_name, | |
search_query=search_query, | |
from_date=from_date, | |
to_date=to_date, | |
)) | |
if not contents: | |
return pd.DataFrame() | |
# Convert to DataFrame | |
data = [] | |
for content in contents: | |
data.append({ | |
"id": content.id, | |
"url": content.url, | |
"title": content.title, | |
"content_type": content.content_type.value if content.content_type else None, | |
"content_status": content.content_status.value if content.content_status else None, | |
"source_name": content.source_name, | |
"source_type": content.source_type, | |
"language": content.language, | |
"scraped_at": content.scraped_at, | |
"relevance_score": content.relevance_score, | |
"sentiment_score": content.sentiment_score, | |
}) | |
return pd.DataFrame(data) | |
def add_dark_web_content( | |
url: str, | |
content: str, | |
title: Optional[str] = None, | |
content_type: ContentType = ContentType.OTHER, | |
source_name: Optional[str] = None, | |
source_type: Optional[str] = None, | |
) -> Optional[DarkWebContent]: | |
""" | |
Add a new dark web content. | |
Args: | |
url: URL of the content | |
content: Text content | |
title: Title of the content | |
content_type: Type of content | |
source_name: Name of the source | |
source_type: Type of source | |
Returns: | |
Optional[DarkWebContent]: Created content or None | |
""" | |
session = get_db_session() | |
if not session: | |
return None | |
return run_async(create_content( | |
db=session, | |
url=url, | |
content=content, | |
title=title, | |
content_type=content_type, | |
source_name=source_name, | |
source_type=source_type, | |
)) | |
def get_dark_web_mentions( | |
page: int = 1, | |
size: int = 10, | |
keyword: Optional[str] = None, | |
content_id: Optional[int] = None, | |
is_verified: Optional[bool] = None, | |
from_date: Optional[datetime] = None, | |
to_date: Optional[datetime] = None, | |
) -> pd.DataFrame: | |
""" | |
Get dark web mentions as a DataFrame. | |
Args: | |
page: Page number | |
size: Page size | |
keyword: Filter by keyword | |
content_id: Filter by content ID | |
is_verified: Filter by verification status | |
from_date: Filter by created_at >= from_date | |
to_date: Filter by created_at <= to_date | |
Returns: | |
pd.DataFrame: DataFrame with dark web mentions | |
""" | |
session = get_db_session() | |
if not session: | |
return pd.DataFrame() | |
mentions = run_async(get_mentions( | |
db=session, | |
pagination=PaginationParams(page=page, size=size), | |
keyword=keyword, | |
content_id=content_id, | |
is_verified=is_verified, | |
from_date=from_date, | |
to_date=to_date, | |
)) | |
if not mentions: | |
return pd.DataFrame() | |
# Convert to DataFrame | |
data = [] | |
for mention in mentions: | |
data.append({ | |
"id": mention.id, | |
"content_id": mention.content_id, | |
"keyword": mention.keyword, | |
"snippet": mention.snippet, | |
"mention_type": mention.mention_type, | |
"confidence": mention.confidence, | |
"is_verified": mention.is_verified, | |
"created_at": mention.created_at, | |
}) | |
return pd.DataFrame(data) | |
def add_dark_web_mention( | |
content_id: int, | |
keyword: str, | |
context: Optional[str] = None, | |
snippet: Optional[str] = None, | |
) -> Optional[DarkWebMention]: | |
""" | |
Add a new dark web mention. | |
Args: | |
content_id: ID of the content where the mention was found | |
keyword: Keyword that was mentioned | |
context: Text surrounding the mention | |
snippet: Extract of text containing the mention | |
Returns: | |
Optional[DarkWebMention]: Created mention or None | |
""" | |
session = get_db_session() | |
if not session: | |
return None | |
return run_async(create_mention( | |
db=session, | |
content_id=content_id, | |
keyword=keyword, | |
context=context, | |
snippet=snippet, | |
)) | |
# Alerts functions | |
def get_alerts_df( | |
page: int = 1, | |
size: int = 10, | |
severity: Optional[List[ThreatSeverity]] = None, | |
status: Optional[List[AlertStatus]] = None, | |
category: Optional[List[AlertCategory]] = None, | |
is_read: Optional[bool] = None, | |
search_query: Optional[str] = None, | |
from_date: Optional[datetime] = None, | |
to_date: Optional[datetime] = None, | |
) -> pd.DataFrame: | |
""" | |
Get alerts as a DataFrame. | |
Args: | |
page: Page number | |
size: Page size | |
severity: Filter by severity | |
status: Filter by status | |
category: Filter by category | |
is_read: Filter by read status | |
search_query: Search in title and description | |
from_date: Filter by generated_at >= from_date | |
to_date: Filter by generated_at <= to_date | |
Returns: | |
pd.DataFrame: DataFrame with alerts | |
""" | |
session = get_db_session() | |
if not session: | |
return pd.DataFrame() | |
alerts = run_async(get_alerts( | |
db=session, | |
pagination=PaginationParams(page=page, size=size), | |
severity=severity, | |
status=status, | |
category=category, | |
is_read=is_read, | |
search_query=search_query, | |
from_date=from_date, | |
to_date=to_date, | |
)) | |
if not alerts: | |
return pd.DataFrame() | |
# Convert to DataFrame | |
data = [] | |
for alert in alerts: | |
data.append({ | |
"id": alert.id, | |
"title": alert.title, | |
"description": alert.description, | |
"severity": alert.severity.value if alert.severity else None, | |
"status": alert.status.value if alert.status else None, | |
"category": alert.category.value if alert.category else None, | |
"generated_at": alert.generated_at, | |
"source_url": alert.source_url, | |
"is_read": alert.is_read, | |
"threat_id": alert.threat_id, | |
"mention_id": alert.mention_id, | |
"assigned_to_id": alert.assigned_to_id, | |
"action_taken": alert.action_taken, | |
"resolved_at": alert.resolved_at, | |
}) | |
return pd.DataFrame(data) | |
def add_alert( | |
title: str, | |
description: str, | |
severity: ThreatSeverity, | |
category: AlertCategory, | |
source_url: Optional[str] = None, | |
threat_id: Optional[int] = None, | |
mention_id: Optional[int] = None, | |
) -> Optional[Alert]: | |
""" | |
Add a new alert. | |
Args: | |
title: Alert title | |
description: Alert description | |
severity: Alert severity | |
category: Alert category | |
source_url: Source URL for the alert | |
threat_id: ID of related threat | |
mention_id: ID of related dark web mention | |
Returns: | |
Optional[Alert]: Created alert or None | |
""" | |
session = get_db_session() | |
if not session: | |
return None | |
return run_async(create_alert( | |
db=session, | |
title=title, | |
description=description, | |
severity=severity, | |
category=category, | |
source_url=source_url, | |
threat_id=threat_id, | |
mention_id=mention_id, | |
)) | |
def update_alert( | |
alert_id: int, | |
status: AlertStatus, | |
action_taken: Optional[str] = None, | |
) -> Optional[Alert]: | |
""" | |
Update alert status. | |
Args: | |
alert_id: Alert ID | |
status: New status | |
action_taken: Description of action taken | |
Returns: | |
Optional[Alert]: Updated alert or None | |
""" | |
session = get_db_session() | |
if not session: | |
return None | |
return run_async(update_alert_status( | |
db=session, | |
alert_id=alert_id, | |
status=status, | |
action_taken=action_taken, | |
)) | |
def get_alert_severity_counts( | |
from_date: Optional[datetime] = None, | |
to_date: Optional[datetime] = None, | |
) -> Dict[str, int]: | |
""" | |
Get count of alerts by severity. | |
Args: | |
from_date: Filter by generated_at >= from_date | |
to_date: Filter by generated_at <= to_date | |
Returns: | |
Dict[str, int]: Mapping of severity to count | |
""" | |
session = get_db_session() | |
if not session: | |
return {} | |
return run_async(get_alert_counts_by_severity( | |
db=session, | |
from_date=from_date, | |
to_date=to_date, | |
)) | |
# Threats functions | |
def get_threats_df( | |
page: int = 1, | |
size: int = 10, | |
severity: Optional[List[ThreatSeverity]] = None, | |
status: Optional[List[ThreatStatus]] = None, | |
category: Optional[List[ThreatCategory]] = None, | |
search_query: Optional[str] = None, | |
from_date: Optional[datetime] = None, | |
to_date: Optional[datetime] = None, | |
) -> pd.DataFrame: | |
""" | |
Get threats as a DataFrame. | |
Args: | |
page: Page number | |
size: Page size | |
severity: Filter by severity | |
status: Filter by status | |
category: Filter by category | |
search_query: Search in title and description | |
from_date: Filter by discovered_at >= from_date | |
to_date: Filter by discovered_at <= to_date | |
Returns: | |
pd.DataFrame: DataFrame with threats | |
""" | |
session = get_db_session() | |
if not session: | |
return pd.DataFrame() | |
threats = run_async(get_threats( | |
db=session, | |
pagination=PaginationParams(page=page, size=size), | |
severity=severity, | |
status=status, | |
category=category, | |
search_query=search_query, | |
from_date=from_date, | |
to_date=to_date, | |
)) | |
if not threats: | |
return pd.DataFrame() | |
# Convert to DataFrame | |
data = [] | |
for threat in threats: | |
data.append({ | |
"id": threat.id, | |
"title": threat.title, | |
"description": threat.description, | |
"severity": threat.severity.value if threat.severity else None, | |
"status": threat.status.value if threat.status else None, | |
"category": threat.category.value if threat.category else None, | |
"source_url": threat.source_url, | |
"source_name": threat.source_name, | |
"source_type": threat.source_type, | |
"discovered_at": threat.discovered_at, | |
"affected_entity": threat.affected_entity, | |
"affected_entity_type": threat.affected_entity_type, | |
"confidence_score": threat.confidence_score, | |
"risk_score": threat.risk_score, | |
}) | |
return pd.DataFrame(data) | |
def add_threat( | |
title: str, | |
description: str, | |
severity: ThreatSeverity, | |
category: ThreatCategory, | |
status: ThreatStatus = ThreatStatus.NEW, | |
source_url: Optional[str] = None, | |
source_name: Optional[str] = None, | |
source_type: Optional[str] = None, | |
affected_entity: Optional[str] = None, | |
affected_entity_type: Optional[str] = None, | |
confidence_score: float = 0.0, | |
risk_score: float = 0.0, | |
) -> Optional[Threat]: | |
""" | |
Add a new threat. | |
Args: | |
title: Threat title | |
description: Threat description | |
severity: Threat severity | |
category: Threat category | |
status: Threat status | |
source_url: URL of the source | |
source_name: Name of the source | |
source_type: Type of source | |
affected_entity: Name of affected entity | |
affected_entity_type: Type of affected entity | |
confidence_score: Confidence score (0-1) | |
risk_score: Risk score (0-1) | |
Returns: | |
Optional[Threat]: Created threat or None | |
""" | |
session = get_db_session() | |
if not session: | |
return None | |
return run_async(create_threat( | |
db=session, | |
title=title, | |
description=description, | |
severity=severity, | |
category=category, | |
status=status, | |
source_url=source_url, | |
source_name=source_name, | |
source_type=source_type, | |
affected_entity=affected_entity, | |
affected_entity_type=affected_entity_type, | |
confidence_score=confidence_score, | |
risk_score=risk_score, | |
)) | |
def add_indicator( | |
threat_id: int, | |
value: str, | |
indicator_type: IndicatorType, | |
description: Optional[str] = None, | |
is_verified: bool = False, | |
context: Optional[str] = None, | |
source: Optional[str] = None, | |
) -> Optional[Indicator]: | |
""" | |
Add an indicator to a threat. | |
Args: | |
threat_id: Threat ID | |
value: Indicator value | |
indicator_type: Indicator type | |
description: Indicator description | |
is_verified: Whether the indicator is verified | |
context: Context of the indicator | |
source: Source of the indicator | |
Returns: | |
Optional[Indicator]: Created indicator or None | |
""" | |
session = get_db_session() | |
if not session: | |
return None | |
return run_async(add_indicator_to_threat( | |
db=session, | |
threat_id=threat_id, | |
value=value, | |
indicator_type=indicator_type, | |
description=description, | |
is_verified=is_verified, | |
context=context, | |
source=source, | |
)) | |
def get_threat_stats( | |
from_date: Optional[datetime] = None, | |
to_date: Optional[datetime] = None, | |
) -> Dict[str, Any]: | |
""" | |
Get threat statistics. | |
Args: | |
from_date: Filter by discovered_at >= from_date | |
to_date: Filter by discovered_at <= to_date | |
Returns: | |
Dict[str, Any]: Threat statistics | |
""" | |
session = get_db_session() | |
if not session: | |
return {} | |
return run_async(get_threat_statistics( | |
db=session, | |
from_date=from_date, | |
to_date=to_date, | |
)) | |
# Reports functions | |
def get_reports_df( | |
page: int = 1, | |
size: int = 10, | |
report_type: Optional[List[ReportType]] = None, | |
status: Optional[List[ReportStatus]] = None, | |
severity: Optional[List[ThreatSeverity]] = None, | |
search_query: Optional[str] = None, | |
from_date: Optional[datetime] = None, | |
to_date: Optional[datetime] = None, | |
) -> pd.DataFrame: | |
""" | |
Get reports as a DataFrame. | |
Args: | |
page: Page number | |
size: Page size | |
report_type: Filter by report type | |
status: Filter by status | |
severity: Filter by severity | |
search_query: Search in title and summary | |
from_date: Filter by created_at >= from_date | |
to_date: Filter by created_at <= to_date | |
Returns: | |
pd.DataFrame: DataFrame with reports | |
""" | |
session = get_db_session() | |
if not session: | |
return pd.DataFrame() | |
reports = run_async(get_reports( | |
db=session, | |
pagination=PaginationParams(page=page, size=size), | |
report_type=report_type, | |
status=status, | |
severity=severity, | |
search_query=search_query, | |
from_date=from_date, | |
to_date=to_date, | |
)) | |
if not reports: | |
return pd.DataFrame() | |
# Convert to DataFrame | |
data = [] | |
for report in reports: | |
data.append({ | |
"id": report.id, | |
"report_id": report.report_id, | |
"title": report.title, | |
"summary": report.summary, | |
"report_type": report.report_type.value if report.report_type else None, | |
"status": report.status.value if report.status else None, | |
"severity": report.severity.value if report.severity else None, | |
"publish_date": report.publish_date, | |
"created_at": report.created_at, | |
"time_period_start": report.time_period_start, | |
"time_period_end": report.time_period_end, | |
"author_id": report.author_id, | |
}) | |
return pd.DataFrame(data) | |
def add_report( | |
title: str, | |
summary: str, | |
content: str, | |
report_type: ReportType, | |
report_id: str, | |
status: ReportStatus = ReportStatus.DRAFT, | |
severity: Optional[ThreatSeverity] = None, | |
publish_date: Optional[datetime] = None, | |
time_period_start: Optional[datetime] = None, | |
time_period_end: Optional[datetime] = None, | |
keywords: Optional[List[str]] = None, | |
author_id: Optional[int] = None, | |
) -> Optional[Report]: | |
""" | |
Add a new report. | |
Args: | |
title: Report title | |
summary: Report summary | |
content: Report content | |
report_type: Type of report | |
report_id: Custom ID for the report | |
status: Report status | |
severity: Report severity | |
publish_date: Publication date | |
time_period_start: Start of time period covered | |
time_period_end: End of time period covered | |
keywords: List of keywords related to the report | |
author_id: ID of the report author | |
Returns: | |
Optional[Report]: Created report or None | |
""" | |
session = get_db_session() | |
if not session: | |
return None | |
return run_async(create_report( | |
db=session, | |
title=title, | |
summary=summary, | |
content=content, | |
report_type=report_type, | |
report_id=report_id, | |
status=status, | |
severity=severity, | |
publish_date=publish_date, | |
time_period_start=time_period_start, | |
time_period_end=time_period_end, | |
keywords=keywords, | |
author_id=author_id, | |
)) | |
# Helper functions | |
def get_time_range_dates(time_range: str) -> Tuple[datetime, datetime]: | |
""" | |
Get start and end dates for a time range. | |
Args: | |
time_range: Time range string (e.g., "Last 7 Days") | |
Returns: | |
Tuple[datetime, datetime]: (start_date, end_date) | |
""" | |
end_date = datetime.utcnow() | |
if time_range == "Last 24 Hours": | |
start_date = end_date - timedelta(days=1) | |
elif time_range == "Last 7 Days": | |
start_date = end_date - timedelta(days=7) | |
elif time_range == "Last 30 Days": | |
start_date = end_date - timedelta(days=30) | |
elif time_range == "Last Quarter": | |
start_date = end_date - timedelta(days=90) | |
else: # Default to last 30 days | |
start_date = end_date - timedelta(days=30) | |
return start_date, end_date | |
# Initialize database connection | |
def init_db_connection(): | |
"""Initialize database connection and check if tables exist.""" | |
session = get_db_session() | |
if not session: | |
return False | |
# Check if tables exist | |
from sqlalchemy.future import select | |
try: | |
# Try to query if tables exist using SQLAlchemy text() | |
from sqlalchemy import text | |
query = text("SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'users')") | |
result = run_async(session.execute(query)) | |
exists = result.scalar() | |
return exists | |
except Exception as e: | |
# Tables might not exist yet | |
print(f"Error checking database: {e}") | |
return False |