mominah commited on
Commit
0b80ea1
·
verified ·
1 Parent(s): 2d010d1

Create video_rag_routes.py

Browse files
Files changed (1) hide show
  1. video_rag_routes.py +148 -0
video_rag_routes.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # video_rag_routes.py
2
+
3
+ import os
4
+ import uuid
5
+ from fastapi import APIRouter, HTTPException, UploadFile, File
6
+ from fastapi.responses import JSONResponse
7
+ from pydantic import BaseModel
8
+ from langchain_huggingface import HuggingFaceEmbeddings
9
+ from langchain_community.vectorstores import FAISS
10
+ from langchain.chains import ConversationalRetrievalChain
11
+ from langchain_core.prompts import ChatPromptTemplate
12
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
13
+ from langchain_groq import ChatGroq
14
+ from google import genai
15
+ from google.genai import types
16
+
17
# Router for all video-RAG endpoints; every route below is mounted under /video_rag.
router = APIRouter(prefix="/video_rag", tags=["video_rag"])

# ——— Helpers ————————————————————————————————————————————
20
+
21
def init_google_client():
    """Build a Google GenAI client from the GOOGLE_API_KEY environment variable.

    Returns:
        A configured ``genai.Client``.

    Raises:
        ValueError: if GOOGLE_API_KEY is unset or empty.
    """
    key = os.getenv("GOOGLE_API_KEY", "")
    if not key:
        raise ValueError("GOOGLE_API_KEY must be set")
    return genai.Client(api_key=key)
26
+
27
def get_llm():
    """Create the Groq-hosted chat model used to answer questions.

    Returns:
        A ``ChatGroq`` instance (llama-3.3-70b-versatile, deterministic,
        capped at 1024 output tokens).

    Raises:
        ValueError: if CHATGROQ_API_KEY is unset or empty.
    """
    key = os.getenv("CHATGROQ_API_KEY", "")
    if not key:
        raise ValueError("CHATGROQ_API_KEY must be set")
    return ChatGroq(
        model="llama-3.3-70b-versatile",
        temperature=0,
        max_tokens=1024,
        api_key=key,
    )
37
+
38
def get_embeddings():
    """Return a CPU-hosted BGE-small embedding model with normalized vectors."""
    config = dict(
        model_name="BAAI/bge-small-en",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": True},
    )
    return HuggingFaceEmbeddings(**config)
44
+
45
# RAG prompt: grounds the model in the retrieved transcript chunks and tells
# it to decline rather than guess when the context lacks the answer.
quiz_prompt = """
You are an assistant specialized in answering questions based on the provided context.
If the context does not contain the answer, reply “I don't know.”
Context:
{context}

Question:
{question}

Answer:
"""

# The system turn carries the template (with {context} and {question} slots
# filled by the chain); the human turn repeats the question verbatim.
chat_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", quiz_prompt),
        ("human", "{question}"),
    ]
)
61
+
62
def create_chain(retriever):
    """Wire a retriever and the Groq LLM into a conversational RAG chain.

    Args:
        retriever: vector-store retriever over the transcript chunks.

    Returns:
        A ``ConversationalRetrievalChain`` ("stuff" combine strategy, custom
        prompt) that also returns its source documents.
    """
    options = {
        "llm": get_llm(),
        "retriever": retriever,
        "return_source_documents": True,
        "chain_type": "stuff",
        "combine_docs_chain_kwargs": {"prompt": chat_prompt},
        "verbose": False,
    }
    return ConversationalRetrievalChain.from_llm(**options)
