Spaces:
Running
on
Zero
Running
on
Zero
from lutils import openf, writef | |
from src.data.chapters import Chapters, sec_to_hms | |
from src.data.prompt import Prompt | |
from src.utils import RankedLogger | |
log = RankedLogger(__name__, rank_zero_only=True) | |
class ChaptersASR(Chapters):
    """Chapters dataset extended with lazily loaded ASR transcripts.

    The ASR data is a mapping from video ID to an entry holding three
    parallel sequences under the keys ``"text"``, ``"start"`` and ``"end"``.
    It is read from disk on first access and cached on the instance.
    """

    def __init__(self, vidc_dir: str = "dataset/", subset=""):
        """Initialise the parent dataset and an empty ASR cache.

        Args:
            vidc_dir: Dataset root directory, forwarded to ``Chapters``.
            subset: Optional subset name, forwarded to ``Chapters``.
        """
        super().__init__(vidc_dir=vidc_dir, subset=subset)
        # Populated by load_asr_data(); None means "not loaded yet".
        self._asrs = None

    @property
    def asrs(self):
        """Return the ASR mapping, loading it from disk on first access.

        This is a property because ``get_asr`` and ``__contains__`` use
        ``self.asrs`` directly as a mapping (membership test and indexing).
        """
        if self._asrs is None:
            self.load_asr_data()
        return self._asrs

    def load_asr_data(self):
        """Load the ASR mapping into ``self._asrs`` (no-op if already loaded).

        Resolution order:
            1. No subset: read ``docs/asrs.json``.
            2. Subset with an existing ``asrs_<subset>.json``: read it.
            3. Otherwise read the validation file (subset name contains
               ``"val"``), the training file (contains ``"train"``) or the
               full ``docs/asrs.json``, keep only this dataset's video IDs,
               and write the result to ``asrs_<subset>.json`` for next time.
        """
        if self._asrs is not None:
            return

        # NOTE(review): the "/" joins assume the parent class stores
        # vidc_dir as a pathlib.Path -- confirm in Chapters.
        if not self.subset:
            self._asrs = openf(self.vidc_dir / "docs/asrs.json")
            return

        asr_pth = self.vidc_dir / f"docs/subset_data/asrs/asrs_{self.subset}.json"
        if asr_pth.exists():
            self._asrs = openf(asr_pth)
            return

        log.info(f"ASR data not found for subset {self.subset}.")
        asr_val_pth = self.vidc_dir / "docs/subset_data/asrs/asrs_val.json"
        asr_train_pth = self.vidc_dir / "docs/subset_data/asrs/asrs_train.json"

        if "val" in self.subset and asr_val_pth.exists():
            log.info("Loading from ASR validation file.")
            asrs = openf(asr_val_pth)
        elif "train" in self.subset and asr_train_pth.exists():
            log.info("Loading from ASR training file.")
            asrs = openf(asr_train_pth)
        else:
            log.info("Loading from ASR file.")
            asrs = openf(self.vidc_dir / "docs/asrs.json")

        # Keep only the videos that belong to this subset and have ASR.
        video_ids = set(self.video_ids) & set(asrs.keys())
        self._asrs = {vid_id: asrs[vid_id] for vid_id in video_ids}

        # parents=True: when we fell back to docs/asrs.json the intermediate
        # docs/subset_data/ directory may not exist yet.
        asr_pth.parent.mkdir(parents=True, exist_ok=True)
        writef(asr_pth, self._asrs)

    def get_asr(self, video_id, add_end=False):
        """Render the ASR of one video as newline-terminated text.

        Args:
            video_id: Key into the ASR mapping.
            add_end: When true each line is ``"<start> - <end>: <text>"``,
                otherwise ``"<start>: <text>"``. Timestamps are formatted by
                ``sec_to_hms``.

        Returns:
            The joined lines followed by a trailing newline, or ``None`` when
            the video has no ASR entry.
        """
        asrs = self.asrs
        if video_id not in asrs:
            return None

        asr = asrs[video_id]
        asr_clean = []
        for t, s, e in zip(asr["text"], asr["start"], asr["end"]):
            t = t.strip()
            s = sec_to_hms(s)
            e = sec_to_hms(e)
            if add_end:
                asr_clean.append(f"{s} - {e}: {t}")
            else:
                asr_clean.append(f"{s}: {t}")

        return "\n".join(asr_clean) + "\n"

    def __contains__(self, vid_id):
        """Return whether ``vid_id`` has an ASR entry."""
        return vid_id in self.asrs
class PromptASR(Prompt):
    """Prompt variant whose transcript is the ASR text of a ``ChaptersASR``."""

    def __init__(self, chapters: ChaptersASR, add_end=False):
        """Store the dataset via the parent class and remember ``add_end``.

        Args:
            chapters: Dataset providing ``get_asr`` and membership tests.
            add_end: Forwarded to ``ChaptersASR.get_asr`` to include the end
                timestamp of each ASR line.
        """
        super().__init__(chapters=chapters)
        self.add_end = add_end

    def get_task_prompt(self):
        """Return the fixed task instruction sentence (newline-terminated)."""
        task = "segment the text into distinct chapters based on thematic shifts or changes in topics.\n"
        return task

    def get_transcript(self, vid_id):
        """Return the ASR transcript for ``vid_id``.

        Raises:
            AssertionError: If the dataset has no ASR for ``vid_id``.
        """
        # NOTE(review): self.chapters is presumably set by Prompt.__init__
        # from the ``chapters`` argument -- confirm in the parent class.
        transcript = self.chapters.get_asr(vid_id, add_end=self.add_end)
        assert transcript is not None, f"ASR not found for video ID: {vid_id}"
        return transcript

    def __contains__(self, vid_id):
        """Delegate membership to the underlying chapters dataset."""
        return vid_id in self.chapters
if __name__ == "__main__":
    # Manual smoke test: sample one video from a subset and print the
    # training prompt, the ASR transcript and the expected output for it.
    chapters = ChaptersASR(subset="s10k_train")
    vid_id = chapters.sample()
    prompt = PromptASR(chapters=chapters)

    # Attributes are looked up one at a time so each call happens right
    # before its print, in the same order as three separate statements.
    for method_name in ("get_prompt_train", "get_transcript", "get_output"):
        print(getattr(prompt, method_name)(vid_id))