"""
Subscription service.

This module provides functions for managing subscription plans and user
subscriptions. When a Stripe secret key is configured, plan and subscription
changes are mirrored to Stripe; otherwise only the database is updated.
"""

import os
import logging
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional, Tuple, Union

import stripe
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete
from sqlalchemy.orm import joinedload

from src.models.subscription import (
    SubscriptionPlan, UserSubscription, PaymentHistory,
    SubscriptionTier, BillingPeriod, SubscriptionStatus, PaymentStatus
)
from src.models.user import User

# Set up Stripe API key. When it is unset, every Stripe step below is skipped.
stripe.api_key = os.environ.get("STRIPE_SECRET_KEY")
STRIPE_PUBLISHABLE_KEY = os.environ.get("STRIPE_PUBLISHABLE_KEY")

# Set up logging
logger = logging.getLogger(__name__)

# NOTE(review): the stripe-python calls used in this module are the synchronous
# API, so they block the event loop while a request is in flight. Consider
# offloading them (e.g. to a worker thread) if request latency matters.


def _to_stripe_amount(price: float) -> int:
    """
    Convert a currency amount (e.g. dollars) to Stripe's smallest unit (cents).

    Rounds instead of truncating: in binary floating point
    ``19.99 * 100 == 1998.9999999999998``, which ``int()`` would turn into
    1998 cents instead of 1999.

    Args:
        price: Amount in major currency units

    Returns:
        Amount in cents
    """
    return round(price * 100)


def _yes_no(flag: bool) -> str:
    """Encode a boolean feature flag as stored in Stripe product metadata."""
    return "yes" if flag else "no"


def _build_stripe_metadata(
    limits: Dict[str, Optional[int]],
    features: Dict[str, Optional[bool]]
) -> Dict[str, Any]:
    """
    Build Stripe product metadata from plan limits and feature flags.

    Entries whose value is None are skipped, so the same helper serves both
    plan creation (all values given) and partial updates.

    Args:
        limits: Numeric plan limits keyed by metadata name
        features: Boolean feature flags keyed by metadata name; encoded "yes"/"no"

    Returns:
        Metadata dict with limits first, then feature flags
    """
    metadata: Dict[str, Any] = {
        key: value for key, value in limits.items() if value is not None
    }
    metadata.update(
        {key: _yes_no(value) for key, value in features.items() if value is not None}
    )
    return metadata


def _create_stripe_price(
    product_id: str,
    price: float,
    interval: str,
    billing_period: str
) -> Any:
    """
    Create a recurring USD Stripe price for a product.

    Args:
        product_id: Stripe product the price belongs to
        price: Amount in dollars
        interval: Stripe recurring interval ("month" or "year")
        billing_period: Value stored under the price's "billing_period" metadata key

    Returns:
        The created Stripe Price object
    """
    return stripe.Price.create(
        product=product_id,
        unit_amount=_to_stripe_amount(price),  # Stripe uses cents
        currency="usd",
        recurring={"interval": interval},
        metadata={"billing_period": billing_period}
    )


async def get_subscription_plans(
    db: AsyncSession,
    active_only: bool = True
) -> List[SubscriptionPlan]:
    """
    Get all subscription plans.

    Args:
        db: Database session
        active_only: If True, only return active plans

    Returns:
        List of subscription plans
    """
    query = select(SubscriptionPlan)
    if active_only:
        # SQLAlchemy needs the `==` comparison to build a SQL expression.
        query = query.where(SubscriptionPlan.is_active == True)  # noqa: E712

    result = await db.execute(query)
    return result.scalars().all()


async def get_subscription_plan_by_id(
    db: AsyncSession,
    plan_id: int
) -> Optional[SubscriptionPlan]:
    """
    Get a subscription plan by ID.

    Args:
        db: Database session
        plan_id: ID of the plan to get

    Returns:
        Subscription plan or None if not found
    """
    query = select(SubscriptionPlan).where(SubscriptionPlan.id == plan_id)
    result = await db.execute(query)
    return result.scalars().first()


