File size: 23,394 Bytes
89ae94f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
"""
Subscription service.

This module provides functions for managing subscriptions.
"""
import os
import logging
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional, Tuple, Union

import stripe
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete
from sqlalchemy.orm import joinedload

from src.models.subscription import (
    SubscriptionPlan, UserSubscription, PaymentHistory,
    SubscriptionTier, BillingPeriod, SubscriptionStatus, PaymentStatus
)
from src.models.user import User

# Configure the Stripe client from the environment. os.environ.get returns
# None when a variable is unset; the functions below test `stripe.api_key`
# before making Stripe calls, so an unset secret key makes them skip Stripe.
stripe.api_key = os.environ.get("STRIPE_SECRET_KEY")
# Publishable key, exposed as a module-level constant (None when unset).
STRIPE_PUBLISHABLE_KEY = os.environ.get("STRIPE_PUBLISHABLE_KEY")

# Module-level logger named after this module
logger = logging.getLogger(__name__)


async def get_subscription_plans(
    db: AsyncSession,
    active_only: bool = True
) -> List[SubscriptionPlan]:
    """
    Fetch subscription plans from the database.

    Args:
        db: Database session
        active_only: When True, restrict the result to plans whose
            ``is_active`` column is true; when False, return every plan

    Returns:
        The matching subscription plans
    """
    stmt = select(SubscriptionPlan)
    if active_only:
        # ``== True`` builds a SQL expression here, not a Python boolean.
        stmt = stmt.where(SubscriptionPlan.is_active == True)

    rows = await db.execute(stmt)
    return rows.scalars().all()


async def get_subscription_plan_by_id(
    db: AsyncSession,
    plan_id: int
) -> Optional[SubscriptionPlan]:
    """
    Look up a single subscription plan by its ID.

    Args:
        db: Database session
        plan_id: ID of the plan to look up

    Returns:
        The first plan whose ``id`` equals ``plan_id``, or None if no
        such plan exists
    """
    rows = await db.execute(
        select(SubscriptionPlan).where(SubscriptionPlan.id == plan_id)
    )
    return rows.scalars().first()


async def get_subscription_plan_by_tier(
    db: AsyncSession,
    tier: SubscriptionTier
) -> Optional[SubscriptionPlan]:
    """
    Look up a single subscription plan by its tier.

    Args:
        db: Database session
        tier: Tier of the plan to look up

    Returns:
        The first plan whose ``tier`` equals ``tier``, or None if no
        such plan exists
    """
    rows = await db.execute(
        select(SubscriptionPlan).where(SubscriptionPlan.tier == tier)
    )
    return rows.scalars().first()


