hardik8588 committed on
Commit c09a046 · verified · 1 Parent(s): 1ceb85a

Upload 10 files

Files changed (10)
  1. .gitattributes +3 -35
  2. Dockerfile +10 -0
  3. README.md +40 -11
  4. app.py +1408 -0
  5. auth.py +655 -0
  6. fix_users_table.py +180 -0
  7. initialize_plans.py +25 -0
  8. legal_analysis.db +0 -0
  9. paypal_integration.py +1004 -0
  10. requirements.txt +21 -0
.gitattributes CHANGED
@@ -1,35 +1,3 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
 
Dockerfile ADDED
@@ -0,0 +1,10 @@
+ FROM python:3.10
+
+ WORKDIR /code
+
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ COPY . .
+
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,11 +1,40 @@
- ---
- title: Doc Analyzer
- emoji: 📈
- colorFrom: blue
- colorTo: gray
- sdk: docker
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: Legal Document Analysis API
+ emoji: 📄
+ colorFrom: blue
+ colorTo: indigo
+ sdk: docker
+ pinned: false
+ license: mit
+ ---
+
+ # Legal Document Analysis API
+
+ This API provides tools for analyzing legal documents, videos, and audio files. It uses NLP models to extract insights, summarize content, and answer legal questions.
+
+ ## Features
+
+ - Document analysis (PDF)
+ - Video and audio transcription and analysis
+ - Legal question answering
+ - Risk assessment and visualization
+ - Contract clause analysis
+
+ ## Deployment
+
+ This API is deployed on Hugging Face Spaces.
+
+ ## API Endpoints
+
+ - `/analyze_legal_document` - Analyze legal documents (PDF)
+ - `/analyze_legal_video` - Analyze legal videos
+ - `/analyze_legal_audio` - Analyze legal audio
+ - `/legal_chatbot/{task_id}` - Ask questions about an analyzed document
+
+ ## Technologies
+
+ - FastAPI
+ - Hugging Face Transformers
+ - SpaCy
+ - PyTorch
+ - MoviePy
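The typical flow against these endpoints, as implemented in app.py, is to obtain a bearer token from /token and then upload a PDF to /analyze_legal_document. A minimal client sketch follows; the base URL, credentials, and file name are placeholders, and it assumes the requests library is available on the client side:

```python
# Minimal client sketch for the endpoints above; BASE_URL, credentials, and the
# PDF path are placeholders, not values from this repository.
import requests

BASE_URL = "http://localhost:7860"

# 1. Obtain a bearer token (the /token route uses the OAuth2 password form).
token_resp = requests.post(
    f"{BASE_URL}/token",
    data={"username": "user@example.com", "password": "secret"},
)
token_resp.raise_for_status()
token = token_resp.json()["access_token"]

# 2. Upload a PDF for analysis with the token in the Authorization header.
with open("contract.pdf", "rb") as f:
    analysis = requests.post(
        f"{BASE_URL}/analyze_legal_document",
        headers={"Authorization": f"Bearer {token}"},
        files={"file": ("contract.pdf", f, "application/pdf")},
    )
analysis.raise_for_status()
result = analysis.json()
print(result["task_id"], result["summary"])

# 3. Ask a follow-up question about the analyzed document.
answer = requests.post(
    f"{BASE_URL}/legal_chatbot/{result['task_id']}",
    headers={"Authorization": f"Bearer {token}"},
    data={"question": "What are the termination terms?"},
)
print(answer.json())
```

Note that the chatbot route is gated on the subscription tier's feature list in app.py, so a free-tier account may receive a 403 on the last call.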
app.py ADDED
@@ -0,0 +1,1408 @@
1
+ import os
2
+ import io
3
+ import time
4
+ import uuid
5
+ import tempfile
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+ import pdfplumber
9
+ import spacy
10
+ import torch
11
+ import sqlite3
12
+ import uvicorn
13
+ import moviepy.editor as mp
14
+ from threading import Thread
15
+ from datetime import datetime, timedelta
16
+ from typing import List, Dict, Optional
17
+ from fastapi import FastAPI, File, UploadFile, Form, Depends, HTTPException, status, Header
18
+ from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
19
+ from fastapi.staticfiles import StaticFiles
20
+ from fastapi.middleware.cors import CORSMiddleware
21
+ import logging
22
+ from pydantic import BaseModel
23
+ from transformers import (
24
+ AutoTokenizer,
25
+ AutoModelForQuestionAnswering,
26
+ pipeline,
27
+ TrainingArguments,
28
+ Trainer
29
+ )
30
+ from sentence_transformers import SentenceTransformer
31
+ from passlib.context import CryptContext
32
+ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
33
+ import jwt
34
+ from dotenv import load_dotenv
35
+ # Import get_db_connection from auth
36
+ from auth import (
37
+ User, UserCreate, Token, get_current_active_user, authenticate_user,
38
+ create_access_token, hash_password, register_user, check_subscription_access,
39
+ SUBSCRIPTION_TIERS, JWT_EXPIRATION_DELTA, get_db_connection, update_auth_db_schema, get_subscription_plans
40
+
41
+ )
42
+ # PayPal subscription helpers
43
+ from paypal_integration import (
44
+ create_user_subscription, verify_subscription_payment,
45
+ update_user_subscription, handle_subscription_webhook, initialize_database
46
+ )
47
+ from fastapi import Request
48
+
49
+ logging.basicConfig(
50
+ level=logging.INFO,
51
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
52
+ )
53
+ logger = logging.getLogger("app")
54
+
55
+ # Initialize the database
56
+ # Initialize FastAPI app
57
+ app = FastAPI(
58
+ title="Legal Document Analysis API",
59
+ description="API for analyzing legal documents, videos, and audio",
60
+ version="1.0.0"
61
+ )
62
+
63
+ # Set up CORS middleware
64
+ app.add_middleware(
65
+ CORSMiddleware,
66
+ allow_origins=["http://localhost:3000"], # Frontend URL
67
+ allow_credentials=True,
68
+ allow_methods=["*"],
69
+ allow_headers=["*"],
70
+ )
71
+ initialize_database()
72
+ try:
73
+ update_auth_db_schema()
74
+ logger.info("Database schema updated successfully")
75
+ except Exception as e:
76
+ logger.error(f"Database schema update error: {e}")
77
+
78
+ # Create static directory for file storage
79
+ os.makedirs("static", exist_ok=True)
80
+ os.makedirs("uploads", exist_ok=True)
81
+ os.makedirs("temp", exist_ok=True)
82
+ app.mount("/static", StaticFiles(directory="static"), name="static")
83
+
84
+ # Set device for model inference
85
+ device = "cuda" if torch.cuda.is_available() else "cpu"
86
+ print(f"Using device: {device}")
87
+
88
+ # Initialize chat history
89
+ chat_history = []
90
+
91
+ # Document context storage
92
+ document_contexts = {}
93
+
94
+ def store_document_context(task_id, text):
95
+ """Store document text for later retrieval."""
96
+ document_contexts[task_id] = text
97
+
98
+ def load_document_context(task_id):
99
+ """Load document text for a given task ID."""
100
+ return document_contexts.get(task_id, "")
101
+
102
+
103
+ load_dotenv()
104
+ DB_PATH = os.getenv("DB_PATH", os.path.join(os.path.dirname(__file__), "data/user_data.db"))
105
+ os.makedirs(os.path.join(os.path.dirname(__file__), "data"), exist_ok=True)
106
+
107
+ def fine_tune_qa_model():
108
+ """Fine-tunes a QA model on the CUAD dataset."""
109
+ print("Loading base model for fine-tuning...")
110
+ tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
111
+ model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
112
+
113
+ # Load and preprocess CUAD dataset
114
+ print("Loading CUAD dataset...")
115
+ from datasets import load_dataset
116
+
117
+ try:
118
+ dataset = load_dataset("cuad")
119
+ except Exception as e:
120
+ print(f"Error loading CUAD dataset: {str(e)}")
121
+ print("Downloading CUAD dataset from alternative source...")
122
+ # Implement alternative dataset loading here
123
+ return tokenizer, model
124
+
125
+ print(f"Dataset loaded with {len(dataset['train'])} training examples")
126
+
127
+ # Preprocess the dataset
128
+ def preprocess_function(examples):
129
+ questions = [q.strip() for q in examples["question"]]
130
+ contexts = [c.strip() for c in examples["context"]]
131
+
132
+ inputs = tokenizer(
133
+ questions,
134
+ contexts,
135
+ max_length=384,
136
+ truncation="only_second",
137
+ stride=128,
138
+ return_overflowing_tokens=True,
139
+ return_offsets_mapping=True,
140
+ padding="max_length",
141
+ )
142
+
143
+ offset_mapping = inputs.pop("offset_mapping")
144
+ sample_map = inputs.pop("overflow_to_sample_mapping")
145
+
146
+ answers = examples["answers"]
147
+ start_positions = []
148
+ end_positions = []
149
+
150
+ for i, offset in enumerate(offset_mapping):
151
+ sample_idx = sample_map[i]
152
+ answer = answers[sample_idx]
153
+
154
+ start_char = answer["answer_start"][0] if len(answer["answer_start"]) > 0 else 0
155
+ end_char = start_char + len(answer["text"][0]) if len(answer["text"]) > 0 else 0
156
+
157
+ sequence_ids = inputs.sequence_ids(i)
158
+
159
+ # Find the start and end of the context
160
+ idx = 0
161
+ while sequence_ids[idx] != 1:
162
+ idx += 1
163
+ context_start = idx
164
+
165
+ while idx < len(sequence_ids) and sequence_ids[idx] == 1:
166
+ idx += 1
167
+ context_end = idx - 1
168
+
169
+ # If the answer is not fully inside the context, label is (0, 0)
170
+ if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
171
+ start_positions.append(0)
172
+ end_positions.append(0)
173
+ else:
174
+ # Otherwise it's the start and end token positions
175
+ idx = context_start
176
+ while idx <= context_end and offset[idx][0] <= start_char:
177
+ idx += 1
178
+ start_positions.append(idx - 1)
179
+
180
+ idx = context_end
181
+ while idx >= context_start and offset[idx][1] >= end_char:
182
+ idx -= 1
183
+ end_positions.append(idx + 1)
184
+
185
+ inputs["start_positions"] = start_positions
186
+ inputs["end_positions"] = end_positions
187
+ return inputs
188
+
189
+ print("Preprocessing dataset...")
190
+ processed_dataset = dataset.map(
191
+ preprocess_function,
192
+ batched=True,
193
+ remove_columns=dataset["train"].column_names,
194
+ )
195
+
196
+ print("Splitting dataset...")
197
+ train_dataset = processed_dataset["train"]
198
+ val_dataset = processed_dataset["validation"]
199
+
200
+ train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
201
+ val_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
202
+
203
+ training_args = TrainingArguments(
204
+ output_dir="./fine_tuned_legal_qa",
205
+ evaluation_strategy="steps",
206
+ eval_steps=100,
207
+ learning_rate=2e-5,
208
+ per_device_train_batch_size=16,
209
+ per_device_eval_batch_size=16,
210
+ num_train_epochs=1,
211
+ weight_decay=0.01,
212
+ logging_steps=50,
213
+ save_steps=100,
214
+ load_best_model_at_end=True,
215
+ report_to=[]
216
+ )
217
+
218
+ print("✅ Starting fine tuning on CUAD QA dataset...")
219
+ trainer = Trainer(
220
+ model=model,
221
+ args=training_args,
222
+ train_dataset=train_dataset,
223
+ eval_dataset=val_dataset,
224
+ tokenizer=tokenizer,
225
+ )
226
+
227
+ trainer.train()
228
+ print("✅ Fine tuning completed. Saving model...")
229
+
230
+ model.save_pretrained("./fine_tuned_legal_qa")
231
+ tokenizer.save_pretrained("./fine_tuned_legal_qa")
232
+
233
+ return tokenizer, model
234
+
235
+ #############################
236
+ # Load NLP Models #
237
+ #############################
238
+
239
+ # Initialize model variables
240
+ nlp = None
241
+ summarizer = None
242
+ embedding_model = None
243
+ ner_model = None
244
+ speech_to_text = None
245
+ cuad_model = None
246
+ cuad_tokenizer = None
247
+ qa_model = None
248
+
249
+ # Add model caching functionality
250
+ import pickle
251
+ import os.path
252
+
253
+ MODELS_CACHE_DIR = os.path.join(os.path.dirname(__file__), "models_cache")
254
+ os.makedirs(MODELS_CACHE_DIR, exist_ok=True)
255
+
256
+ def save_model_to_cache(model, model_name):
257
+ """Save a model to the cache directory"""
258
+ try:
259
+ cache_path = os.path.join(MODELS_CACHE_DIR, f"{model_name}.pkl")
260
+ with open(cache_path, 'wb') as f:
261
+ pickle.dump(model, f)
262
+ print(f"✅ Saved {model_name} to cache")
263
+ return True
264
+ except Exception as e:
265
+ print(f"⚠️ Failed to save {model_name} to cache: {str(e)}")
266
+ return False
267
+
268
+ def load_model_from_cache(model_name):
269
+ """Load a model from the cache directory"""
270
+ try:
271
+ cache_path = os.path.join(MODELS_CACHE_DIR, f"{model_name}.pkl")
272
+ if os.path.exists(cache_path):
273
+ with open(cache_path, 'rb') as f:
274
+ model = pickle.load(f)
275
+ print(f"✅ Loaded {model_name} from cache")
276
+ return model
277
+ return None
278
+ except Exception as e:
279
+ print(f"⚠️ Failed to load {model_name} from cache: {str(e)}")
280
+ return None
281
+
282
+ # Add a flag to control model loading
283
+ LOAD_MODELS = os.getenv("LOAD_MODELS", "True").lower() in ("true", "1", "t")
284
+
285
+ try:
286
+ if LOAD_MODELS:
287
+ # Try to load SpaCy from cache first
288
+ nlp = load_model_from_cache("spacy_model")
289
+ if nlp is None:
290
+ try:
291
+ nlp = spacy.load("en_core_web_sm")
292
+ save_model_to_cache(nlp, "spacy_model")
293
+ except OSError:
294
+ print("⚠️ SpaCy model not found, downloading...")
295
+ spacy.cli.download("en_core_web_sm")
296
+ nlp = spacy.load("en_core_web_sm")
297
+ save_model_to_cache(nlp, "spacy_model")
298
+
299
+ print("✅ Loading NLP models...")
300
+
301
+ # Load the summarizer with caching
302
+ print("Loading summarizer model...")
303
+ summarizer = load_model_from_cache("summarizer_model")
304
+ if summarizer is None:
305
+ try:
306
+ summarizer = pipeline("summarization", model="facebook/bart-large-cnn",
307
+ device=0 if torch.cuda.is_available() else -1)
308
+ save_model_to_cache(summarizer, "summarizer_model")
309
+ print("✅ Summarizer loaded successfully")
310
+ except Exception as e:
311
+ print(f"⚠️ Error loading summarizer: {str(e)}")
312
+ try:
313
+ print("Trying alternative summarizer model...")
314
+ summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6",
315
+ device=0 if torch.cuda.is_available() else -1)
316
+ save_model_to_cache(summarizer, "summarizer_model")
317
+ print("✅ Alternative summarizer loaded successfully")
318
+ except Exception as e2:
319
+ print(f"⚠️ Error loading alternative summarizer: {str(e2)}")
320
+ summarizer = None
321
+
322
+ # Load the embedding model with caching
323
+ print("Loading embedding model...")
324
+ embedding_model = load_model_from_cache("embedding_model")
325
+ if embedding_model is None:
326
+ try:
327
+ embedding_model = SentenceTransformer("all-mpnet-base-v2", device=device)
328
+ save_model_to_cache(embedding_model, "embedding_model")
329
+ print("✅ Embedding model loaded successfully")
330
+ except Exception as e:
331
+ print(f"⚠️ Error loading embedding model: {str(e)}")
332
+ embedding_model = None
333
+
334
+ # Load the NER model with caching
335
+ print("Loading NER model...")
336
+ ner_model = load_model_from_cache("ner_model")
337
+ if ner_model is None:
338
+ try:
339
+ ner_model = pipeline("ner", model="dslim/bert-base-NER",
340
+ device=0 if torch.cuda.is_available() else -1)
341
+ save_model_to_cache(ner_model, "ner_model")
342
+ print("✅ NER model loaded successfully")
343
+ except Exception as e:
344
+ print(f"⚠️ Error loading NER model: {str(e)}")
345
+ ner_model = None
346
+
347
+ # Speech to text model with caching
348
+ print("Loading speech to text model...")
349
+ speech_to_text = load_model_from_cache("speech_to_text_model")
350
+ if speech_to_text is None:
351
+ try:
352
+ speech_to_text = pipeline("automatic-speech-recognition",
353
+ model="openai/whisper-medium",
354
+ chunk_length_s=30,
355
+ device_map="auto" if torch.cuda.is_available() else "cpu")
356
+ save_model_to_cache(speech_to_text, "speech_to_text_model")
357
+ print("✅ Speech to text model loaded successfully")
358
+ except Exception as e:
359
+ print(f"⚠️ Error loading speech to text model: {str(e)}")
360
+ speech_to_text = None
361
+
362
+ # Load the fine-tuned model with caching
363
+ print("Loading fine-tuned CUAD QA model...")
364
+ cuad_model = load_model_from_cache("cuad_model")
365
+ cuad_tokenizer = load_model_from_cache("cuad_tokenizer")
366
+
367
+ if cuad_model is None or cuad_tokenizer is None:
368
+ try:
369
+ cuad_tokenizer = AutoTokenizer.from_pretrained("hardik8588/fine-tuned-legal-qa")
370
+ from transformers import AutoModelForQuestionAnswering
371
+ cuad_model = AutoModelForQuestionAnswering.from_pretrained("hardik8588/fine-tuned-legal-qa")
372
+ cuad_model.to(device)
373
+ save_model_to_cache(cuad_tokenizer, "cuad_tokenizer")
374
+ save_model_to_cache(cuad_model, "cuad_model")
375
+ print("✅ Successfully loaded fine-tuned model")
376
+ except Exception as e:
377
+ print(f"⚠️ Error loading fine-tuned model: {str(e)}")
378
+ print("⚠️ Falling back to pre-trained model...")
379
+ try:
380
+ cuad_tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
381
+ from transformers import AutoModelForQuestionAnswering
382
+ cuad_model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
383
+ cuad_model.to(device)
384
+ save_model_to_cache(cuad_tokenizer, "cuad_tokenizer")
385
+ save_model_to_cache(cuad_model, "cuad_model")
386
+ print("✅ Pre-trained model loaded successfully")
387
+ except Exception as e2:
388
+ print(f"⚠️ Error loading pre-trained model: {str(e2)}")
389
+ cuad_model = None
390
+ cuad_tokenizer = None
391
+
392
+ # Load a general QA model with caching
393
+ print("Loading general QA model...")
394
+ qa_model = load_model_from_cache("qa_model")
395
+ if qa_model is None:
396
+ try:
397
+ qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")
398
+ save_model_to_cache(qa_model, "qa_model")
399
+ print("✅ QA model loaded successfully")
400
+ except Exception as e:
401
+ print(f"⚠️ Error loading QA model: {str(e)}")
402
+ qa_model = None
403
+
404
+ print("✅ All models loaded successfully")
405
+ else:
406
+ print("⚠️ Model loading skipped (LOAD_MODELS=False)")
407
+
408
+ except Exception as e:
409
+ print(f"⚠️ Error loading models: {str(e)}")
410
+ # Instead of raising an error, set fallback behavior
411
+ nlp = None
412
+ summarizer = None
413
+ embedding_model = None
414
+ ner_model = None
415
+ speech_to_text = None
416
+ cuad_model = None
417
+ cuad_tokenizer = None
418
+ qa_model = None
419
+ print("⚠️ Running with limited functionality due to model loading errors")
420
+
421
+ def legal_chatbot(user_input, context):
422
+ """Uses a real NLP model for legal Q&A."""
423
+ global chat_history
424
+ chat_history.append({"role": "user", "content": user_input})
425
+ response = qa_model(question=user_input, context=context)["answer"]
426
+ chat_history.append({"role": "assistant", "content": response})
427
+ return response
428
+
429
+ def extract_text_from_pdf(pdf_file):
430
+ """Extracts text from a PDF file using pdfplumber."""
431
+ try:
432
+ # Suppress pdfplumber warnings about CropBox
433
+ import logging
434
+ logging.getLogger("pdfminer").setLevel(logging.ERROR)
435
+
436
+ with pdfplumber.open(pdf_file) as pdf:
437
+ print(f"Processing PDF with {len(pdf.pages)} pages")
438
+ text = ""
439
+ for i, page in enumerate(pdf.pages):
440
+ page_text = page.extract_text() or ""
441
+ text += page_text + "\n"
442
+ if (i + 1) % 10 == 0: # Log progress every 10 pages
443
+ print(f"Processed {i + 1} pages...")
444
+
445
+ print(f"✅ PDF text extraction complete: {len(text)} characters extracted")
446
+ return text.strip() if text else None
447
+ except Exception as e:
448
+ print(f"❌ PDF extraction error: {str(e)}")
449
+ raise HTTPException(status_code=400, detail=f"PDF extraction failed: {str(e)}")
450
+
451
+ def process_video_to_text(video_file_path):
452
+ """Extract audio from video and convert to text."""
453
+ try:
454
+ print(f"Processing video file at {video_file_path}")
455
+ temp_audio_path = os.path.join("temp", "extracted_audio.wav")
456
+ video = mp.VideoFileClip(video_file_path)
457
+ video.audio.write_audiofile(temp_audio_path, codec='pcm_s16le')
458
+ print(f"Audio extracted to {temp_audio_path}")
459
+ result = speech_to_text(temp_audio_path)
460
+ transcript = result["text"]
461
+ print(f"Transcription completed: {len(transcript)} characters")
462
+ if os.path.exists(temp_audio_path):
463
+ os.remove(temp_audio_path)
464
+ return transcript
465
+ except Exception as e:
466
+ print(f"Error in video processing: {str(e)}")
467
+ raise HTTPException(status_code=400, detail=f"Video processing failed: {str(e)}")
468
+
469
+ def process_audio_to_text(audio_file_path):
470
+ """Process audio file and convert to text."""
471
+ try:
472
+ print(f"Processing audio file at {audio_file_path}")
473
+ result = speech_to_text(audio_file_path)
474
+ transcript = result["text"]
475
+ print(f"Transcription completed: {len(transcript)} characters")
476
+ return transcript
477
+ except Exception as e:
478
+ print(f"Error in audio processing: {str(e)}")
479
+ raise HTTPException(status_code=400, detail=f"Audio processing failed: {str(e)}")
480
+
481
+ def extract_named_entities(text):
482
+ """Extracts named entities from legal text."""
483
+ max_length = 10000
484
+ entities = []
485
+ for i in range(0, len(text), max_length):
486
+ chunk = text[i:i+max_length]
487
+ doc = nlp(chunk)
488
+ entities.extend([{"entity": ent.text, "label": ent.label_} for ent in doc.ents])
489
+ return entities
490
+
491
+ def analyze_risk(text):
492
+ """Analyzes legal risk in the document using keyword-based analysis."""
493
+ risk_keywords = {
494
+ "Liability": ["liability", "responsible", "responsibility", "legal obligation"],
495
+ "Termination": ["termination", "breach", "contract end", "default"],
496
+ "Indemnification": ["indemnification", "indemnify", "hold harmless", "compensate", "compensation"],
497
+ "Payment Risk": ["payment", "terms", "reimbursement", "fee", "schedule", "invoice", "money"],
498
+ "Insurance": ["insurance", "coverage", "policy", "claims"],
499
+ }
500
+ risk_scores = {category: 0 for category in risk_keywords}
501
+ lower_text = text.lower()
502
+ for category, keywords in risk_keywords.items():
503
+ for keyword in keywords:
504
+ risk_scores[category] += lower_text.count(keyword.lower())
505
+ return risk_scores
506
+
507
+ def extract_context_for_risk_terms(text, risk_keywords, window=1):
508
+ """
509
+ Extracts and summarizes the context around risk terms.
510
+ """
511
+ doc = nlp(text)
512
+ sentences = list(doc.sents)
513
+ risk_contexts = {category: [] for category in risk_keywords}
514
+ for i, sent in enumerate(sentences):
515
+ sent_text_lower = sent.text.lower()
516
+ for category, details in risk_keywords.items():
517
+ for keyword in details["keywords"]:
518
+ if keyword.lower() in sent_text_lower:
519
+ start_idx = max(0, i - window)
520
+ end_idx = min(len(sentences), i + window + 1)
521
+ context_chunk = " ".join([s.text for s in sentences[start_idx:end_idx]])
522
+ risk_contexts[category].append(context_chunk)
523
+ summarized_contexts = {}
524
+ for category, contexts in risk_contexts.items():
525
+ if contexts:
526
+ combined_context = " ".join(contexts)
527
+ try:
528
+ summary_result = summarizer(combined_context, max_length=100, min_length=30, do_sample=False)
529
+ summary = summary_result[0]['summary_text']
530
+ except Exception as e:
531
+ summary = "Context summarization failed."
532
+ summarized_contexts[category] = summary
533
+ else:
534
+ summarized_contexts[category] = "No contextual details found."
535
+ return summarized_contexts
536
+
537
+ def get_detailed_risk_info(text):
538
+ """
539
+ Returns detailed risk information by merging risk scores with descriptive details
540
+ and contextual summaries from the document.
541
+ """
542
+ risk_details = {
543
+ "Liability": {
544
+ "description": "Liability refers to the legal responsibility for losses or damages.",
545
+ "common_concerns": "Broad liability clauses may expose parties to unforeseen risks.",
546
+ "recommendations": "Review and negotiate clear limits on liability.",
547
+ "example": "E.g., 'The party shall be liable for direct damages due to negligence.'"
548
+ },
549
+ "Termination": {
550
+ "description": "Termination involves conditions under which a contract can be ended.",
551
+ "common_concerns": "Unilateral termination rights or ambiguous conditions can be risky.",
552
+ "recommendations": "Ensure termination clauses are balanced and include notice periods.",
553
+ "example": "E.g., 'Either party may terminate the agreement with 30 days notice.'"
554
+ },
555
+ "Indemnification": {
556
+ "description": "Indemnification requires one party to compensate for losses incurred by the other.",
557
+ "common_concerns": "Overly broad indemnification can shift significant risk.",
558
+ "recommendations": "Negotiate clear limits and carve-outs where necessary.",
559
+ "example": "E.g., 'The seller shall indemnify the buyer against claims from product defects.'"
560
+ },
561
+ "Payment Risk": {
562
+ "description": "Payment risk pertains to terms regarding fees, schedules, and reimbursements.",
563
+ "common_concerns": "Vague payment terms or hidden charges increase risk.",
564
+ "recommendations": "Clarify payment conditions and include penalties for delays.",
565
+ "example": "E.g., 'Payments must be made within 30 days, with a 2% late fee thereafter.'"
566
+ },
567
+ "Insurance": {
568
+ "description": "Insurance risk covers the adequacy and scope of required coverage.",
569
+ "common_concerns": "Insufficient insurance can leave parties exposed in unexpected events.",
570
+ "recommendations": "Review insurance requirements to ensure they meet the risk profile.",
571
+ "example": "E.g., 'The contractor must maintain liability insurance with at least $1M coverage.'"
572
+ }
573
+ }
574
+ risk_scores = analyze_risk(text)
575
+ risk_keywords_context = {
576
+ "Liability": {"keywords": ["liability", "responsible", "responsibility", "legal obligation"]},
577
+ "Termination": {"keywords": ["termination", "breach", "contract end", "default"]},
578
+ "Indemnification": {"keywords": ["indemnification", "indemnify", "hold harmless", "compensate", "compensation"]},
579
+ "Payment Risk": {"keywords": ["payment", "terms", "reimbursement", "fee", "schedule", "invoice", "money"]},
580
+ "Insurance": {"keywords": ["insurance", "coverage", "policy", "claims"]}
581
+ }
582
+ risk_contexts = extract_context_for_risk_terms(text, risk_keywords_context, window=1)
583
+ detailed_info = {}
584
+ for risk_term, score in risk_scores.items():
585
+ if score > 0:
586
+ info = risk_details.get(risk_term, {"description": "No details available."})
587
+ detailed_info[risk_term] = {
588
+ "score": score,
589
+ "description": info.get("description", ""),
590
+ "common_concerns": info.get("common_concerns", ""),
591
+ "recommendations": info.get("recommendations", ""),
592
+ "example": info.get("example", ""),
593
+ "context_summary": risk_contexts.get(risk_term, "No context available.")
594
+ }
595
+ return detailed_info
596
+
597
+ def analyze_contract_clauses(text):
598
+ """Analyzes contract clauses using the fine-tuned CUAD QA model."""
599
+ max_length = 512
600
+ step = 256
601
+ clauses_detected = []
602
+ try:
603
+ clause_types = list(cuad_model.config.id2label.values())
604
+ except Exception as e:
605
+ clause_types = [
606
+ "Obligations of Seller", "Governing Law", "Termination", "Indemnification",
607
+ "Confidentiality", "Insurance", "Non-Compete", "Change of Control",
608
+ "Assignment", "Warranty", "Limitation of Liability", "Arbitration",
609
+ "IP Rights", "Force Majeure", "Revenue/Profit Sharing", "Audit Rights"
610
+ ]
611
+ chunks = [text[i:i+max_length] for i in range(0, len(text), step) if i+step < len(text)]
612
+ for chunk in chunks:
613
+ inputs = cuad_tokenizer(chunk, return_tensors="pt", truncation=True, max_length=512).to(device)
614
+ with torch.no_grad():
615
+ outputs = cuad_model(**inputs)
616
+ predictions = torch.sigmoid(outputs.start_logits).cpu().numpy()[0]
617
+ for idx, confidence in enumerate(predictions):
618
+ if confidence > 0.5 and idx < len(clause_types):
619
+ clauses_detected.append({"type": clause_types[idx], "confidence": float(confidence)})
620
+ aggregated_clauses = {}
621
+ for clause in clauses_detected:
622
+ clause_type = clause["type"]
623
+ if clause_type not in aggregated_clauses or clause["confidence"] > aggregated_clauses[clause_type]["confidence"]:
624
+ aggregated_clauses[clause_type] = clause
625
+ return list(aggregated_clauses.values())
626
+
627
+ def summarize_text(text):
628
+ """Summarizes legal text using the summarizer model."""
629
+ try:
630
+ if summarizer is None:
631
+ return "Basic analysis (NLP models not available)"
632
+
633
+ # Split text into chunks if it's too long
634
+ max_chunk_size = 1024
635
+ if len(text) > max_chunk_size:
636
+ chunks = [text[i:i+max_chunk_size] for i in range(0, len(text), max_chunk_size)]
637
+ summaries = []
638
+ for chunk in chunks:
639
+ summary = summarizer(chunk, max_length=100, min_length=30, do_sample=False)
640
+ summaries.append(summary[0]['summary_text'])
641
+ return " ".join(summaries)
642
+ else:
643
+ summary = summarizer(text, max_length=100, min_length=30, do_sample=False)
644
+ return summary[0]['summary_text']
645
+ except Exception as e:
646
+ print(f"Error in summarization: {str(e)}")
647
+ return "Summarization failed. Please try again later."
648
+
649
+ @app.post("/analyze_legal_document")
650
+ async def analyze_legal_document(
651
+ file: UploadFile = File(...),
652
+ current_user: User = Depends(get_current_active_user)
653
+ ):
654
+ """Analyzes a legal document (PDF) and returns insights based on subscription tier."""
655
+ try:
656
+ # Calculate file size in MB
657
+ file_content = await file.read()
658
+ file_size_mb = len(file_content) / (1024 * 1024)
659
+
660
+ # Check subscription access for document analysis
661
+ check_subscription_access(current_user, "document_analysis", file_size_mb)
662
+
663
+ print(f"Processing file: {file.filename}")
664
+
665
+ # Create a temporary file to store the uploaded PDF
666
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
667
+ tmp.write(file_content)
668
+ tmp_path = tmp.name
669
+
670
+ # Extract text from PDF
671
+ text = extract_text_from_pdf(tmp_path)
672
+
673
+ # Clean up the temporary file
674
+ os.unlink(tmp_path)
675
+
676
+ if not text:
677
+ raise HTTPException(status_code=400, detail="Could not extract text from PDF")
678
+
679
+ # Generate a task ID
680
+ task_id = str(uuid.uuid4())
681
+
682
+ # Store document context for later retrieval
683
+ store_document_context(task_id, text)
684
+
685
+ # Basic analysis available to all tiers
686
+ summary = summarize_text(text)
687
+ entities = extract_named_entities(text)
688
+ risk_scores = analyze_risk(text)
689
+
690
+ # Prepare response based on subscription tier
691
+ response = {
692
+ "task_id": task_id,
693
+ "summary": summary,
694
+ "entities": entities,
695
+ "risk_assessment": risk_scores,
696
+ "subscription_tier": current_user.subscription_tier
697
+ }
698
+
699
+ # Add premium features if user has access
700
+ if current_user.subscription_tier == "premium_tier":
701
+ # Add detailed risk assessment
702
+ if "detailed_risk_assessment" in SUBSCRIPTION_TIERS[current_user.subscription_tier]["features"]:
703
+ detailed_risk = get_detailed_risk_info(text)
704
+ response["detailed_risk_assessment"] = detailed_risk
705
+
706
+ # Add contract clause analysis
707
+ if "contract_clause_analysis" in SUBSCRIPTION_TIERS[current_user.subscription_tier]["features"]:
708
+ clauses = analyze_contract_clauses(text)
709
+ response["contract_clauses"] = clauses
710
+
711
+ return response
712
+
713
+ except Exception as e:
714
+ print(f"Error analyzing document: {str(e)}")
715
+ raise HTTPException(status_code=500, detail=f"Error analyzing document: {str(e)}")
716
+
717
+ @app.post("/analyze_legal_video")
718
+ async def analyze_legal_video(
719
+ file: UploadFile = File(...),
720
+ current_user: User = Depends(get_current_active_user)
721
+ ):
722
+ """Analyzes legal video by transcribing and analyzing the transcript."""
723
+ try:
724
+ # Calculate file size in MB
725
+ file_content = await file.read()
726
+ file_size_mb = len(file_content) / (1024 * 1024)
727
+
728
+ # Check subscription access for video analysis
729
+ check_subscription_access(current_user, "video_analysis", file_size_mb)
730
+
731
+ print(f"Processing video file: {file.filename}")
732
+
733
+ # Create a temporary file to store the uploaded video
734
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp:
735
+ tmp.write(file_content)
736
+ tmp_path = tmp.name
737
+
738
+ # Process video to extract transcript
739
+ transcript = process_video_to_text(tmp_path)
740
+
741
+ # Clean up the temporary file
742
+ os.unlink(tmp_path)
743
+
744
+ if not transcript:
745
+ raise HTTPException(status_code=400, detail="Could not extract transcript from video")
746
+
747
+ # Generate a task ID
748
+ task_id = str(uuid.uuid4())
749
+
750
+ # Store document context for later retrieval
751
+ store_document_context(task_id, transcript)
752
+
753
+ # Basic analysis
754
+ summary = summarize_text(transcript)
755
+ entities = extract_named_entities(transcript)
756
+ risk_scores = analyze_risk(transcript)
757
+
758
+ # Prepare response
759
+ response = {
760
+ "task_id": task_id,
761
+ "transcript": transcript,
762
+ "summary": summary,
763
+ "entities": entities,
764
+ "risk_assessment": risk_scores,
765
+ "subscription_tier": current_user.subscription_tier
766
+ }
767
+
768
+ # Add premium features if user has access
769
+ if current_user.subscription_tier == "premium_tier":
770
+ # Add detailed risk assessment
771
+ if "detailed_risk_assessment" in SUBSCRIPTION_TIERS[current_user.subscription_tier]["features"]:
772
+ detailed_risk = get_detailed_risk_info(transcript)
773
+ response["detailed_risk_assessment"] = detailed_risk
774
+
775
+ return response
776
+
777
+ except Exception as e:
778
+ print(f"Error analyzing video: {str(e)}")
779
+ raise HTTPException(status_code=500, detail=f"Error analyzing video: {str(e)}")
780
+
781
+
782
+ @app.post("/legal_chatbot/{task_id}")
783
+ async def chat_with_document(
784
+ task_id: str,
785
+ question: str = Form(...),
786
+ current_user: User = Depends(get_current_active_user)
787
+ ):
788
+ """Chat with a document using the legal chatbot."""
789
+ try:
790
+ # Check if user has access to chatbot feature
791
+ if "chatbot" not in SUBSCRIPTION_TIERS[current_user.subscription_tier]["features"]:
792
+ raise HTTPException(
793
+ status_code=403,
794
+ detail=f"The chatbot feature is not available in your {current_user.subscription_tier} subscription. Please upgrade to access this feature."
795
+ )
796
+
797
+ # Check if document context exists
798
+ context = load_document_context(task_id)
799
+ if not context:
800
+ raise HTTPException(status_code=404, detail="Document context not found. Please analyze a document first.")
801
+
802
+ # Use the chatbot to answer the question
803
+ answer = legal_chatbot(question, context)
804
+
805
+ return {"answer": answer, "chat_history": chat_history}
806
+
807
+ except Exception as e:
808
+ print(f"Error in chatbot: {str(e)}")
809
+ raise HTTPException(status_code=500, detail=f"Error in chatbot: {str(e)}")
810
+
811
+ @app.get("/")
812
+ async def root():
813
+ """Root endpoint that returns a welcome message."""
814
+ return HTMLResponse(content="""
815
+ <html>
816
+ <head>
817
+ <title>Legal Document Analysis API</title>
818
+ <style>
819
+ body {
820
+ font-family: Arial, sans-serif;
821
+ max-width: 800px;
822
+ margin: 0 auto;
823
+ padding: 20px;
824
+ }
825
+ h1 {
826
+ color: #2c3e50;
827
+ }
828
+ .endpoint {
829
+ background-color: #f8f9fa;
830
+ padding: 15px;
831
+ margin-bottom: 10px;
832
+ border-radius: 5px;
833
+ }
834
+ .method {
835
+ font-weight: bold;
836
+ color: #e74c3c;
837
+ }
838
+ </style>
839
+ </head>
840
+ <body>
841
+ <h1>Legal Document Analysis API</h1>
842
+ <p>Welcome to the Legal Document Analysis API. This API provides tools for analyzing legal documents, videos, and audio.</p>
843
+ <h2>Available Endpoints:</h2>
844
+ <div class="endpoint">
845
+ <p><span class="method">POST</span> /analyze_legal_document - Analyze a legal document (PDF)</p>
846
+ </div>
847
+ <div class="endpoint">
848
+ <p><span class="method">POST</span> /analyze_legal_video - Analyze a legal video</p>
849
+ </div>
850
+ <div class="endpoint">
851
+ <p><span class="method">POST</span> /analyze_legal_audio - Analyze legal audio</p>
852
+ </div>
853
+ <div class="endpoint">
854
+ <p><span class="method">POST</span> /legal_chatbot/{task_id} - Chat with a document</p>
855
+ </div>
856
+ <div class="endpoint">
857
+ <p><span class="method">POST</span> /register - Register a new user</p>
858
+ </div>
859
+ <div class="endpoint">
860
+ <p><span class="method">POST</span> /token - Login to get an access token</p>
861
+ </div>
862
+ <div class="endpoint">
863
+ <p><span class="method">GET</span> /users/me - Get current user information</p>
864
+ </div>
865
+ <div class="endpoint">
866
+ <p><span class="method">POST</span> /subscribe/{tier} - Subscribe to a plan</p>
867
+ </div>
868
+ <p>For more details, visit the <a href="/docs">API documentation</a>.</p>
869
+ </body>
870
+ </html>
871
+ """)
872
+
873
+ @app.post("/register", response_model=Token)
874
+ async def register_new_user(user_data: UserCreate):
875
+ """Register a new user with a free subscription"""
876
+ try:
877
+ success, result = register_user(user_data.email, user_data.password)
878
+
879
+ if not success:
880
+ raise HTTPException(status_code=400, detail=result)
881
+
882
+ return {"access_token": result["access_token"], "token_type": "bearer"}
883
+
884
+ except HTTPException:
885
+ # Re-raise HTTP exceptions
886
+ raise
887
+ except Exception as e:
888
+ print(f"Registration error: {str(e)}")
889
+ raise HTTPException(status_code=500, detail=f"Registration failed: {str(e)}")
890
+
891
+ @app.post("/token", response_model=Token)
892
+ async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
893
+ """Endpoint for OAuth2 token generation"""
894
+ try:
895
+ # Add debug logging
896
+ logger.info(f"Token request for username: {form_data.username}")
897
+
898
+ user = authenticate_user(form_data.username, form_data.password)
899
+ if not user:
900
+ logger.warning(f"Authentication failed for: {form_data.username}")
901
+ raise HTTPException(
902
+ status_code=status.HTTP_401_UNAUTHORIZED,
903
+ detail="Incorrect username or password",
904
+ headers={"WWW-Authenticate": "Bearer"},
905
+ )
906
+
907
+ access_token = create_access_token(user.id)
908
+ if not access_token:
909
+ logger.error(f"Failed to create access token for user: {user.id}")
910
+ raise HTTPException(
911
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
912
+ detail="Could not create access token",
913
+ )
914
+
915
+ logger.info(f"Login successful for: {form_data.username}")
916
+ return {"access_token": access_token, "token_type": "bearer"}
917
+ except Exception as e:
918
+ logger.error(f"Token endpoint error: {e}")
919
+ raise HTTPException(
920
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
921
+ detail=f"Login error: {str(e)}",
922
+ )
923
+
924
+
925
+ @app.get("/debug/token")
926
+ async def debug_token(authorization: str = Header(None)):
927
+ """Debug endpoint to check token validity"""
928
+ try:
929
+ if not authorization:
930
+ return {"valid": False, "error": "No authorization header provided"}
931
+
932
+ # Extract token from Authorization header
933
+ scheme, token = authorization.split()
934
+ if scheme.lower() != 'bearer':
935
+ return {"valid": False, "error": "Not a bearer token"}
936
+
937
+ # Log the token for debugging
938
+ logger.info(f"Debugging token: {token[:10]}...")
939
+
940
+ # Try to validate the token
941
+ try:
942
+ user = await get_current_active_user(token)
943
+ return {"valid": True, "user_id": user.id, "email": user.email}
944
+ except Exception as e:
945
+ return {"valid": False, "error": str(e)}
946
+ except Exception as e:
947
+ return {"valid": False, "error": f"Token debug error: {str(e)}"}
948
+
949
+
950
+ @app.post("/login")
951
+ async def api_login(email: str, password: str):
952
+ success, result = login_user(email, password)
953
+ if not success:
954
+ raise HTTPException(
955
+ status_code=status.HTTP_401_UNAUTHORIZED,
956
+ detail=result
957
+ )
958
+ return result
959
+
960
+ @app.get("/health")
961
+ def health_check():
962
+ """Simple health check endpoint to verify the API is running"""
963
+ return {"status": "ok", "message": "API is running"}
964
+
965
+ @app.get("/users/me", response_model=User)
966
+ async def read_users_me(current_user: User = Depends(get_current_active_user)):
967
+ return current_user
968
+
969
+ @app.post("/analyze_legal_audio")
970
+ async def analyze_legal_audio(
971
+ file: UploadFile = File(...),
972
+ current_user: User = Depends(get_current_active_user)
973
+ ):
974
+ """Analyzes legal audio by transcribing and analyzing the transcript."""
975
+ try:
976
+ # Calculate file size in MB
977
+ file_content = await file.read()
978
+ file_size_mb = len(file_content) / (1024 * 1024)
979
+
980
+ # Check subscription access for audio analysis
981
+ check_subscription_access(current_user, "audio_analysis", file_size_mb)
982
+
983
+ print(f"Processing audio file: {file.filename}")
984
+
985
+ # Create a temporary file to store the uploaded audio
986
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp:
987
+ tmp.write(file_content)
988
+ tmp_path = tmp.name
989
+
990
+ # Process audio to extract transcript
991
+ transcript = process_audio_to_text(tmp_path)
992
+
993
+ # Clean up the temporary file
994
+ os.unlink(tmp_path)
995
+
996
+ if not transcript:
997
+ raise HTTPException(status_code=400, detail="Could not extract transcript from audio")
998
+
999
+ # Generate a task ID
1000
+ task_id = str(uuid.uuid4())
1001
+
1002
+ # Store document context for later retrieval
1003
+ store_document_context(task_id, transcript)
1004
+
1005
+ # Basic analysis
1006
+ summary = summarize_text(transcript)
1007
+ entities = extract_named_entities(transcript)
1008
+ risk_scores = analyze_risk(transcript)
1009
+
1010
+ # Prepare response
1011
+ response = {
1012
+ "task_id": task_id,
1013
+ "transcript": transcript,
1014
+ "summary": summary,
1015
+ "entities": entities,
1016
+ "risk_assessment": risk_scores,
1017
+ "subscription_tier": current_user.subscription_tier
1018
+ }
1019
+
1020
+ # Add premium features if user has access
1021
+ if current_user.subscription_tier == "premium_tier":
1022
+ # Add detailed risk assessment
1023
+ if "detailed_risk_assessment" in SUBSCRIPTION_TIERS[current_user.subscription_tier]["features"]:
1024
+ detailed_risk = get_detailed_risk_info(transcript)
1025
+ response["detailed_risk_assessment"] = detailed_risk
1026
+
1027
+ return response
1028
+
1029
+ except Exception as e:
1030
+ print(f"Error analyzing audio: {str(e)}")
1031
+ raise HTTPException(status_code=500, detail=f"Error analyzing audio: {str(e)}")
1032
+
1033
+
1034
+
1035
+ # Subscription management endpoints
1036
+ @app.get("/users/me/subscription")
1037
+ async def get_user_subscription(current_user: User = Depends(get_current_active_user)):
1038
+ """Get the current user's subscription details"""
1039
+ try:
1040
+ # Get subscription details from database
1041
+ conn = get_db_connection()
1042
+ cursor = conn.cursor()
1043
+
1044
+ # Get the most recent active subscription
1045
+ try:
1046
+ cursor.execute(
1047
+ "SELECT id, tier, status, created_at, expires_at, paypal_subscription_id FROM subscriptions "
1048
+ "WHERE user_id = ? AND status = 'active' ORDER BY created_at DESC LIMIT 1",
1049
+ (current_user.id,)
1050
+ )
1051
+ subscription = cursor.fetchone()
1052
+ except sqlite3.OperationalError as e:
1053
+ # Handle missing tier column
1054
+ if "no such column: tier" in str(e):
1055
+ logger.warning("Subscriptions table missing 'tier' column. Returning default subscription.")
1056
+ subscription = None
1057
+ else:
1058
+ raise
1059
+
1060
+ # Get subscription tiers with pricing directly from SUBSCRIPTION_TIERS
1061
+ subscription_tiers = {
1062
+ "free_tier": {
1063
+ "price": SUBSCRIPTION_TIERS["free_tier"]["price"],
1064
+ "currency": SUBSCRIPTION_TIERS["free_tier"]["currency"],
1065
+ "features": SUBSCRIPTION_TIERS["free_tier"]["features"]
1066
+ },
1067
+ "standard_tier": {
1068
+ "price": SUBSCRIPTION_TIERS["standard_tier"]["price"],
1069
+ "currency": SUBSCRIPTION_TIERS["standard_tier"]["currency"],
1070
+ "features": SUBSCRIPTION_TIERS["standard_tier"]["features"]
1071
+ },
1072
+ "premium_tier": {
1073
+ "price": SUBSCRIPTION_TIERS["premium_tier"]["price"],
1074
+ "currency": SUBSCRIPTION_TIERS["premium_tier"]["currency"],
1075
+ "features": SUBSCRIPTION_TIERS["premium_tier"]["features"]
1076
+ }
1077
+ }
1078
+
1079
+ if subscription:
1080
+ sub_id, tier, status, created_at, expires_at, paypal_id = subscription
1081
+ result = {
1082
+ "id": sub_id,
1083
+ "tier": tier,
1084
+ "status": status,
1085
+ "created_at": created_at,
1086
+ "expires_at": expires_at,
1087
+ "paypal_subscription_id": paypal_id,
1088
+ "current_tier": current_user.subscription_tier,
1089
+ "subscription_tiers": subscription_tiers
1090
+ }
1091
+ else:
1092
+ result = {
1093
+ "tier": "free_tier",
1094
+ "status": "active",
1095
+ "current_tier": current_user.subscription_tier,
1096
+ "subscription_tiers": subscription_tiers
1097
+ }
1098
+
1099
+ conn.close()
1100
+ return result
1101
+ except Exception as e:
1102
+ logger.error(f"Error getting subscription: {str(e)}")
1103
+ raise HTTPException(status_code=500, detail=f"Error getting subscription: {str(e)}")
1104
+ # Request model for subscription creation
1105
+ class SubscriptionCreate(BaseModel):
1106
+ tier: str
1107
+
1108
+ @app.post("/create_subscription")
1109
+ async def create_subscription(
1110
+ subscription: SubscriptionCreate,
1111
+ current_user: User = Depends(get_current_active_user)
1112
+ ):
1113
+ """Create a subscription for the current user"""
1114
+ try:
1115
+ # Log the request for debugging
1116
+ logger.info(f"Creating subscription for user {current_user.email} with tier {subscription.tier}")
1117
+ logger.info(f"Available tiers: {list(SUBSCRIPTION_TIERS.keys())}")
1118
+
1119
+ # Validate tier
1120
+ valid_tiers = ["standard_tier", "premium_tier"]
1121
+ if subscription.tier not in valid_tiers:
1122
+ logger.warning(f"Invalid tier requested: {subscription.tier}")
1123
+ raise HTTPException(status_code=400, detail=f"Invalid tier: {subscription.tier}. Must be one of {valid_tiers}")
1124
+
1125
+ # Create subscription
1126
+ logger.info(f"Calling create_user_subscription with email: {current_user.email}, tier: {subscription.tier}")
1127
+ success, result = create_user_subscription(current_user.email, subscription.tier)
1128
+
1129
+ if not success:
1130
+ logger.error(f"Failed to create subscription: {result}")
1131
+ raise HTTPException(status_code=400, detail=result)
1132
+
1133
+ logger.info(f"Subscription created successfully: {result}")
1134
+ return result
1135
+ except Exception as e:
1136
+ logger.error(f"Error creating subscription: {str(e)}")
1137
+ # Include the full traceback for better debugging
1138
+ import traceback
1139
+ logger.error(f"Traceback: {traceback.format_exc()}")
1140
+ raise HTTPException(status_code=500, detail=f"Error creating subscription: {str(e)}")
1141
+
1142
+ @app.post("/subscribe/{tier}")
1143
+ async def subscribe_to_tier(
1144
+ tier: str,
1145
+ current_user: User = Depends(get_current_active_user)
1146
+ ):
1147
+ """Subscribe to a specific tier"""
1148
+ try:
1149
+ # Validate tier
1150
+ valid_tiers = ["standard_tier", "premium_tier"]
1151
+ if tier not in valid_tiers:
1152
+ raise HTTPException(status_code=400, detail=f"Invalid tier: {tier}. Must be one of {valid_tiers}")
1153
+
1154
+ # Create subscription
1155
+ success, result = create_user_subscription(current_user.email, tier)
1156
+
1157
+ if not success:
1158
+ raise HTTPException(status_code=400, detail=result)
1159
+
1160
+ return result
1161
+ except Exception as e:
1162
+ logger.error(f"Error creating subscription: {str(e)}")
1163
+ raise HTTPException(status_code=500, detail=f"Error creating subscription: {str(e)}")
1164
+
1165
+ @app.post("/subscription/create")
1166
+ async def create_subscription_json(request: Request, current_user: User = Depends(get_current_active_user)):
1167
+ """Create a subscription for the current user"""
1168
+ try:
1169
+ data = await request.json()
1170
+ tier = data.get("tier")
1171
+
1172
+ if not tier:
1173
+ return JSONResponse(
1174
+ status_code=400,
1175
+ content={"detail": "Tier is required"}
1176
+ )
1177
+
1178
+ # Log the request for debugging
1179
+ logger.info(f"Creating subscription for user {current_user.email} with tier {tier}")
1180
+
1181
+ # Create the subscription using the imported function directly
1182
+ success, result = create_user_subscription(current_user.email, tier)
1183
+
1184
+ if success:
1185
+ # Make sure we're returning the approval_url in the response
1186
+ logger.info(f"Subscription created successfully: {result}")
1187
+ logger.info(f"Approval URL: {result.get('approval_url')}")
1188
+
1189
+ return {
1190
+ "success": True,
1191
+ "data": {
1192
+ "approval_url": result["approval_url"],
1193
+ "subscription_id": result["subscription_id"],
1194
+ "tier": result["tier"]
1195
+ }
1196
+ }
1197
+ else:
1198
+ logger.error(f"Failed to create subscription: {result}")
1199
+ return JSONResponse(
1200
+ status_code=400,
1201
+ content={"success": False, "detail": result}
1202
+ )
1203
+ except Exception as e:
1204
+ logger.error(f"Error creating subscription: {str(e)}")
1205
+ import traceback
1206
+ logger.error(f"Traceback: {traceback.format_exc()}")
1207
+ return JSONResponse(
1208
+ status_code=500,
1209
+ content={"success": False, "detail": f"Error creating subscription: {str(e)}"}
1210
+ )
1211
+
1212
+ @app.post("/admin/initialize-paypal-plans")
1213
+ async def initialize_paypal_plans(request: Request):
1214
+ """Initialize PayPal subscription plans"""
1215
+ try:
1216
+ # This should be protected with admin authentication in production
1217
+ plans = initialize_subscription_plans()
1218
+
1219
+ if plans:
1220
+ return JSONResponse(
1221
+ status_code=200,
1222
+ content={"success": True, "plans": plans}
1223
+ )
1224
+ else:
1225
+ return JSONResponse(
1226
+ status_code=500,
1227
+ content={"success": False, "detail": "Failed to initialize plans"}
1228
+ )
1229
+ except Exception as e:
1230
+ logger.error(f"Error initializing PayPal plans: {str(e)}")
1231
+ return JSONResponse(
1232
+ status_code=500,
1233
+ content={"success": False, "detail": f"Error initializing plans: {str(e)}"}
1234
+ )
1235
+
1236
+
1237
+ @app.post("/subscription/verify")
1238
+ async def verify_subscription(request: Request, current_user: User = Depends(get_current_active_user)):
1239
+ """Verify a subscription after payment"""
1240
+ try:
1241
+ data = await request.json()
1242
+ subscription_id = data.get("subscription_id")
1243
+
1244
+ if not subscription_id:
1245
+ return JSONResponse(
1246
+ status_code=400,
1247
+ content={"success": False, "detail": "Subscription ID is required"}
1248
+ )
1249
+
1250
+ logger.info(f"Verifying subscription: {subscription_id}")
1251
+
1252
+ # Verify the subscription with PayPal
1253
+ success, result = verify_paypal_subscription(subscription_id)
1254
+
1255
+ if not success:
1256
+ logger.error(f"Subscription verification failed: {result}")
1257
+ return JSONResponse(
1258
+ status_code=400,
1259
+ content={"success": False, "detail": str(result)}
1260
+ )
1261
+
1262
+ # Update the user's subscription in the database
1263
+ conn = get_db_connection()
1264
+ cursor = conn.cursor()
1265
+
1266
+ # Get the subscription details
1267
+ cursor.execute(
1268
+ "SELECT tier FROM subscriptions WHERE paypal_subscription_id = ?",
1269
+ (subscription_id,)
1270
+ )
1271
+ subscription = cursor.fetchone()
1272
+
1273
+ if not subscription:
1274
+ # This is a new subscription, get the tier from the PayPal response
1275
+ tier = "standard_tier" # Default to standard tier
1276
+ # You could extract the tier from the PayPal plan ID if needed
1277
+
1278
+ # Create a new subscription record
1279
+ sub_id = str(uuid.uuid4())
1280
+ start_date = datetime.now()
1281
+ expires_at = start_date + timedelta(days=30)
1282
+
1283
+ cursor.execute(
1284
+ "INSERT INTO subscriptions (id, user_id, tier, status, created_at, expires_at, paypal_subscription_id) VALUES (?, ?, ?, ?, ?, ?, ?)",
1285
+ (sub_id, current_user.id, tier, "active", start_date, expires_at, subscription_id)
1286
+ )
1287
+ else:
1288
+ # Update existing subscription
1289
+ tier = subscription[0]
1290
+ cursor.execute(
1291
+ "UPDATE subscriptions SET status = 'active' WHERE paypal_subscription_id = ?",
1292
+ (subscription_id,)
1293
+ )
1294
+
1295
+ # Update user's subscription tier
1296
+ cursor.execute(
1297
+ "UPDATE users SET subscription_tier = ? WHERE id = ?",
1298
+ (tier, current_user.id)
1299
+ )
1300
+
1301
+ conn.commit()
1302
+ conn.close()
1303
+
1304
+ return JSONResponse(
1305
+ status_code=200,
1306
+ content={"success": True, "detail": "Subscription verified successfully"}
1307
+ )
1308
+
1309
+ except Exception as e:
1310
+ logger.error(f"Error verifying subscription: {str(e)}")
1311
+ return JSONResponse(
1312
+ status_code=500,
1313
+ content={"success": False, "detail": f"Error verifying subscription: {str(e)}"}
1314
+ )
1315
+
1316
+ @app.post("/subscription/webhook")
1317
+ async def subscription_webhook(request: Request):
1318
+ """Handle PayPal subscription webhooks"""
1319
+ try:
1320
+ payload = await request.json()
1321
+ success, result = handle_subscription_webhook(payload)
1322
+
1323
+ if not success:
1324
+ logger.error(f"Webhook processing failed: {result}")
1325
+ return {"status": "error", "message": result}
1326
+
1327
+ return {"status": "success", "message": result}
1328
+ except Exception as e:
1329
+ logger.error(f"Error processing webhook: {str(e)}")
1330
+ return {"status": "error", "message": f"Error processing webhook: {str(e)}"}
1331
+
1332
+ @app.get("/subscription/verify/{subscription_id}")
1333
+ async def verify_subscription_by_id(
1334
+ subscription_id: str,
1335
+ current_user: User = Depends(get_current_active_user)
1336
+ ):
1337
+ """Verify a subscription payment and update user tier"""
1338
+ try:
1339
+ # Verify the subscription
1340
+ success, result = verify_subscription_payment(subscription_id)
1341
+
1342
+ if not success:
1343
+ raise HTTPException(status_code=400, detail=f"Subscription verification failed: {result}")
1344
+
1345
+ # Get the plan ID from the subscription to determine tier
1346
+ plan_id = result.get("plan_id", "")
1347
+
1348
+ # Connect to DB to get the tier for this plan
1349
+ conn = get_db_connection()
1350
+ cursor = conn.cursor()
1351
+ cursor.execute("SELECT tier FROM paypal_plans WHERE plan_id = ?", (plan_id,))
1352
+ tier_result = cursor.fetchone()
1353
+ conn.close()
1354
+
1355
+ if not tier_result:
1356
+ raise HTTPException(status_code=400, detail="Could not determine subscription tier")
1357
+
1358
+ tier = tier_result[0]
1359
+
1360
+ # Update the user's subscription
1361
+ success, update_result = update_user_subscription(current_user.email, subscription_id, tier)
1362
+
1363
+ if not success:
1364
+ raise HTTPException(status_code=500, detail=f"Failed to update subscription: {update_result}")
1365
+
1366
+ return {
1367
+ "message": f"Successfully subscribed to {tier} tier",
1368
+ "subscription_id": subscription_id,
1369
+ "status": result.get("status", ""),
1370
+ "next_billing_time": result.get("billing_info", {}).get("next_billing_time", "")
1371
+ }
1372
+
1373
+ except HTTPException:
1374
+ raise
1375
+ except Exception as e:
1376
+ print(f"Subscription verification error: {str(e)}")
1377
+ raise HTTPException(status_code=500, detail=f"Subscription verification failed: {str(e)}")
1378
+
1379
+ @app.post("/webhook/paypal")
1380
+ async def paypal_webhook(request: Request):
1381
+ """Handle PayPal subscription webhooks"""
1382
+ try:
1383
+ payload = await request.json()
1384
+ logger.info(f"Received PayPal webhook: {payload.get('event_type', 'unknown event')}")
1385
+
1386
+ # Process the webhook
1387
+ result = handle_subscription_webhook(payload)
1388
+
1389
+ return {"status": "success", "message": "Webhook processed"}
1390
+ except Exception as e:
1391
+ logger.error(f"Webhook processing error: {str(e)}")
1392
+ # Return 200 even on error to acknowledge receipt to PayPal
1393
+ return {"status": "error", "message": str(e)}
1394
+
1395
+ # Add this to your startup code
1396
+ @app.on_event("startup")
1397
+ async def startup_event():
1398
+ """Initialize subscription plans on startup"""
1399
+ try:
1400
+ # Initialize PayPal subscription plans if needed
1401
+ # If you have an initialize_subscription_plans function in your paypal_integration.py,
1402
+ # you can call it here
1403
+ print("Application started successfully")
1404
+ except Exception as e:
1405
+ print(f"Error during startup: {str(e)}")
1406
+
1407
+ if __name__ == "__main__":
1408
+ uvicorn.run("app:app", host="0.0.0.0", port=8500, reload=True)
auth.py ADDED
@@ -0,0 +1,655 @@
1
+ import sqlite3
2
+ import uuid
3
+ import os
4
+ import logging
5
+ from datetime import datetime, timedelta
6
+ import hashlib  # note: JWT (imported below) is what is actually used for token signing
7
+ from passlib.hash import bcrypt
8
+ from dotenv import load_dotenv
9
+ from fastapi import Depends, HTTPException
10
+ from fastapi.security import OAuth2PasswordBearer
11
+ from pydantic import BaseModel
12
+ from typing import Optional
13
+ from fastapi import HTTPException, status
14
+ import jwt
15
+ from jwt.exceptions import PyJWTError
16
+ import sqlite3
17
+
18
+ # Load environment variables
19
+ load_dotenv()
20
+
21
+ # Configure logging
22
+ logging.basicConfig(
23
+ level=logging.INFO,
24
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
25
+ )
26
+ logger = logging.getLogger('auth')
27
+
28
+ # Security configuration
29
+ SECRET_KEY = os.getenv("JWT_SECRET", "your-secret-key-for-development-only")
30
+ ALGORITHM = "HS256"
31
+ JWT_EXPIRATION_DELTA = timedelta(days=1) # Token valid for 1 day
32
+ # Database path from environment variable or default
33
+ # Fix the incorrect DB_PATH
34
+ DB_PATH = os.getenv("DB_PATH", os.path.join(os.path.dirname(__file__), "data/user_data.db"))
35
+
36
+ # FastAPI OAuth2 scheme
37
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
38
+
39
+ # Pydantic models for FastAPI
40
+ class User(BaseModel):
41
+ id: str
42
+ email: str
43
+ subscription_tier: str = "free_tier"
44
+ subscription_expiry: Optional[datetime] = None
45
+ api_calls_remaining: int = 5
46
+ last_reset_date: Optional[datetime] = None
47
+
48
+ class UserCreate(BaseModel):
49
+ email: str
50
+ password: str
51
+
52
+ class Token(BaseModel):
53
+ access_token: str
54
+ token_type: str
55
+
56
+ class TokenData(BaseModel):
57
+ user_id: Optional[str] = None
58
+
59
+ # Subscription tiers and limits
60
+ # Update the SUBSCRIPTION_TIERS dictionary
61
+ SUBSCRIPTION_TIERS = {
62
+ "free_tier": {
63
+ "price": 0,
64
+ "currency": "INR",
65
+ "features": ["basic_document_analysis", "basic_risk_assessment"],
66
+ "limits": {
67
+ "document_size_mb": 5,
68
+ "documents_per_month": 3,
69
+ "video_size_mb": 0,
70
+ "audio_size_mb": 0,
71
+ "daily_api_calls": 10, # <-- Add this
72
+ "max_document_size_mb": 5 # <-- Add this
73
+ }
74
+ },
75
+ "standard_tier": {
76
+ "price": 799,
77
+ "currency": "INR",
78
+ "features": ["basic_document_analysis", "basic_risk_assessment", "video_analysis", "audio_analysis", "chatbot"],
79
+ "limits": {
80
+ "document_size_mb": 20,
81
+ "documents_per_month": 20,
82
+ "video_size_mb": 100,
83
+ "audio_size_mb": 50,
84
+ "daily_api_calls": 100, # <-- Add this
85
+ "max_document_size_mb": 20 # <-- Add this
86
+ }
87
+ },
88
+ "premium_tier": {
89
+ "price": 1499,
90
+ "currency": "INR",
91
+ "features": ["basic_document_analysis", "basic_risk_assessment", "video_analysis", "audio_analysis", "chatbot", "detailed_risk_assessment", "contract_clause_analysis"],
92
+ "limits": {
93
+ "document_size_mb": 50,
94
+ "documents_per_month": 999999,
95
+ "video_size_mb": 500,
96
+ "audio_size_mb": 200,
97
+ "daily_api_calls": 1000, # <-- Add this
98
+ "max_document_size_mb": 50 # <-- Add this
99
+ }
100
+ }
101
+ }
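Note that the per-tier limits sit under the nested "limits" key, so lookups elsewhere in this module read them as in this small illustrative sketch:

# Illustrative access pattern (values taken from the dictionary above)
free = SUBSCRIPTION_TIERS["free_tier"]
free["limits"]["daily_api_calls"]        # 10
free["limits"]["max_document_size_mb"]   # 5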
102
+
103
+ # Database connection management
104
+ def get_db_connection():
105
+ """Create and return a database connection with proper error handling"""
106
+ try:
107
+ # Ensure the directory exists
108
+ db_dir = os.path.dirname(DB_PATH)
109
+ os.makedirs(db_dir, exist_ok=True)
110
+
111
+ conn = sqlite3.connect(DB_PATH)
112
+ conn.row_factory = sqlite3.Row # Return rows as dictionaries
113
+ return conn
114
+ except sqlite3.Error as e:
115
+ logger.error(f"Database connection error: {e}")
116
+ raise Exception(f"Database connection failed: {e}")
117
+
118
+ # Database setup
119
+ # In the init_auth_db function, update the CREATE TABLE statement to match our schema
120
+ def init_auth_db():
121
+ """Initialize the authentication database with required tables"""
122
+ try:
123
+ conn = get_db_connection()
124
+ c = conn.cursor()
125
+
126
+ # Create users table with the correct schema
127
+ c.execute('''
128
+ CREATE TABLE IF NOT EXISTS users (
129
+ id TEXT PRIMARY KEY,
130
+ email TEXT UNIQUE NOT NULL,
131
+ hashed_password TEXT NOT NULL,
132
+ password TEXT,
133
+ subscription_tier TEXT DEFAULT 'free_tier',
134
+ is_active BOOLEAN DEFAULT 1,
135
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
136
+ api_calls_remaining INTEGER DEFAULT 10,
137
+ last_reset_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP
138
+ )
139
+ ''')
140
+
141
+ # Create subscriptions table
142
+ c.execute('''
143
+ CREATE TABLE IF NOT EXISTS subscriptions (
144
+ id TEXT PRIMARY KEY,
145
+ user_id TEXT,
146
+ tier TEXT,
147
+ plan_id TEXT,
148
+ status TEXT,
149
+ created_at TIMESTAMP,
150
+ expires_at TIMESTAMP,
151
+ paypal_subscription_id TEXT,
152
+ FOREIGN KEY (user_id) REFERENCES users (id)
153
+ )
154
+ ''')
155
+
156
+ # Create usage stats table
157
+ c.execute('''
158
+ CREATE TABLE IF NOT EXISTS usage_stats (
159
+ id TEXT PRIMARY KEY,
160
+ user_id TEXT,
161
+ month INTEGER,
162
+ year INTEGER,
163
+ analyses_used INTEGER,
164
+ FOREIGN KEY (user_id) REFERENCES users (id)
165
+ )
166
+ ''')
167
+
168
+ # Create tokens table for refresh tokens
169
+ c.execute('''
170
+ CREATE TABLE IF NOT EXISTS refresh_tokens (
171
+ user_id TEXT,
172
+ token TEXT,
173
+ expires_at TIMESTAMP,
174
+ FOREIGN KEY (user_id) REFERENCES users (id)
175
+ )
176
+ ''')
177
+
178
+ conn.commit()
179
+ logger.info("Database initialized successfully")
180
+ except Exception as e:
181
+ logger.error(f"Database initialization error: {e}")
182
+ raise
183
+ finally:
184
+ if conn:
185
+ conn.close()
186
+
187
+ # Initialize the database
188
+ init_auth_db()
189
+
190
+ # Password hashing with bcrypt
191
+ # Update the password hashing and verification functions to use a more reliable method
192
+
193
+ # Replace these functions
194
+ # Remove these conflicting functions
195
+ # def hash_password(password):
196
+ # """Hash a password using bcrypt"""
197
+ # return bcrypt.hash(password)
198
+ #
199
+ # def verify_password(plain_password, hashed_password):
200
+ # """Verify a password against its hash"""
201
+ # return bcrypt.verify(plain_password, hashed_password)
202
+
203
+ # Keep only these improved functions
204
+ def hash_password(password):
205
+ """Hash a password using bcrypt"""
206
+ # Use a more direct approach to avoid bcrypt version issues
207
+ import bcrypt
208
+ # Convert password to bytes if it's not already
209
+ if isinstance(password, str):
210
+ password = password.encode('utf-8')
211
+ # Generate salt and hash
212
+ salt = bcrypt.gensalt()
213
+ hashed = bcrypt.hashpw(password, salt)
214
+ # Return as string for storage
215
+ return hashed.decode('utf-8')
216
+
217
+ def verify_password(plain_password, hashed_password):
218
+ """Verify a password against its hash"""
219
+ import bcrypt
220
+ # Convert inputs to bytes if they're not already
221
+ if isinstance(plain_password, str):
222
+ plain_password = plain_password.encode('utf-8')
223
+ if isinstance(hashed_password, str):
224
+ hashed_password = hashed_password.encode('utf-8')
225
+
226
+ try:
227
+ # Use direct bcrypt verification
228
+ return bcrypt.checkpw(plain_password, hashed_password)
229
+ except Exception as e:
230
+ logger.error(f"Password verification error: {e}")
231
+ return False
232
+
233
+ # User registration
234
+ def register_user(email, password):
235
+ try:
236
+ conn = get_db_connection()
237
+ c = conn.cursor()
238
+
239
+ # Check if user already exists
240
+ c.execute("SELECT * FROM users WHERE email = ?", (email,))
241
+ if c.fetchone():
242
+ return False, "Email already registered"
243
+
244
+ # Create new user
245
+ user_id = str(uuid.uuid4())
246
+
247
+ # Add more detailed logging
248
+ logger.info(f"Registering new user with email: {email}")
249
+ hashed_pw = hash_password(password)
250
+ logger.info(f"Password hashed successfully: {bool(hashed_pw)}")
251
+
252
+ c.execute("""
253
+ INSERT INTO users
254
+ (id, email, hashed_password, subscription_tier, api_calls_remaining, last_reset_date)
255
+ VALUES (?, ?, ?, ?, ?, ?)
256
+ """, (user_id, email, hashed_pw, "free_tier", 5, datetime.now()))
257
+
258
+ conn.commit()
259
+ logger.info(f"User registered successfully: {email}")
260
+
261
+ # Verify the user was actually stored
262
+ c.execute("SELECT * FROM users WHERE email = ?", (email,))
263
+ stored_user = c.fetchone()
264
+ logger.info(f"User verification after registration: {bool(stored_user)}")
265
+
266
+ access_token = create_access_token(user_id)
267
+ return True, {
268
+ "user_id": user_id,
269
+ "access_token": access_token,
270
+ "token_type": "bearer"
271
+ }
272
+ except Exception as e:
273
+ logger.error(f"User registration error: {e}")
274
+ return False, f"Registration failed: {str(e)}"
275
+ finally:
276
+ if conn:
277
+ conn.close()
278
+
279
+ # User login
280
+ # Fix the authenticate_user function
281
+ # In the authenticate_user function, update the password verification to use hashed_password
282
+ def authenticate_user(email, password):
283
+ """Authenticate a user and return user data with tokens"""
284
+ try:
285
+ conn = get_db_connection()
286
+ c = conn.cursor()
287
+
288
+ # Get user by email
289
+ c.execute("SELECT * FROM users WHERE email = ? AND is_active = 1", (email,))
290
+ user = c.fetchone()
291
+
292
+ if not user:
293
+ logger.warning(f"User not found: {email}")
294
+ return None
295
+
296
+ # Add debug logging for password verification
297
+ logger.info(f"Verifying password for user: {email}")
298
+ logger.info(f"Stored hashed password: {user['hashed_password'][:20]}...")
299
+
300
+ try:
301
+ # Check if password verification works
302
+ is_valid = verify_password(password, user['hashed_password'])
303
+ logger.info(f"Password verification result: {is_valid}")
304
+
305
+ if not is_valid:
306
+ logger.warning(f"Password verification failed for user: {email}")
307
+ return None
308
+ except Exception as e:
309
+ logger.error(f"Password verification error: {e}")
310
+ return None
311
+
312
+ # Update last login time if column exists
313
+ try:
314
+ c.execute("UPDATE users SET last_login = ? WHERE id = ?",
315
+ (datetime.now(), user['id']))
316
+ conn.commit()
317
+ except sqlite3.OperationalError:
318
+ # last_login column might not exist
319
+ pass
320
+
321
+ # Convert sqlite3.Row to dict to use get() method
322
+ user_dict = dict(user)
323
+
324
+ # Create and return a User object
325
+ return User(
326
+ id=user_dict['id'],
327
+ email=user_dict['email'],
328
+ subscription_tier=user_dict.get('subscription_tier', 'free_tier'),
329
+ subscription_expiry=None, # Handle this properly if needed
330
+ api_calls_remaining=user_dict.get('api_calls_remaining', 5),
331
+ last_reset_date=user_dict.get('last_reset_date')
332
+ )
333
+ except Exception as e:
334
+ logger.error(f"Login error: {e}")
335
+ return None
336
+ finally:
337
+ if conn:
338
+ conn.close()
339
+
340
+ # Token generation and validation - completely replaced
341
+ def create_access_token(user_id):
342
+ """Create a new access token for a user"""
343
+ try:
344
+ # Create a JWT token with user_id and expiration
345
+ expiration = datetime.now() + JWT_EXPIRATION_DELTA
346
+
347
+ # Create a token payload
348
+ payload = {
349
+ "sub": user_id,
350
+ "exp": expiration.timestamp()
351
+ }
352
+
353
+ # Generate the JWT token
354
+ token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
355
+
356
+ logger.info(f"Created access token for user: {user_id}")
357
+ return token
358
+ except Exception as e:
359
+ logger.error(f"Token creation error: {e}")
360
+ return None
361
+
362
+
363
+ def update_auth_db_schema():
364
+ """Update the authentication database schema with any missing columns"""
365
+ try:
366
+ conn = get_db_connection()
367
+ c = conn.cursor()
368
+
369
+ # Check if tier column exists in subscriptions table
370
+ c.execute("PRAGMA table_info(subscriptions)")
371
+ columns = [column[1] for column in c.fetchall()]
372
+
373
+ # Add tier column if it doesn't exist
374
+ if "tier" not in columns:
375
+ logger.info("Adding 'tier' column to subscriptions table")
376
+ c.execute("ALTER TABLE subscriptions ADD COLUMN tier TEXT")
377
+ conn.commit()
378
+ logger.info("Database schema updated successfully")
379
+
380
+ conn.close()
381
+ except Exception as e:
382
+ logger.error(f"Database schema update error: {e}")
383
+ raise HTTPException(
384
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
385
+ detail=f"Database schema update error: {str(e)}"
386
+ )
387
+
388
+ # Add this to your get_current_user function
389
+ async def get_current_user(token: str = Depends(oauth2_scheme)):
390
+ credentials_exception = HTTPException(
391
+ status_code=status.HTTP_401_UNAUTHORIZED,
392
+ detail="Could not validate credentials",
393
+ headers={"WWW-Authenticate": "Bearer"},
394
+ )
395
+ try:
396
+ # Decode the JWT token
397
+ payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
398
+ user_id: str = payload.get("sub")
399
+ if user_id is None:
400
+ logger.error("Token missing 'sub' field")
401
+ raise credentials_exception
402
+ except Exception as e:
403
+ logger.error(f"Token validation error: {str(e)}")
404
+ raise credentials_exception
405
+
406
+ # Get user from database
407
+ conn = get_db_connection()
408
+ cursor = conn.cursor()
409
+ cursor.execute("SELECT id, email, subscription_tier, is_active FROM users WHERE id = ?", (user_id,))
410
+ user_data = cursor.fetchone()
411
+ conn.close()
412
+
413
+ if user_data is None:
414
+ logger.error(f"User not found: {user_id}")
415
+ raise credentials_exception
416
+
417
+ user = User(
418
+ id=user_data[0],
419
+ email=user_data[1],
420
+ subscription_tier=user_data[2],
421
+ is_active=bool(user_data[3])
422
+ )
423
+
424
+ return user
425
+
426
+ async def get_current_active_user(current_user: User = Depends(get_current_user)):
427
+ """Get the current active user"""
428
+ return current_user
429
+
430
+ def create_user_subscription(email, tier):
431
+ """Create a subscription for a user"""
432
+ try:
433
+ # Get user by email
434
+ conn = get_db_connection()
435
+ c = conn.cursor()
436
+
437
+ # Get user ID
438
+ c.execute("SELECT id FROM users WHERE email = ?", (email,))
439
+ user_data = c.fetchone()
440
+
441
+ if not user_data:
442
+ return False, "User not found"
443
+
444
+ user_id = user_data['id']
445
+
446
+ # Check if tier is valid
447
+ valid_tiers = ["standard_tier", "premium_tier"]
448
+ if tier not in valid_tiers:
449
+ return False, f"Invalid tier: {tier}. Must be one of {valid_tiers}"
450
+
451
+ # Create subscription
452
+ subscription_id = str(uuid.uuid4())
453
+ created_at = datetime.now()
454
+ expires_at = created_at + timedelta(days=30) # 30-day subscription
455
+
456
+ # Insert subscription
457
+ c.execute("""
458
+ INSERT INTO subscriptions
459
+ (id, user_id, tier, status, created_at, expires_at)
460
+ VALUES (?, ?, ?, ?, ?, ?)
461
+ """, (subscription_id, user_id, tier, "active", created_at, expires_at))
462
+
463
+ # Update user's subscription tier
464
+ c.execute("""
465
+ UPDATE users
466
+ SET subscription_tier = ?
467
+ WHERE id = ?
468
+ """, (tier, user_id))
469
+
470
+ conn.commit()
471
+
472
+ return True, {
473
+ "id": subscription_id,
474
+ "user_id": user_id,
475
+ "tier": tier,
476
+ "status": "active",
477
+ "created_at": created_at.isoformat(),
478
+ "expires_at": expires_at.isoformat()
479
+ }
480
+ except Exception as e:
481
+ logger.error(f"Subscription creation error: {e}")
482
+ return False, f"Failed to create subscription: {str(e)}"
483
+ finally:
484
+ if conn:
485
+ conn.close()
486
+
487
+ def get_user(user_id: str):
488
+ """Get user by ID"""
489
+ try:
490
+ conn = get_db_connection()
491
+ c = conn.cursor()
492
+
493
+ # Get user
494
+ c.execute("SELECT * FROM users WHERE id = ? AND is_active = 1", (user_id,))
495
+ user_data = c.fetchone()
496
+
497
+ if not user_data:
498
+ return None
499
+
500
+ # Convert to User model
501
+ user_dict = dict(user_data)
502
+
503
+ # Handle datetime conversions if needed
504
+ if user_dict.get("subscription_expiry") and isinstance(user_dict["subscription_expiry"], str):
505
+ user_dict["subscription_expiry"] = datetime.fromisoformat(user_dict["subscription_expiry"])
506
+ if user_dict.get("last_reset_date") and isinstance(user_dict["last_reset_date"], str):
507
+ user_dict["last_reset_date"] = datetime.fromisoformat(user_dict["last_reset_date"])
508
+
509
+ return User(
510
+ id=user_dict['id'],
511
+ email=user_dict['email'],
512
+ subscription_tier=user_dict['subscription_tier'],
513
+ subscription_expiry=user_dict.get('subscription_expiry'),
514
+ api_calls_remaining=user_dict.get('api_calls_remaining', 5),
515
+ last_reset_date=user_dict.get('last_reset_date')
516
+ )
517
+ except Exception as e:
518
+ logger.error(f"Get user error: {e}")
519
+ return None
520
+ finally:
521
+ if conn:
522
+ conn.close()
523
+
524
+ def check_subscription_access(user: User, feature: str, file_size_mb: Optional[float] = None):
525
+ """Check if the user has access to the requested feature and file size"""
526
+ # Check if subscription is expired
527
+ if user.subscription_tier != "free_tier" and user.subscription_expiry and user.subscription_expiry < datetime.now():
528
+ # Downgrade to free tier if subscription expired
529
+ user.subscription_tier = "free_tier"
530
+ user.api_calls_remaining = SUBSCRIPTION_TIERS["free_tier"]["limits"]["daily_api_calls"]
531
+ with get_db_connection() as conn:
532
+ c = conn.cursor()
533
+ c.execute("""
534
+ UPDATE users
535
+ SET subscription_tier = ?, api_calls_remaining = ?
536
+ WHERE id = ?
537
+ """, (user.subscription_tier, user.api_calls_remaining, user.id))
538
+ conn.commit()
539
+
540
+ # Reset API calls if needed
541
+ user = reset_api_calls_if_needed(user)
542
+
543
+ # Check if user has API calls remaining
544
+ if user.api_calls_remaining <= 0:
545
+ raise HTTPException(
546
+ status_code=429,
547
+ detail="API call limit reached for today. Please upgrade your subscription or try again tomorrow."
548
+ )
549
+
550
+ # Check if feature is available in user's subscription tier
551
+ tier_features = SUBSCRIPTION_TIERS[user.subscription_tier]["features"]
552
+ if feature not in tier_features:
553
+ raise HTTPException(
554
+ status_code=403,
555
+ detail=f"The {feature} feature is not available in your {user.subscription_tier} subscription. Please upgrade to access this feature."
556
+ )
557
+
558
+ # Check file size limit if applicable
559
+ if file_size_mb:
560
+ max_size = SUBSCRIPTION_TIERS[user.subscription_tier]["limits"]["max_document_size_mb"]
561
+ if file_size_mb > max_size:
562
+ raise HTTPException(
563
+ status_code=413,
564
+ detail=f"File size exceeds the {max_size}MB limit for your {user.subscription_tier} subscription. Please upgrade or use a smaller file."
565
+ )
566
+
567
+ # Decrement API calls remaining
568
+ user.api_calls_remaining -= 1
569
+ with get_db_connection() as conn:
570
+ c = conn.cursor()
571
+ c.execute("""
572
+ UPDATE users
573
+ SET api_calls_remaining = ?
574
+ WHERE id = ?
575
+ """, (user.api_calls_remaining, user.id))
576
+ conn.commit()
577
+
578
+ return True
579
+
580
+ def reset_api_calls_if_needed(user: User):
581
+ """Reset API call counter if it's a new day"""
582
+ today = datetime.now().date()
583
+ if user.last_reset_date is None or user.last_reset_date.date() < today:
584
+ tier_limits = SUBSCRIPTION_TIERS[user.subscription_tier]
585
+ user.api_calls_remaining = tier_limits["limits"]["daily_api_calls"]
586
+ user.last_reset_date = datetime.now()
587
+ # Update the user in the database
588
+ with get_db_connection() as conn:
589
+ c = conn.cursor()
590
+ c.execute("""
591
+ UPDATE users
592
+ SET api_calls_remaining = ?, last_reset_date = ?
593
+ WHERE id = ?
594
+ """, (user.api_calls_remaining, user.last_reset_date, user.id))
595
+ conn.commit()
596
+
597
+ return user
598
+
599
+ def login_user(email, password):
600
+ """Login a user with email and password"""
601
+ try:
602
+ # Authenticate user
603
+ user = authenticate_user(email, password)
604
+ if not user:
605
+ return False, "Incorrect username or password"
606
+
607
+ # Create access token
608
+ access_token = create_access_token(user.id)
609
+
610
+ # Create refresh token
611
+ refresh_token = str(uuid.uuid4())
612
+ expires_at = datetime.now() + timedelta(days=30)
613
+
614
+ # Store refresh token
615
+ conn = get_db_connection()
616
+ c = conn.cursor()
617
+ c.execute("INSERT INTO refresh_tokens VALUES (?, ?, ?)",
618
+ (user.id, refresh_token, expires_at))
619
+ conn.commit()
620
+
621
+ # Get subscription info
622
+ c.execute("SELECT * FROM subscriptions WHERE user_id = ? AND status = 'active'", (user.id,))
623
+ subscription = c.fetchone()
624
+
625
+ # Convert subscription to dict if it exists, otherwise set to None
626
+ subscription_dict = dict(subscription) if subscription else None
627
+
628
+ conn.close()
629
+
630
+ return True, {
631
+ "user_id": user.id,
632
+ "email": user.email,
633
+ "access_token": access_token,
634
+ "refresh_token": refresh_token,
635
+ "subscription": subscription_dict
636
+ }
637
+ except Exception as e:
638
+ logger.error(f"Login error: {e}")
639
+ return False, f"Login failed: {str(e)}"
640
+
641
+
642
+ def get_subscription_plans():
643
+ """
644
+ Returns a list of available subscription plans based on SUBSCRIPTION_TIERS.
645
+ """
646
+ plans = []
647
+ for tier, details in SUBSCRIPTION_TIERS.items():
648
+ plans.append({
649
+ "tier": tier,
650
+ "price": details["price"],
651
+ "currency": details["currency"],
652
+ "features": details["features"],
653
+ "limits": details["limits"]
654
+ })
655
+ return plans
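To exercise the helpers above from a Python shell (a sketch; it assumes the module can create its SQLite file at DB_PATH):

from auth import register_user, login_user, get_subscription_plans

ok, signup = register_user("test@example.com", "s3cret-password")
print(ok, signup)            # on success: user_id, access_token, token_type

ok, session = login_user("test@example.com", "s3cret-password")
print(ok, session)           # access_token, refresh_token and any active subscription

print(get_subscription_plans())   # the three tiers defined in SUBSCRIPTION_TIERS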
fix_users_table.py ADDED
@@ -0,0 +1,180 @@
1
+ import sqlite3
2
+ import os
3
+ import uuid
4
+ import datetime
5
+
6
+ # Define both database paths
7
+ DB_PATH_1 = os.path.join(os.path.dirname(__file__), "../data/user_data.db")
8
+ DB_PATH_2 = os.path.join(os.path.dirname(__file__), "data/user_data.db")
9
+
10
+ # Define the function to create users table
11
+ # Make sure the create_users_table function allows NULL for hashed_password temporarily
12
+ def create_users_table(cursor):
13
+ """Create the users table with all required columns"""
14
+ cursor.execute('''
15
+ CREATE TABLE users (
16
+ id TEXT PRIMARY KEY,
17
+ email TEXT UNIQUE NOT NULL,
18
+ hashed_password TEXT DEFAULT 'temp_hash_for_migration',
19
+ password TEXT,
20
+ subscription_tier TEXT DEFAULT 'free_tier',
21
+ is_active BOOLEAN DEFAULT 1,
22
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
23
+ api_calls_remaining INTEGER DEFAULT 10,
24
+ last_reset_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP
25
+ )
26
+ ''')
27
+
28
+ # Update the CREATE TABLE statement to include all necessary columns
29
+ def fix_users_table(db_path):
30
+ # Make sure the data directory exists
31
+ data_dir = os.path.dirname(db_path)
32
+ if not os.path.exists(data_dir):
33
+ print(f"Creating data directory: {data_dir}")
34
+ os.makedirs(data_dir, exist_ok=True)
35
+
36
+ if not os.path.exists(db_path):
37
+ print(f"Database does not exist at: {os.path.abspath(db_path)}")
38
+ return False
39
+
40
+ print(f"Using database path: {os.path.abspath(db_path)}")
41
+
42
+ # Connect to the database
43
+ conn = sqlite3.connect(db_path)
44
+ cursor = conn.cursor()
45
+
46
+ # Check if users table exists
47
+ cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='users'")
48
+ if cursor.fetchone():
49
+ print("Users table exists, checking schema...")
50
+
51
+ # Check columns
52
+ cursor.execute("PRAGMA table_info(users)")
53
+ columns_info = cursor.fetchall()
54
+ columns = [column[1] for column in columns_info]
55
+
56
+ # List of all required columns
57
+ required_columns = ['id', 'email', 'hashed_password', 'password', 'subscription_tier',
58
+ 'is_active', 'created_at', 'api_calls_remaining', 'last_reset_date']
59
+
60
+ # Check if any required column is missing
61
+ missing_columns = [col for col in required_columns if col not in columns]
62
+
63
+ if missing_columns:
64
+ print(f"Schema needs fixing. Missing columns: {', '.join(missing_columns)}")
65
+
66
+ # Dynamically build the SELECT query based on available columns
67
+ available_columns = [col for col in columns if col != 'id'] # Exclude id as we'll generate new ones
68
+
69
+ if not available_columns:
70
+ print("No usable columns found in users table, creating new table...")
71
+ cursor.execute("DROP TABLE users")
72
+ create_users_table(cursor)
73
+ print("Created new empty users table with correct schema")
74
+ else:
75
+ # Backup existing users with available columns
76
+ select_query = f"SELECT {', '.join(available_columns)} FROM users"
77
+ print(f"Backing up users with query: {select_query}")
78
+ cursor.execute(select_query)
79
+ existing_users = cursor.fetchall()
80
+
81
+ # Drop the existing table
82
+ cursor.execute("DROP TABLE users")
83
+
84
+ # Create the table with the correct schema
85
+ create_users_table(cursor)
86
+
87
+ # Restore the users with new UUIDs for IDs
88
+ if existing_users:
89
+ current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
90
+ for user in existing_users:
91
+ user_id = str(uuid.uuid4())
92
+
93
+ # Create a dictionary to map column names to values
94
+ user_data = {'id': user_id}
95
+ for i, col in enumerate(available_columns):
96
+ user_data[col] = user[i]
97
+
98
+ # Set default values for missing columns
99
+ # Add a default value for hashed_password in the Set default values section
100
+ if 'hashed_password' not in user_data:
101
+ user_data['hashed_password'] = 'temp_hash_for_migration' # Temporary hash for migration
102
+ if 'subscription_tier' not in user_data:
103
+ user_data['subscription_tier'] = 'free_tier'
104
+ if 'is_active' not in user_data:
105
+ user_data['is_active'] = 1
106
+ if 'created_at' not in user_data:
107
+ user_data['created_at'] = current_time
108
+ if 'api_calls_remaining' not in user_data:
109
+ user_data['api_calls_remaining'] = 10
110
+ if 'last_reset_date' not in user_data:
111
+ user_data['last_reset_date'] = current_time
112
+
113
+ # Build INSERT query with all required columns
114
+ insert_columns = ['id']
115
+ insert_values = [user_id]
116
+
117
+ # Add values for columns that exist in the old table
118
+ for col in available_columns:
119
+ insert_columns.append(col)
120
+ insert_values.append(user_data[col])
121
+
122
+ # Add default values for columns that don't exist in the old table
123
+ for col in required_columns:
124
+ # Add hashed_password to the column default values section
125
+ if col not in ['id'] + available_columns:
126
+ insert_columns.append(col)
127
+ if col == 'subscription_tier':
128
+ insert_values.append('free_tier')
129
+ elif col == 'is_active':
130
+ insert_values.append(1)
131
+ elif col == 'created_at':
132
+ insert_values.append(current_time)
133
+ elif col == 'api_calls_remaining':
134
+ insert_values.append(10)
135
+ elif col == 'last_reset_date':
136
+ insert_values.append(current_time)
137
+ elif col == 'hashed_password':
138
+ insert_values.append('temp_hash_for_migration') # Temporary hash for migration
139
+ else:
140
+ insert_values.append(None) # Default to NULL for other columns
141
+
142
+ placeholders = ', '.join(['?'] * len(insert_columns))
143
+ insert_query = f"INSERT INTO users ({', '.join(insert_columns)}) VALUES ({placeholders})"
144
+
145
+ cursor.execute(insert_query, insert_values)
146
+
147
+ print(f"Fixed users table, restored {len(existing_users)} users")
148
+ else:
149
+ print("Users table schema is correct")
150
+ else:
151
+ print("Users table doesn't exist, creating it now...")
152
+ create_users_table(cursor)
153
+ print("Users table created successfully")
154
+
155
+ # Commit changes and close connection
156
+ conn.commit()
157
+ conn.close()
158
+ return True
159
+
160
+ if __name__ == "__main__":
161
+ print("Checking first database location...")
162
+ success1 = fix_users_table(DB_PATH_1)
163
+
164
+ print("\nChecking second database location...")
165
+ success2 = fix_users_table(DB_PATH_2)
166
+
167
+ if not (success1 or success2):
168
+ print("\nWarning: Could not find any existing database files.")
169
+ print("Creating a new database at the primary location...")
170
+ # Create a new database at the primary location
171
+ data_dir = os.path.dirname(DB_PATH_1)
172
+ if not os.path.exists(data_dir):
173
+ os.makedirs(data_dir, exist_ok=True)
174
+
175
+ conn = sqlite3.connect(DB_PATH_1)
176
+ cursor = conn.cursor()
177
+ create_users_table(cursor)
178
+ conn.commit()
179
+ conn.close()
180
+ print(f"Created new database at: {os.path.abspath(DB_PATH_1)}")
initialize_plans.py ADDED
@@ -0,0 +1,25 @@
1
+ import os
2
+ import sys
3
+ from dotenv import load_dotenv
4
+ from paypal_integration import initialize_subscription_plans
5
+
6
+ # Load environment variables
7
+ load_dotenv()
8
+
9
+ def main():
10
+ """Initialize PayPal subscription plans"""
11
+ print("Initializing PayPal subscription plans...")
12
+ plans = initialize_subscription_plans()
13
+
14
+ if plans:
15
+ print("✅ Plans initialized successfully:")
16
+ for tier, plan_id in plans.items():
17
+ print(f" - {tier}: {plan_id}")
18
+ return True
19
+ else:
20
+ print("❌ Failed to initialize plans. Check the logs for details.")
21
+ return False
22
+
23
+ if __name__ == "__main__":
24
+ success = main()
25
+ sys.exit(0 if success else 1)
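The script relies on the PayPal credentials being present in the environment (typically via a .env file) before it talks to the API; a sketch of what it expects and how it can be invoked programmatically:

# Environment variables read by paypal_integration (placeholders, normally set in .env):
#   PAYPAL_CLIENT_ID - PayPal client id (required)
#   PAYPAL_SECRET    - matching secret (required)
#   PAYPAL_BASE_URL  - optional, defaults to https://api-m.sandbox.paypal.com
from initialize_plans import main

main()   # prints the created plan ids, or a failure message if the PayPal calls fail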
legal_analysis.db ADDED
Binary file (28.7 kB).
 
paypal_integration.py ADDED
@@ -0,0 +1,1004 @@
1
+ import requests
2
+ import json
3
+ import sqlite3
4
+ from datetime import datetime, timedelta
5
+ import uuid
6
+ import os
7
+ import logging
8
+ from requests.adapters import HTTPAdapter
9
+ from requests.packages.urllib3.util.retry import Retry
10
+ from auth import get_db_connection
11
+ from dotenv import load_dotenv
12
+ load_dotenv()  # make sure PayPal credentials from .env are available before they are read below
13
+ # PayPal API Configuration - Remove default values for production
14
+ PAYPAL_CLIENT_ID = os.getenv("PAYPAL_CLIENT_ID")
15
+ PAYPAL_SECRET = os.getenv("PAYPAL_SECRET")
16
+ PAYPAL_BASE_URL = os.getenv("PAYPAL_BASE_URL", "https://api-m.sandbox.paypal.com")
17
+
18
+ # Add validation to ensure credentials are provided
19
+ # Set up logging; create the log directory first so the FileHandler does not fail on a fresh checkout
+ os.makedirs(os.path.join(os.path.dirname(__file__), "../logs"), exist_ok=True)
20
+ logging.basicConfig(
21
+ level=logging.INFO,
22
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
23
+ handlers=[
24
+ logging.FileHandler(os.path.join(os.path.dirname(__file__), "../logs/paypal.log")),
25
+ logging.StreamHandler()
26
+ ]
27
+ )
28
+ logger = logging.getLogger("paypal_integration")
29
+
30
+ # Then replace print statements with logger calls
31
+ # For example:
32
+ if not PAYPAL_CLIENT_ID or not PAYPAL_SECRET:
33
+ logger.warning("PayPal credentials not found in environment variables")
34
+
35
+
36
+ # Get PayPal access token
37
+ # Add better error handling for production
38
+ # Create a session with retry capability
39
+ def create_retry_session(retries=3, backoff_factor=0.3):
40
+ session = requests.Session()
41
+ retry = Retry(
42
+ total=retries,
43
+ read=retries,
44
+ connect=retries,
45
+ backoff_factor=backoff_factor,
46
+ status_forcelist=[500, 502, 503, 504],
47
+ )
48
+ adapter = HTTPAdapter(max_retries=retry)
49
+ session.mount('http://', adapter)
50
+ session.mount('https://', adapter)
51
+ return session
52
+
53
+ # Then use this session for API calls
54
+ # Replace get_access_token with logger instead of print
55
+ def get_access_token():
56
+ url = f"{PAYPAL_BASE_URL}/v1/oauth2/token"
57
+ headers = {
58
+ "Accept": "application/json",
59
+ "Accept-Language": "en_US"
60
+ }
61
+ data = "grant_type=client_credentials"
62
+
63
+ try:
64
+ session = create_retry_session()
65
+ response = session.post(
66
+ url,
67
+ auth=(PAYPAL_CLIENT_ID, PAYPAL_SECRET),
68
+ headers=headers,
69
+ data=data
70
+ )
71
+
72
+ if response.status_code == 200:
73
+ return response.json()["access_token"]
74
+ else:
75
+ logger.error(f"Error getting access token: {response.status_code}")
76
+ return None
77
+ except Exception as e:
78
+ logger.error(f"Exception in get_access_token: {str(e)}")
79
+ return None
80
+
81
+ def call_paypal_api(endpoint, method="GET", data=None, token=None):
82
+ """
83
+ Helper function to make PayPal API calls
84
+
85
+ Args:
86
+ endpoint: API endpoint (without base URL)
87
+ method: HTTP method (GET, POST, etc.)
88
+ data: Request payload (for POST/PUT)
89
+ token: PayPal access token (will be fetched if None)
90
+
91
+ Returns:
92
+ tuple: (success, response_data or error_message)
93
+ """
94
+ try:
95
+ if not token:
96
+ token = get_access_token()
97
+ if not token:
98
+ return False, "Failed to get PayPal access token"
99
+
100
+ url = f"{PAYPAL_BASE_URL}{endpoint}"
101
+ headers = {
102
+ "Content-Type": "application/json",
103
+ "Authorization": f"Bearer {token}"
104
+ }
105
+
106
+ session = create_retry_session()
107
+
108
+ if method.upper() == "GET":
109
+ response = session.get(url, headers=headers)
110
+ elif method.upper() == "POST":
111
+ response = session.post(url, headers=headers, data=json.dumps(data) if data else None)
112
+ elif method.upper() == "PUT":
113
+ response = session.put(url, headers=headers, data=json.dumps(data) if data else None)
114
+ else:
115
+ return False, f"Unsupported HTTP method: {method}"
116
+
117
+ if response.status_code in [200, 201, 204]:
118
+ if response.status_code == 204: # No content
119
+ return True, {}
120
+ return True, response.json() if response.text else {}
121
+ else:
122
+ logger.error(f"PayPal API error: {response.status_code} - {response.text}")
123
+ return False, f"PayPal API error: {response.status_code} - {response.text}"
124
+
125
+ except Exception as e:
126
+ logger.error(f"Error calling PayPal API: {str(e)}")
127
+ return False, f"Error calling PayPal API: {str(e)}"
128
+
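As an illustration of the helper above (not part of the file), fetching a subscription's details looks like this:

# Illustrative only: look up a subscription via the generic helper defined above
ok, data = call_paypal_api("/v1/billing/subscriptions/I-EXAMPLE123", "GET")
if ok:
    print(data.get("status"), data.get("plan_id"))
else:
    print("PayPal call failed:", data)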
129
+ def create_paypal_subscription(user_id, tier):
130
+ """Create a PayPal subscription for a user"""
131
+ try:
132
+ # Get the price from the subscription tier
133
+ from auth import SUBSCRIPTION_TIERS
134
+
135
+ if tier not in SUBSCRIPTION_TIERS:
136
+ return False, f"Invalid tier: {tier}"
137
+
138
+ price = SUBSCRIPTION_TIERS[tier]["price"]
139
+ currency = SUBSCRIPTION_TIERS[tier]["currency"]
140
+
141
+ # Create a PayPal subscription (implement PayPal API calls here)
142
+ # For now, just return a success response
143
+ return True, {
144
+ "subscription_id": f"test_sub_{uuid.uuid4()}",
145
+ "status": "ACTIVE",
146
+ "tier": tier,
147
+ "price": price,
148
+ "currency": currency
149
+ }
150
+ except Exception as e:
151
+ logger.error(f"Error creating PayPal subscription: {str(e)}")
152
+ return False, f"Failed to create PayPal subscription: {str(e)}"
153
+
154
+
155
+ # Create a product in PayPal
156
+ def create_product(name, description):
157
+ """Create a product in PayPal"""
158
+ payload = {
159
+ "name": name,
160
+ "description": description,
161
+ "type": "SERVICE",
162
+ "category": "SOFTWARE"
163
+ }
164
+
165
+ success, result = call_paypal_api("/v1/catalogs/products", "POST", payload)
166
+ if success:
167
+ return result["id"]
168
+ else:
169
+ logger.error(f"Failed to create product: {result}")
170
+ return None
171
+
172
+ # Create a subscription plan in PayPal
173
+ # Note: the billing cycles below are priced in USD on PayPal, while SUBSCRIPTION_TIERS in auth.py lists INR prices
174
+ def create_plan(product_id, name, price, interval="MONTH", interval_count=1):
175
+ """Create a subscription plan in PayPal"""
176
+ payload = {
177
+ "product_id": product_id,
178
+ "name": name,
179
+ "billing_cycles": [
180
+ {
181
+ "frequency": {
182
+ "interval_unit": interval,
183
+ "interval_count": interval_count
184
+ },
185
+ "tenure_type": "REGULAR",
186
+ "sequence": 1,
187
+ "total_cycles": 0, # Infinite cycles
188
+ "pricing_scheme": {
189
+ "fixed_price": {
190
+ "value": str(price),
191
+ "currency_code": "USD"
192
+ }
193
+ }
194
+ }
195
+ ],
196
+ "payment_preferences": {
197
+ "auto_bill_outstanding": True,
198
+ "setup_fee": {
199
+ "value": "0",
200
+ "currency_code": "USD"
201
+ },
202
+ "setup_fee_failure_action": "CONTINUE",
203
+ "payment_failure_threshold": 3
204
+ }
205
+ }
206
+
207
+ success, result = call_paypal_api("/v1/billing/plans", "POST", payload)
208
+ if success:
209
+ return result["id"]
210
+ else:
211
+ logger.error(f"Failed to create plan: {result}")
212
+ return None
213
+
214
+ # Note: the plans created below are priced in USD on PayPal
215
+ def initialize_subscription_plans():
216
+ """
217
+ Initialize PayPal subscription plans for the application.
218
+ This should be called once to set up the plans in PayPal.
219
+ """
220
+ try:
221
+ # Check if plans already exist
222
+ existing_plans = get_subscription_plans()
223
+ if existing_plans and len(existing_plans) >= 2:
224
+ logger.info("PayPal plans already initialized")
225
+ return existing_plans
226
+
227
+ # First, create products for each tier
228
+ products = {
229
+ "standard_tier": {
230
+ "name": "Standard Legal Document Analysis",
231
+ "description": "Standard subscription with document analysis features",
232
+ "type": "SERVICE",
233
+ "category": "SOFTWARE"
234
+ },
235
+ "premium_tier": {
236
+ "name": "Premium Legal Document Analysis",
237
+ "description": "Premium subscription with all document analysis features",
238
+ "type": "SERVICE",
239
+ "category": "SOFTWARE"
240
+ }
241
+ }
242
+
243
+ product_ids = {}
244
+ for tier, product_data in products.items():
245
+ success, result = call_paypal_api("/v1/catalogs/products", "POST", product_data)
246
+ if success:
247
+ product_ids[tier] = result["id"]
248
+ logger.info(f"Created PayPal product for {tier}: {result['id']}")
249
+ else:
250
+ logger.error(f"Failed to create product for {tier}: {result}")
251
+ return None
252
+
253
+ # Define the plans with product IDs - Changed currency to USD
254
+ plans = {
255
+ "standard_tier": {
256
+ "product_id": product_ids["standard_tier"],
257
+ "name": "Standard Plan",
258
+ "description": "Standard subscription with basic features",
259
+ "billing_cycles": [
260
+ {
261
+ "frequency": {
262
+ "interval_unit": "MONTH",
263
+ "interval_count": 1
264
+ },
265
+ "tenure_type": "REGULAR",
266
+ "sequence": 1,
267
+ "total_cycles": 0,
268
+ "pricing_scheme": {
269
+ "fixed_price": {
270
+ "value": "9.99",
271
+ "currency_code": "USD"
272
+ }
273
+ }
274
+ }
275
+ ],
276
+ "payment_preferences": {
277
+ "auto_bill_outstanding": True,
278
+ "setup_fee": {
279
+ "value": "0",
280
+ "currency_code": "USD"
281
+ },
282
+ "setup_fee_failure_action": "CONTINUE",
283
+ "payment_failure_threshold": 3
284
+ }
285
+ },
286
+ "premium_tier": {
287
+ "product_id": product_ids["premium_tier"],
288
+ "name": "Premium Plan",
289
+ "description": "Premium subscription with all features",
290
+ "billing_cycles": [
291
+ {
292
+ "frequency": {
293
+ "interval_unit": "MONTH",
294
+ "interval_count": 1
295
+ },
296
+ "tenure_type": "REGULAR",
297
+ "sequence": 1,
298
+ "total_cycles": 0,
299
+ "pricing_scheme": {
300
+ "fixed_price": {
301
+ "value": "19.99",
302
+ "currency_code": "USD"
303
+ }
304
+ }
305
+ }
306
+ ],
307
+ "payment_preferences": {
308
+ "auto_bill_outstanding": True,
309
+ "setup_fee": {
310
+ "value": "0",
311
+ "currency_code": "USD"
312
+ },
313
+ "setup_fee_failure_action": "CONTINUE",
314
+ "payment_failure_threshold": 3
315
+ }
316
+ }
317
+ }
318
+
319
+ # Create the plans in PayPal
320
+ created_plans = {}
321
+ for tier, plan_data in plans.items():
322
+ success, result = call_paypal_api("/v1/billing/plans", "POST", plan_data)
323
+ if success:
324
+ created_plans[tier] = result["id"]
325
+ logger.info(f"Created PayPal plan for {tier}: {result['id']}")
326
+ else:
327
+ logger.error(f"Failed to create plan for {tier}: {result}")
328
+
329
+ # Save the plan IDs to a file
330
+ if created_plans:
331
+ save_subscription_plans(created_plans)
332
+ return created_plans
333
+ else:
334
+ logger.error("Failed to create any PayPal plans")
335
+ return None
336
+ except Exception as e:
337
+ logger.error(f"Error initializing subscription plans: {str(e)}")
338
+ return None
339
+
340
+ # Update create_subscription_link to use call_paypal_api helper
341
+ def create_subscription_link(plan_id):
342
+ # Get the plan IDs
343
+ plans = get_subscription_plans()
344
+ if not plans:
345
+ return None
346
+
347
+ # Use environment variable for the app URL to make it work in different environments
348
+ app_url = os.getenv("APP_URL", "http://localhost:8501")
349
+
350
+ payload = {
351
+ "plan_id": plans[plan_id],
352
+ "application_context": {
353
+ "brand_name": "Legal Document Analyzer",
354
+ "locale": "en_US",
355
+ "shipping_preference": "NO_SHIPPING",
356
+ "user_action": "SUBSCRIBE_NOW",
357
+ "return_url": f"{app_url}?status=success&subscription_id={{id}}",
358
+ "cancel_url": f"{app_url}?status=cancel"
359
+ }
360
+ }
361
+
362
+ success, data = call_paypal_api("/v1/billing/subscriptions", "POST", payload)
363
+ if not success:
364
+ logger.error(f"Error creating subscription: {data}")
365
+ return None
366
+
367
+ try:
368
+ return {
369
+ "subscription_id": data["id"],
370
+ "approval_url": next(link["href"] for link in data["links"] if link["rel"] == "approve")
371
+ }
372
+ except Exception as e:
373
+ logger.error(f"Exception processing subscription response: {str(e)}")
374
+ return None
375
+
376
+ # Fix the webhook handler function signature to match how it's called in app.py
377
+ def handle_subscription_webhook(payload):
378
+ """
379
+ Handle PayPal subscription webhooks
380
+
381
+ Args:
382
+ payload: The full webhook payload
383
+
384
+ Returns:
385
+ tuple: (success, result)
386
+ - success: True if successful, False otherwise
387
+ - result: Success message or error message
388
+ """
389
+ try:
390
+ event_type = payload.get("event_type")
391
+ resource = payload.get("resource", {})
392
+
393
+ logger.info(f"Received PayPal webhook: {event_type}")
394
+
395
+ # Handle different event types
396
+ if event_type == "BILLING.SUBSCRIPTION.CREATED":
397
+ # A subscription was created
398
+ subscription_id = resource.get("id")
399
+ if not subscription_id:
400
+ return False, "Missing subscription ID in webhook"
401
+
402
+ # Update subscription status in database
403
+ conn = get_db_connection()
404
+ cursor = conn.cursor()
405
+ cursor.execute(
406
+ "UPDATE subscriptions SET status = 'pending' WHERE paypal_subscription_id = ?",
407
+ (subscription_id,)
408
+ )
409
+ conn.commit()
410
+ conn.close()
411
+
412
+ return True, "Subscription created successfully"
413
+
414
+ elif event_type == "BILLING.SUBSCRIPTION.ACTIVATED":
415
+ # A subscription was activated
416
+ subscription_id = resource.get("id")
417
+ if not subscription_id:
418
+ return False, "Missing subscription ID in webhook"
419
+
420
+ # Update subscription status in database
421
+ conn = get_db_connection()
422
+ cursor = conn.cursor()
423
+ cursor.execute(
424
+ "UPDATE subscriptions SET status = 'active' WHERE paypal_subscription_id = ?",
425
+ (subscription_id,)
426
+ )
427
+ conn.commit()
428
+ conn.close()
429
+
430
+ return True, "Subscription activated successfully"
431
+
432
+ elif event_type == "BILLING.SUBSCRIPTION.CANCELLED":
433
+ # A subscription was cancelled
434
+ subscription_id = resource.get("id")
435
+ if not subscription_id:
436
+ return False, "Missing subscription ID in webhook"
437
+
438
+ # Update subscription status in database
439
+ conn = get_db_connection()
440
+ cursor = conn.cursor()
441
+ cursor.execute(
442
+ "UPDATE subscriptions SET status = 'cancelled' WHERE paypal_subscription_id = ?",
443
+ (subscription_id,)
444
+ )
445
+ conn.commit()
446
+ conn.close()
447
+
448
+ return True, "Subscription cancelled successfully"
449
+
450
+ elif event_type == "BILLING.SUBSCRIPTION.SUSPENDED":
451
+ # A subscription was suspended
452
+ subscription_id = resource.get("id")
453
+ if not subscription_id:
454
+ return False, "Missing subscription ID in webhook"
455
+
456
+ # Update subscription status in database
457
+ conn = get_db_connection()
458
+ cursor = conn.cursor()
459
+ cursor.execute(
460
+ "UPDATE subscriptions SET status = 'suspended' WHERE paypal_subscription_id = ?",
461
+ (subscription_id,)
462
+ )
463
+ conn.commit()
464
+ conn.close()
465
+
466
+ return True, "Subscription suspended successfully"
467
+
468
+ else:
469
+ # Unhandled event type
470
+ logger.info(f"Unhandled webhook event type: {event_type}")
471
+ return True, f"Unhandled event type: {event_type}"
472
+
473
+ except Exception as e:
474
+ logger.error(f"Error handling webhook: {str(e)}")
475
+ return False, f"Error handling webhook: {str(e)}"
476
+ # Add this function to update user subscription
477
+ def update_user_subscription(user_email, subscription_id, tier):
478
+ """
479
+ Update a user's subscription status
480
+
481
+ Args:
482
+ user_email: The email of the user
483
+ subscription_id: The PayPal subscription ID
484
+ tier: The subscription tier
485
+
486
+ Returns:
487
+ tuple: (success, result)
488
+ - success: True if successful, False otherwise
489
+ - result: Success message or error message
490
+ """
491
+ try:
492
+ # Get user ID from email
493
+ conn = get_db_connection()
494
+ cursor = conn.cursor()
495
+ cursor.execute("SELECT id FROM users WHERE email = ?", (user_email,))
496
+ user_result = cursor.fetchone()
497
+
498
+ if not user_result:
499
+ conn.close()
500
+ return False, f"User not found: {user_email}"
501
+
502
+ user_id = user_result[0]
503
+
504
+ # Update the subscription status
505
+ cursor.execute(
506
+ "UPDATE subscriptions SET status = 'active' WHERE user_id = ? AND paypal_subscription_id = ?",
507
+ (user_id, subscription_id)
508
+ )
509
+
510
+ # Deactivate any other active subscriptions for this user
511
+ cursor.execute(
512
+ "UPDATE subscriptions SET status = 'inactive' WHERE user_id = ? AND paypal_subscription_id != ? AND status = 'active'",
513
+ (user_id, subscription_id)
514
+ )
515
+
516
+ # Update the user's subscription tier
517
+ cursor.execute(
518
+ "UPDATE users SET subscription_tier = ? WHERE email = ?",
519
+ (tier, user_email)
520
+ )
521
+
522
+ conn.commit()
523
+ conn.close()
524
+
525
+ return True, f"Subscription updated to {tier} tier"
526
+
527
+ except Exception as e:
528
+ logger.error(f"Error updating user subscription: {str(e)}")
529
+ return False, f"Error updating subscription: {str(e)}"
530
+
531
+ # Add this near the top with other path definitions
532
+ # Update the PLAN_IDS_PATH definition to use the correct path
533
+ PLAN_IDS_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "data", "plan_ids.json"))
534
+
535
+ # Make sure the data directory exists
536
+ os.makedirs(os.path.dirname(PLAN_IDS_PATH), exist_ok=True)
537
+
538
+ # Add this debug log to see where the file is expected
539
+ logger.info(f"PayPal plans will be stored at: {PLAN_IDS_PATH}")
540
+
541
+ # Add this function if it's not defined elsewhere
542
+ def get_db_connection():
543
+ """Get a connection to the SQLite database"""
544
+ DB_PATH = os.getenv("DB_PATH", os.path.join(os.path.dirname(__file__), "../data/user_data.db"))
545
+ # Make sure the data directory exists
546
+ os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
547
+ return sqlite3.connect(DB_PATH)
548
+
549
+ # Add this function to create subscription tables if needed
+ def initialize_database():
+     """Initialize the database tables needed for subscriptions"""
+     conn = get_db_connection()
+     cursor = conn.cursor()
+
+     # Check if subscriptions table exists
+     cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='subscriptions'")
+     if cursor.fetchone():
+         # Table exists, check if required columns exist
+         cursor.execute("PRAGMA table_info(subscriptions)")
+         columns = [column[1] for column in cursor.fetchall()]
+
+         # Check for missing columns and add them if needed
+         if "user_id" not in columns:
+             logger.info("Adding 'user_id' column to subscriptions table")
+             cursor.execute("ALTER TABLE subscriptions ADD COLUMN user_id TEXT NOT NULL DEFAULT ''")
+
+         if "created_at" not in columns:
+             logger.info("Adding 'created_at' column to subscriptions table")
+             cursor.execute("ALTER TABLE subscriptions ADD COLUMN created_at TIMESTAMP")
+
+         if "expires_at" not in columns:
+             logger.info("Adding 'expires_at' column to subscriptions table")
+             cursor.execute("ALTER TABLE subscriptions ADD COLUMN expires_at TIMESTAMP")
+
+         if "paypal_subscription_id" not in columns:
+             logger.info("Adding 'paypal_subscription_id' column to subscriptions table")
+             cursor.execute("ALTER TABLE subscriptions ADD COLUMN paypal_subscription_id TEXT")
+     else:
+         # Create subscriptions table with all required columns
+         cursor.execute('''
+             CREATE TABLE IF NOT EXISTS subscriptions (
+                 id TEXT PRIMARY KEY,
+                 user_id TEXT NOT NULL,
+                 tier TEXT NOT NULL,
+                 status TEXT NOT NULL,
+                 created_at TIMESTAMP NOT NULL,
+                 expires_at TIMESTAMP,
+                 paypal_subscription_id TEXT
+             )
+         ''')
+         logger.info("Created subscriptions table with all required columns")
+
+     # Create PayPal plans table if it doesn't exist
+     cursor.execute('''
+         CREATE TABLE IF NOT EXISTS paypal_plans (
+             plan_id TEXT PRIMARY KEY,
+             tier TEXT NOT NULL,
+             price REAL NOT NULL,
+             currency TEXT NOT NULL,
+             created_at TIMESTAMP NOT NULL
+         )
+     ''')
+
+     conn.commit()
+     conn.close()
+     logger.info("Database initialization completed")
+
+
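With the schema above in place, looking up a user's current subscription is a straightforward query; a sketch (the user id value is a placeholder):

    conn = get_db_connection()
    cursor = conn.cursor()
    cursor.execute(
        "SELECT tier, status, expires_at FROM subscriptions "
        "WHERE user_id = ? AND status = 'active' ORDER BY created_at DESC",
        ("example-user-id",)   # placeholder user id
    )
    active_subscription = cursor.fetchone()   # None if the user has no active row
    conn.close()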
+ def create_user_subscription_mock(user_email, tier):
+     """
+     Create a mock subscription for testing
+
+     Args:
+         user_email: The email of the user
+         tier: The subscription tier
+
+     Returns:
+         tuple: (success, result)
+     """
+     try:
+         logger.info(f"Creating mock subscription for {user_email} at tier {tier}")
+
+         # Get user ID from email
+         conn = get_db_connection()
+         cursor = conn.cursor()
+         cursor.execute("SELECT id FROM users WHERE email = ?", (user_email,))
+         user_result = cursor.fetchone()
+
+         if not user_result:
+             conn.close()
+             return False, f"User not found: {user_email}"
+
+         user_id = user_result[0]
+
+         # Create a mock subscription ID
+         subscription_id = f"mock_sub_{uuid.uuid4()}"
+
+         # Store the subscription in database
+         sub_id = str(uuid.uuid4())
+         start_date = datetime.now()
+
+         cursor.execute(
+             "INSERT INTO subscriptions (id, user_id, tier, status, created_at, expires_at, paypal_subscription_id) VALUES (?, ?, ?, ?, ?, ?, ?)",
+             (sub_id, user_id, tier, "active", start_date, start_date + timedelta(days=30), subscription_id)
+         )
+
+         # Update user's subscription tier
+         cursor.execute(
+             "UPDATE users SET subscription_tier = ? WHERE id = ?",
+             (tier, user_id)
+         )
+
+         conn.commit()
+         conn.close()
+
+         # Use environment variable for the app URL
+         app_url = os.getenv("APP_URL", "http://localhost:3000")
+
+         # Return success with mock approval URL that matches the real PayPal URL pattern
+         return True, {
+             "subscription_id": subscription_id,
+             "approval_url": f"{app_url}/subscription/callback?status=success&subscription_id={subscription_id}",
+             "tier": tier
+         }
+
+     except Exception as e:
+         logger.error(f"Error creating mock subscription: {str(e)}")
+         return False, f"Error creating subscription: {str(e)}"
+
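A usage sketch for local testing (the email and tier are illustrative, and the user must already exist in the users table):

    success, result = create_user_subscription_mock("user@example.com", "premium_tier")
    if success:
        print(result["approval_url"])   # mock callback URL, no PayPal round trip
    else:
        print(f"Mock subscription failed: {result}")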
+ # Add this at the end of the file
+ def initialize():
+     """Initialize the PayPal integration module"""
+     try:
+         # Create necessary directories
+         os.makedirs(os.path.dirname(PLAN_IDS_PATH), exist_ok=True)
+
+         # Initialize database
+         initialize_database()
+
+         # Initialize subscription plans
+         plans = get_subscription_plans()
+         if plans:
+             logger.info(f"Subscription plans initialized: {plans}")
+         else:
+             logger.warning("Failed to initialize subscription plans")
+
+         return True
+     except Exception as e:
+         logger.error(f"Error initializing PayPal integration: {str(e)}")
+         return False
+
+ # The module-level initialize() call is made at the very end of this file, once
+ # get_subscription_plans (defined below) exists; calling it here would fail at
+ # import time because that name has not been defined yet.
+
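If relying on import-time side effects is undesirable, the application could instead run the initialization explicitly when it starts. A sketch assuming a FastAPI app object (this hook is illustrative and not taken from app.py):

    from fastapi import FastAPI

    app = FastAPI()

    @app.on_event("startup")
    def init_paypal_integration():
        # initialize() creates the data directory, the tables, and loads the saved plan IDs
        if not initialize():
            logger.warning("PayPal integration did not initialize cleanly")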
+ # Add this function to get subscription plans
+ def get_subscription_plans():
+     """
+     Get all available subscription plans with correct pricing
+     """
+     try:
+         # Check if we have plan IDs saved in a file
+         if os.path.exists(PLAN_IDS_PATH):
+             try:
+                 with open(PLAN_IDS_PATH, 'r') as f:
+                     plans = json.load(f)
+                 logger.info(f"Loaded subscription plans from {PLAN_IDS_PATH}: {plans}")
+                 return plans
+             except Exception as e:
+                 logger.error(f"Error reading plan IDs file: {str(e)}")
+                 return {}
+
+         # If no file exists, return empty dict
+         logger.warning(f"No plan IDs file found at {PLAN_IDS_PATH}. Please initialize subscription plans.")
+         return {}
+
+     except Exception as e:
+         logger.error(f"Error getting subscription plans: {str(e)}")
+         return {}
+
+
+ def create_user_subscription(user_email, tier):
+     """
+     Create a real PayPal subscription for a user
+
+     Args:
+         user_email: The email of the user
+         tier: The subscription tier (standard_tier or premium_tier)
+
+     Returns:
+         tuple: (success, result)
+             - success: True if successful, False otherwise
+             - result: Dictionary with subscription details or error message
+     """
+     try:
+         # Validate tier
+         valid_tiers = ["standard_tier", "premium_tier"]
+         if tier not in valid_tiers:
+             return False, f"Invalid tier: {tier}. Must be one of {valid_tiers}"
+
+         # Get the plan IDs
+         plans = get_subscription_plans()
+
+         # Log the plans for debugging
+         logger.info(f"Available subscription plans: {plans}")
+
+         # If no plans found, check if the file exists and try to load it directly
+         if not plans:
+             if os.path.exists(PLAN_IDS_PATH):
+                 logger.info(f"Plan IDs file exists at {PLAN_IDS_PATH}, but couldn't load plans. Trying direct load.")
+                 try:
+                     with open(PLAN_IDS_PATH, 'r') as f:
+                         plans = json.load(f)
+                     logger.info(f"Directly loaded plans: {plans}")
+                 except Exception as e:
+                     logger.error(f"Error directly loading plans: {str(e)}")
+             else:
+                 logger.error(f"Plan IDs file does not exist at {PLAN_IDS_PATH}")
+
+         # If still no plans, return error
+         if not plans:
+             logger.error("No PayPal plans found. Please initialize plans first.")
+             return False, "PayPal plans not configured. Please contact support."
+
+         # Check if the tier exists in plans
+         if tier not in plans:
+             return False, f"No plan found for tier: {tier}"
+
+         # Use environment variable for the app URL
+         app_url = os.getenv("APP_URL", "http://localhost:3000")
+
+         # Create the subscription with PayPal
+         payload = {
+             "plan_id": plans[tier],
+             "subscriber": {
+                 "email_address": user_email
+             },
+             "application_context": {
+                 "brand_name": "Legal Document Analyzer",
+                 "locale": "en-US",  # PayPal expects a BCP 47 locale ("en-US", not "en_US")
+                 "shipping_preference": "NO_SHIPPING",
+                 "user_action": "SUBSCRIBE_NOW",
+                 "return_url": f"{app_url}/subscription/callback?status=success",
+                 "cancel_url": f"{app_url}/subscription/callback?status=cancel"
+             }
+         }
+
+         # Make the API call to PayPal
+         success, subscription_data = call_paypal_api("/v1/billing/subscriptions", "POST", payload)
+         if not success:
+             return False, subscription_data  # This is already an error message
+
+         # Extract the approval URL
+         approval_url = next((link["href"] for link in subscription_data["links"]
+                              if link["rel"] == "approve"), None)
+
+         if not approval_url:
+             return False, "No approval URL found in PayPal response"
+
+         # Get user ID from email
+         conn = get_db_connection()
+         cursor = conn.cursor()
+         cursor.execute("SELECT id FROM users WHERE email = ?", (user_email,))
+         user_result = cursor.fetchone()
+
+         if not user_result:
+             conn.close()
+             return False, f"User not found: {user_email}"
+
+         user_id = user_result[0]
+
+         # Store pending subscription in database
+         sub_id = str(uuid.uuid4())
+         start_date = datetime.now()
+
+         cursor.execute(
+             "INSERT INTO subscriptions (id, user_id, tier, status, created_at, expires_at, paypal_subscription_id) VALUES (?, ?, ?, ?, ?, ?, ?)",
+             (sub_id, user_id, tier, "pending", start_date, None, subscription_data["id"])
+         )
+
+         conn.commit()
+         conn.close()
+
+         # Return success with approval URL
+         return True, {
+             "subscription_id": subscription_data["id"],
+             "approval_url": approval_url,
+             "tier": tier
+         }
+
+     except Exception as e:
+         logger.error(f"Error creating user subscription: {str(e)}")
+         return False, f"Error creating subscription: {str(e)}"
+
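A sketch of how this function might be exposed over HTTP (the route path and request model are assumptions for illustration, not taken from app.py):

    from fastapi import APIRouter, HTTPException
    from pydantic import BaseModel

    router = APIRouter()

    class SubscriptionRequest(BaseModel):
        email: str
        tier: str   # "standard_tier" or "premium_tier"

    @router.post("/subscribe")
    def subscribe(request: SubscriptionRequest):
        success, result = create_user_subscription(request.email, request.tier)
        if not success:
            raise HTTPException(status_code=400, detail=result)
        # The client then redirects the user to result["approval_url"] to approve payment
        return result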
+ # Add a function to cancel a subscription
+ def cancel_subscription(subscription_id, reason="Customer requested cancellation"):
+     """
+     Cancel a PayPal subscription
+
+     Args:
+         subscription_id: The PayPal subscription ID
+         reason: The reason for cancellation
+
+     Returns:
+         tuple: (success, result)
+             - success: True if successful, False otherwise
+             - result: Success message or error message
+     """
+     try:
+         # Cancel the subscription with PayPal
+         payload = {
+             "reason": reason
+         }
+
+         success, result = call_paypal_api(
+             f"/v1/billing/subscriptions/{subscription_id}/cancel",
+             "POST",
+             payload
+         )
+
+         if not success:
+             return False, result
+
+         # Update subscription status in database
+         conn = get_db_connection()
+         cursor = conn.cursor()
+         cursor.execute(
+             "UPDATE subscriptions SET status = 'cancelled' WHERE paypal_subscription_id = ?",
+             (subscription_id,)
+         )
+
+         # Get the user ID for this subscription
+         cursor.execute(
+             "SELECT user_id FROM subscriptions WHERE paypal_subscription_id = ?",
+             (subscription_id,)
+         )
+         user_result = cursor.fetchone()
+
+         if user_result:
+             # Update user to free tier
+             cursor.execute(
+                 "UPDATE users SET subscription_tier = 'free_tier' WHERE id = ?",
+                 (user_result[0],)
+             )
+
+         conn.commit()
+         conn.close()
+
+         return True, "Subscription cancelled successfully"
+
+     except Exception as e:
+         logger.error(f"Error cancelling subscription: {str(e)}")
+         return False, f"Error cancelling subscription: {str(e)}"
+
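Cancellation is symmetric with creation; a minimal usage sketch (the subscription id is a placeholder):

    success, message = cancel_subscription("I-EXAMPLE123456", reason="User downgraded plan")
    if not success:
        logger.error(f"Cancellation failed: {message}")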
+ def verify_subscription_payment(subscription_id):
+     """
+     Verify a subscription payment with PayPal
+
+     Args:
+         subscription_id: The PayPal subscription ID
+
+     Returns:
+         tuple: (success, result)
+             - success: True if successful, False otherwise
+             - result: Dictionary with subscription details or error message
+     """
+     try:
+         # Get subscription details from PayPal using our helper
+         success, subscription_data = call_paypal_api(f"/v1/billing/subscriptions/{subscription_id}")
+         if not success:
+             return False, subscription_data  # This is already an error message
+
+         # Check subscription status
+         status = subscription_data.get("status", "").upper()
+
+         if status not in ["ACTIVE", "APPROVED"]:
+             return False, f"Subscription is not active: {status}"
+
+         # Return success with subscription data
+         return True, subscription_data
+
+     except Exception as e:
+         logger.error(f"Error verifying subscription: {str(e)}")
+         return False, f"Error verifying subscription: {str(e)}"
+
+ def verify_paypal_subscription(subscription_id):
+     """
+     Verify a PayPal subscription
+
+     Args:
+         subscription_id: The PayPal subscription ID
+
+     Returns:
+         tuple: (success, result)
+     """
+     try:
+         # Skip verification for mock subscriptions
+         if subscription_id.startswith("mock_sub_"):
+             return True, {"status": "ACTIVE"}
+
+         # For real subscriptions, call PayPal API
+         success, result = call_paypal_api(f"/v1/billing/subscriptions/{subscription_id}", "GET")
+
+         if success:
+             # Check subscription status
+             if result.get("status") == "ACTIVE":
+                 return True, result
+             else:
+                 return False, f"Subscription is not active: {result.get('status')}"
+         else:
+             logger.error(f"PayPal API error: {result}")
+             return False, f"Failed to verify subscription: {result}"
+     except Exception as e:
+         logger.error(f"Error verifying PayPal subscription: {str(e)}")
+         return False, f"Error verifying subscription: {str(e)}"
+
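These verification helpers are what the return redirect would call once PayPal sends the user back to /subscription/callback; a hypothetical glue function (the name and arguments are illustrative):

    def handle_subscription_callback(status, subscription_id):
        # status and subscription_id come from the callback query string
        if status != "success":
            return False, "Subscription was cancelled by the user"
        ok, details = verify_paypal_subscription(subscription_id)
        if not ok:
            return False, details
        # On success, the pending row in `subscriptions` can be promoted to 'active'
        return True, details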
+ # Add this function to save subscription plans
+ def save_subscription_plans(plans):
+     """
+     Save subscription plans to a file
+
+     Args:
+         plans: Dictionary of plan IDs by tier
+     """
+     try:
+         with open(PLAN_IDS_PATH, 'w') as f:
+             json.dump(plans, f)
+         logger.info(f"Saved subscription plans to {PLAN_IDS_PATH}")
+         return True
+     except Exception as e:
+         logger.error(f"Error saving subscription plans: {str(e)}")
+         return False
+
+
+ # Call initialize when the module is imported; it is placed last so that every
+ # helper it uses is already defined.
+ initialize()
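For completeness, a sketch of seeding plan_ids.json by hand, for example from an initialization script, assuming the billing plans already exist on the PayPal side (the plan IDs below are placeholders):

    plans = {
        "standard_tier": "P-EXAMPLE-STANDARD-PLAN-ID",   # placeholder PayPal plan IDs
        "premium_tier": "P-EXAMPLE-PREMIUM-PLAN-ID",
    }
    if save_subscription_plans(plans):
        print(get_subscription_plans())   # should echo the mapping back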
requirements.txt ADDED
@@ -0,0 +1,21 @@
+ fastapi>=0.95.0
+ uvicorn>=0.21.1
+ pydantic>=1.10.7
+ python-multipart>=0.0.6
+ python-dotenv>=1.0.0
+ pdfplumber>=0.9.0
+ spacy>=3.5.2
+ torch>=2.0.0
+ transformers>=4.28.1
+ sentence-transformers>=2.2.2
+ moviepy>=1.0.3
+ matplotlib>=3.7.1
+ numpy>=1.24.2
+ passlib>=1.7.4
+ python-jose[cryptography]>=3.3.0
+ bcrypt>=4.0.1
+ requests>=2.28.2
+ SQLAlchemy>=2.0.9
+ aiofiles>=23.1.0
+ huggingface_hub>=0.16.4
+ en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.5.0/en_core_web_sm-3.5.0-py3-none-any.whl