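"""Gradio Space for lecture-video Q&A.

Embeds the user's question, retrieves matching subtitle segments from a
Qdrant collection (RAG), asks a fine-tuned Gemma GGUF model to pick the
single best segment, then extracts and serves the corresponding clip
from the streamed lecture-video dataset.
"""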
import gradio as gr
from llama_cpp import Llama
from qdrant_client import QdrantClient
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import cv2
import os
import tempfile
import uuid
import re
import subprocess
import time
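# Assumed runtime environment: gradio, llama-cpp-python, qdrant-client,
# datasets, sentence-transformers, and opencv-python installed via the
# Space's requirements.txt, with the ffmpeg binary available on PATH.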
# Configuration
QDRANT_COLLECTION_NAME = "video_frames"
VIDEO_SEGMENT_DURATION = 40  # Extract 40 seconds around the selected timestamp

# Load Qdrant key
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
if not QDRANT_API_KEY:
    print("Error: QDRANT_API_KEY environment variable not found.")
    print("Please add your Qdrant API key as a secret named 'QDRANT_API_KEY' in your Hugging Face Space settings.")
    raise ValueError("QDRANT_API_KEY environment variable not set.")
print("Initializing LLM...") | |
try: | |
llm = Llama.from_pretrained( | |
repo_id="m1tch/gemma-finetune-ai_class_gguf", | |
filename="gemma-3_ai_class.Q8_0.gguf", | |
n_gpu_layers=-1, | |
n_ctx=2048, | |
verbose=False | |
) | |
print("LLM initialized successfully.") | |
except Exception as e: | |
print(f"Error initializing LLM: {e}") | |
raise | |
print("Connecting to Qdrant...") | |
try: | |
qdrant_client = QdrantClient( | |
url="https://2c18d413-cbb5-441c-b060-4c8c2302dcde.us-east4-0.gcp.cloud.qdrant.io:6333/", | |
api_key=QDRANT_API_KEY, | |
timeout=60 | |
) | |
qdrant_client.get_collections() | |
print("Qdrant connection successful.") | |
except Exception as e: | |
print(f"Error connecting to Qdrant: {e}") | |
raise | |
print("Loading dataset stream...") | |
try: | |
# Load video dataset | |
dataset = load_dataset("aegean-ai/ai-lectures-spring-24", split="train", streaming=True) | |
print(f"Dataset loaded. First item example: {next(iter(dataset))['__key__']}") | |
except Exception as e: | |
print(f"Error loading dataset: {e}") | |
raise | |
print("Loading Sentence Transformer model...")
try:
    embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
    print("Sentence Transformer model loaded.")
except Exception as e:
    print(f"Error loading Sentence Transformer model: {e}")
    raise
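# Note: all-MiniLM-L6-v2 produces 384-dimensional embeddings, so the
# "video_frames" collection is assumed to have been created with
# 384-dimensional vectors and a matching distance metric.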
def rag_query(client, collection_name, query_text, top_k=5, filter_condition=None):
    """
    Query the vector database with text. Returns a dictionary with search
    results and metadata. Uses the pre-loaded embedding_model.
    """
    try:
        query_vector = embedding_model.encode(query_text).tolist()
        search_params = {
            "collection_name": collection_name,
            "query_vector": query_vector,
            "limit": top_k,
            "with_payload": True,
            "with_vectors": False
        }
        if filter_condition:
            # qdrant-client expects payload filters under the "query_filter" keyword
            search_params["query_filter"] = filter_condition
        search_results = client.search(**search_params)
        formatted_results = []
        for idx, result in enumerate(search_results):
            formatted_results.append({
                "rank": idx + 1,
                "score": result.score,
                "video_id": result.payload.get("video_id"),
                "timestamp": result.payload.get("timestamp"),
                "subtitle": result.payload.get("subtitle"),
                "frame_number": result.payload.get("frame_number")
            })
        return {
            "query": query_text,
            "results": formatted_results,
            "avg_score": sum(r.score for r in search_results) / len(search_results) if search_results else 0
        }
    except Exception as e:
        print(f"Error during RAG query: {e}")
        return {"error": str(e), "query": query_text, "results": []}
def extract_video_segment(video_id, start_time, duration, dataset):
    """
    Extract a single video segment and return its file path (or None on
    failure), suitable for a Gradio Video component.
    """
    target_id = str(video_id)
    target_key = f"videos/{target_id}/{target_id}"
    start_time = float(start_time)
    duration = float(duration)

    unique_id = str(uuid.uuid4())
    temp_dir = os.path.join(tempfile.gettempdir(), f"gradio_video_{unique_id}")
    os.makedirs(temp_dir, exist_ok=True)
    temp_video_path = os.path.join(temp_dir, f"{target_id}_full_{unique_id}.mp4")
    output_path_opencv = os.path.join(temp_dir, f"output_opencv_{unique_id}.mp4")
    output_path_ffmpeg = os.path.join(temp_dir, f"output_ffmpeg_{unique_id}.mp4")

    print(f"Attempting to extract segment for video_id={target_id}, start={start_time}, duration={duration}")
    print(f"Looking for dataset key: {target_key}")
    print(f"Temporary directory: {temp_dir}")
    try:
        found = False
        retries = 3
        dataset_iterator = iter(dataset)
        # Scan up to retries * 100 samples of the stream for the target key
        for attempt in range(retries * 100):
            try:
                sample = next(dataset_iterator)
                if '__key__' in sample and sample['__key__'] == target_key:
                    found = True
                    print(f"Found video key {target_key}. Saving to {temp_video_path}...")
                    with open(temp_video_path, 'wb') as f:
                        f.write(sample['mp4'])
                    print(f"Video saved successfully ({os.path.getsize(temp_video_path)} bytes).")
                    break
            except StopIteration:
                print("Reached end of dataset stream without finding the video.")
                break
            except Exception as e:
                print(f"Error iterating dataset: {e}")
                time.sleep(1)
        if not found:
            print(f"Could not find video with ID {target_id} (key: {target_key}) in the dataset stream after {attempt + 1} attempts.")
            return None
        # Process the saved video
        if not os.path.exists(temp_video_path) or os.path.getsize(temp_video_path) == 0:
            print(f"Temporary video file {temp_video_path} is missing or empty.")
            return None

        cap = cv2.VideoCapture(temp_video_path)
        if not cap.isOpened():
            print(f"Error opening video file with OpenCV: {temp_video_path}")
            return None

        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            print(f"Warning: Invalid FPS ({fps}) detected for {temp_video_path}. Assuming 30 FPS.")
            fps = 30
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_vid_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        vid_duration = total_vid_frames / fps if fps > 0 else 0
        print(f"Video properties: {width}x{height} @ {fps:.2f}fps, Total Duration: {vid_duration:.2f}s")

        start_frame = int(start_time * fps)
        end_frame = int((start_time + duration) * fps)
        # Clamp frame numbers to the valid range
        start_frame = max(0, start_frame)
        end_frame = min(total_vid_frames, end_frame)
        if start_frame >= total_vid_frames or start_frame >= end_frame:
            print(f"Calculated start frame ({start_frame}) is beyond video length ({total_vid_frames}) or segment is invalid.")
            cap.release()
            return None

        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        frames_to_write = end_frame - start_frame
        print(f"Extracting frames from {start_frame} to {end_frame} ({frames_to_write} frames)")
        # Try OpenCV first
        fourcc_opencv = cv2.VideoWriter_fourcc(*'mp4v')  # mp4v is often more compatible than avc1 with base OpenCV
        out_opencv = cv2.VideoWriter(output_path_opencv, fourcc_opencv, fps, (width, height))
        if not out_opencv.isOpened():
            print("Error opening OpenCV VideoWriter with mp4v.")
            cap.release()
            return None

        frames_written_opencv = 0
        while frames_written_opencv < frames_to_write:
            ret, frame = cap.read()
            if not ret:
                print("Warning: Ran out of frames before reaching target end frame.")
                break
            out_opencv.write(frame)
            frames_written_opencv += 1
        out_opencv.release()
        print(f"OpenCV finished writing {frames_written_opencv} frames to {output_path_opencv}")
        cap.release()
        # FFmpeg re-encode is preferred (browser-friendly H.264); fall back to
        # the OpenCV output if FFmpeg is unavailable or fails.
        final_output_path = None

        def fallback_to_opencv():
            # Use the OpenCV-encoded segment if it exists and is non-empty
            if os.path.exists(output_path_opencv) and os.path.getsize(output_path_opencv) > 0:
                return output_path_opencv
            print("OpenCV output is also invalid or empty.")
            return None

        try:
            cmd = [
                'ffmpeg',
                '-ss', str(start_time),   # Start time
                '-i', temp_video_path,    # Input file (original downloaded)
                '-t', str(duration),      # Duration of the segment
                '-c:v', 'libx264',
                '-profile:v', 'baseline',
                '-level', '3.0',
                '-preset', 'fast',
                '-pix_fmt', 'yuv420p',
                '-movflags', '+faststart',
                '-c:a', 'aac',
                '-b:a', '128k',
                '-y',
                output_path_ffmpeg
            ]
            print(f"Running FFmpeg command: {' '.join(cmd)}")
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=120)  # cap the re-encode at 2 minutes
            if result.returncode == 0 and os.path.exists(output_path_ffmpeg) and os.path.getsize(output_path_ffmpeg) > 0:
                print(f"FFmpeg processing successful. Output: {output_path_ffmpeg}")
                final_output_path = output_path_ffmpeg
            else:
                print(f"FFmpeg error (Return Code: {result.returncode}):")
                print(f"FFmpeg stdout:\n{result.stdout}")
                print(f"FFmpeg stderr:\n{result.stderr}")
                print("Falling back to OpenCV output.")
                final_output_path = fallback_to_opencv()
        except subprocess.TimeoutExpired:
            print("FFmpeg command timed out. Falling back to OpenCV output.")
            final_output_path = fallback_to_opencv()
        except FileNotFoundError:
            print("Error: ffmpeg command not found. Make sure FFmpeg is installed and in your system's PATH. Falling back to OpenCV output.")
            final_output_path = fallback_to_opencv()
        except Exception as e:
            print(f"An unexpected error occurred during FFmpeg processing: {e}. Falling back to OpenCV output.")
            final_output_path = fallback_to_opencv()
        # Clean up the full downloaded video now that the segment exists
        if os.path.exists(temp_video_path):
            try:
                os.remove(temp_video_path)
                print(f"Cleaned up temporary full video: {temp_video_path}")
            except Exception as e:
                print(f"Warning: Could not remove temporary file {temp_video_path}: {e}")
        # If FFmpeg failed, remove its partial output
        if final_output_path != output_path_ffmpeg and os.path.exists(output_path_ffmpeg):
            try:
                os.remove(output_path_ffmpeg)
            except Exception as e:
                print(f"Warning: Could not remove failed ffmpeg output {output_path_ffmpeg}: {e}")

        print(f"Returning video segment path: {final_output_path}")
        return final_output_path
    except Exception as e:
        print(f"Error processing video segment for {video_id}: {e}")
        import traceback
        traceback.print_exc()
        if 'cap' in locals() and cap.isOpened():
            cap.release()
        if 'out_opencv' in locals() and out_opencv.isOpened():
            out_opencv.release()
        for path in (temp_video_path, output_path_opencv, output_path_ffmpeg):
            if os.path.exists(path):
                os.remove(path)
        return None
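# Hypothetical call: extract_video_segment("12", 300.0, 40, dataset) returns a
# path such as /tmp/gradio_video_<uuid>/output_ffmpeg_<uuid>.mp4 on success,
# or None if the video is not found or both encoders fail.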
QDRANT_COLLECTION_NAME = "video_frames" | |
VIDEO_SEGMENT_DURATION = 40 # Extract 40 seconds around the timestamp | |
def parse_llm_output(text):
    """
    Parse the LLM's structured output, using regex for the simple fields
    (video_id, timestamp) and plain string manipulation for the reasoning
    field as a workaround for regex matching issues.
    """
    data = {}
    # Parse video_id and timestamp with regex
    simple_patterns = {
        'video_id': r"\{Best Result:\s*\[?([^\]\}]+)\]?\s*\}",
        'timestamp': r"\{Timestamp:\s*\[?([^\]\}]+)\]?\s*\}",
    }
    for key, pattern in simple_patterns.items():
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            value = match.group(1).strip()
            value = value.strip('\'"\u201c\u201d\u2018\u2019')  # strip straight and curly quotes
            data[key] = value
        else:
            print(f"Warning: Could not parse '{key}' using regex pattern: {pattern}")
            data[key] = None
    # Parse reasoning via string search
    reasoning_value = None
    try:
        key_marker_lower = "{reasoning:"
        start_index = text.lower().find(key_marker_lower)
        if start_index != -1:
            value_start = start_index + len(key_marker_lower)
            end_index = text.find('}', value_start)
            if end_index != -1:
                value = text[value_start:end_index].strip()
                if value.startswith('[') and value.endswith(']'):
                    value = value[1:-1]
                value = value.strip('\'"\u201c\u201d\u2018\u2019').strip()  # strip quotes and whitespace
                reasoning_value = value
            else:
                print("Warning: Found '{reasoning:' marker but no closing '}' found afterwards.")
        else:
            print("Warning: Marker '{reasoning:' not found in text.")
    except Exception as e:
        print(f"Error during string manipulation parsing for reasoning: {e}")
    data['reasoning'] = reasoning_value

    # Sanity-check the timestamp
    if data.get('timestamp'):
        try:
            float(data['timestamp'])
        except ValueError:
            print(f"Warning: Parsed timestamp '{data['timestamp']}' is not a valid number.")

    print(f"Parsed LLM output: {data}")
    return data
def process_query_and_get_video(query_text):
    """
    Orchestrate the RAG query, LLM selection, parsing, and video extraction.
    """
    print(f"\n--- Processing query: '{query_text}' ---")

    # 1. RAG query
    print("Step 1: Performing RAG query...")
    rag_results = rag_query(qdrant_client, QDRANT_COLLECTION_NAME, query_text)
    if "error" in rag_results or not rag_results.get("results"):
        error_msg = rag_results.get('error', 'No relevant segments found by RAG.')
        print(f"RAG Error/No Results: {error_msg}")
        return f"Error during RAG search: {error_msg}", None
    print(f"RAG query successful. Found {len(rag_results['results'])} results.")

    # 2. Format LLM prompt
    print("Step 2: Formatting prompt for LLM...")
prompt = f"""You are tasked with selecting the most relevant information from a set of video subtitle segments to answer a query. | |
QUERY (also seen below): "{query_text}" | |
For each result provided, evaluate how well it directly addresses the definition or explanation related to the query. Pay attention to: | |
1. Clarity of explanation | |
2. Relevance to the query | |
3. Completeness of information | |
From the provided results, select the SINGLE BEST match that most directly answers the query. | |
Format your response STRICTLY as follows, with each field on a new line: | |
{{Best Result: [video_id]}} | |
{{Timestamp: [timestamp]}} | |
{{Content: [subtitle text]}} | |
{{Reasoning: [Brief explanation of why this result best answers the query]}} | |
{rag_results}""" | |
    # 3. Call LLM
    print("Step 3: Querying the LLM...")
    try:
        output = llm.create_chat_completion(
            messages=[
                {"role": "system", "content": "You are a helpful assistant designed to select the best video segment based on relevance to a query, following a specific output format."},
                {"role": "user", "content": prompt},
            ],
            temperature=0.1,  # low temperature keeps the output close to the required format
            max_tokens=300
        )
        llm_response_text = output['choices'][0]['message']['content']
        print(f"LLM Response:\n{llm_response_text}")
    except Exception as e:
        print(f"Error during LLM call: {e}")
        return f"Error calling LLM: {e}", None
    # 4. Parse LLM response
    print("Step 4: Parsing LLM response...")
    parsed_data = parse_llm_output(llm_response_text)
    video_id = parsed_data.get('video_id')
    timestamp_str = parsed_data.get('timestamp')
    reasoning = parsed_data.get('reasoning')

    if not video_id or not timestamp_str:
        print("Error: Could not parse required video_id or timestamp from LLM response.")
        fallback_reasoning = reasoning if reasoning else "Could not determine the best segment."
        error_msg = f"Failed to parse LLM response. LLM said:\n---\n{llm_response_text}\n---\nReasoning (if found): {fallback_reasoning}"
        return error_msg, None

    try:
        timestamp = float(timestamp_str)
        # Start slightly before the timestamp so the answer has some lead-in
        start_time = max(0.0, timestamp - (VIDEO_SEGMENT_DURATION / 4))
    except ValueError:
        print(f"Error: Could not convert parsed timestamp '{timestamp_str}' to float.")
        error_msg = f"Invalid timestamp format from LLM ('{timestamp_str}'). LLM reasoning (if found): {reasoning}"
        return error_msg, None

    final_reasoning = reasoning if reasoning else "No reasoning provided by LLM."

    # 5. Extract video segment
    print(f"Step 5: Extracting video segment (ID: {video_id}, Start: {start_time:.2f}s, Duration: {VIDEO_SEGMENT_DURATION}s)...")
    video_path = extract_video_segment(video_id, start_time, VIDEO_SEGMENT_DURATION, dataset)

    if video_path and os.path.exists(video_path):
        print(f"Video segment extracted successfully: {video_path}")
        return final_reasoning, video_path
    else:
        print("Failed to extract video segment.")
        error_msg = f"{final_reasoning}\n\n(However, failed to extract the corresponding video segment for ID {video_id} at timestamp {timestamp_str}.)"
        return error_msg, None
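# Hypothetical round trip: process_query_and_get_video("What is a convolution?")
# -> ("Reasoning text explaining the chosen segment",
#     "/tmp/gradio_video_<uuid>/output_ffmpeg_<uuid>.mp4")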
with gr.Blocks() as iface:
    gr.Markdown(
        """
        # Lecture Videos Q&A
        Ask a question about the lectures. The system retrieves relevant segments with RAG,
        uses a fine-tuned LLM to select the best one, and displays the corresponding video clip.
        """
    )
    with gr.Row():
        query_input = gr.Textbox(label="Your Question", placeholder="Using only the videos, explain how ResNets work.")
        submit_button = gr.Button("Ask & Find Video")
    with gr.Row():
        reasoning_output = gr.Markdown(label="LLM Reasoning")
    with gr.Row():
        video_output = gr.Video(label="Relevant Video Segment")

    submit_button.click(
        fn=process_query_and_get_video,
        inputs=query_input,
        outputs=[reasoning_output, video_output]
    )
    gr.Examples(
        examples=[
            "Using only the videos, explain how ResNets work.",
            "Using only the videos, explain the advantages of CNNs over fully connected networks.",
            "Using only the videos, explain the binary cross entropy loss function.",
        ],
        inputs=query_input,
        outputs=[reasoning_output, video_output],
        fn=process_query_and_get_video,
        cache_examples=False,
    )
print("Launching Gradio interface...") | |
iface.launch(debug=True, share=False) |