CyberForge / src /api /services /subscription_service.py
Replit Deployment
Deployment from Replit
89ae94f
"""
Subscription service.
This module provides functions for managing subscriptions.
"""
import os
import logging
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional, Tuple, Union
import stripe
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete
from sqlalchemy.orm import joinedload
from src.models.subscription import (
SubscriptionPlan, UserSubscription, PaymentHistory,
SubscriptionTier, BillingPeriod, SubscriptionStatus, PaymentStatus
)
from src.models.user import User
# Configure the Stripe client from the environment. os.environ.get returns
# None when the variable is unset; the functions in this module test
# stripe.api_key before calling Stripe, so an unset key skips those calls.
stripe.api_key = os.environ.get("STRIPE_SECRET_KEY")
# Publishable key, also read from the environment (None when unset).
STRIPE_PUBLISHABLE_KEY = os.environ.get("STRIPE_PUBLISHABLE_KEY")
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
async def get_subscription_plans(
    db: AsyncSession,
    active_only: bool = True
) -> List[SubscriptionPlan]:
    """
    Fetch subscription plans from the database.

    Args:
        db: Database session
        active_only: When True, restrict the result to plans flagged active

    Returns:
        All subscription plans matching the filter
    """
    stmt = select(SubscriptionPlan)
    if active_only:
        # "== True" is intentional: it builds a SQL column comparison.
        stmt = stmt.where(SubscriptionPlan.is_active == True)  # noqa: E712
    rows = await db.execute(stmt)
    return rows.scalars().all()
async def get_subscription_plan_by_id(
    db: AsyncSession,
    plan_id: int
) -> Optional[SubscriptionPlan]:
    """
    Look up a single subscription plan by its primary key.

    Args:
        db: Database session
        plan_id: ID of the plan to look up

    Returns:
        The matching subscription plan, or None when no row has that ID
    """
    rows = await db.execute(
        select(SubscriptionPlan).where(SubscriptionPlan.id == plan_id)
    )
    return rows.scalars().first()
async def get_subscription_plan_by_tier(
    db: AsyncSession,
    tier: SubscriptionTier
) -> Optional[SubscriptionPlan]:
    """
    Look up a subscription plan by its tier.

    Args:
        db: Database session
        tier: Tier to search for

    Returns:
        The first plan with that tier, or None when there is none
    """
    tier_matches = SubscriptionPlan.tier == tier
    stmt = select(SubscriptionPlan).where(tier_matches)
    found = (await db.execute(stmt)).scalars()
    return found.first()
