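# Retrieval backend for the video-lecture search Space: a text query is embedded with a
# fine-tuned CLIP (ViT-B-32) model, matched against subtitle chunks stored in Qdrant,
# overlapping hits are merged per video, and the matching segment is cut out with ffmpeg.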
import open_clip
import torch
import tempfile
import subprocess
import os
from datetime import datetime
from collections import defaultdict
from datasets import load_dataset
from qdrant_client import QdrantClient
from huggingface_hub import login
# Load the CLIP tokenizer and a ViT-B-32 model, then restore fine-tuned weights from a local checkpoint.
tokenizer = open_clip.get_tokenizer("ViT-B-32")
model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained=None)
checkpoint_path = "finetuned_clip.pt"
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
model.eval()

# Connect to the Qdrant Cloud collection holding the embedded video chunks.
qdrant = QdrantClient(
    url=os.environ["QDRANT_CLOUD_URL"],
    api_key=os.environ["QDRANT_API_KEY"],
    prefer_grpc=False,
)
collection_name = "video_chunks"

# Authenticate with the Hugging Face Hub so the lecture dataset can be streamed.
login(token=os.environ["HF_API_TOKEN"])
def timestamp_to_seconds(ts):
    h, m, s = ts.split(":")
    return int(h) * 3600 + int(m) * 60 + float(s)
def seconds_to_timestamp(seconds):
    h = int(seconds // 3600)
    m = int((seconds % 3600) // 60)
    s = seconds % 60
    return f"{h:02}:{m:02}:{s:06.3f}"
def smart_merge_subtitles(a, b):
    # Merge two subtitle strings, keeping any shared overlap only once so that text
    # repeated across adjacent chunks is not duplicated.
    if b in a:
        return a
    if a in b:
        return b
    for i in range(min(len(a), len(b)), 0, -1):
        if a.endswith(b[:i]):
            return a + b[i:]
        if b.endswith(a[:i]):
            return b + a[i:]
    return a + " " + b
def merge_chunks(chunks):
    # Group search hits by video, then merge chunks whose time ranges overlap,
    # stitching their subtitles together with smart_merge_subtitles.
    grouped = defaultdict(list)
    for chunk in chunks:
        payload = chunk.payload
        grouped[payload["video_id"]].append({
            "start": timestamp_to_seconds(payload["start_time"]),
            "end": timestamp_to_seconds(payload["end_time"]),
            "subtitle": payload["subtitle"].strip(),
            "merged": False
        })
    merged_chunks = []
    for video_id, video_chunks in grouped.items():
        for i, chunk in enumerate(video_chunks):
            if chunk["merged"]:
                continue
            merged_chunk = chunk.copy()
            chunk["merged"] = True
            for j, other in enumerate(video_chunks):
                if i == j or other["merged"]:
                    continue
                # Overlap test: the two ranges intersect unless one ends before the other starts.
                if not (merged_chunk["end"] < other["start"] or other["end"] < merged_chunk["start"]):
                    merged_chunk["start"] = min(merged_chunk["start"], other["start"])
                    merged_chunk["end"] = max(merged_chunk["end"], other["end"])
                    merged_chunk["subtitle"] = smart_merge_subtitles(merged_chunk["subtitle"], other["subtitle"])
                    other["merged"] = True
            merged_chunks.append({
                "video_id": video_id,
                "start_time": seconds_to_timestamp(merged_chunk["start"]),
                "end_time": seconds_to_timestamp(merged_chunk["end"]),
                "subtitle": merged_chunk["subtitle"]
            })
    return merged_chunks
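# Example (illustrative): two hits from the same video spanning 00:01:00-00:01:30 and
# 00:01:20-00:02:00 overlap, so they come back as one chunk covering 00:01:00-00:02:00
# with their subtitles merged; hits from different videos are never combined.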
def get_video_segment(video_id, start_time, end_time):
    # Stream the lecture dataset until the requested video is found, write the full
    # mp4 to a temp file, then cut out [start_time, end_time] with ffmpeg (stream copy, no re-encode).
    dataset = load_dataset("aegean-ai/ai-lectures-spring-24", split="train", streaming=True)
    sample = None
    for candidate in dataset:
        if candidate["__key__"] == video_id:
            sample = candidate
            break
    if sample is None:
        print(f"Video {video_id} not found in dataset")
        return None
    tmp_dir = tempfile.gettempdir()
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
    full_path = os.path.join(tmp_dir, f"full_{timestamp}.mp4")
    trimmed_path = os.path.join(tmp_dir, f"clip_{timestamp}.mp4")
    with open(full_path, "wb") as f:
        f.write(sample["mp4"])
    cmd = [
        "ffmpeg", "-ss", start_time, "-to", end_time,
        "-i", full_path,
        "-c:v", "copy", "-c:a", "copy", "-y",
        trimmed_path
    ]
    result = subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    os.remove(full_path)
    if result.returncode != 0:
        print("FFmpeg failed")
        return None
    return trimmed_path
def retrieval(question):
    # Embed the text query with the fine-tuned CLIP text encoder and search Qdrant.
    text_tokens = tokenizer([question])
    with torch.no_grad():
        query_vec = model.encode_text(text_tokens).squeeze(0).cpu().numpy()
    search_result = qdrant.search(
        collection_name=collection_name,
        query_vector=query_vec.tolist(),
        limit=40
    )
    # Over-fetch, drop hits with very short subtitles, and keep the 10 best remaining.
    filtered_results = [
        res for res in search_result
        if len(res.payload.get("subtitle", "")) >= 35
    ]
    return filtered_results[:10]
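# Minimal usage sketch (assumes a populated "video_chunks" collection and a valid local
# checkpoint; the query string below is a made-up example, not taken from the dataset).
if __name__ == "__main__":
    hits = retrieval("What is backpropagation?")
    segments = merge_chunks(hits)
    if segments:
        first = segments[0]
        print(first["video_id"], first["start_time"], first["end_time"])
        print(first["subtitle"])
        clip_path = get_video_segment(first["video_id"], first["start_time"], first["end_time"])
        if clip_path:
            print("Trimmed clip written to:", clip_path)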