gavinzli commited on
Commit
c75e17a
·
1 Parent(s): a979028

Refactor Google OAuth2 callback to include state validation and error handling

Browse files
Files changed (2) hide show
  1. app/models/db/__init__.py +6 -17
  2. app/router/auth.py +74 -46
app/models/db/__init__.py CHANGED
@@ -1,32 +1,21 @@
1
  """This module is responsible for initializing the database connection and creating the necessary tables."""
2
- # import faiss
3
  from pinecone import Pinecone, ServerlessSpec
4
- # from langchain_community.vectorstores import FAISS
5
- # from langchain_community.docstore.in_memory import InMemoryDocstore
6
  from langchain_pinecone import PineconeVectorStore
7
  from models.llm import EmbeddingsModel
8
 
9
  embeddings = EmbeddingsModel("all-MiniLM-L6-v2")
10
 
11
- # vectorstore = FAISS(
12
- # embedding_function=embeddings,
13
- # index=faiss.IndexFlatL2(len(embeddings.embed_query("hello world"))),
14
- # docstore=InMemoryDocstore(),
15
- # index_to_docstore_id={}
16
- # )
17
-
18
  pc = Pinecone()
19
- index_name = "mails"
20
- embedding_dim = len(embeddings.embed_query("hello world"))
21
- if not pc.has_index(index_name):
22
  pc.create_index(
23
- name=index_name,
24
- dimension=embedding_dim, # Replace with your model dimensions
25
- metric="cosine", # Replace with your model metric
26
  spec=ServerlessSpec(
27
  cloud="aws",
28
  region="us-east-1"
29
  )
30
  )
31
- index = pc.Index(index_name)
32
  vectorstore = PineconeVectorStore(index=index, embedding=embeddings)
 
1
  """This module is responsible for initializing the database connection and creating the necessary tables."""
 
2
  from pinecone import Pinecone, ServerlessSpec
 
 
3
  from langchain_pinecone import PineconeVectorStore
4
  from models.llm import EmbeddingsModel
5
 
6
  embeddings = EmbeddingsModel("all-MiniLM-L6-v2")
7
 
 
 
 
 
 
 
 
8
  pc = Pinecone()
9
+ INDEX_NAME = "mails"
10
+ if not pc.has_index(INDEX_NAME):
 
11
  pc.create_index(
12
+ name=INDEX_NAME,
13
+ dimension=len(embeddings.embed_query("hello")),
14
+ metric="cosine",
15
  spec=ServerlessSpec(
16
  cloud="aws",
17
  region="us-east-1"
18
  )
19
  )
20
+ index = pc.Index(INDEX_NAME)
21
  vectorstore = PineconeVectorStore(index=index, embedding=embeddings)
app/router/auth.py CHANGED
@@ -2,10 +2,11 @@
2
  import os
3
  import json
4
  import pickle
5
- from fastapi import APIRouter, Request
6
  from fastapi.responses import JSONResponse
7
  from google_auth_oauthlib.flow import InstalledAppFlow
8
  from googleapiclient.discovery import build
 
9
 
10
  router = APIRouter(tags=["auth"])
11
 
@@ -44,52 +45,79 @@ async def get_auth_url():
44
  return JSONResponse({"url": auth_url})
45
 
46
  @router.get("/auth/google/callback")
47
- async def google_callback(code: str, request: Request):
48
- """
49
- Handles the Google OAuth2 callback by exchanging the authorization code for credentials,
50
- retrieving the user's Gmail profile, and saving the credentials to a file.
 
51
 
52
- Args:
53
- code (str): The authorization code returned by Google's OAuth2 server.
54
- request (Request): The incoming HTTP request object.
 
 
55
 
56
- Returns:
57
- JSONResponse: A JSON response containing the user's Gmail profile information.
58
 
59
- Side Effects:
60
- - Saves the user's credentials to a pickle file named after their email address.
61
- - Stores the credentials in the session state of the request.
 
62
 
63
- Dependencies:
64
- - google_auth_oauthlib.flow.InstalledAppFlow: Used to handle the OAuth2 flow.
65
- - googleapiclient.discovery.build: Used to build the Gmail API service.
66
- - json: Used to serialize and deserialize credentials.
67
- - pickle: Used to save credentials to a file.
68
- """
69
- flow = InstalledAppFlow.from_client_config(CLIENT_CONFIG, SCOPES)
70
- flow.redirect_uri = REDIRECT_URI
71
- flow.fetch_token(code=code)
72
- credentials = flow.credentials
73
- request.state.session["credential"] = json.loads(credentials.to_json())
74
- # cred_dict = (request.state.session.get("credential"))
75
- # cred = Credentials(
76
- # token=cred_dict["token"],
77
- # refresh_token=cred_dict["refresh_token"],
78
- # token_uri=cred_dict["token_uri"],
79
- # client_id=cred_dict["client_id"],
80
- # client_secret=cred_dict["client_secret"],
81
- # scopes=cred_dict["scopes"],
82
- # )
83
- # service = build("gmail", "v1", credentials=Credentials(
84
- # token=cred_dict["token"],
85
- # refresh_token=cred_dict["refresh_token"],
86
- # token_uri=cred_dict["token_uri"],
87
- # client_id=cred_dict["client_id"],
88
- # client_secret=cred_dict["client_secret"],
89
- # scopes=cred_dict["scopes"],
90
- # ))
91
- service = build("gmail", "v1", credentials=credentials)
92
- profile = service.users().getProfile(userId="me").execute()
93
- with open(f"cache/{profile['emailAddress']}.pickle", "wb") as token:
94
- pickle.dump(credentials, token)
95
- return JSONResponse(profile)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import os
3
  import json