async def create_subscription_plan(
    db: AsyncSession,
    name: str,
    tier: SubscriptionTier,
    description: str,
    price_monthly: float,
    price_annually: float,
    max_alerts: int = 10,
    max_reports: int = 5,
    max_searches_per_day: int = 20,
    max_monitoring_keywords: int = 10,
    max_data_retention_days: int = 30,
    supports_api_access: bool = False,
    supports_live_feed: bool = False,
    supports_dark_web_monitoring: bool = False,
    supports_export: bool = False,
    supports_advanced_analytics: bool = False,
    create_stripe_product: bool = True
) -> Optional[SubscriptionPlan]:
    """
    Create a new subscription plan, optionally mirroring it in Stripe.

    Stripe creation is best-effort: if any Stripe call fails the error is
    logged and the plan is still stored locally with whichever Stripe IDs
    were obtained before the failure (possibly none).

    Args:
        db: Database session
        name: Name of the plan
        tier: Tier of the plan (at most one plan per tier is created)
        description: Description of the plan
        price_monthly: Monthly price of the plan, in USD
        price_annually: Annual price of the plan, in USD
        max_alerts: Maximum number of alerts allowed
        max_reports: Maximum number of reports allowed
        max_searches_per_day: Maximum number of searches per day
        max_monitoring_keywords: Maximum number of monitoring keywords
        max_data_retention_days: Maximum number of days to retain data
        supports_api_access: Whether the plan supports API access
        supports_live_feed: Whether the plan supports live feed
        supports_dark_web_monitoring: Whether the plan supports dark web monitoring
        supports_export: Whether the plan supports data export
        supports_advanced_analytics: Whether the plan supports advanced analytics
        create_stripe_product: Whether to create a Stripe product (and a
            monthly and an annual price) for this plan; ignored when no
            Stripe API key is configured

    Returns:
        The created subscription plan, or None if a plan with the same
        tier already exists
    """
    # Refuse to create a second plan for a tier that already has one.
    existing_plan = await get_subscription_plan_by_tier(db, tier)
    if existing_plan:
        logger.warning(f"Subscription plan with tier {tier} already exists")
        return None

    stripe_product_id = None
    stripe_monthly_price_id = None
    stripe_annual_price_id = None
    if create_stripe_product and stripe.api_key:
        try:
            feature_flags = {
                "supports_api_access": supports_api_access,
                "supports_live_feed": supports_live_feed,
                "supports_dark_web_monitoring": supports_dark_web_monitoring,
                "supports_export": supports_export,
                "supports_advanced_analytics": supports_advanced_analytics,
            }
            metadata = {
                "tier": tier.value,
                "max_alerts": max_alerts,
                "max_reports": max_reports,
                "max_searches_per_day": max_searches_per_day,
                "max_monitoring_keywords": max_monitoring_keywords,
                "max_data_retention_days": max_data_retention_days,
            }
            # Boolean features are stored in Stripe metadata as "yes"/"no".
            metadata.update(
                {key: "yes" if enabled else "no" for key, enabled in feature_flags.items()}
            )
            product = stripe.Product.create(
                name=name,
                description=description,
                metadata=metadata
            )
            stripe_product_id = product.id

            # Stripe expects an integer amount in cents. Round to the
            # nearest cent instead of truncating: 19.99 * 100 evaluates to
            # 1998.9999999999998, which int() alone would turn into 1998.
            monthly_price = stripe.Price.create(
                product=product.id,
                unit_amount=int(round(price_monthly * 100)),
                currency="usd",
                recurring={"interval": "month"},
                metadata={"billing_period": "monthly"}
            )
            stripe_monthly_price_id = monthly_price.id

            annual_price = stripe.Price.create(
                product=product.id,
                unit_amount=int(round(price_annually * 100)),
                currency="usd",
                recurring={"interval": "year"},
                metadata={"billing_period": "annually"}
            )
            stripe_annual_price_id = annual_price.id
            logger.info(f"Created Stripe product {product.id} for plan {name}")
        except Exception as e:
            # Best-effort: keep going and store the plan with the IDs
            # collected so far.
            logger.error(f"Failed to create Stripe product for plan {name}: {e}")

    # Persist the plan locally.
    plan = SubscriptionPlan(
        name=name,
        tier=tier,
        description=description,
        price_monthly=price_monthly,
        price_annually=price_annually,
        max_alerts=max_alerts,
        max_reports=max_reports,
        max_searches_per_day=max_searches_per_day,
        max_monitoring_keywords=max_monitoring_keywords,
        max_data_retention_days=max_data_retention_days,
        supports_api_access=supports_api_access,
        supports_live_feed=supports_live_feed,
        supports_dark_web_monitoring=supports_dark_web_monitoring,
        supports_export=supports_export,
        supports_advanced_analytics=supports_advanced_analytics,
        stripe_product_id=stripe_product_id,
        stripe_monthly_price_id=stripe_monthly_price_id,
        stripe_annual_price_id=stripe_annual_price_id
    )
    db.add(plan)
    await db.commit()
    await db.refresh(plan)
    return plan
