Create video_rag_routes.py

video_rag_routes.py  ADDED  +148 -0

@@ -0,0 +1,148 @@
# video_rag_routes.py

import os
import uuid
from fastapi import APIRouter, HTTPException, UploadFile, File
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain_core.prompts import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_groq import ChatGroq
from google import genai
from google.genai import types
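
# How the pieces fit together: Gemini transcribes the video, the transcript
# is chunked and indexed in FAISS with BGE embeddings, and a Groq-hosted
# Llama 3.3 model answers questions over the retrieved chunks.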

router = APIRouter(prefix="/video_rag", tags=["video_rag"])

# ─── Helpers ──────────────────────────────────────────────

def init_google_client():
    api_key = os.getenv("GOOGLE_API_KEY", "")
    if not api_key:
        raise ValueError("GOOGLE_API_KEY must be set")
    return genai.Client(api_key=api_key)

def get_llm():
    api_key = os.getenv("CHATGROQ_API_KEY", "")
    if not api_key:
        raise ValueError("CHATGROQ_API_KEY must be set")
    return ChatGroq(
        model="llama-3.3-70b-versatile",
        temperature=0,
        max_tokens=1024,
        api_key=api_key,
    )

def get_embeddings():
    return HuggingFaceEmbeddings(
        model_name="BAAI/bge-small-en",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": True},
    )
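
# With normalize_embeddings=True the BGE vectors are unit-length, so FAISS's
# default L2 ranking orders chunks the same way cosine similarity would.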

# Simple prompt template for RAG
quiz_prompt = """
You are an assistant specialized in answering questions based on the provided context.
If the context does not contain the answer, reply "I don't know."
Context:
{context}

Question:
{question}

Answer:
"""
chat_prompt = ChatPromptTemplate.from_messages([
    ("system", quiz_prompt),
    ("human", "{question}"),
])

def create_chain(retriever):
    return ConversationalRetrievalChain.from_llm(
        llm=get_llm(),
        retriever=retriever,
        return_source_documents=True,
        chain_type="stuff",
        combine_docs_chain_kwargs={"prompt": chat_prompt},
        verbose=False,
    )
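
# ConversationalRetrievalChain condenses the incoming question plus chat
# history into a standalone query, retrieves against it, then "stuff"s the
# top-k chunks into the {context} slot of chat_prompt.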

# In-memory session store
sessions: dict[str, dict] = {}

def process_transcription(text: str) -> str:
    # split → embed → index → store retriever & empty history
    splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=20)
    chunks = splitter.split_text(text)
    vs = FAISS.from_texts(chunks, get_embeddings())
    retr = vs.as_retriever(search_kwargs={"k": 3})
    sid = str(uuid.uuid4())
    sessions[sid] = {"retriever": retr, "history": []}
    return sid
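
# Note: sessions live in process memory only; they are lost on restart and
# not shared across workers, so this suits a single-worker deployment.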

# ─── Endpoints ────────────────────────────────────────────

class URLIn(BaseModel):
    youtube_url: str

@router.post("/transcribe_video")
async def transcribe_url(body: URLIn):
    client = init_google_client()
    try:
        resp = client.models.generate_content(
            model="models/gemini-2.0-flash",
            contents=types.Content(parts=[
                types.Part(text="Transcribe the video"),
                types.Part(file_data=types.FileData(file_uri=body.youtube_url))
            ])
        )
        txt = resp.candidates[0].content.parts[0].text
        sid = process_transcription(txt)
        return {"session_id": sid}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
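
# Example call (hypothetical host and video URL):
#   curl -X POST http://localhost:8000/video_rag/transcribe_video \
#        -H "Content-Type: application/json" \
#        -d '{"youtube_url": "https://www.youtube.com/watch?v=..."}'
# Response: {"session_id": "<uuid4>"}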

@router.post("/upload_video")
async def upload_file(
    file: UploadFile = File(...),
    prompt: str = "Transcribe the video",
):
    data = await file.read()
    client = init_google_client()
    try:
        resp = client.models.generate_content(
            model="models/gemini-2.0-flash",
            contents=types.Content(parts=[
                types.Part(text=prompt),
                types.Part(inline_data=types.Blob(data=data, mime_type=file.content_type))
            ])
        )
        txt = resp.candidates[0].content.parts[0].text
        sid = process_transcription(txt)
        return {"session_id": sid}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
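
# Note: inline_data sends the whole file in the request body, which the
# Gemini API caps at roughly 20 MB; larger videos would need the Files API.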

class QueryIn(BaseModel):
    session_id: str
    query: str

@router.post("/vid_query")
async def query_rag(body: QueryIn):
    sess = sessions.get(body.session_id)
    if not sess:
        raise HTTPException(status_code=404, detail="Session not found")
    chain = create_chain(sess["retriever"])
    result = chain({
        "question": body.query,
        "chat_history": sess["history"]
    })
    answer = result.get("answer", "I don't know.")
    # update history
    sess["history"].append((body.query, answer))
    # collect source snippets
    docs = result.get("source_documents") or []
    srcs = [getattr(d, "page_content", str(d)) for d in docs]
    return {"answer": answer, "source_documents": srcs}
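
The router is self-contained. A minimal sketch of mounting it in an app follows; the main.py module name and app variable are assumptions for illustration, not part of this commit:

# main.py: hypothetical entry point mounting the router above
from fastapi import FastAPI
from video_rag_routes import router as video_rag_router

app = FastAPI()
app.include_router(video_rag_router)  # serves /video_rag/transcribe_video, /video_rag/upload_video, /video_rag/vid_query

# Run with a single worker, since the session store is in-memory:
#   uvicorn main:app --workers 1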