async def get_subscription_plan_by_tier(
    db: AsyncSession,
    tier: SubscriptionTier
) -> Optional[SubscriptionPlan]:
    """
    Get a subscription plan by tier.

    Args:
        db: Database session
        tier: Tier of the plan to get

    Returns:
        Subscription plan or None if not found (the first match if several exist)
    """
    query = select(SubscriptionPlan).where(SubscriptionPlan.tier == tier)
    result = await db.execute(query)
    return result.scalars().first()


async def create_subscription_plan(
    db: AsyncSession,
    name: str,
    tier: SubscriptionTier,
    description: str,
    price_monthly: float,
    price_annually: float,
    max_alerts: int = 10,
    max_reports: int = 5,
    max_searches_per_day: int = 20,
    max_monitoring_keywords: int = 10,
    max_data_retention_days: int = 30,
    supports_api_access: bool = False,
    supports_live_feed: bool = False,
    supports_dark_web_monitoring: bool = False,
    supports_export: bool = False,
    supports_advanced_analytics: bool = False,
    create_stripe_product: bool = True
) -> Optional[SubscriptionPlan]:
    """
    Create a new subscription plan.

    Stripe creation is best-effort: if it fails, the error is logged and the
    plan is still created locally with whatever Stripe IDs were obtained before
    the failure.

    Args:
        db: Database session
        name: Name of the plan
        tier: Tier of the plan
        description: Description of the plan
        price_monthly: Monthly price of the plan
        price_annually: Annual price of the plan
        max_alerts: Maximum number of alerts allowed
        max_reports: Maximum number of reports allowed
        max_searches_per_day: Maximum number of searches per day
        max_monitoring_keywords: Maximum number of monitoring keywords
        max_data_retention_days: Maximum number of days to retain data
        supports_api_access: Whether the plan supports API access
        supports_live_feed: Whether the plan supports live feed
        supports_dark_web_monitoring: Whether the plan supports dark web monitoring
        supports_export: Whether the plan supports data export
        supports_advanced_analytics: Whether the plan supports advanced analytics
        create_stripe_product: Whether to create a Stripe product for this plan

    Returns:
        Created subscription plan or None if a plan with the same tier exists
    """
    # Only one plan per tier is allowed
    existing_plan = await get_subscription_plan_by_tier(db, tier)
    if existing_plan:
        logger.warning(f"Subscription plan with tier {tier} already exists")
        return None

    stripe_product_id = None
    stripe_monthly_price_id = None
    stripe_annual_price_id = None

    if create_stripe_product and stripe.api_key:
        try:
            product = stripe.Product.create(
                name=name,
                description=description,
                metadata={
                    "tier": tier.value,
                    **_build_stripe_metadata(
                        limits={
                            "max_alerts": max_alerts,
                            "max_reports": max_reports,
                            "max_searches_per_day": max_searches_per_day,
                            "max_monitoring_keywords": max_monitoring_keywords,
                            "max_data_retention_days": max_data_retention_days,
                        },
                        features={
                            "supports_api_access": supports_api_access,
                            "supports_live_feed": supports_live_feed,
                            "supports_dark_web_monitoring": supports_dark_web_monitoring,
                            "supports_export": supports_export,
                            "supports_advanced_analytics": supports_advanced_analytics,
                        },
                    ),
                }
            )
            stripe_product_id = product.id

            stripe_monthly_price_id = _create_stripe_price(
                product.id, price_monthly, "month", "monthly"
            ).id
            stripe_annual_price_id = _create_stripe_price(
                product.id, price_annually, "year", "annually"
            ).id

            logger.info(f"Created Stripe product {product.id} for plan {name}")
        except Exception as e:
            # Best effort: keep going and create the plan locally.
            logger.exception(f"Failed to create Stripe product for plan {name}: {e}")

    plan = SubscriptionPlan(
        name=name,
        tier=tier,
        description=description,
        price_monthly=price_monthly,
        price_annually=price_annually,
        max_alerts=max_alerts,
        max_reports=max_reports,
        max_searches_per_day=max_searches_per_day,
        max_monitoring_keywords=max_monitoring_keywords,
        max_data_retention_days=max_data_retention_days,
        supports_api_access=supports_api_access,
        supports_live_feed=supports_live_feed,
        supports_dark_web_monitoring=supports_dark_web_monitoring,
        supports_export=supports_export,
        supports_advanced_analytics=supports_advanced_analytics,
        stripe_product_id=stripe_product_id,
        stripe_monthly_price_id=stripe_monthly_price_id,
        stripe_annual_price_id=stripe_annual_price_id
    )

    db.add(plan)
    await db.commit()
    await db.refresh(plan)

    return plan