async def update_subscription_plan(
    db: AsyncSession,
    plan_id: int,
    name: Optional[str] = None,
    description: Optional[str] = None,
    price_monthly: Optional[float] = None,
    price_annually: Optional[float] = None,
    is_active: Optional[bool] = None,
    max_alerts: Optional[int] = None,
    max_reports: Optional[int] = None,
    max_searches_per_day: Optional[int] = None,
    max_monitoring_keywords: Optional[int] = None,
    max_data_retention_days: Optional[int] = None,
    supports_api_access: Optional[bool] = None,
    supports_live_feed: Optional[bool] = None,
    supports_dark_web_monitoring: Optional[bool] = None,
    supports_export: Optional[bool] = None,
    supports_advanced_analytics: Optional[bool] = None,
    update_stripe_product: bool = True
) -> Optional[SubscriptionPlan]:
    """
    Update a subscription plan; arguments left as None are not changed.

    Stripe synchronisation is best-effort: a Stripe failure is logged and
    the database is still updated with the requested values (plus any new
    Stripe price IDs obtained before the failure).

    Args:
        db: Database session
        plan_id: ID of the plan to update
        name: New name of the plan
        description: New description of the plan
        price_monthly: New monthly price of the plan, in USD
        price_annually: New annual price of the plan, in USD
        is_active: New active status of the plan
        max_alerts: New maximum number of alerts allowed
        max_reports: New maximum number of reports allowed
        max_searches_per_day: New maximum number of searches per day
        max_monitoring_keywords: New maximum number of monitoring keywords
        max_data_retention_days: New maximum number of days to retain data
        supports_api_access: New API access support status
        supports_live_feed: New live feed support status
        supports_dark_web_monitoring: New dark web monitoring support status
        supports_export: New data export support status
        supports_advanced_analytics: New advanced analytics support status
        update_stripe_product: Whether to update the Stripe product for this
            plan; only acted on when the plan has a Stripe product ID and a
            Stripe API key is configured

    Returns:
        The updated subscription plan (re-read from the database), or None
        if no plan with the given ID exists
    """
    plan = await get_subscription_plan_by_id(db, plan_id)
    if not plan:
        logger.warning(f"Subscription plan with ID {plan_id} not found")
        return None

    limit_keys = (
        "max_alerts",
        "max_reports",
        "max_searches_per_day",
        "max_monitoring_keywords",
        "max_data_retention_days",
    )
    flag_keys = (
        "supports_api_access",
        "supports_live_feed",
        "supports_dark_web_monitoring",
        "supports_export",
        "supports_advanced_analytics",
    )

    # Keep only the fields the caller actually supplied.
    requested = {
        "name": name,
        "description": description,
        "price_monthly": price_monthly,
        "price_annually": price_annually,
        "is_active": is_active,
        "max_alerts": max_alerts,
        "max_reports": max_reports,
        "max_searches_per_day": max_searches_per_day,
        "max_monitoring_keywords": max_monitoring_keywords,
        "max_data_retention_days": max_data_retention_days,
        "supports_api_access": supports_api_access,
        "supports_live_feed": supports_live_feed,
        "supports_dark_web_monitoring": supports_dark_web_monitoring,
        "supports_export": supports_export,
        "supports_advanced_analytics": supports_advanced_analytics,
    }
    update_data = {
        field: value for field, value in requested.items() if value is not None
    }

    if update_stripe_product and plan.stripe_product_id and stripe.api_key:
        try:
            product_update_data = {
                key: update_data[key]
                for key in ("name", "description")
                if key in update_data
            }
            metadata_update = {
                key: update_data[key] for key in limit_keys if key in update_data
            }
            # Boolean features are stored in Stripe metadata as "yes"/"no".
            metadata_update.update(
                {
                    key: "yes" if update_data[key] else "no"
                    for key in flag_keys
                    if key in update_data
                }
            )
            if metadata_update:
                product_update_data["metadata"] = metadata_update
            if product_update_data:
                stripe.Product.modify(plan.stripe_product_id, **product_update_data)

            # Can't update an existing price in Stripe, so a new price is
            # created and the plan is pointed at it; this code does not
            # touch the previous price. Amounts are rounded to the nearest
            # cent because int() alone truncates float error
            # (19.99 * 100 == 1998.9999999999998 would become 1998).
            if price_monthly is not None and plan.stripe_monthly_price_id:
                new_monthly_price = stripe.Price.create(
                    product=plan.stripe_product_id,
                    unit_amount=int(round(price_monthly * 100)),
                    currency="usd",
                    recurring={"interval": "month"},
                    metadata={"billing_period": "monthly"}
                )
                update_data["stripe_monthly_price_id"] = new_monthly_price.id
            if price_annually is not None and plan.stripe_annual_price_id:
                new_annual_price = stripe.Price.create(
                    product=plan.stripe_product_id,
                    unit_amount=int(round(price_annually * 100)),
                    currency="usd",
                    recurring={"interval": "year"},
                    metadata={"billing_period": "annually"}
                )
                update_data["stripe_annual_price_id"] = new_annual_price.id
            logger.info(f"Updated Stripe product {plan.stripe_product_id} for plan {plan.name}")
        except Exception as e:
            # Best-effort: the database update below still runs.
            logger.error(f"Failed to update Stripe product for plan {plan.name}: {e}")

    if update_data:
        await db.execute(
            update(SubscriptionPlan)
            .where(SubscriptionPlan.id == plan_id)
            .values(**update_data)
        )
        await db.commit()
        # Re-read so the returned object reflects the stored values.
        plan = await get_subscription_plan_by_id(db, plan_id)
    return plan