async def create_subscription_plan(
    db: AsyncSession,
    name: str,
    tier: SubscriptionTier,
    description: str,
    price_monthly: float,
    price_annually: float,
    max_alerts: int = 10,
    max_reports: int = 5,
    max_searches_per_day: int = 20,
    max_monitoring_keywords: int = 10,
    max_data_retention_days: int = 30,
    supports_api_access: bool = False,
    supports_live_feed: bool = False,
    supports_dark_web_monitoring: bool = False,
    supports_export: bool = False,
    supports_advanced_analytics: bool = False,
    create_stripe_product: bool = True
) -> Optional[SubscriptionPlan]:
    """
    Create a new subscription plan, optionally mirroring it in Stripe.

    When ``create_stripe_product`` is True and a Stripe API key is
    configured, a Stripe product plus a monthly and an annual USD price
    are created. If a Stripe call fails part-way, the error is logged and
    the plan is still saved with whichever Stripe IDs were obtained
    before the failure (the remaining ones stay None).

    Args:
        db: Database session
        name: Name of the plan
        tier: Tier of the plan
        description: Description of the plan
        price_monthly: Monthly price of the plan
        price_annually: Annual price of the plan
        max_alerts: Maximum number of alerts allowed
        max_reports: Maximum number of reports allowed
        max_searches_per_day: Maximum number of searches per day
        max_monitoring_keywords: Maximum number of monitoring keywords
        max_data_retention_days: Maximum number of days to retain data
        supports_api_access: Whether the plan supports API access
        supports_live_feed: Whether the plan supports live feed
        supports_dark_web_monitoring: Whether the plan supports dark web monitoring
        supports_export: Whether the plan supports data export
        supports_advanced_analytics: Whether the plan supports advanced analytics
        create_stripe_product: Whether to create a Stripe product for this plan

    Returns:
        Created subscription plan, or None if a plan with the same tier
        already exists
    """
    # Check if plan with the same tier already exists
    existing_plan = await get_subscription_plan_by_tier(db, tier)

    if existing_plan:
        logger.warning(f"Subscription plan with tier {tier} already exists")
        return None

    # Create Stripe product if requested
    stripe_product_id = None
    stripe_monthly_price_id = None
    stripe_annual_price_id = None

    if create_stripe_product and stripe.api_key:
        try:
            # Create Stripe product
            product = stripe.Product.create(
                name=name,
                description=description,
                metadata={
                    "tier": tier.value,
                    "max_alerts": max_alerts,
                    "max_reports": max_reports,
                    "max_searches_per_day": max_searches_per_day,
                    "max_monitoring_keywords": max_monitoring_keywords,
                    "max_data_retention_days": max_data_retention_days,
                    "supports_api_access": "yes" if supports_api_access else "no",
                    "supports_live_feed": "yes" if supports_live_feed else "no",
                    "supports_dark_web_monitoring": "yes" if supports_dark_web_monitoring else "no",
                    "supports_export": "yes" if supports_export else "no",
                    "supports_advanced_analytics": "yes" if supports_advanced_analytics else "no"
                }
            )

            stripe_product_id = product.id

            # Stripe uses cents. Round before converting: a float product
            # such as 19.99 * 100 evaluates to 1998.9999999999998, which
            # int() alone would truncate to 1998 (one cent short).
            monthly_price = stripe.Price.create(
                product=product.id,
                unit_amount=int(round(price_monthly * 100)),
                currency="usd",
                recurring={"interval": "month"},
                metadata={"billing_period": "monthly"}
            )

            stripe_monthly_price_id = monthly_price.id

            # Create annual price (same rounding rule as above)
            annual_price = stripe.Price.create(
                product=product.id,
                unit_amount=int(round(price_annually * 100)),
                currency="usd",
                recurring={"interval": "year"},
                metadata={"billing_period": "annually"}
            )

            stripe_annual_price_id = annual_price.id

            logger.info(f"Created Stripe product {product.id} for plan {name}")
        except Exception as e:
            # Best effort: the plan is still stored locally without (or with
            # only some of) its Stripe IDs.
            logger.error(f"Failed to create Stripe product for plan {name}: {e}")

    # Create plan in database
    plan = SubscriptionPlan(
        name=name,
        tier=tier,
        description=description,
        price_monthly=price_monthly,
        price_annually=price_annually,
        max_alerts=max_alerts,
        max_reports=max_reports,
        max_searches_per_day=max_searches_per_day,
        max_monitoring_keywords=max_monitoring_keywords,
        max_data_retention_days=max_data_retention_days,
        supports_api_access=supports_api_access,
        supports_live_feed=supports_live_feed,
        supports_dark_web_monitoring=supports_dark_web_monitoring,
        supports_export=supports_export,
        supports_advanced_analytics=supports_advanced_analytics,
        stripe_product_id=stripe_product_id,
        stripe_monthly_price_id=stripe_monthly_price_id,
        stripe_annual_price_id=stripe_annual_price_id
    )

    db.add(plan)
    await db.commit()
    # Reload the row so database-generated values are present on the object
    await db.refresh(plan)

    return plan


