import os
import subprocess
import tempfile
from collections import defaultdict
from datetime import datetime

import open_clip
import torch
from datasets import load_dataset
from huggingface_hub import login
from qdrant_client import QdrantClient

# Text encoder: ViT-B-32 CLIP with fine-tuned weights loaded from a local checkpoint.
tokenizer = open_clip.get_tokenizer("ViT-B-32")
model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained=None)
checkpoint_path = "finetuned_clip.pt"
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
model.eval()

# Qdrant Cloud collection holding the pre-computed video-chunk embeddings.
qdrant = QdrantClient(url=os.environ["QDRANT_CLOUD_URL"],
                      api_key=os.environ["QDRANT_API_KEY"],
                      prefer_grpc=False)
collection_name = "video_chunks"

# Authenticate to the Hugging Face Hub so the lecture dataset can be streamed.
login(token=os.environ["HF_API_TOKEN"])
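# Runtime requirements implied by the setup above: the fine-tuned checkpoint
# "finetuned_clip.pt" must sit next to this file, and the environment must
# provide QDRANT_CLOUD_URL, QDRANT_API_KEY, and HF_API_TOKEN; the Qdrant
# collection "video_chunks" is assumed to already contain the chunk embeddings.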
def timestamp_to_seconds(ts):
    # "HH:MM:SS.mmm" -> seconds as a float.
    h, m, s = ts.split(":")
    return int(h) * 3600 + int(m) * 60 + float(s)


def seconds_to_timestamp(seconds):
    # Seconds as a float -> "HH:MM:SS.mmm".
    h = int(seconds // 3600)
    m = int((seconds % 3600) // 60)
    s = seconds % 60
    return f"{h:02}:{m:02}:{s:06.3f}"
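# Round-trip example for the two helpers above:
#     timestamp_to_seconds("01:02:03.500")  # -> 3723.5
#     seconds_to_timestamp(3723.5)          # -> "01:02:03.500"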
def smart_merge_subtitles(a, b):
    # Combine two subtitle strings without duplicating shared text: if one
    # contains the other, keep the longer one; otherwise stitch them together
    # at the longest suffix/prefix overlap; else fall back to concatenation.
    if b in a:
        return a
    if a in b:
        return b
    for i in range(min(len(a), len(b)), 0, -1):
        if a.endswith(b[:i]):
            return a + b[i:]
        if b.endswith(a[:i]):
            return b + a[i:]
    return a + " " + b
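# Example of the overlap stitching (made-up subtitle text):
#     smart_merge_subtitles("the network learns", "learns by gradient descent")
#     # -> "the network learns by gradient descent"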
def merge_chunks(chunks):
    # Group retrieved chunks by video, then merge chunks whose time ranges
    # overlap (or touch) into a single chunk with a combined subtitle.
    grouped = defaultdict(list)
    for chunk in chunks:
        payload = chunk.payload
        grouped[payload["video_id"]].append({
            "start": timestamp_to_seconds(payload["start_time"]),
            "end": timestamp_to_seconds(payload["end_time"]),
            "subtitle": payload["subtitle"].strip(),
            "merged": False
        })
    merged_chunks = []
    for video_id, video_chunks in grouped.items():
        for i, chunk in enumerate(video_chunks):
            if chunk["merged"]:
                continue
            merged_chunk = chunk.copy()
            chunk["merged"] = True
            for j, other in enumerate(video_chunks):
                if i == j or other["merged"]:
                    continue
                # Overlapping (or adjacent) intervals: extend the merged chunk.
                if not (merged_chunk["end"] < other["start"] or other["end"] < merged_chunk["start"]):
                    merged_chunk["start"] = min(merged_chunk["start"], other["start"])
                    merged_chunk["end"] = max(merged_chunk["end"], other["end"])
                    merged_chunk["subtitle"] = smart_merge_subtitles(merged_chunk["subtitle"], other["subtitle"])
                    other["merged"] = True
            merged_chunks.append({
                "video_id": video_id,
                "start_time": seconds_to_timestamp(merged_chunk["start"]),
                "end_time": seconds_to_timestamp(merged_chunk["end"]),
                "subtitle": merged_chunk["subtitle"]
            })
    return merged_chunks
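# Sketch of merge_chunks on two overlapping hits from the same video. The
# input objects only need a .payload attribute, so SimpleNamespace stands in
# for a Qdrant scored point here and all values are made up for illustration:
#
#     from types import SimpleNamespace
#     hits = [
#         SimpleNamespace(payload={"video_id": "lecture_01",
#                                  "start_time": "00:00:10.000",
#                                  "end_time": "00:00:20.000",
#                                  "subtitle": "gradient descent updates"}),
#         SimpleNamespace(payload={"video_id": "lecture_01",
#                                  "start_time": "00:00:18.000",
#                                  "end_time": "00:00:30.000",
#                                  "subtitle": "updates the weights"}),
#     ]
#     merge_chunks(hits)
#     # -> [{"video_id": "lecture_01", "start_time": "00:00:10.000",
#     #      "end_time": "00:00:30.000",
#     #      "subtitle": "gradient descent updates the weights"}]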
def get_video_segment(video_id, start_time, end_time):
    # Stream the lecture dataset and stop at the requested video.
    dataset = load_dataset("aegean-ai/ai-lectures-spring-24", split="train", streaming=True)
    for sample in dataset:
        if sample["__key__"] == video_id:
            break
    else:
        # No matching video in the dataset.
        return None
    tmp_dir = tempfile.gettempdir()
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
    full_path = os.path.join(tmp_dir, f"full_{timestamp}.mp4")
    trimmed_path = os.path.join(tmp_dir, f"clip_{timestamp}.mp4")
    with open(full_path, "wb") as f:
        f.write(sample["mp4"])
    # Cut the requested segment without re-encoding (stream copy).
    cmd = [
        "ffmpeg", "-ss", start_time, "-to", end_time,
        "-i", full_path,
        "-c:v", "copy", "-c:a", "copy", "-y",
        trimmed_path
    ]
    result = subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    os.remove(full_path)
    if result.returncode != 0:
        print("FFmpeg failed")
        return None
    return trimmed_path
def retrieval(question):
    # Encode the question with the fine-tuned CLIP text encoder and search
    # the Qdrant collection for the closest video chunks.
    text_tokens = tokenizer([question])
    with torch.no_grad():
        query_vec = model.encode_text(text_tokens).squeeze(0).cpu().numpy()
    search_result = qdrant.search(
        collection_name=collection_name,
        query_vector=query_vec.tolist(),
        limit=40
    )
    # Drop near-empty subtitles, then keep the top 10 hits.
    filtered_results = [
        res for res in search_result
        if len(res.payload.get("subtitle", "")) >= 35
    ]
    return filtered_results[:10]
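# Intended end-to-end flow (the question text is only a placeholder, and
# retrieval() assumes the collection was built with the same fine-tuned CLIP
# text encoder used above):
#
#     hits = retrieval("How does backpropagation compute gradients?")
#     for chunk in merge_chunks(hits):
#         clip_path = get_video_segment(chunk["video_id"],
#                                       chunk["start_time"],
#                                       chunk["end_time"])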