# chapter-llama/src/data/utils_asr.py
# Last commit by lucas-ventura (1f3866f, verified):
#   Rename utils_asr.py to src/data/utils_asr.py
from lutils import openf, writef
from src.data.chapters import Chapters, sec_to_hms
from src.data.prompt import Prompt
from src.utils import RankedLogger
# Module-level logger shared by the classes in this file.
# NOTE(review): rank_zero_only=True presumably limits output to the rank-0
# process in distributed runs -- confirm against src.utils.RankedLogger.
log = RankedLogger(__name__, rank_zero_only=True)
class ChaptersASR(Chapters):
    """Chapters dataset that also exposes per-video ASR transcripts.

    The ASR mapping is loaded lazily the first time :attr:`asrs` is read.
    Each entry is indexed with the keys ``"text"``, ``"start"`` and ``"end"``,
    which are iterated in lockstep when a transcript is rendered.
    """

    def __init__(self, vidc_dir: str = "dataset/", subset=""):
        """Initialise the base dataset and defer ASR loading.

        Args:
            vidc_dir: Dataset root, forwarded to ``Chapters.__init__``.
            subset: Subset name, forwarded to ``Chapters.__init__``. An empty
                value makes :meth:`load_asr_data` read the full ASR file.
        """
        super().__init__(vidc_dir=vidc_dir, subset=subset)
        # Filled by load_asr_data() on first access of the `asrs` property.
        self._asrs = None

    @property
    def asrs(self):
        """Return the ASR mapping, loading it from disk on first access."""
        if self._asrs is None:
            self.load_asr_data()
        return self._asrs

    def load_asr_data(self):
        """Populate ``self._asrs`` from disk; no-op when already loaded.

        Without a subset, ``docs/asrs.json`` is loaded in full. With a subset,
        the per-subset cache ``docs/subset_data/asrs/asrs_<subset>.json`` is
        used when it exists. Otherwise the data is read from the validation
        file, the training file or the full file (in that order of
        preference), restricted to ``self.video_ids``, and written to the
        per-subset cache path for subsequent runs.
        """
        if self._asrs is not None:
            return

        # NOTE(review): self.vidc_dir is combined with "/" below, so it is
        # assumed to be a pathlib.Path set up by the base class -- confirm.
        if not self.subset:
            self._asrs = openf(self.vidc_dir / "docs/asrs.json")
            return

        asr_pth = self.vidc_dir / f"docs/subset_data/asrs/asrs_{self.subset}.json"
        if asr_pth.exists():
            self._asrs = openf(asr_pth)
            return

        log.info(f"ASR data not found for subset {self.subset}.")
        asr_val_pth = self.vidc_dir / "docs/subset_data/asrs/asrs_val.json"
        asr_train_pth = self.vidc_dir / "docs/subset_data/asrs/asrs_train.json"
        if "val" in self.subset and asr_val_pth.exists():
            log.info("Loading from ASR validation file.")
            asrs = openf(asr_val_pth)
        elif "train" in self.subset and asr_train_pth.exists():
            log.info("Loading from ASR training file.")
            asrs = openf(asr_train_pth)
        else:
            log.info("Loading from ASR file.")
            asrs = openf(self.vidc_dir / "docs/asrs.json")

        # Keep only this subset's videos. Iterating self.video_ids (instead of
        # a set intersection) gives the cached file a reproducible key order.
        self._asrs = {
            vid_id: asrs[vid_id] for vid_id in self.video_ids if vid_id in asrs
        }

        # parents=True: when the data came from docs/asrs.json the whole
        # docs/subset_data/asrs/ directory chain may not exist yet.
        asr_pth.parent.mkdir(parents=True, exist_ok=True)
        writef(asr_pth, self._asrs)

    def get_asr(self, video_id, add_end=False):
        """Render the transcript of one video as timestamped lines.

        Args:
            video_id: Key into the ASR mapping.
            add_end: When true, each line is ``"<start> - <end>: <text>"``;
                otherwise ``"<start>: <text>"``.

        Returns:
            The lines joined with newlines plus a trailing newline, or
            ``None`` when ``video_id`` has no ASR entry.
        """
        asrs = self.asrs
        if video_id not in asrs:
            return None

        asr = asrs[video_id]
        lines = []
        for text, start, end in zip(asr["text"], asr["start"], asr["end"]):
            stamp = sec_to_hms(start)
            if add_end:
                # The end time is only formatted when it is actually shown.
                stamp = f"{stamp} - {sec_to_hms(end)}"
            lines.append(f"{stamp}: {text.strip()}")
        return "\n".join(lines) + "\n"

    def __contains__(self, vid_id):
        """Return whether ``vid_id`` has an ASR entry (triggers lazy load)."""
        return vid_id in self.asrs
class PromptASR(Prompt):
    """Prompt variant whose transcript comes from a video's ASR data."""

    def __init__(self, chapters: ChaptersASR, add_end=False):
        """Set up the prompt on top of an ASR-aware chapters object.

        Args:
            chapters: Forwarded to ``Prompt.__init__``; later read back as
                ``self.chapters`` to fetch transcripts.
            add_end: Passed to ``chapters.get_asr`` to select whether each
                transcript line carries an end timestamp.
        """
        super().__init__(chapters=chapters)
        self.add_end = add_end

    def get_task_prompt(self):
        """Return the task instruction text (ends with a newline)."""
        task = "segment the text into distinct chapters based on thematic shifts or changes in topics.\n"
        return task

    def get_transcript(self, vid_id):
        """Return the formatted ASR transcript for ``vid_id``.

        Fails the assertion when the chapters object has no ASR for the id.
        """
        transcript = self.chapters.get_asr(vid_id, add_end=self.add_end)
        assert transcript is not None, f"ASR not found for video ID: {vid_id}"
        return transcript

    def __contains__(self, vid_id):
        """Delegate membership to the underlying chapters object."""
        return vid_id in self.chapters
if __name__ == "__main__":
    # Manual smoke test: sample one video from a training subset and print
    # its training prompt, its transcript and its expected output.
    demo_chapters = ChaptersASR(subset="s10k_train")
    sample_id = demo_chapters.sample()
    demo_prompt = PromptASR(chapters=demo_chapters)

    print(demo_prompt.get_prompt_train(sample_id))
    print(demo_prompt.get_transcript(sample_id))
    print(demo_prompt.get_output(sample_id))