async def update_subscription_plan(
    db: AsyncSession,
    plan_id: int,
    name: Optional[str] = None,
    description: Optional[str] = None,
    price_monthly: Optional[float] = None,
    price_annually: Optional[float] = None,
    is_active: Optional[bool] = None,
    max_alerts: Optional[int] = None,
    max_reports: Optional[int] = None,
    max_searches_per_day: Optional[int] = None,
    max_monitoring_keywords: Optional[int] = None,
    max_data_retention_days: Optional[int] = None,
    supports_api_access: Optional[bool] = None,
    supports_live_feed: Optional[bool] = None,
    supports_dark_web_monitoring: Optional[bool] = None,
    supports_export: Optional[bool] = None,
    supports_advanced_analytics: Optional[bool] = None,
    update_stripe_product: bool = True
) -> Optional[SubscriptionPlan]:
    """
    Update a subscription plan.

    Only arguments that are not None are applied; None means "leave this
    field unchanged". When ``update_stripe_product`` is True, the plan has
    a Stripe product ID and a Stripe API key is configured, the Stripe
    product is updated as well and new Stripe prices are created for any
    changed price. A Stripe failure is logged and the database update
    still proceeds with whatever was collected before the failure.

    Args:
        db: Database session
        plan_id: ID of the plan to update
        name: New name of the plan
        description: New description of the plan
        price_monthly: New monthly price of the plan
        price_annually: New annual price of the plan
        is_active: New active status of the plan
        max_alerts: New maximum number of alerts allowed
        max_reports: New maximum number of reports allowed
        max_searches_per_day: New maximum number of searches per day
        max_monitoring_keywords: New maximum number of monitoring keywords
        max_data_retention_days: New maximum number of days to retain data
        supports_api_access: New API access support status
        supports_live_feed: New live feed support status
        supports_dark_web_monitoring: New dark web monitoring support status
        supports_export: New data export support status
        supports_advanced_analytics: New advanced analytics support status
        update_stripe_product: Whether to update the Stripe product for this plan

    Returns:
        Updated subscription plan (the unchanged plan if nothing was
        requested), or None if no plan with ``plan_id`` exists
    """
    # Get existing plan
    plan = await get_subscription_plan_by_id(db, plan_id)

    if not plan:
        logger.warning(f"Subscription plan with ID {plan_id} not found")
        return None

    # Every updatable column, keyed by column name; None means "not requested".
    requested = {
        "name": name,
        "description": description,
        "price_monthly": price_monthly,
        "price_annually": price_annually,
        "is_active": is_active,
        "max_alerts": max_alerts,
        "max_reports": max_reports,
        "max_searches_per_day": max_searches_per_day,
        "max_monitoring_keywords": max_monitoring_keywords,
        "max_data_retention_days": max_data_retention_days,
        "supports_api_access": supports_api_access,
        "supports_live_feed": supports_live_feed,
        "supports_dark_web_monitoring": supports_dark_web_monitoring,
        "supports_export": supports_export,
        "supports_advanced_analytics": supports_advanced_analytics,
    }
    update_data = {
        column: value for column, value in requested.items() if value is not None
    }

    # Update Stripe product if requested
    if update_stripe_product and plan.stripe_product_id and stripe.api_key:
        # Numeric limits are sent to Stripe metadata as-is; feature flags
        # are sent as "yes"/"no" strings.
        limit_keys = (
            "max_alerts",
            "max_reports",
            "max_searches_per_day",
            "max_monitoring_keywords",
            "max_data_retention_days",
        )
        flag_keys = (
            "supports_api_access",
            "supports_live_feed",
            "supports_dark_web_monitoring",
            "supports_export",
            "supports_advanced_analytics",
        )

        try:
            product_update_data = {
                key: requested[key]
                for key in ("name", "description")
                if requested[key] is not None
            }

            metadata_update = {
                key: requested[key]
                for key in limit_keys
                if requested[key] is not None
            }
            metadata_update.update(
                (key, "yes" if requested[key] else "no")
                for key in flag_keys
                if requested[key] is not None
            )

            if metadata_update:
                product_update_data["metadata"] = metadata_update

            if product_update_data:
                stripe.Product.modify(plan.stripe_product_id, **product_update_data)

            # Update prices if needed
            if price_monthly is not None and plan.stripe_monthly_price_id:
                # Can't update existing price in Stripe, create a new one.
                # Stripe uses cents; round before converting because a float
                # product such as 19.99 * 100 is 1998.9999999999998 and
                # int() alone would truncate it to 1998.
                new_monthly_price = stripe.Price.create(
                    product=plan.stripe_product_id,
                    unit_amount=int(round(price_monthly * 100)),
                    currency="usd",
                    recurring={"interval": "month"},
                    metadata={"billing_period": "monthly"}
                )

                update_data["stripe_monthly_price_id"] = new_monthly_price.id

            if price_annually is not None and plan.stripe_annual_price_id:
                # Can't update existing price in Stripe, create a new one
                # (same rounding rule as the monthly price).
                new_annual_price = stripe.Price.create(
                    product=plan.stripe_product_id,
                    unit_amount=int(round(price_annually * 100)),
                    currency="usd",
                    recurring={"interval": "year"},
                    metadata={"billing_period": "annually"}
                )

                update_data["stripe_annual_price_id"] = new_annual_price.id

            logger.info(f"Updated Stripe product {plan.stripe_product_id} for plan {plan.name}")
        except Exception as e:
            # Best effort: the local update below still runs.
            logger.error(f"Failed to update Stripe product for plan {plan.name}: {e}")

    # Update plan in database
    if update_data:
        await db.execute(
            update(SubscriptionPlan)
            .where(SubscriptionPlan.id == plan_id)
            .values(**update_data)
        )

        await db.commit()

        # Re-query so the returned object reflects the stored values
        plan = await get_subscription_plan_by_id(db, plan_id)

    return plan


