import open_clip
import torch
import tempfile
import subprocess
import os
from datetime import datetime
from collections import defaultdict
from datasets import load_dataset
from qdrant_client import QdrantClient
from huggingface_hub import login

# Load the fine-tuned CLIP text/image encoder from a local checkpoint.
tokenizer = open_clip.get_tokenizer("ViT-B-32")
model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained=None)
checkpoint_path = "finetuned_clip.pt"
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
model.eval()

# Connect to the Qdrant Cloud collection holding the indexed video chunks.
qdrant = QdrantClient(
    url=os.environ["QDRANT_CLOUD_URL"],
    api_key=os.environ["QDRANT_API_KEY"],
    prefer_grpc=False,
)
collection_name = "video_chunks"

# Authenticate with the Hugging Face Hub to stream the lecture dataset.
login(token=os.environ["HF_API_TOKEN"])


def timestamp_to_seconds(ts):
    """Convert an 'HH:MM:SS.mmm' timestamp into seconds."""
    h, m, s = ts.split(":")
    return int(h) * 3600 + int(m) * 60 + float(s)


def seconds_to_timestamp(seconds):
    """Convert seconds back into an 'HH:MM:SS.mmm' timestamp."""
    h = int(seconds // 3600)
    m = int((seconds % 3600) // 60)
    s = seconds % 60
    return f"{h:02}:{m:02}:{s:06.3f}"


def smart_merge_subtitles(a, b):
    """Merge two subtitle strings, collapsing containment and overlapping suffix/prefix text."""
    if b in a:
        return a
    if a in b:
        return b
    for i in range(min(len(a), len(b)), 0, -1):
        if a.endswith(b[:i]):
            return a + b[i:]
        if b.endswith(a[:i]):
            return b + a[i:]
    return a + " " + b


def merge_chunks(chunks):
    """Group retrieved chunks by video and merge time-overlapping segments into single chunks."""
    grouped = defaultdict(list)
    for chunk in chunks:
        payload = chunk.payload
        grouped[payload["video_id"]].append({
            "start": timestamp_to_seconds(payload["start_time"]),
            "end": timestamp_to_seconds(payload["end_time"]),
            "subtitle": payload["subtitle"].strip(),
            "merged": False,
        })

    merged_chunks = []
    for video_id, video_chunks in grouped.items():
        for i, chunk in enumerate(video_chunks):
            if chunk["merged"]:
                continue
            merged_chunk = chunk.copy()
            chunk["merged"] = True
            for j, other in enumerate(video_chunks):
                if i == j or other["merged"]:
                    continue
                # Two segments overlap unless one ends before the other starts.
                if not (merged_chunk["end"] < other["start"] or other["end"] < merged_chunk["start"]):
                    merged_chunk["start"] = min(merged_chunk["start"], other["start"])
                    merged_chunk["end"] = max(merged_chunk["end"], other["end"])
                    merged_chunk["subtitle"] = smart_merge_subtitles(merged_chunk["subtitle"], other["subtitle"])
                    other["merged"] = True
            merged_chunks.append({
                "video_id": video_id,
                "start_time": seconds_to_timestamp(merged_chunk["start"]),
                "end_time": seconds_to_timestamp(merged_chunk["end"]),
                "subtitle": merged_chunk["subtitle"],
            })
    return merged_chunks


def get_video_segment(video_id, start_time, end_time):
    """Stream the source video from the dataset and cut out [start_time, end_time] with FFmpeg."""
    dataset = load_dataset("aegean-ai/ai-lectures-spring-24", split="train", streaming=True)
    for sample in dataset:
        if sample["__key__"] == video_id:
            break
    else:
        print(f"Video {video_id} not found in dataset")
        return None

    tmp_dir = tempfile.gettempdir()
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
    full_path = os.path.join(tmp_dir, f"full_{timestamp}.mp4")
    trimmed_path = os.path.join(tmp_dir, f"clip_{timestamp}.mp4")

    with open(full_path, "wb") as f:
        f.write(sample["mp4"])

    # Stream-copy the requested segment without re-encoding.
    cmd = [
        "ffmpeg",
        "-ss", start_time,
        "-to", end_time,
        "-i", full_path,
        "-c:v", "copy",
        "-c:a", "copy",
        "-y", trimmed_path,
    ]
    result = subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    os.remove(full_path)

    if result.returncode != 0:
        print("FFmpeg failed")
        return None
    return trimmed_path


def retrieval(question):
    """Embed the question with CLIP and return up to 10 Qdrant hits with a meaningful subtitle."""
    text_tokens = tokenizer([question])
    with torch.no_grad():
        query_vec = model.encode_text(text_tokens).squeeze(0).cpu().numpy()

    search_result = qdrant.search(
        collection_name=collection_name,
        query_vector=query_vec.tolist(),
        limit=40,
    )

    # Drop chunks whose subtitle is too short to carry useful context.
    filtered_results = [
        res for res in search_result
        if len(res.payload.get("subtitle", "")) >= 35
    ]
    return filtered_results[:10]
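
# --- Usage sketch ---
# A minimal example of how these pieces might be chained end to end: embed a question,
# retrieve and merge overlapping chunks, then cut a clip for the top merged result.
# The example question and the choice to keep only the first merged chunk are
# illustrative assumptions, not something fixed by the functions above.
if __name__ == "__main__":
    question = "How does backpropagation update the weights of a neural network?"  # assumed example query
    hits = retrieval(question)          # top filtered Qdrant hits for the question
    merged = merge_chunks(hits)         # collapse overlapping segments per video
    if merged:
        top = merged[0]
        clip_path = get_video_segment(top["video_id"], top["start_time"], top["end_time"])
        print("Subtitle:", top["subtitle"])
        print("Clip written to:", clip_path)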