71
+
72
# In-memory session store: session_id -> {"retriever": <FAISS retriever>,
# "history": [(question, answer), ...]}.
# NOTE(review): sessions are never evicted and live only in this process —
# entries are lost on restart and not shared across workers; confirm this is
# acceptable for deployment.
sessions: dict[str, dict] = {}
74
+
75
def process_transcription(text: str) -> str:
    """Index a transcript for retrieval and open a new chat session.

    Splits the text into overlapping chunks, embeds them into a FAISS index,
    and registers a top-3 retriever plus an empty chat history under a fresh
    UUID in the module-level ``sessions`` store.

    Args:
        text: raw transcript text to index.

    Returns:
        The new session id (UUID4 string).
    """
    chunker = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=20)
    pieces = chunker.split_text(text)
    index = FAISS.from_texts(pieces, get_embeddings())
    session_id = str(uuid.uuid4())
    sessions[session_id] = {
        "retriever": index.as_retriever(search_kwargs={"k": 3}),
        "history": [],
    }
    return session_id
84
+
85
+ # ——— Endpoints ——————————————————————————————————————————
86
+
87
class URLIn(BaseModel):
    """Request body for POST /video_rag/transcribe_video."""

    # Video URL handed to Gemini as file_data for transcription.
    youtube_url: str
89
+
90
@router.post("/transcribe_video")
async def transcribe_url(body: URLIn):
    """Transcribe a video by URL via Gemini and start a RAG session.

    Args:
        body: request body containing the video URL.

    Returns:
        ``{"session_id": <uuid>}`` for use with /vid_query.

    Raises:
        HTTPException: 500 if transcription or indexing fails.
    """
    client = init_google_client()
    try:
        request_content = types.Content(
            parts=[
                types.Part(text="Transcribe the video"),
                types.Part(file_data=types.FileData(file_uri=body.youtube_url)),
            ]
        )
        resp = client.models.generate_content(
            model="models/gemini-2.0-flash",
            contents=request_content,
        )
        transcript = resp.candidates[0].content.parts[0].text
        return {"session_id": process_transcription(transcript)}
    except Exception as e:
        # Boundary handler: surface any Gemini/indexing failure as a 500.
        raise HTTPException(status_code=500, detail=str(e))
106
+
107
@router.post("/upload_video")
async def upload_file(
    file: UploadFile = File(...),
    prompt: str = "Transcribe the video",
):
    """Transcribe an uploaded video file via Gemini and start a RAG session.

    Args:
        file: uploaded media file, sent to Gemini as inline bytes.
        prompt: instruction given to the model alongside the media.

    Returns:
        ``{"session_id": <uuid>}`` for use with /vid_query.

    Raises:
        HTTPException: 500 if transcription or indexing fails.
    """
    data = await file.read()
    client = init_google_client()
    try:
        media_blob = types.Blob(data=data, mime_type=file.content_type)
        resp = client.models.generate_content(
            model="models/gemini-2.0-flash",
            contents=types.Content(
                parts=[
                    types.Part(text=prompt),
                    types.Part(inline_data=media_blob),
                ]
            ),
        )
        transcript = resp.candidates[0].content.parts[0].text
        return {"session_id": process_transcription(transcript)}
    except Exception as e:
        # Boundary handler: surface any Gemini/indexing failure as a 500.
        raise HTTPException(status_code=500, detail=str(e))
127
+
128
class QueryIn(BaseModel):
    """Request body for POST /video_rag/vid_query."""

    # Session id returned by /transcribe_video or /upload_video.
    session_id: str
    # Natural-language question about the transcribed video.
    query: str
131
+
132
@router.post("/vid_query")
async def query_rag(body: QueryIn):
    """Answer a question against a previously transcribed video.

    Runs the conversational RAG chain over the session's retriever, appends
    the (question, answer) pair to the session history, and returns the
    answer with the retrieved source snippets.

    Raises:
        HTTPException: 404 if the session id is unknown.
    """
    session = sessions.get(body.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    chain = create_chain(session["retriever"])
    result = chain({"question": body.query, "chat_history": session["history"]})
    answer = result.get("answer", "I don't know.")

    # Persist this turn so follow-up questions carry the conversation.
    session["history"].append((body.query, answer))

    retrieved = result.get("source_documents") or []
    snippets = [getattr(doc, "page_content", str(doc)) for doc in retrieved]
    return {"answer": answer, "source_documents": snippets}