Spaces:
Running
Running
import os
from functools import lru_cache
from html import escape

import gradio as gr
from transformers import AutoTokenizer
def get_available_models() -> list[str]:
    """List local models that can be offered in the model dropdown.

    A model is any direct sub-directory of ``models`` (relative to the
    current working directory) that contains a ``config.json`` file.

    Returns:
        The qualifying directory names in sorted order, or an empty list
        when the ``models`` directory does not exist.
    """
    root = "models"
    if not os.path.exists(root):
        return []
    return sorted(
        entry
        for entry in os.listdir(root)
        if os.path.isdir(os.path.join(root, entry))
        and os.path.isfile(os.path.join(root, entry, "config.json"))
    )
def tokenize_text( | |
model_name: str, text: str | |
) -> tuple[str | None, str | None, int | None, dict | None, int, int]: | |
"""处理tokenize请求""" | |
if not model_name: | |
return "Please choose a model and input some texts", None, None, None, 0, 0 | |
if not text: | |
text = "Please choose a model and input some texts" | |
try: | |
# 加载tokenizer | |
model_path = os.path.join("models", model_name) | |
if os.path.isdir(model_path): | |
tokenizer = AutoTokenizer.from_pretrained( | |
model_path, trust_remote_code=True, device_map="cpu" | |
) | |
else: | |
tokenizer = AutoTokenizer.from_pretrained( | |
model_name, trust_remote_code=True, device_map="cpu" | |
) | |
tokenizer_type = tokenizer.__class__.__name__ | |
if hasattr(tokenizer, "vocab_size"): | |
vocab_size = tokenizer.vocab_size | |
elif hasattr(tokenizer, "get_vocab"): | |
vocab_size = len(tokenizer.get_vocab()) | |
else: | |
vocab_size = -1 | |
sp_token_list = [ | |
"pad_token", | |
"eos_token", | |
"bos_token", | |
"sep_token", | |
"cls_token", | |
"unk_token", | |
"mask_token", | |
"image_token", | |
"audio_token", | |
"video_token", | |
"vision_bos_token", | |
"vision_eos_token", | |
"audio_bos_token", | |
"audio_eos_token", | |
] | |
special_tokens = {} | |
for token_name in sp_token_list: | |
if ( | |
hasattr(tokenizer, token_name) | |
and getattr(tokenizer, token_name) is not None | |
): | |
token_value = getattr(tokenizer, token_name) | |
if token_value and str(token_value).strip(): | |
special_tokens[token_name] = str(token_value) | |
# Tokenize处理 | |
input_ids = tokenizer.encode(text, add_special_tokens=True) | |
# 生成带颜色的HTML | |
colors = ["#A8D8EA", "#AA96DA", "#FCBAD3"] | |
html_parts = [] | |
for i, token_id in enumerate(input_ids): | |
# 转义HTML特殊字符 | |
safe_token = escape(tokenizer.decode(token_id)) | |
# 交替颜色 | |
color = colors[i % len(colors)] | |
html_part = ( | |
f'<span style="background-color: {color};' | |
f"margin: 2px; padding: 2px 5px; border-radius: 3px;" | |
f'display: inline-block; font-size: 1.2em;">' | |
f"{safe_token}<br/>" | |
f'<sub style="font-size: 0.9em;">{token_id}</sub>' | |
f"</span>" | |
) | |
html_parts.append(html_part) | |
# 统计信息 | |
token_len = len(input_ids) | |
char_len = len(text) | |
return ( | |
"".join(html_parts), | |
tokenizer_type, | |
vocab_size, | |
special_tokens, | |
token_len, | |
char_len, | |
) | |
except Exception as e: | |
error_msg = f"Error: {str(e)}" | |
return error_msg, None, None, None, 0, 0 | |
# --- UI components -----------------------------------------------------------
# Components are created here, outside the Blocks context, and placed into
# the layout later via .render().

banner_md = (
    "# 🎨 Tokenize it!\n"
    "Powerful token visualization tool for your text inputs. 🚀\n"
    "Works for LLMs both online and *locally* on your machine!"
)
banner = gr.Markdown(banner_md)

# Inputs: model picker (local models listed, free text allowed), text, button.
model_selector = gr.Dropdown(
    label="Choose or enter model name",
    choices=get_available_models(),
    allow_custom_value=True,
    interactive=True,
)
text_input = gr.Textbox(
    label="Input Text",
    placeholder="Hello World!",
    lines=4,
)
submit_btn = gr.Button("🚀 Tokenize!", variant="primary")

# Read-only tokenizer details.
tokenizer_type = gr.Textbox(label="Tokenizer Type", interactive=False)
vocab_size = gr.Number(label="Vocab Size", interactive=False)
sp_tokens = gr.JSON(label="Special Tokens")

# Tokenization results and counters.
output_html = gr.HTML(label="Tokenized Output", elem_classes="token-output")
token_count = gr.Number(label="Token Count", value=0, interactive=False)
char_count = gr.Number(label="Character Count", value=0, interactive=False)
# Custom styles for the app, supplied through the Blocks `css` argument.
# NOTE(review): the `.token-output span` margin is overridden by the inline
# `margin: 2px` emitted for each token span. No component in this file sets
# elem_classes="stats-output", so that rule currently has no effect — confirm
# whether either rule is still needed.
custom_css = """
.token-output span {
    margin: 3px;
    vertical-align: top;
}
.stats-output {
    font-weight: bold !important;
    color: #2c3e50 !important;
}
"""

# NOTE(review): the original nesting was lost in the source; this layout is
# reconstructed from the render order — confirm against the intended UI.
with gr.Blocks(
    title="Token Visualizer", theme="NoCrypt/miku", css=custom_css
) as webui:
    banner.render()
    with gr.Row(scale=2):
        # Left column: inputs and the rendered tokens.
        with gr.Column():
            model_selector.render()
            text_input.render()
            submit_btn.render()
            output_html.render()
        # Right column: tokenizer details and counters.
        with gr.Column():
            with gr.Accordion("Details", open=False):
                with gr.Row():
                    tokenizer_type.render()
                    vocab_size.render()
                sp_tokens.render()
            with gr.Row():
                token_count.render()
                char_count.render()

    # Output order must match the tuple returned by tokenize_text.
    submit_btn.click(
        fn=tokenize_text,
        inputs=[model_selector, text_input],
        outputs=[
            output_html,
            tokenizer_type,
            vocab_size,
            sp_tokens,
            token_count,
            char_count,
        ],
    )
def _serve() -> None:
    """Create the local ``models`` directory if missing, then launch the app."""
    os.makedirs("models", exist_ok=True)
    webui.launch(pwa=True)


if __name__ == "__main__":
    _serve()