async def get_user_subscription(
    db: AsyncSession,
    user_id: int
) -> Optional[UserSubscription]:
    """
    Fetch a user's current (non-canceled) subscription.

    Args:
        db: Database session
        user_id: ID of the user

    Returns:
        The first subscription of the user whose status is not CANCELED,
        with its plan eagerly loaded, or None if there is none
    """
    stmt = select(UserSubscription).options(joinedload(UserSubscription.plan))
    stmt = stmt.where(UserSubscription.user_id == user_id)
    stmt = stmt.where(UserSubscription.status != SubscriptionStatus.CANCELED)
    rows = await db.execute(stmt)
    return rows.scalars().first()
async def get_user_subscription_by_id(
    db: AsyncSession,
    subscription_id: int
) -> Optional[UserSubscription]:
    """
    Look up a user subscription by its primary key, whatever its status.

    Args:
        db: Database session
        subscription_id: ID of the subscription

    Returns:
        The subscription with its plan eagerly loaded, or None if no row
        has that ID
    """
    with_plan = joinedload(UserSubscription.plan)
    stmt = (
        select(UserSubscription)
        .options(with_plan)
        .where(UserSubscription.id == subscription_id)
    )
    return (await db.execute(stmt)).scalars().first()
async def create_user_subscription(
    db: AsyncSession,
    user_id: int,
    plan_id: int,
    billing_period: BillingPeriod = BillingPeriod.MONTHLY,
    create_stripe_subscription: bool = True,
    payment_method_id: Optional[str] = None
) -> Optional[UserSubscription]:
    """
    Create a new user subscription and its initial payment record.

    Args:
        db: Database session
        user_id: ID of the user
        plan_id: ID of the subscription plan
        billing_period: Billing period (monthly or annually)
        create_stripe_subscription: Whether to create a Stripe subscription;
            only acted on when a Stripe API key is configured and the plan
            has a Stripe product ID
        payment_method_id: ID of the payment method to use (required when a
            Stripe subscription is actually created)

    Returns:
        The created user subscription, or None if the user or plan does not
        exist, the user already has a non-canceled subscription, the billing
        period is invalid, the payment method is missing, or a Stripe call
        failed

    Raises:
        Exception: Any database error raised while persisting the records
            is re-raised after the session has been rolled back.
    """
    # The user must exist.
    user_rows = await db.execute(select(User).where(User.id == user_id))
    user = user_rows.scalars().first()
    if not user:
        logger.warning(f"User with ID {user_id} not found")
        return None

    # The plan must exist.
    plan = await get_subscription_plan_by_id(db, plan_id)
    if not plan:
        logger.warning(f"Subscription plan with ID {plan_id} not found")
        return None

    # Only one non-canceled subscription per user.
    existing_subscription = await get_user_subscription(db, user_id)
    if existing_subscription:
        logger.warning(f"User with ID {user_id} already has an active subscription")
        return None

    # The local period end is an approximation (30 / 365 days from now).
    now = datetime.utcnow()
    if billing_period == BillingPeriod.MONTHLY:
        current_period_end = now + timedelta(days=30)
        price = plan.price_monthly
        stripe_price_id = plan.stripe_monthly_price_id
    elif billing_period == BillingPeriod.ANNUALLY:
        current_period_end = now + timedelta(days=365)
        price = plan.price_annually
        stripe_price_id = plan.stripe_annual_price_id
    else:
        logger.warning(f"Invalid billing period: {billing_period}")
        return None

    stripe_subscription_id = None
    stripe_customer_id = None
    # Without a confirmed Stripe charge the payment is recorded as pending.
    payment_status = PaymentStatus.PENDING
    if create_stripe_subscription and stripe.api_key and plan.stripe_product_id:
        if not payment_method_id:
            logger.warning("Payment method ID is required to create a Stripe subscription")
            return None
        try:
            # Reuse the first Stripe customer with this e-mail, else create one.
            customers = stripe.Customer.list(email=user.email)
            if customers.data:
                customer = customers.data[0]
            else:
                customer = stripe.Customer.create(
                    email=user.email,
                    name=user.full_name,
                    metadata={"user_id": user_id}
                )
            stripe_customer_id = customer.id

            # Attach the payment method and make it the invoice default.
            stripe.PaymentMethod.attach(
                payment_method_id,
                customer=stripe_customer_id
            )
            stripe.Customer.modify(
                stripe_customer_id,
                invoice_settings={
                    "default_payment_method": payment_method_id
                }
            )

            stripe_subscription = stripe.Subscription.create(
                customer=stripe_customer_id,
                items=[
                    {"price": stripe_price_id}
                ],
                expand=["latest_invoice.payment_intent"]
            )
            stripe_subscription_id = stripe_subscription.id
            # Only record the payment as succeeded when Stripe reports the
            # subscription as active; other statuses (e.g. "incomplete",
            # "trialing") mean the first charge has not been collected.
            # NOTE(review): the local subscription status is still stored as
            # ACTIVE in those cases - confirm how it should be mapped.
            if stripe_subscription.status == "active":
                payment_status = PaymentStatus.SUCCEEDED
            logger.info(f"Created Stripe subscription {stripe_subscription_id} for user {user_id}")
        except Exception as e:
            logger.error(f"Failed to create Stripe subscription for user {user_id}: {e}")
            return None

    subscription = UserSubscription(
        user_id=user_id,
        plan_id=plan_id,
        status=SubscriptionStatus.ACTIVE,
        billing_period=billing_period,
        current_period_start=now,
        current_period_end=current_period_end,
        stripe_subscription_id=stripe_subscription_id,
        stripe_customer_id=stripe_customer_id
    )
    try:
        db.add(subscription)
        # flush() assigns subscription.id without committing, so the
        # subscription and its payment row are committed in one transaction.
        await db.flush()
        db.add(
            PaymentHistory(
                user_id=user_id,
                subscription_id=subscription.id,
                amount=price,
                currency="USD",
                status=payment_status
            )
        )
        await db.commit()
    except Exception:
        await db.rollback()
        if stripe_subscription_id:
            # The Stripe subscription already exists; leave a trace so it
            # can be reconciled or canceled manually.
            logger.error(
                f"Stripe subscription {stripe_subscription_id} for user {user_id} "
                "was created but could not be recorded in the database"
            )
        raise

    # Refresh after the final commit so the returned object is loaded even
    # when the session expires instances on commit.
    await db.refresh(subscription)
    return subscription
