Spaces:
Running on Zero
Running on Zero
from pathlib import Path | |
from lutils import writef | |
from tqdm import tqdm | |
from src.test.utils_chapters import extract_chapters, filter_chapters | |
from src.utils import RankedLogger | |
# Module-level logger used by get_chapters and VidChaptersTester below.
# NOTE(review): rank_zero_only=True presumably restricts output to the rank-0
# process in distributed runs — confirm against src.utils.RankedLogger.
log = RankedLogger(__name__, rank_zero_only=True)
def get_chapters(
    inference,
    prompt,
    max_new_tokens,
    do_sample=False,
    vid_duration=None,
    use_cache=True,
    vid_id="",
):
    """Generate text for one video and parse chapters out of it.

    ``inference`` is called once. If no chapters survive parsing and
    filtering and the attempt was not already sampled, generation is retried
    exactly once with ``do_sample=True``.

    Args:
        inference: Callable invoked with the keyword arguments ``prompt``,
            ``max_new_tokens``, ``add_special_tokens=True``, ``do_sample`` and
            ``use_cache``. It returns either the generated text or an ``int``;
            an ``int`` is treated as "input too long" and is the input length.
        prompt: Prompt forwarded unchanged to ``inference``.
        max_new_tokens: Generation budget forwarded to ``inference``.
        do_sample: Sampling flag for the first attempt. When False and no
            chapters are found, one retry is made with sampling enabled.
        vid_duration: Forwarded to ``filter_chapters``.
        use_cache: Forwarded to ``inference`` on the first attempt and on the
            retry.
        vid_id: Video identifier, used only in the retry log message.

    Returns:
        ``(length, None)`` when ``inference`` returned an ``int``; otherwise
        ``(output_text, chapters)``, where ``chapters`` is whatever
        ``filter_chapters`` returned and may be empty even after the retry.
    """
    output_text = inference(
        prompt=prompt,
        max_new_tokens=max_new_tokens,
        add_special_tokens=True,
        do_sample=do_sample,
        use_cache=use_cache,
    )
    if isinstance(output_text, int):
        # the input is too long, return the length of the input
        return output_text, None

    chapters = extract_chapters(output_text)
    chapters = filter_chapters(chapters, vid_duration=vid_duration)

    if not chapters and not do_sample:
        log.info(f"No chapters found for {vid_id}, trying again with sampling")
        # Forward use_cache and vid_id as well, so the retry honours the
        # caller's settings instead of falling back to the defaults.
        return get_chapters(
            inference,
            prompt,
            max_new_tokens,
            do_sample=True,
            vid_duration=vid_duration,
            use_cache=use_cache,
            vid_id=vid_id,
        )
    return output_text, chapters
class VidChaptersTester:
    """Generate chapters for every video in a test dataloader and save them.

    Each result is written to ``save_dir/<vid_id[:2]>/<vid_id>.json``. Videos
    whose JSON file already exists are skipped, so an interrupted run can be
    resumed by calling the tester again with the same ``save_dir``.
    """

    def __init__(self, save_dir: str, do_sample=False, **kwargs):
        """Create the output directory and store the sampling option.

        Args:
            save_dir: Root directory for outputs. It is created, including any
                missing parent directories, if it does not exist.
            do_sample: Passed to ``get_chapters`` as the sampling flag for the
                first generation attempt.
            **kwargs: Accepted and ignored (presumably so a config system can
                pass extra keys — confirm with the caller).
        """
        self.save_dir = Path(save_dir)
        # parents=True: without it a nested save_dir whose parent does not
        # exist yet raises FileNotFoundError.
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.do_sample = do_sample

    def __call__(
        self,
        inference,
        test_dataloader,
        max_new_tokens=1024,
    ):
        """Run chapter generation over ``test_dataloader``.

        Args:
            inference: Callable forwarded to ``get_chapters``.
            test_dataloader: Sized iterable of batches. Each batch is indexed
                with ``"vid_id"``, ``"prompt"``, ``"transcript"`` and
                ``"vid_duration"``, and only element ``[0]`` of each is used.
            max_new_tokens: Generation budget forwarded to ``get_chapters``.
        """
        # The context manager closes the progress bar even if inference raises.
        with tqdm(total=len(test_dataloader), desc="Evaluating chapters") as pbar:
            for batch in test_dataloader:
                # NOTE(review): only the first item of each batch is read, so
                # this assumes batch_size=1 — confirm with the dataloader config.
                vid_id = batch["vid_id"][0]
                prompt = batch["prompt"][0]
                transcript = batch["transcript"][0]
                vid_duration = batch["vid_duration"][0]
                prompt += transcript

                chapters_pth = self.save_dir / f"{vid_id[:2]}" / f"{vid_id}.json"
                chapters_pth.parent.mkdir(exist_ok=True)
                if not chapters_pth.exists():
                    pbar.set_description(f"vid_id: {vid_id}")
                    self._process_video(
                        inference,
                        prompt,
                        max_new_tokens,
                        vid_id,
                        vid_duration,
                        chapters_pth,
                    )
                pbar.update(1)

    def _process_video(
        self,
        inference,
        prompt,
        max_new_tokens,
        vid_id,
        vid_duration,
        chapters_pth,
    ):
        """Generate chapters for one video and write the outcome to disk."""
        output_text, chapters = get_chapters(
            inference,
            prompt,
            max_new_tokens,
            do_sample=self.do_sample,
            vid_duration=vid_duration,
            vid_id=vid_id,
        )
        if chapters is None:
            # get_chapters returns chapters=None only when inference reported
            # the prompt as too long; output_text is then that int length.
            log.info(f"Input too long for {vid_id}, {output_text} tokens")
            error_pth = chapters_pth.with_suffix(".txt")
            writef(error_pth, [output_text])
        elif chapters:
            vid_data = {
                "chapters": chapters,
                "output": output_text,
            }
            writef(chapters_pth, vid_data)
        # When chapters is empty (even after the sampling retry) nothing is
        # written, so this video is attempted again on the next run.