Spaces:
Running
Running
Create utils.py
Browse files
utils.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import plotly.graph_objects as go
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
def list_supported_models(task):
|
5 |
+
if task == "Text Classification":
|
6 |
+
return ["distilbert-base-uncased", "bert-base-uncased", "roberta-base"]
|
7 |
+
elif task == "Text Generation":
|
8 |
+
return ["gpt2", "distilgpt2"]
|
9 |
+
elif task == "Question Answering":
|
10 |
+
return ["deepset/roberta-base-squad2", "distilbert-base-cased-distilled-squad"]
|
11 |
+
return []
|
12 |
+
|
13 |
+
def visualize_attention(attentions, tokenizer, inputs):
|
14 |
+
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
|
15 |
+
last_layer_attention = attentions[-1][0] # shape: [num_heads, seq_len, seq_len]
|
16 |
+
avg_attention = last_layer_attention.mean(dim=0).detach().numpy()
|
17 |
+
|
18 |
+
fig = go.Figure(data=go.Heatmap(
|
19 |
+
z=avg_attention,
|
20 |
+
x=tokens,
|
21 |
+
y=tokens,
|
22 |
+
colorscale='Viridis'
|
23 |
+
))
|
24 |
+
fig.update_layout(title="Average Attention - Last Layer", xaxis_nticks=len(tokens))
|
25 |
+
return fig
|