rahideer commited on
Commit
b36e408
·
verified ·
1 Parent(s): 7b5fc70

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +25 -0
utils.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import plotly.graph_objects as go
2
+ import numpy as np
3
+
4
+ def list_supported_models(task):
5
+ if task == "Text Classification":
6
+ return ["distilbert-base-uncased", "bert-base-uncased", "roberta-base"]
7
+ elif task == "Text Generation":
8
+ return ["gpt2", "distilgpt2"]
9
+ elif task == "Question Answering":
10
+ return ["deepset/roberta-base-squad2", "distilbert-base-cased-distilled-squad"]
11
+ return []
12
+
13
+ def visualize_attention(attentions, tokenizer, inputs):
14
+ tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
15
+ last_layer_attention = attentions[-1][0] # shape: [num_heads, seq_len, seq_len]
16
+ avg_attention = last_layer_attention.mean(dim=0).detach().numpy()
17
+
18
+ fig = go.Figure(data=go.Heatmap(
19
+ z=avg_attention,
20
+ x=tokens,
21
+ y=tokens,
22
+ colorscale='Viridis'
23
+ ))
24
+ fig.update_layout(title="Average Attention - Last Layer", xaxis_nticks=len(tokens))
25
+ return fig