kisate committed on
Commit
73ab266
·
1 Parent(s): ae6c4fd
app.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pyarrow.parquet as pq
3
+ import pyarrow.compute as pc
4
+ from transformers import AutoTokenizer
5
+ import os
6
+ import numpy as np
7
+
8
+
9
# Token table loaded once at import; `update` slices rows out of it by
# position and decodes its "tokens" column.
token_table = pq.read_table("weights/tokens.parquet")

# Directory holding one activation-cache parquet file per dropdown choice.
cache_path = "weights/caches"

# os.listdir() returns entries in arbitrary order and includes every directory
# entry, so keep only parquet files and sort them to get a stable dropdown.
parquets = sorted(
    name for name in os.listdir(cache_path) if name.endswith(".parquet")
)

# Default value of the tokenizer text box.
TOKENIZER = "microsoft/Phi-3-mini-4k-instruct"

# Number of extra rows of `token_table` taken on each side of a hit when
# building its context window.
nearby = 8
# Multiplier applied to the stored "nearby" activation values.
stride = 0.25
# Number of steps in the colour legend (n_bins + 1 entries are generated).
n_bins = 10
17
+
18
with gr.Blocks() as demo:
    # Session state slot; no handler in this block reads or writes it.
    feature_table = gr.State(None)

    # Inputs.
    tokenizer_name = gr.Textbox(TOKENIZER)
    dropdown = gr.Dropdown(parquets)
    feature_input = gr.Number(0)
    token_range = gr.Number(64)

    # Outputs.
    frequency = gr.Number(0, label="Total frequency (%)")
    histogram = gr.LinePlot(x="activation", y="freq")
    cm = gr.HighlightedText()
    # Use the same canonical component name as `cm` above.
    frame = gr.HighlightedText(show_legend=True)

    # Tokenizers already loaded, keyed by name, so that a change event does
    # not reload the tokenizer every time it fires.
    _tokenizer_cache = {}

    def _load_tokenizer(name):
        """Return the tokenizer for *name*, loading it only on first use."""
        if name not in _tokenizer_cache:
            _tokenizer_cache[name] = AutoTokenizer.from_pretrained(name)
        return _tokenizer_cache[name]

    def update(cache_name, feature, tokenizer_name, token_range):
        """Recompute every output for the selected cache file and feature.

        Args:
            cache_name: File name inside ``cache_path`` chosen in the
                dropdown, or ``None`` when nothing is selected.
            feature: Value compared against the cache's "feature" column.
            tokenizer_name: Name passed to ``AutoTokenizer.from_pretrained``.
            token_range: Value of the ``token_range`` input; currently unused.

        Returns:
            A 4-tuple ``(flat_data, color_map_data, freq_t, total_freq)``
            matching the outputs ``[frame, cm, histogram, frequency]``:
            highlighted (token, value) pairs, the colour legend entries, the
            "activation"/"freq" columns for the plot, and the summed "freq"
            column multiplied by 100.
        """
        if cache_name is None or feature is None:
            # Nothing selected (or the number box was cleared): return one
            # empty value per wired output instead of a bare None.
            return [], [], None, 0

        tokenizer = _load_tokenizer(tokenizer_name)
        table = pq.read_table(f"{cache_path}/{cache_name}")
        table_feat = table.filter(pc.field("feature") == feature).to_pandas()

        # The plot data and the total are taken before the rows are filtered.
        freq_t = table_feat[["activation", "freq"]]
        total_freq = float(table_feat["freq"].sum()) * 100

        table_feat = table_feat[table_feat["activation"] > 0]
        table_feat = table_feat[table_feat["freq"] > 0]
        table_feat = table_feat.sort_values("activation", ascending=False)

        # Decode a window of token ids around each hit position.
        texts = table_feat["token"].apply(
            lambda x: tokenizer.decode(
                token_table[max(0, x - nearby - 1):x + nearby + 1]["tokens"].to_numpy()
            )
        )
        # NOTE(review): re-tokenizing the decoded text presumably yields one
        # piece per stored "nearby" value; zip() below silently truncates to
        # the shorter side if it does not -- confirm.
        texts = [tokenizer.tokenize(text) for text in texts]

        activations = table_feat["nearby"].to_numpy()
        if len(activations) > 0:
            activations = np.stack(activations) * stride
            # Rows were filtered to activation > 0, so max_act is positive.
            max_act = table_feat["activation"].max()
            activations = activations / max_act

            # One run of (token, value) pairs per row, each ended by a newline.
            highlight_data = [
                [(token, value) for token, value in zip(row_tokens, row_values)]
                + [("\n", 0)]
                for row_tokens, row_values in zip(texts, activations)
            ]
            flat_data = [item for sublist in highlight_data for item in sublist]

            # Legend: n_bins + 1 evenly spaced levels in [0, 1], each labelled
            # with the activation value it stands for.
            color_map_data = [i / n_bins for i in range(n_bins + 1)]
            color_map_data = [(f"{i*max_act:.2f}", i) for i in color_map_data]
        else:
            flat_data = []
            color_map_data = []

        return flat_data, color_map_data, freq_t, total_freq

    # Both triggers share the same handler, inputs and outputs.
    update_inputs = [dropdown, feature_input, tokenizer_name, token_range]
    update_outputs = [frame, cm, histogram, frequency]
    dropdown.change(update, update_inputs, update_outputs)
    feature_input.change(update, update_inputs, update_outputs)
78
+
79
+
80
if __name__ == "__main__":
    # Launch the Blocks app when this file is run as a script.
    demo.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ pyarrow
2
+ transformers[cpu]
3
+ numpy
4
+ pandas
5
+ datasets
weights/caches/phi-l12-r4-st0.25x128-activations.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b36de99f765834429dede8b705b92fa9e0fd804bf3a35f323d3c964fae0158d0
3
+ size 12256546
weights/caches/phi-l14-r4-st0.25x128-activations.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1286d1fceb351bc6d67df3491892738fa515e81e1a7543e7a92024b535c6954a
3
+ size 15270782
weights/caches/phi-l16-r4-st0.25x128-activations.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bd1a604f72bf11d03f652f48d4a18d093a745312226c123baefabd77bad7e5e5
3
+ size 12232213
weights/caches/phi-l18-r4-st0.25x128-activations.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c9331f82be59a4eaa9297fd12670462312d084839f471a4e5db109db54b8439
3
+ size 13454437
weights/tokens.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:abc26a2910593929e66edd0549529b0768562a225efe26960c619b41495394a8
3
+ size 1550772