async def cancel_user_subscription(
    db: AsyncSession,
    subscription_id: int,
    cancel_stripe_subscription: bool = True
) -> Optional[UserSubscription]:
    """
    Cancel a user subscription.

    When requested (and the subscription has a Stripe subscription ID and a
    Stripe API key is configured), the Stripe subscription is set to cancel
    at the end of its current period. The local record is then marked
    CANCELED with the current UTC time as canceled_at.

    Args:
        db: Database session
        subscription_id: ID of the subscription to cancel
        cancel_stripe_subscription: Whether to cancel the Stripe subscription

    Returns:
        The canceled user subscription (re-read from the database), or None
        if the subscription does not exist or the Stripe cancellation failed
    """
    subscription = await get_user_subscription_by_id(db, subscription_id)
    if not subscription:
        logger.warning(f"Subscription with ID {subscription_id} not found")
        return None

    if cancel_stripe_subscription and subscription.stripe_subscription_id and stripe.api_key:
        try:
            stripe.Subscription.modify(
                subscription.stripe_subscription_id,
                cancel_at_period_end=True
            )
            logger.info(f"Canceled Stripe subscription {subscription.stripe_subscription_id} at period end")
        except Exception as e:
            logger.error(f"Failed to cancel Stripe subscription {subscription.stripe_subscription_id}: {e}")
            # Do not mark the local record as canceled while Stripe would
            # keep billing the customer; report the failure to the caller.
            return None

    # NOTE(review): Stripe cancels at period end, but the local status is
    # switched to CANCELED immediately - confirm this is intended.
    canceled_at = datetime.utcnow()
    await db.execute(
        update(UserSubscription)
        .where(UserSubscription.id == subscription_id)
        .values(
            status=SubscriptionStatus.CANCELED,
            canceled_at=canceled_at
        )
    )
    await db.commit()

    # Re-read so the returned object reflects the stored values.
    return await get_user_subscription_by_id(db, subscription_id)