kisate commited on
Commit
c98496e
·
1 Parent(s): 73ab266

Add explanations

Browse files
Files changed (1) hide show
  1. app.py +39 -16
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  import pyarrow.parquet as pq
3
  import pyarrow.compute as pc
4
  from transformers import AutoTokenizer
 
5
  import os
6
  import numpy as np
7
 
@@ -11,36 +12,50 @@ cache_path = "weights/caches"
11
  parquets = os.listdir(cache_path)
12
  TOKENIZER = "microsoft/Phi-3-mini-4k-instruct"
13
 
 
 
 
 
 
 
14
  nearby = 8
15
  stride = 0.25
16
  n_bins = 10
17
 
 
 
 
18
  with gr.Blocks() as demo:
19
  feature_table = gr.State(None)
20
 
21
  tokenizer_name = gr.Textbox(TOKENIZER)
22
- dropdown = gr.Dropdown(parquets)
23
- feature_input = gr.Number(0)
24
- token_range = gr.Number(64)
 
 
 
 
 
 
25
 
26
  frequency = gr.Number(0, label="Total frequency (%)")
27
- histogram = gr.LinePlot(x="activation", y="freq")
 
 
 
 
28
  cm = gr.HighlightedText()
29
- frame = gr.Highlightedtext(
30
- show_legend=True
31
- )
32
 
33
- def update(cache_name, feature, tokenizer_name, token_range):
34
- if cache_name is None:
35
- return
36
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
37
- table = pq.read_table(f"{cache_path}/{cache_name}")
38
  table_feat = table.filter(pc.field("feature") == feature).to_pandas()
39
 
40
- freq_t = table_feat[["activation", "freq"]]
41
  total_freq = float(table_feat["freq"].sum()) * 100
42
 
43
-
44
  table_feat = table_feat[table_feat["activation"] > 0]
45
  table_feat = table_feat[table_feat["freq"] > 0]
46
 
@@ -70,11 +85,19 @@ with gr.Blocks() as demo:
70
  flat_data = []
71
  color_map_data = []
72
 
73
- return flat_data, color_map_data, freq_t, total_freq
 
 
 
 
 
 
 
 
74
 
75
 
76
- dropdown.change(update, [dropdown, feature_input, tokenizer_name, token_range], [frame, cm, histogram, frequency])
77
- feature_input.change(update, [dropdown, feature_input, tokenizer_name, token_range], [frame, cm, histogram, frequency])
78
 
79
 
80
  if __name__ == "__main__":
 
2
  import pyarrow.parquet as pq
3
  import pyarrow.compute as pc
4
  from transformers import AutoTokenizer
5
+ from datasets import load_dataset
6
  import os
7
  import numpy as np
8
 
 
12
  parquets = os.listdir(cache_path)
13
  TOKENIZER = "microsoft/Phi-3-mini-4k-instruct"
14
 
15
+ dataset = load_dataset("kisate-team/feature-explanations", split="train")
16
+
17
+ layers = dataset.unique("layer")
18
+
19
+ features = {layer:{item["feature"]:item for item in dataset if item["layer"] == layer} for layer in layers}
20
+
21
  nearby = 8
22
  stride = 0.25
23
  n_bins = 10
24
 
25
+ def make_cache_name(layer):
26
+ return f"{cache_path}/phi-l{layer}-r4-st0.25x128-activations.parquet"
27
+
28
  with gr.Blocks() as demo:
29
  feature_table = gr.State(None)
30
 
31
  tokenizer_name = gr.Textbox(TOKENIZER)
32
+ layer_dropdown = gr.Dropdown(layers)
33
+ feature_dropdown = gr.Dropdown()
34
+
35
+ def update_features(layer):
36
+ feature_dropdown = gr.Dropdown(features[layer].keys())
37
+ return feature_dropdown
38
+
39
+ layer_dropdown.input(update_features, layer_dropdown, feature_dropdown)
40
+
41
 
42
  frequency = gr.Number(0, label="Total frequency (%)")
43
+ # histogram = gr.LinePlot(x="activation", y="freq")
44
+
45
+ autoi_expl = gr.Textbox()
46
+ selfe_expl = gr.Textbox()
47
+
48
  cm = gr.HighlightedText()
49
+ frame = gr.Highlightedtext()
 
 
50
 
51
+ def update(layer, feature, tokenizer_name):
 
 
52
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
53
+ table = pq.read_table(make_cache_name(layer))
54
  table_feat = table.filter(pc.field("feature") == feature).to_pandas()
55
 
56
+ # freq_t = table_feat[["activation", "freq"]]
57
  total_freq = float(table_feat["freq"].sum()) * 100
58
 
 
59
  table_feat = table_feat[table_feat["activation"] > 0]
60
  table_feat = table_feat[table_feat["freq"] > 0]
61
 
 
85
  flat_data = []
86
  color_map_data = []
87
 
88
+ autoi_expl = features[layer][feature]["explanation"]
89
+ selfe_expl = features[layer][feature]["gen_explanations"]
90
+
91
+ if selfe_expl is not None:
92
+ selfe_expl = "\n".join(
93
+ f"{i+1}. \"{x}\"" for i, x in enumerate(selfe_expl)
94
+ )
95
+
96
+ return flat_data, color_map_data, total_freq, autoi_expl, selfe_expl
97
 
98
 
99
+ feature_dropdown.change(update, [layer_dropdown, feature_dropdown, tokenizer_name], [frame, cm, frequency, autoi_expl, selfe_expl])
100
+ # feature_input.change(update, [dropdown, feature_input, tokenizer_name, token_range], [frame, cm, histogram, frequency])
101
 
102
 
103
  if __name__ == "__main__":