async def get_user_subscription(
    db: AsyncSession,
    user_id: int
) -> Optional[UserSubscription]:
    """
    Get a user's current (non-canceled) subscription.

    Args:
        db: Database session
        user_id: ID of the user

    Returns:
        The first subscription of the user whose status is not CANCELED,
        with its plan eagerly loaded, or None if there is none
    """
    stmt = select(UserSubscription).where(
        UserSubscription.user_id == user_id
    ).where(
        UserSubscription.status != SubscriptionStatus.CANCELED
    ).options(
        # Load the related plan in the same query
        joinedload(UserSubscription.plan)
    )

    rows = await db.execute(stmt)
    return rows.scalars().first()


async def get_user_subscription_by_id(
    db: AsyncSession,
    subscription_id: int
) -> Optional[UserSubscription]:
    """
    Look up a user subscription by its ID, regardless of status.

    Args:
        db: Database session
        subscription_id: ID of the subscription

    Returns:
        The matching user subscription with its plan eagerly loaded, or
        None if not found
    """
    stmt = select(UserSubscription).where(
        UserSubscription.id == subscription_id
    ).options(
        # Load the related plan in the same query
        joinedload(UserSubscription.plan)
    )

    rows = await db.execute(stmt)
    return rows.scalars().first()


async def create_user_subscription(
    db: AsyncSession,
    user_id: int,
    plan_id: int,
    billing_period: BillingPeriod = BillingPeriod.MONTHLY,
    create_stripe_subscription: bool = True,
    payment_method_id: Optional[str] = None
) -> Optional[UserSubscription]:
    """
    Create a new user subscription.

    When ``create_stripe_subscription`` is True, a Stripe API key is
    configured and the plan has a Stripe product, a Stripe customer is
    looked up by e-mail (or created), the payment method is attached and
    made the default, and a Stripe subscription is created. The local
    subscription row is then stored with status ACTIVE together with a
    payment history row.

    Args:
        db: Database session
        user_id: ID of the user
        plan_id: ID of the subscription plan
        billing_period: Billing period (monthly or annually)
        create_stripe_subscription: Whether to create a Stripe subscription
        payment_method_id: ID of the payment method to use (required if create_stripe_subscription is True)

    Returns:
        Created user subscription, or None if the user or plan does not
        exist, the user already has a non-canceled subscription, the
        billing period is invalid, a required payment method or Stripe
        price is missing, or a Stripe call fails
    """
    # Check if user exists
    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalars().first()

    if not user:
        logger.warning(f"User with ID {user_id} not found")
        return None

    # Check if plan exists
    plan = await get_subscription_plan_by_id(db, plan_id)

    if not plan:
        logger.warning(f"Subscription plan with ID {plan_id} not found")
        return None

    # Check if user already has an active subscription
    existing_subscription = await get_user_subscription(db, user_id)

    if existing_subscription:
        logger.warning(f"User with ID {user_id} already has an active subscription")
        return None

    # Calculate subscription period (fixed 30/365-day approximations)
    now = datetime.utcnow()

    if billing_period == BillingPeriod.MONTHLY:
        current_period_end = now + timedelta(days=30)
        price = plan.price_monthly
        stripe_price_id = plan.stripe_monthly_price_id
    elif billing_period == BillingPeriod.ANNUALLY:
        current_period_end = now + timedelta(days=365)
        price = plan.price_annually
        stripe_price_id = plan.stripe_annual_price_id
    else:
        logger.warning(f"Invalid billing period: {billing_period}")
        return None

    # Create Stripe subscription if requested
    stripe_subscription_id = None
    stripe_customer_id = None
    # Stays PENDING unless Stripe confirms the subscription was paid.
    payment_status = PaymentStatus.PENDING

    if create_stripe_subscription and stripe.api_key and plan.stripe_product_id:
        if not payment_method_id:
            logger.warning("Payment method ID is required to create a Stripe subscription")
            return None

        # Fail before touching Stripe: without a price the subscription
        # call would fail anyway, after a customer had already been
        # created and the payment method attached.
        if not stripe_price_id:
            logger.warning(
                f"Subscription plan with ID {plan_id} has no Stripe price "
                f"for billing period {billing_period}"
            )
            return None

        try:
            # Create or retrieve Stripe customer
            customers = stripe.Customer.list(email=user.email)

            if customers.data:
                stripe_customer_id = customers.data[0].id
            else:
                customer = stripe.Customer.create(
                    email=user.email,
                    name=user.full_name,
                    metadata={"user_id": user_id}
                )
                stripe_customer_id = customer.id

            # Attach payment method to customer
            stripe.PaymentMethod.attach(
                payment_method_id,
                customer=stripe_customer_id
            )

            # Set as default payment method
            stripe.Customer.modify(
                stripe_customer_id,
                invoice_settings={
                    "default_payment_method": payment_method_id
                }
            )

            # Create subscription
            stripe_subscription = stripe.Subscription.create(
                customer=stripe_customer_id,
                items=[
                    {"price": stripe_price_id}
                ],
                expand=["latest_invoice.payment_intent"]
            )

            stripe_subscription_id = stripe_subscription.id

            # Stripe can return a subscription whose first payment has not
            # gone through (e.g. status "incomplete") without raising, so
            # only an "active" subscription is recorded as paid.
            if getattr(stripe_subscription, "status", None) == "active":
                payment_status = PaymentStatus.SUCCEEDED

            logger.info(f"Created Stripe subscription {stripe_subscription.id} for user {user_id}")
        except Exception as e:
            logger.error(f"Failed to create Stripe subscription for user {user_id}: {e}")
            return None

    # Create subscription in database
    # NOTE(review): the local status is always ACTIVE, even when Stripe
    # reports the subscription as not yet paid; a more precise status would
    # need SubscriptionStatus members not visible in this module.
    subscription = UserSubscription(
        user_id=user_id,
        plan_id=plan_id,
        status=SubscriptionStatus.ACTIVE,
        billing_period=billing_period,
        current_period_start=now,
        current_period_end=current_period_end,
        stripe_subscription_id=stripe_subscription_id,
        stripe_customer_id=stripe_customer_id
    )

    db.add(subscription)
    # Committed on its own so the Stripe IDs are stored locally even if
    # recording the payment below fails.
    await db.commit()
    await db.refresh(subscription)

    # Record payment
    if subscription.id:
        payment = PaymentHistory(
            user_id=user_id,
            subscription_id=subscription.id,
            amount=price,
            currency="USD",
            status=payment_status
        )

        db.add(payment)
        await db.commit()

    return subscription


