Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,308 Bytes
90559ad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
from lutils import openf, writef
from src.data.chapters import Chapters, sec_to_hms
from src.data.prompt import Prompt
from src.utils import RankedLogger
# Module-level logger for this file.
# NOTE(review): rank_zero_only=True presumably restricts output to the
# rank-zero process in multi-process runs -- confirm against
# src.utils.RankedLogger.
log = RankedLogger(__name__, rank_zero_only=True)
class ChaptersASR(Chapters):
    """Chapters dataset that also exposes per-video ASR transcripts.

    Transcripts are loaded lazily on first access to :attr:`asrs`. When a
    subset is set, a subset-specific JSON cache under
    ``docs/subset_data/asrs/`` is read, and it is created on demand from a
    larger ASR file when it does not exist yet.

    NOTE(review): ``self.vidc_dir``, ``self.subset``, ``self.video_ids`` are
    assumed to be provided by ``Chapters``, with ``vidc_dir`` being a
    ``pathlib.Path``-like object -- confirm against the base class.
    """

    def __init__(self, vidc_dir: str = "dataset/", subset=""):
        """Initialise the base dataset; ASR data is not loaded yet."""
        super().__init__(vidc_dir=vidc_dir, subset=subset)
        # Populated by load_asr_data() on first use.
        self._asrs = None

    @property
    def asrs(self):
        """Mapping of video id to its ASR entry, loaded on first access."""
        if self._asrs is None:
            self.load_asr_data()
        return self._asrs

    def load_asr_data(self):
        """Load the ASR mapping into ``self._asrs`` (no-op if already loaded).

        Without a subset, the full ``docs/asrs.json`` is loaded. With a
        subset, the subset cache file is loaded if present; otherwise the
        entries are taken from the validation, training or full ASR file,
        restricted to ``self.video_ids``, and written to the cache file.
        """
        if self._asrs is not None:
            return

        if not self.subset:
            self._asrs = openf(self.vidc_dir / "docs/asrs.json")
            return

        asr_pth = self.vidc_dir / f"docs/subset_data/asrs/asrs_{self.subset}.json"
        if asr_pth.exists():
            self._asrs = openf(asr_pth)
            return

        log.info(f"ASR data not found for subset {self.subset}.")
        asr_val_pth = self.vidc_dir / "docs/subset_data/asrs/asrs_val.json"
        asr_train_pth = self.vidc_dir / "docs/subset_data/asrs/asrs_train.json"

        # Prefer the smaller split-level file matching the subset name before
        # falling back to the full ASR file.
        if "val" in self.subset and asr_val_pth.exists():
            log.info("Loading from ASR validation file.")
            asrs = openf(asr_val_pth)
        elif "train" in self.subset and asr_train_pth.exists():
            log.info("Loading from ASR training file.")
            asrs = openf(asr_train_pth)
        else:
            log.info("Loading from ASR file.")
            asrs = openf(self.vidc_dir / "docs/asrs.json")

        # Keep only this subset's videos. Iterating video_ids (rather than a
        # set intersection) gives the cache file a reproducible key order.
        self._asrs = {
            vid_id: asrs[vid_id] for vid_id in self.video_ids if vid_id in asrs
        }

        # parents=True: intermediate directories may not exist yet when the
        # data came from the top-level docs/asrs.json.
        asr_pth.parent.mkdir(parents=True, exist_ok=True)
        writef(asr_pth, self._asrs)

    def get_asr(self, video_id, add_end=False):
        """Return the transcript of ``video_id`` as newline-terminated text.

        Each ASR segment becomes one line, ``"<start>: <text>"`` or, when
        ``add_end`` is true, ``"<start> - <end>: <text>"``, with times
        formatted by ``sec_to_hms``. Returns ``None`` if the video has no
        ASR entry.

        The entry is read as parallel ``"text"``, ``"start"`` and ``"end"``
        sequences.
        """
        if video_id not in self.asrs:
            return None

        asr = self.asrs[video_id]
        lines = []
        for text, start, end in zip(asr["text"], asr["start"], asr["end"]):
            text = text.strip()
            start = sec_to_hms(start)
            end = sec_to_hms(end)
            if add_end:
                lines.append(f"{start} - {end}: {text}")
            else:
                lines.append(f"{start}: {text}")

        return "\n".join(lines) + "\n"

    def __contains__(self, vid_id):
        """Return whether ``vid_id`` has an ASR entry."""
        return vid_id in self.asrs
class PromptASR(Prompt):
    """Prompt whose transcript is the ASR text provided by a ``ChaptersASR``.

    NOTE(review): ``self.chapters`` is assumed to be set by
    ``Prompt.__init__`` from the ``chapters`` argument -- confirm against the
    base class.
    """

    def __init__(self, chapters: ChaptersASR, add_end=False):
        """Store the dataset and whether transcript lines include end times."""
        super().__init__(chapters=chapters)
        self.add_end = add_end

    def get_task_prompt(self):
        """Return the task instruction text for this prompt."""
        return "segment the text into distinct chapters based on thematic shifts or changes in topics.\n"

    def get_transcript(self, vid_id):
        """Return the ASR transcript text for ``vid_id``.

        Raises:
            AssertionError: if the dataset has no ASR entry for ``vid_id``.
        """
        vid_asr = self.chapters.get_asr(vid_id, add_end=self.add_end)
        if vid_asr is None:
            # Raised explicitly instead of via `assert` so the check is not
            # stripped under `python -O`; the exception type and message are
            # kept so existing callers see the same error.
            raise AssertionError(f"ASR not found for video ID: {vid_id}")
        return vid_asr

    def __contains__(self, vid_id):
        """Return whether ``vid_id`` is contained in the underlying dataset."""
        return vid_id in self.chapters
if __name__ == "__main__":
    # Manual smoke test: sample one video from a subset and print its
    # training prompt, transcript and expected output, in that order.
    asr_chapters = ChaptersASR(subset="s10k_train")
    sampled_id = asr_chapters.sample()
    asr_prompt = PromptASR(chapters=asr_chapters)

    for method_name in ("get_prompt_train", "get_transcript", "get_output"):
        print(getattr(asr_prompt, method_name)(sampled_id))
|