async def update_subscription_plan(
    db: AsyncSession,
    plan_id: int,
    name: Optional[str] = None,
    description: Optional[str] = None,
    price_monthly: Optional[float] = None,
    price_annually: Optional[float] = None,
    is_active: Optional[bool] = None,
    max_alerts: Optional[int] = None,
    max_reports: Optional[int] = None,
    max_searches_per_day: Optional[int] = None,
    max_monitoring_keywords: Optional[int] = None,
    max_data_retention_days: Optional[int] = None,
    supports_api_access: Optional[bool] = None,
    supports_live_feed: Optional[bool] = None,
    supports_dark_web_monitoring: Optional[bool] = None,
    supports_export: Optional[bool] = None,
    supports_advanced_analytics: Optional[bool] = None,
    update_stripe_product: bool = True
) -> Optional[SubscriptionPlan]:
    """
    Update a subscription plan.

    Only arguments that are not None are applied. Stripe updates are
    best-effort: failures are logged and the database update still happens.

    Args:
        db: Database session
        plan_id: ID of the plan to update
        name: New name of the plan
        description: New description of the plan
        price_monthly: New monthly price of the plan
        price_annually: New annual price of the plan
        is_active: New active status of the plan
        max_alerts: New maximum number of alerts allowed
        max_reports: New maximum number of reports allowed
        max_searches_per_day: New maximum number of searches per day
        max_monitoring_keywords: New maximum number of monitoring keywords
        max_data_retention_days: New maximum number of days to retain data
        supports_api_access: New API access support status
        supports_live_feed: New live feed support status
        supports_dark_web_monitoring: New dark web monitoring support status
        supports_export: New data export support status
        supports_advanced_analytics: New advanced analytics support status
        update_stripe_product: Whether to update the Stripe product for this plan

    Returns:
        Updated subscription plan or None if the plan was not found
    """
    plan = await get_subscription_plan_by_id(db, plan_id)
    if not plan:
        logger.warning(f"Subscription plan with ID {plan_id} not found")
        return None

    limits = {
        "max_alerts": max_alerts,
        "max_reports": max_reports,
        "max_searches_per_day": max_searches_per_day,
        "max_monitoring_keywords": max_monitoring_keywords,
        "max_data_retention_days": max_data_retention_days,
    }
    features = {
        "supports_api_access": supports_api_access,
        "supports_live_feed": supports_live_feed,
        "supports_dark_web_monitoring": supports_dark_web_monitoring,
        "supports_export": supports_export,
        "supports_advanced_analytics": supports_advanced_analytics,
    }
    requested = {
        "name": name,
        "description": description,
        "price_monthly": price_monthly,
        "price_annually": price_annually,
        "is_active": is_active,
        **limits,
        **features,
    }
    # Only columns the caller actually supplied are written.
    update_data = {
        column: value for column, value in requested.items() if value is not None
    }

    if update_stripe_product and plan.stripe_product_id and stripe.api_key:
        try:
            product_update_data: Dict[str, Any] = {
                key: value
                for key, value in (("name", name), ("description", description))
                if value is not None
            }
            metadata_update = _build_stripe_metadata(limits, features)
            if metadata_update:
                product_update_data["metadata"] = metadata_update

            if product_update_data:
                stripe.Product.modify(plan.stripe_product_id, **product_update_data)

            # Stripe prices cannot change their amount, so a price change creates
            # a new Price and stores its ID.
            # NOTE(review): the previous Price objects are left active in Stripe,
            # and existing Stripe subscriptions still reference them.
            if price_monthly is not None and plan.stripe_monthly_price_id:
                update_data["stripe_monthly_price_id"] = _create_stripe_price(
                    plan.stripe_product_id, price_monthly, "month", "monthly"
                ).id

            if price_annually is not None and plan.stripe_annual_price_id:
                update_data["stripe_annual_price_id"] = _create_stripe_price(
                    plan.stripe_product_id, price_annually, "year", "annually"
                ).id

            logger.info(f"Updated Stripe product {plan.stripe_product_id} for plan {plan.name}")
        except Exception as e:
            # Best effort: any price IDs created before the failure are still saved.
            logger.exception(f"Failed to update Stripe product for plan {plan.name}: {e}")

    if update_data:
        await db.execute(
            update(SubscriptionPlan)
            .where(SubscriptionPlan.id == plan_id)
            .values(**update_data)
        )
        await db.commit()

    # Re-read so the caller gets the persisted values
    return await get_subscription_plan_by_id(db, plan_id)


