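"""Generate chapters for a test set of videos and save the results to disk, one JSON file per video."""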
from pathlib import Path

from lutils import writef
from tqdm import tqdm

from src.test.utils_chapters import extract_chapters, filter_chapters
from src.utils import RankedLogger

log = RankedLogger(__name__, rank_zero_only=True)


def get_chapters(
    inference,
    prompt,
    max_new_tokens,
    do_sample=False,
    vid_duration=None,
    use_cache=True,
    vid_id="",
):
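    """Run `inference` on `prompt` and extract chapters from the generated text.

    Returns a tuple ``(output_text, chapters)``. If the prompt exceeds the
    model's context window, `inference` returns the input length in tokens
    and `chapters` is None. If greedy decoding yields no parsable chapters,
    the generation is retried once with sampling.
    """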
    output_text = inference(
        prompt=prompt,
        max_new_tokens=max_new_tokens,
        add_special_tokens=True,
        do_sample=do_sample,
        use_cache=use_cache,
    )

    if isinstance(output_text, int):
        # The prompt is too long; `inference` returned the input length in tokens.
        return output_text, None

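    # Parse chapter timestamps and titles from the generated text, then filter
    # out invalid entries (e.g. timestamps beyond the video duration).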
    chapters = extract_chapters(output_text)
    chapters = filter_chapters(chapters, vid_duration=vid_duration)

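    # Retry at most once: the recursive call sets do_sample=True, so this
    # branch cannot trigger a second time.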
    if not chapters and not do_sample:
        log.info(f"No chapters found for {vid_id}, trying again with sampling")
        return get_chapters(
            inference,
            prompt,
            max_new_tokens,
            do_sample=True,
            vid_duration=vid_duration,
            use_cache=use_cache,
            vid_id=vid_id,
        )

    return output_text, chapters


class VidChaptersTester:
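    """Runs chapter generation over a test dataloader and saves the results to disk."""
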
    def __init__(self, save_dir: str, do_sample=False, **kwargs):
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.do_sample = do_sample

    def __call__(
        self,
        inference,
        test_dataloader,
        max_new_tokens=1024,
    ):
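        """Generate chapters for every video in `test_dataloader`, one JSON file each."""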
        pbar = tqdm(
            total=len(test_dataloader),
            desc="Evaluating chapters",
        )

        for batch in test_dataloader:
            vid_id = batch["vid_id"][0]
            prompt = batch["prompt"][0]
            transcript = batch["transcript"][0]
            vid_duration = batch["vid_duration"][0]
            prompt += transcript

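            # Shard output files into subdirectories keyed by the first two
            # characters of the video id: save_dir/<vid_id[:2]>/<vid_id>.json.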
            chapters_pth = self.save_dir / f"{vid_id[:2]}" / f"{vid_id}.json"
            chapters_pth.parent.mkdir(exist_ok=True)

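            # Skip videos that already have results so interrupted runs can resume.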
            if chapters_pth.exists():
                pbar.update(1)
                continue

            pbar.set_description(f"vid_id: {vid_id}")

            output_text, chapters = get_chapters(
                inference,
                prompt,
                max_new_tokens,
                do_sample=self.do_sample,
                vid_duration=vid_duration,
                vid_id=vid_id,
            )

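            # `inference` returned the input token count: the prompt did not fit
            # in the context window, so save the count as a .txt marker instead.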
            if chapters is None:
                log.info(f"Input too long for {vid_id}, {output_text} tokens")
                error_pth = chapters_pth.with_suffix(".txt")
                writef(error_pth, [str(output_text)])
                pbar.update(1)
                continue

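            # Save the parsed chapters together with the raw model output.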
            if chapters:
                vid_data = {
                    "chapters": chapters,
                    "output": output_text,
                }
                writef(chapters_pth, vid_data)

            pbar.update(1)

        pbar.close()
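

# Minimal usage sketch (assumptions: `inference` is a callable wrapper around the
# model, and `test_dataloader` yields batches with "vid_id", "prompt",
# "transcript", and "vid_duration" keys, as consumed above):
#
#     tester = VidChaptersTester(save_dir="outputs/chapters", do_sample=False)
#     tester(inference, test_dataloader, max_new_tokens=1024)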