# Source: chapter-llama/src/test/vidchapters.py
# Last commit 2e23f3d (verified) by lucas-ventura:
#   "Rename vidchapters.py to src/test/vidchapters.py"
from pathlib import Path
from lutils import writef
from tqdm import tqdm
from src.test.utils_chapters import extract_chapters, filter_chapters
from src.utils import RankedLogger
# Module-level logger shared by everything in this file.
# NOTE(review): rank_zero_only=True presumably limits output to the rank-zero
# process in multi-process runs -- confirm against src.utils.RankedLogger.
log = RankedLogger(__name__, rank_zero_only=True)
def get_chapters(
    inference,
    prompt,
    max_new_tokens,
    do_sample=False,
    vid_duration=None,
    use_cache=True,
    vid_id="",
):
    """Run ``inference`` on ``prompt`` and parse chapters from its output.

    Args:
        inference: Callable invoked with the keyword arguments ``prompt``,
            ``max_new_tokens``, ``add_special_tokens``, ``do_sample`` and
            ``use_cache``. It returns the generated text, or an ``int`` (the
            input length) when the input is too long.
        prompt: Full prompt passed to ``inference``.
        max_new_tokens: Generation budget forwarded to ``inference``.
        do_sample: Forwarded to ``inference``. When it is false and no
            chapters survive filtering, the call is retried once with
            ``do_sample=True``.
        vid_duration: Forwarded to ``filter_chapters``.
        use_cache: Forwarded to ``inference``, including on the retry.
        vid_id: Video identifier, used only in the log message.

    Returns:
        ``(input_length, None)`` when ``inference`` reports a too-long input,
        otherwise ``(output_text, chapters)`` where ``chapters`` is the
        filtered result (possibly empty).
    """
    output_text = inference(
        prompt=prompt,
        max_new_tokens=max_new_tokens,
        add_special_tokens=True,
        do_sample=do_sample,
        use_cache=use_cache,
    )

    if isinstance(output_text, int):
        # The input is too long: hand the input length back to the caller,
        # with None signalling that no chapter extraction was attempted.
        return output_text, None

    chapters = extract_chapters(output_text)
    chapters = filter_chapters(chapters, vid_duration=vid_duration)

    if not chapters and not do_sample:
        log.info(f"No chapters found for {vid_id}, trying again with sampling")
        # Retry once with sampling. do_sample=True on the nested call prevents
        # any further recursion. use_cache and vid_id are forwarded so the
        # retry honours the caller's settings instead of the defaults.
        return get_chapters(
            inference,
            prompt,
            max_new_tokens,
            do_sample=True,
            vid_duration=vid_duration,
            use_cache=use_cache,
            vid_id=vid_id,
        )

    return output_text, chapters
class VidChaptersTester:
    """Generate chapters for every video in a dataloader and save them to disk.

    Results are written to ``<save_dir>/<first two chars of vid_id>/<vid_id>.json``.
    Videos whose JSON file already exists are skipped, so an interrupted run
    can be resumed.
    """

    def __init__(self, save_dir: str, do_sample=False, **kwargs):
        """Create the output directory and store the sampling flag.

        Args:
            save_dir: Root directory for the per-video output files. Missing
                parent directories are created as well.
            do_sample: Passed to ``get_chapters`` for the first attempt on
                each video.
            **kwargs: Accepted and ignored.
        """
        self.save_dir = Path(save_dir)
        # parents=True so that a nested, not-yet-existing save_dir does not
        # raise FileNotFoundError.
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.do_sample = do_sample

    def __call__(
        self,
        inference,
        test_dataloader,
        max_new_tokens=1024,
    ):
        """Run chapter generation over ``test_dataloader``.

        Args:
            inference: Callable forwarded to ``get_chapters``.
            test_dataloader: Sized iterable of batches. Each batch is indexed
                by ``"vid_id"``, ``"prompt"``, ``"transcript"`` and
                ``"vid_duration"``, and only element ``[0]`` of each field is
                used (presumably batch size 1 -- confirm with the dataloader
                configuration).
            max_new_tokens: Generation budget forwarded to ``get_chapters``.
        """
        # The context manager closes the progress bar even if inference or
        # writing raises part-way through the loop.
        with tqdm(
            total=len(test_dataloader),
            desc="Evaluating chapters",
        ) as pbar:
            for batch in test_dataloader:
                vid_id = batch["vid_id"][0]
                prompt = batch["prompt"][0]
                transcript = batch["transcript"][0]
                vid_duration = batch["vid_duration"][0]
                # The model input is the prompt followed by the transcript.
                prompt += transcript

                # Outputs are sharded into sub-directories named after the
                # first two characters of the video id.
                chapters_pth = self.save_dir / f"{vid_id[:2]}" / f"{vid_id}.json"
                chapters_pth.parent.mkdir(parents=True, exist_ok=True)

                # Resume support: only an existing .json result counts as done.
                # NOTE(review): the .txt "too long" markers written below are
                # not checked here, so those videos are re-run on every
                # invocation -- confirm whether that is intended.
                if chapters_pth.exists():
                    pbar.update(1)
                    continue

                pbar.set_description(f"vid_id: {vid_id}")

                output_text, chapters = get_chapters(
                    inference,
                    prompt,
                    max_new_tokens,
                    do_sample=self.do_sample,
                    vid_duration=vid_duration,
                    vid_id=vid_id,
                )

                if chapters is None:
                    # get_chapters signals a too-long input with chapters=None;
                    # output_text then holds the input length, which is
                    # recorded in a .txt file next to where the .json would be.
                    log.info(f"Input too long for {vid_id}, {output_text} tokens")
                    error_pth = chapters_pth.with_suffix(".txt")
                    writef(error_pth, [output_text])
                    pbar.update(1)
                    continue

                # An empty chapter list writes nothing, so the video is
                # attempted again on the next run.
                if chapters:
                    vid_data = {
                        "chapters": chapters,
                        "output": output_text,
                    }
                    writef(chapters_pth, vid_data)

                pbar.update(1)