async def get_user_subscription(
    db: AsyncSession,
    user_id: int
) -> Optional[UserSubscription]:
    """
    Get a user's active subscription.

    "Active" here means any status other than CANCELED. The plan relationship
    is eagerly loaded.

    Args:
        db: Database session
        user_id: ID of the user

    Returns:
        User subscription or None if not found
    """
    query = (
        select(UserSubscription)
        .where(UserSubscription.user_id == user_id)
        .where(UserSubscription.status != SubscriptionStatus.CANCELED)
        .options(joinedload(UserSubscription.plan))
    )

    result = await db.execute(query)
    return result.scalars().first()


async def get_user_subscription_by_id(
    db: AsyncSession,
    subscription_id: int
) -> Optional[UserSubscription]:
    """
    Get a user subscription by ID, with its plan eagerly loaded.

    Args:
        db: Database session
        subscription_id: ID of the subscription

    Returns:
        User subscription or None if not found
    """
    query = (
        select(UserSubscription)
        .where(UserSubscription.id == subscription_id)
        .options(joinedload(UserSubscription.plan))
    )

    result = await db.execute(query)
    return result.scalars().first()


async def create_user_subscription(
    db: AsyncSession,
    user_id: int,
    plan_id: int,
    billing_period: BillingPeriod = BillingPeriod.MONTHLY,
    create_stripe_subscription: bool = True,
    payment_method_id: Optional[str] = None
) -> Optional[UserSubscription]:
    """
    Create a new user subscription and record its initial payment.

    Args:
        db: Database session
        user_id: ID of the user
        plan_id: ID of the subscription plan
        billing_period: Billing period (monthly or annually)
        create_stripe_subscription: Whether to create a Stripe subscription
        payment_method_id: ID of the payment method to use (required if a Stripe
            subscription is created)

    Returns:
        Created user subscription, or None if the user/plan is missing, the user
        already has a non-canceled subscription, the billing period is invalid,
        or Stripe subscription creation failed
    """
    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalars().first()
    if not user:
        logger.warning(f"User with ID {user_id} not found")
        return None

    plan = await get_subscription_plan_by_id(db, plan_id)
    if not plan:
        logger.warning(f"Subscription plan with ID {plan_id} not found")
        return None

    existing_subscription = await get_user_subscription(db, user_id)
    if existing_subscription:
        logger.warning(f"User with ID {user_id} already has an active subscription")
        return None

    # NOTE(review): fixed 30/365-day periods only approximate Stripe's
    # calendar-based billing cycle.
    now = datetime.utcnow()
    if billing_period == BillingPeriod.MONTHLY:
        current_period_end = now + timedelta(days=30)
        price = plan.price_monthly
        stripe_price_id = plan.stripe_monthly_price_id
    elif billing_period == BillingPeriod.ANNUALLY:
        current_period_end = now + timedelta(days=365)
        price = plan.price_annually
        stripe_price_id = plan.stripe_annual_price_id
    else:
        logger.warning(f"Invalid billing period: {billing_period}")
        return None

    stripe_subscription_id = None
    stripe_customer_id = None
    stripe_status = None

    if create_stripe_subscription and stripe.api_key and plan.stripe_product_id:
        if not payment_method_id:
            logger.warning("Payment method ID is required to create a Stripe subscription")
            return None

        # A plan may have a Stripe product without prices (e.g. a partially failed
        # plan creation); fail before touching the customer in that case.
        if not stripe_price_id:
            logger.warning(
                f"Subscription plan with ID {plan_id} has no Stripe price "
                f"for billing period {billing_period}"
            )
            return None

        try:
            # Reuse the first Stripe customer with this email, else create one
            customers = stripe.Customer.list(email=user.email)
            if customers.data:
                stripe_customer_id = customers.data[0].id
            else:
                stripe_customer_id = stripe.Customer.create(
                    email=user.email,
                    name=user.full_name,
                    metadata={"user_id": user_id}
                ).id

            stripe.PaymentMethod.attach(
                payment_method_id,
                customer=stripe_customer_id
            )

            # Make it the default so subscription invoices charge this method
            stripe.Customer.modify(
                stripe_customer_id,
                invoice_settings={
                    "default_payment_method": payment_method_id
                }
            )

            stripe_subscription = stripe.Subscription.create(
                customer=stripe_customer_id,
                items=[
                    {"price": stripe_price_id}
                ],
                expand=["latest_invoice.payment_intent"]
            )
            stripe_subscription_id = stripe_subscription.id
            stripe_status = getattr(stripe_subscription, "status", None)

            logger.info(f"Created Stripe subscription {stripe_subscription.id} for user {user_id}")
        except Exception as e:
            logger.exception(f"Failed to create Stripe subscription for user {user_id}: {e}")
            return None

    # NOTE(review): stored as ACTIVE even when Stripe reports a non-active status
    # (e.g. "incomplete"); confirm whether SubscriptionStatus has a better fit.
    subscription = UserSubscription(
        user_id=user_id,
        plan_id=plan_id,
        status=SubscriptionStatus.ACTIVE,
        billing_period=billing_period,
        current_period_start=now,
        current_period_end=current_period_end,
        stripe_subscription_id=stripe_subscription_id,
        stripe_customer_id=stripe_customer_id
    )

    db.add(subscription)
    await db.commit()
    await db.refresh(subscription)

    if subscription.id:
        # Under Stripe's default payment behaviour a failed or action-required
        # first payment leaves the subscription "incomplete"; only "active"
        # means the first invoice was paid.
        if stripe_status == "active":
            payment_status = PaymentStatus.SUCCEEDED
        else:
            payment_status = PaymentStatus.PENDING

        payment = PaymentHistory(
            user_id=user_id,
            subscription_id=subscription.id,
            amount=price,
            currency="USD",
            status=payment_status
        )

        db.add(payment)
        await db.commit()
        # With SQLAlchemy's default expire_on_commit, the commit above expires
        # the subscription's attributes; reload them so callers can read the
        # returned object without triggering implicit (async-unsafe) I/O.
        await db.refresh(subscription)

    return subscription