4
  import pickle
5
+ from fastapi import APIRouter, Request, HTTPException
6
  from fastapi.responses import JSONResponse
7
  from google_auth_oauthlib.flow import InstalledAppFlow
8
  from googleapiclient.discovery import build
9
+ from oauthlib.oauth2.rfc6749.errors import InvalidGrantError
10
 
11
  router = APIRouter(tags=["auth"])
12
 
 
45
  return JSONResponse({"url": auth_url})
46
 
47
  @router.get("/auth/google/callback")
48
+ async def google_callback(code: str, state: str = None, scope: str = None, request: Request = None):
49
+ try:
50
+ # Validate state (optional, for CSRF protection)
51
+ if state and request.state.session.get("oauth_state") != state:
52
+ raise HTTPException(status_code=400, detail="Invalid state parameter")
53
 
54
+ flow = InstalledAppFlow.from_client_config(CLIENT_CONFIG, SCOPES)
55
+ flow.redirect_uri = REDIRECT_URI
56
+ flow.fetch_token(code=code)
57
+ credentials = flow.credentials
58
+ request.state.session["credential"] = json.loads(credentials.to_json())
59
 
60
+ service = build("gmail", "v1", credentials=credentials)
61
+ profile = service.users().getProfile(userId="me").execute()
62
 
63
+ # Ensure cache directory exists
64
+ os.makedirs("cache", exist_ok=True)
65
+ with open(f"cache/{profile['emailAddress']}.pickle", "wb") as token:
66
+ pickle.dump(credentials, token)
67
 
68
+ return JSONResponse(profile)
69
+ except InvalidGrantError:
70
+ raise HTTPException(status_code=400, detail="Invalid or expired authorization code")
71
+ except Exception as e:
72
+ raise HTTPException(status_code=500, detail=f"Authentication failed: {str(e)}")
73
+
74
+ # @router.get("/auth/google/callback")
75
+ # async def google_callback(code: str, request: Request):
76
+ # """
77
+ # Handles the Google OAuth2 callback by exchanging the authorization code for credentials,
78
+ # retrieving the user's Gmail profile, and saving the credentials to a file.
79
+
80
+ # Args:
81
+ # code (str): The authorization code returned by Google's OAuth2 server.
82
+ # request (Request): The incoming HTTP request object.
83
+
84
+ # Returns:
85
+ # JSONResponse: A JSON response containing the user's Gmail profile information.
86
+
87
+ # Side Effects:
88
+ # - Saves the user's credentials to a pickle file named after their email address.
89
+ # - Stores the credentials in the session state of the request.
90
+
91
+ # Dependencies:
92
+ # - google_auth_oauthlib.flow.InstalledAppFlow: Used to handle the OAuth2 flow.
93
+ # - googleapiclient.discovery.build: Used to build the Gmail API service.
94
+ # - json: Used to serialize and deserialize credentials.
95
+ # - pickle: Used to save credentials to a file.
96
+ # """
97
+ # flow = InstalledAppFlow.from_client_config(CLIENT_CONFIG, SCOPES)
98
+ # flow.redirect_uri = REDIRECT_URI
99
+ # flow.fetch_token(code=code)
100
+ # credentials = flow.credentials
101
+ # request.state.session["credential"] = json.loads(credentials.to_json())
102
+ # # cred_dict = (request.state.session.get("credential"))
103
+ # # cred = Credentials(
104
+ # # token=cred_dict["token"],
105
+ # # refresh_token=cred_dict["refresh_token"],
106
+ # # token_uri=cred_dict["token_uri"],
107
+ # # client_id=cred_dict["client_id"],
108
+ # # client_secret=cred_dict["client_secret"],
109
+ # # scopes=cred_dict["scopes"],
110
+ # # )
111
+ # # service = build("gmail", "v1", credentials=Credentials(
112
+ # # token=cred_dict["token"],
113
+ # # refresh_token=cred_dict["refresh_token"],
114
+ # # token_uri=cred_dict["token_uri"],
115
+ # # client_id=cred_dict["client_id"],
116
+ # # client_secret=cred_dict["client_secret"],
117
+ # # scopes=cred_dict["scopes"],
118
+ # # ))
119
+ # service = build("gmail", "v1", credentials=credentials)
120
+ # profile = service.users().getProfile(userId="me").execute()
121
+ # with open(f"cache/{profile['emailAddress']}.pickle", "wb") as token:
122
+ # pickle.dump(credentials, token)
123
+ # return JSONResponse(profile)