# video_rag_routes.py
import os
import uuid
from fastapi import APIRouter, HTTPException, UploadFile, File
from pydantic import BaseModel
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain_core.prompts import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_groq import ChatGroq
from google import genai
from google.genai import types


router = APIRouter()


# ─── Helpers ──────────────────────────────────────────────
def init_google_client():
    api_key = os.getenv("GOOGLE_API_KEY", "")
    if not api_key:
        raise ValueError("GOOGLE_API_KEY must be set")
    return genai.Client(api_key=api_key)


def get_llm():
    api_key = os.getenv("CHATGROQ_API_KEY", "")
    if not api_key:
        raise ValueError("CHATGROQ_API_KEY must be set")
    return ChatGroq(
        model="llama-3.3-70b-versatile",
        temperature=0,
        max_tokens=1024,
        api_key=api_key,
    )


def get_embeddings():
    return HuggingFaceEmbeddings(
        model_name="BAAI/bge-small-en",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": True},
    )


# Simple prompt template for RAG
quiz_prompt = """
You are an assistant specialized in answering questions based on the provided context.
If the context does not contain the answer, reply "I don't know."

Context:
{context}

Question:
{question}

Answer:
"""

chat_prompt = ChatPromptTemplate.from_messages([
    ("system", quiz_prompt),
    ("human", "{question}"),
])


def create_chain(retriever):
    return ConversationalRetrievalChain.from_llm(
        llm=get_llm(),
        retriever=retriever,
        return_source_documents=True,
        chain_type="stuff",
        combine_docs_chain_kwargs={"prompt": chat_prompt},
        verbose=False,
    )


# In-memory session store (per-process; sessions are lost on restart)
sessions: dict[str, dict] = {}


def process_transcription(text: str) -> str:
    # split → embed → index → store retriever & empty history
    splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=20)
    chunks = splitter.split_text(text)
    vs = FAISS.from_texts(chunks, get_embeddings())
    retr = vs.as_retriever(search_kwargs={"k": 3})
    sid = str(uuid.uuid4())
    sessions[sid] = {"retriever": retr, "history": []}
    return sid


# ─── Endpoints ────────────────────────────────────────────
class URLIn(BaseModel):
    youtube_url: str


@router.post("/transcribe_video")
async def transcribe_url(body: URLIn):
    client = init_google_client()
    try:
        resp = client.models.generate_content(
            model="models/gemini-2.0-flash",
            contents=types.Content(parts=[
                types.Part(text="Transcribe the video"),
                types.Part(file_data=types.FileData(file_uri=body.youtube_url)),
            ]),
        )
        txt = resp.candidates[0].content.parts[0].text
        sid = process_transcription(txt)
        return {"session_id": sid}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/upload_video")
async def upload_file(
file: UploadFile = File(...),
prompt: str = "Transcribe the video",
):
data = await file.read()
client = init_google_client()
try:
resp = client.models.generate_content(
model="models/gemini-2.0-flash",
contents=types.Content(parts=[
types.Part(text=prompt),
types.Part(inline_data=types.Blob(data=data, mime_type=file.content_type))
])
)
txt = resp.candidates[0].content.parts[0].text
sid = process_transcription(txt)
return {"session_id": sid}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


class QueryIn(BaseModel):
    session_id: str
    query: str


@router.post("/vid_query")
async def query_rag(body: QueryIn):
    sess = sessions.get(body.session_id)
    if not sess:
        raise HTTPException(status_code=404, detail="Session not found")
    chain = create_chain(sess["retriever"])
    result = chain.invoke({
        "question": body.query,
        "chat_history": sess["history"],
    })
    answer = result.get("answer", "I don't know.")
    # update history
    sess["history"].append((body.query, answer))
    # collect source snippets
    docs = result.get("source_documents") or []
    srcs = [getattr(d, "page_content", str(d)) for d in docs]
    return {"answer": answer, "source_documents": srcs}