async def cancel_user_subscription(
    db: AsyncSession,
    subscription_id: int,
    cancel_stripe_subscription: bool = True
) -> Optional[UserSubscription]:
    """
    Cancel a user subscription.

    The Stripe subscription (if any) is set to cancel at the end of the current
    period; the local record is marked CANCELED immediately. A Stripe failure is
    logged and does not stop the local cancellation.

    Args:
        db: Database session
        subscription_id: ID of the subscription to cancel
        cancel_stripe_subscription: Whether to cancel the Stripe subscription

    Returns:
        Canceled user subscription or None if the subscription was not found
    """
    subscription = await get_user_subscription_by_id(db, subscription_id)
    if not subscription:
        logger.warning(f"Subscription with ID {subscription_id} not found")
        return None

    if cancel_stripe_subscription and subscription.stripe_subscription_id and stripe.api_key:
        try:
            stripe.Subscription.modify(
                subscription.stripe_subscription_id,
                cancel_at_period_end=True
            )
            logger.info(f"Canceled Stripe subscription {subscription.stripe_subscription_id} at period end")
        except Exception as e:
            logger.exception(f"Failed to cancel Stripe subscription {subscription.stripe_subscription_id}: {e}")

    # NOTE(review): marking CANCELED now means get_user_subscription stops
    # returning it immediately, although Stripe keeps it until period end.
    now = datetime.utcnow()
    await db.execute(
        update(UserSubscription)
        .where(UserSubscription.id == subscription_id)
        .values(
            status=SubscriptionStatus.CANCELED,
            canceled_at=now
        )
    )
    await db.commit()

    return await get_user_subscription_by_id(db, subscription_id)