File size: 3,308 Bytes
90559ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from lutils import openf, writef

from src.data.chapters import Chapters, sec_to_hms
from src.data.prompt import Prompt
from src.utils import RankedLogger

# Module-level logger shared by the classes in this file.
# NOTE(review): rank_zero_only=True presumably restricts output to the rank-0
# process in distributed runs -- confirm against src.utils.RankedLogger.
log = RankedLogger(__name__, rank_zero_only=True)


class ChaptersASR(Chapters):
    """Chapters dataset extended with lazily loaded ASR transcripts.

    Transcripts are read from JSON files under ``vidc_dir / "docs"`` the first
    time :attr:`asrs` is accessed and are kept in memory afterwards.
    """

    def __init__(self, vidc_dir: str = "dataset/", subset=""):
        """Initialise the base dataset; ASR loading is deferred.

        Args:
            vidc_dir: Dataset root directory, forwarded to ``Chapters``.
            subset: Subset name, forwarded to ``Chapters``.
        """
        super().__init__(vidc_dir=vidc_dir, subset=subset)

        # Filled in by load_asr_data() on first access to `asrs`.
        self._asrs = None

    @property
    def asrs(self):
        """Mapping of video id to ASR entry, loaded on first access."""
        if self._asrs is None:
            self.load_asr_data()
        return self._asrs

    def load_asr_data(self):
        """Populate the in-memory ASR mapping; a no-op once it is loaded.

        When ``self.subset`` is falsy, ``docs/asrs.json`` is loaded unchanged.
        Otherwise ``docs/subset_data/asrs/asrs_<subset>.json`` is used if it
        exists. If it does not, a broader file is loaded (the val file when
        the subset name contains "val", the train file when it contains
        "train", else ``docs/asrs.json``), restricted to ``self.video_ids``
        and written to the per-subset path so the next run can reuse it.
        """
        if self._asrs is not None:
            return

        if not self.subset:
            self._asrs = openf(self.vidc_dir / "docs/asrs.json")
            return

        asr_pth = self.vidc_dir / f"docs/subset_data/asrs/asrs_{self.subset}.json"
        if asr_pth.exists():
            self._asrs = openf(asr_pth)
            return

        log.info(f"ASR data not found for subset {self.subset}.")
        asr_val_pth = self.vidc_dir / "docs/subset_data/asrs/asrs_val.json"
        asr_train_pth = self.vidc_dir / "docs/subset_data/asrs/asrs_train.json"
        if "val" in self.subset and asr_val_pth.exists():
            log.info("Loading from ASR validation file.")
            asrs = openf(asr_val_pth)
        elif "train" in self.subset and asr_train_pth.exists():
            log.info("Loading from ASR training file.")
            asrs = openf(asr_train_pth)
        else:
            log.info("Loading from ASR file.")
            asrs = openf(self.vidc_dir / "docs/asrs.json")

        # Iterate the loaded mapping (not a set intersection) so the key order
        # of the cached file is the same on every run.
        wanted_ids = set(self.video_ids)
        self._asrs = {
            vid_id: entry for vid_id, entry in asrs.items() if vid_id in wanted_ids
        }

        # parents=True: "docs/subset_data" may not exist yet on a fresh
        # dataset checkout, in which case a plain mkdir() would raise.
        asr_pth.parent.mkdir(parents=True, exist_ok=True)
        writef(asr_pth, self._asrs)

    def get_asr(self, video_id, add_end=False):
        """Return the transcript of a video as timestamped text lines.

        Args:
            video_id: Key to look up in :attr:`asrs`.
            add_end: When true, each line reads ``"<start> - <end>: <text>"``;
                otherwise ``"<start>: <text>"``.

        Returns:
            The lines joined by newlines with a trailing newline, or ``None``
            when ``video_id`` has no ASR entry.
        """
        if video_id not in self.asrs:
            return None

        asr = self.asrs[video_id]
        lines = []
        for text, start, end in zip(asr["text"], asr["start"], asr["end"]):
            stamp = sec_to_hms(start)
            if add_end:
                # The end time is only formatted when it is actually shown.
                stamp = f"{stamp} - {sec_to_hms(end)}"
            lines.append(f"{stamp}: {text.strip()}")

        return "\n".join(lines) + "\n"

    def __contains__(self, vid_id):
        """Return whether an ASR entry exists for ``vid_id``."""
        return vid_id in self.asrs


class PromptASR(Prompt):
    """Prompt variant whose transcript is the ASR text of a video."""

    def __init__(self, chapters: ChaptersASR, add_end=False):
        """Store the formatting option and initialise the base prompt.

        Args:
            chapters: ASR-aware chapters dataset, forwarded to ``Prompt``.
            add_end: Passed to ``ChaptersASR.get_asr`` to include end
                timestamps in each transcript line.
        """
        super().__init__(chapters=chapters)
        self.add_end = add_end

    def get_task_prompt(self):
        """Return the fixed task instruction used for ASR-based chaptering."""
        return "segment the text into distinct chapters based on thematic shifts or changes in topics.\n"

    def get_transcript(self, vid_id):
        """Return the formatted ASR transcript for ``vid_id``.

        Raises:
            AssertionError: If no ASR entry exists for ``vid_id``.
        """
        # NOTE(review): self.chapters is presumably set by Prompt.__init__
        # from the `chapters` argument -- confirm in src.data.prompt.
        vid_asr = self.chapters.get_asr(vid_id, add_end=self.add_end)
        if vid_asr is None:
            # Raised explicitly rather than via `assert` so the check is not
            # stripped under `python -O`; the exception type and message are
            # kept the same for existing callers.
            raise AssertionError(f"ASR not found for video ID: {vid_id}")
        return vid_asr

    def __contains__(self, vid_id):
        """Return whether ``vid_id`` is contained in the chapters dataset."""
        return vid_id in self.chapters


if __name__ == "__main__":
    # Manual smoke test: sample one video from a training subset and print
    # its training prompt, its transcript and its expected output in turn.
    demo_chapters = ChaptersASR(subset="s10k_train")
    sampled_id = demo_chapters.sample()

    demo_prompt = PromptASR(chapters=demo_chapters)
    for render in (
        demo_prompt.get_prompt_train,
        demo_prompt.get_transcript,
        demo_prompt.get_output,
    ):
        print(render(sampled_id))