import plotly.graph_objects as go
import numpy as np
from sklearn.decomposition import PCA


def list_supported_models(task):
    """Return the model checkpoints offered for a given task."""
    if task == "Text Classification":
        return ["distilbert-base-uncased", "bert-base-uncased", "roberta-base"]
    elif task == "Text Generation":
        return ["gpt2", "distilgpt2"]
    elif task == "Question Answering":
        return ["deepset/roberta-base-squad2", "distilbert-base-cased-distilled-squad"]
    return []


def visualize_attention(attentions, tokenizer, inputs):
    """Render the last layer's attention, averaged over heads, as a heatmap."""
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    last_layer_attention = attentions[-1][0]  # [heads, seq_len, seq_len]
    avg_attention = last_layer_attention.mean(dim=0).detach().numpy()
    fig = go.Figure(data=go.Heatmap(
        z=avg_attention,
        x=tokens,
        y=tokens,
        colorscale='Viridis'
    ))
    fig.update_layout(title="Average Attention - Last Layer", xaxis_nticks=len(tokens))
    return fig


def plot_token_embeddings(embeddings, tokens):
    """Project token embeddings to 2-D with PCA and scatter-plot one point per token."""
    pca = PCA(n_components=2)
    reduced = pca.fit_transform(embeddings.detach().numpy())
    fig = go.Figure()
    for i, token in enumerate(tokens):
        fig.add_trace(go.Scatter(
            x=[reduced[i][0]], y=[reduced[i][1]],
            text=[token],
            mode='markers+text',
            textposition='top center',
            marker=dict(size=10),
            name=token
        ))
    fig.update_layout(title="Token Embeddings (PCA)", xaxis_title="PC 1", yaxis_title="PC 2")
    return fig
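
For context, here is a minimal sketch of how these helpers could be driven end to end with Hugging Face transformers. The choice of AutoModel with output_attentions=True and output_hidden_states=True, the example sentence, and the variable names below are illustrative assumptions, not part of the app code above.

import torch
from transformers import AutoTokenizer, AutoModel

# Assumption: pick the first checkpoint offered for the task.
model_name = list_supported_models("Text Classification")[0]
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(
    model_name, output_attentions=True, output_hidden_states=True
)

inputs = tokenizer("The quick brown fox jumps over the lazy dog",
                   return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# outputs.attentions: one [batch, heads, seq_len, seq_len] tensor per layer.
attention_fig = visualize_attention(outputs.attentions, tokenizer, inputs)

# outputs.hidden_states[-1][0]: [seq_len, hidden_dim] for the single example.
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
embedding_fig = plot_token_embeddings(outputs.hidden_states[-1][0], tokens)

attention_fig.show()
embedding_fig.show()

Requesting attentions and hidden states once at load time, rather than per forward call, keeps the plotting helpers free of model-configuration concerns.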