import html
import os
import string

import ftfy
import regex as re
import torch
from transformers import AutoTokenizer

from .base_prompter import BasePrompter
from ..models.wan_video_text_encoder import WanTextEncoder


def basic_clean(text):
    # Repair mojibake with ftfy, then unescape HTML entities twice so
    # double-escaped input (e.g. "&amp;eacute;") is fully resolved.
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    # Collapse every run of whitespace into a single space.
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


def canonicalize(text, keep_punctuation_exact_string=None):
    # Lowercase, strip punctuation, and normalize whitespace. If
    # keep_punctuation_exact_string is given (e.g. a '{}' template
    # placeholder), that exact substring is preserved while the
    # punctuation around it is removed.
    text = text.replace('_', ' ')
    if keep_punctuation_exact_string:
        text = keep_punctuation_exact_string.join(
            part.translate(str.maketrans('', '', string.punctuation))
            for part in text.split(keep_punctuation_exact_string))
    else:
        text = text.translate(str.maketrans('', '', string.punctuation))
    text = text.lower()
    text = re.sub(r'\s+', ' ', text)
    return text.strip()
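
# Illustrative behaviour of the cleaning helpers, written as doctest-style
# comments so the module stays import-safe (inputs are made up for this sketch):
#
#   >>> basic_clean("Caf&amp;eacute;  menu ")   # entities resolved, ends trimmed
#   'Café  menu'
#   >>> whitespace_clean("a\t b\n  c")
#   'a b c'
#   >>> canonicalize("Hello_World!!")
#   'hello world'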


class HuggingfaceTokenizer:

    def __init__(self, name, seq_len=None, clean=None, **kwargs):
        assert clean in (None, 'whitespace', 'lower', 'canonicalize')
        self.name = name
        self.seq_len = seq_len
        self.clean = clean

        # init tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
        self.vocab_size = self.tokenizer.vocab_size

    def __call__(self, sequence, **kwargs):
        return_mask = kwargs.pop('return_mask', False)

        # arguments
        _kwargs = {'return_tensors': 'pt'}
        if self.seq_len is not None:
            _kwargs.update({
                'padding': 'max_length',
                'truncation': True,
                'max_length': self.seq_len
            })
        _kwargs.update(**kwargs)

        # tokenization
        if isinstance(sequence, str):
            sequence = [sequence]
        if self.clean:
            sequence = [self._clean(u) for u in sequence]
        ids = self.tokenizer(sequence, **_kwargs)

        # output
        if return_mask:
            return ids.input_ids, ids.attention_mask
        else:
            return ids.input_ids

    def _clean(self, text):
        if self.clean == 'whitespace':
            text = whitespace_clean(basic_clean(text))
        elif self.clean == 'lower':
            text = whitespace_clean(basic_clean(text)).lower()
        elif self.clean == 'canonicalize':
            text = canonicalize(basic_clean(text))
        return text
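
# A minimal usage sketch for the wrapper above; the checkpoint name is an
# assumption for illustration, not something this module prescribes:
#
#   tokenizer = HuggingfaceTokenizer("google/umt5-xxl", seq_len=512, clean='whitespace')
#   ids, mask = tokenizer("a cat surfing a wave", return_mask=True, add_special_tokens=True)
#   # ids and mask are LongTensors of shape (1, 512), padded/truncated to seq_len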


class WanPrompter(BasePrompter):

    def __init__(self, tokenizer_path=None, text_len=512):
        super().__init__()
        self.text_len = text_len
        self.text_encoder = None
        # Ensure the attribute exists even when no tokenizer path is given.
        self.tokenizer = None
        self.fetch_tokenizer(tokenizer_path)

    def fetch_tokenizer(self, tokenizer_path=None):
        if tokenizer_path is not None:
            self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=self.text_len, clean='whitespace')

    def fetch_models(self, text_encoder: WanTextEncoder = None):
        self.text_encoder = text_encoder

    def encode_prompt(self, prompt, positive=True, device="cuda"):
        prompt = self.process_prompt(prompt, positive=positive)
        ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True)
        ids = ids.to(device)
        mask = mask.to(device)
        # Count the non-padding tokens in each sequence.
        seq_lens = mask.gt(0).sum(dim=1).long()
        prompt_emb = self.text_encoder(ids, mask)
        # Trim each embedding to its true (unpadded) token length.
        prompt_emb = [u[:v] for u, v in zip(prompt_emb, seq_lens)]
        return prompt_emb
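
# Putting it together (a sketch; the tokenizer path and the loaded WanTextEncoder
# come from the surrounding pipeline and are assumptions here):
#
#   prompter = WanPrompter(tokenizer_path="google/umt5-xxl", text_len=512)
#   prompter.fetch_models(text_encoder)  # text_encoder: a loaded WanTextEncoder
#   prompt_emb = prompter.encode_prompt("a cat surfing a wave", device="cuda")
#   # prompt_emb is a list with one tensor per prompt, trimmed to its real token count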