async def cancel_user_subscription(
    db: AsyncSession,
    subscription_id: int,
    cancel_stripe_subscription: bool = True
) -> Optional[UserSubscription]:
    """
    Cancel a user subscription.

    When ``cancel_stripe_subscription`` is True, the subscription has a
    Stripe subscription ID and a Stripe API key is configured, the Stripe
    subscription is set to cancel at the end of its current period. The
    local row is then marked CANCELED with ``canceled_at`` set to the
    current UTC time.

    If the Stripe call fails, the local row is left untouched and None is
    returned, so a subscription is never shown as canceled locally while
    Stripe keeps billing it.

    Args:
        db: Database session
        subscription_id: ID of the subscription to cancel
        cancel_stripe_subscription: Whether to cancel the Stripe subscription

    Returns:
        Canceled user subscription, or None if the subscription was not
        found or the Stripe cancellation failed
    """
    # Get subscription
    subscription = await get_user_subscription_by_id(db, subscription_id)

    if not subscription:
        logger.warning(f"Subscription with ID {subscription_id} not found")
        return None

    # Cancel Stripe subscription if requested
    if cancel_stripe_subscription and subscription.stripe_subscription_id and stripe.api_key:
        try:
            stripe.Subscription.modify(
                subscription.stripe_subscription_id,
                cancel_at_period_end=True
            )
        except Exception as e:
            logger.error(f"Failed to cancel Stripe subscription {subscription.stripe_subscription_id}: {e}")
            # Do not mark the subscription canceled locally: Stripe would
            # keep billing it. The caller can retry.
            return None

        logger.info(f"Canceled Stripe subscription {subscription.stripe_subscription_id} at period end")

    # Update subscription in database
    now = datetime.utcnow()

    await db.execute(
        update(UserSubscription)
        .where(UserSubscription.id == subscription_id)
        .values(
            status=SubscriptionStatus.CANCELED,
            canceled_at=now
        )
    )

    await db.commit()

    # Re-query so the returned object reflects the stored values
    subscription = await get_user_subscription_by_id(db, subscription_id)

    return subscription