Dmitrii committed on
Commit
157fcd6
·
1 Parent(s): 7e427fb

add our gemma residuals

Browse files
app.py CHANGED
@@ -7,10 +7,8 @@ import os
7
  import numpy as np
8
 
9
 
10
- token_table = pq.read_table("weights/tokens.parquet")
11
  cache_path = "weights/caches"
12
  parquets = os.listdir(cache_path)
13
- TOKENIZER = "microsoft/Phi-3-mini-4k-instruct"
14
 
15
  dataset = load_dataset("kisate-team/feature-explanations", split="train")
16
 
@@ -20,7 +18,7 @@ def find_revions():
20
  if parquet.endswith(".parquet"):
21
  parts = parquet.split("-")
22
  if len(parts) > 2:
23
- revisions.add(int(parts[2][1:]))
24
  return sorted(revisions)
25
 
26
  def find_layers(revision):
@@ -28,8 +26,8 @@ def find_layers(revision):
28
  for parquet in parquets:
29
  if parquet.endswith(".parquet"):
30
  parts = parquet.split("-")
31
- if len(parts) > 2 and int(parts[2][1:]) == revision:
32
- layers.add(int(parts[1][1:]))
33
  return sorted(layers)
34
 
35
  revisions = find_revions()
@@ -51,13 +49,29 @@ nearby = 8
51
  stride = 0.25
52
  n_bins = 10
53
 
54
- def make_cache_name(layer, revision):
55
- return f"{cache_path}/phi-l{layer}-r{revision}-st0.25x128-activations.parquet"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  with gr.Blocks() as demo:
58
  feature_table = gr.State(None)
59
 
60
- tokenizer_name = gr.Textbox(TOKENIZER, label="Tokenizer")
 
61
  revision_dropdown = gr.Dropdown(revisions, label="Revision")
62
 
63
  layer_dropdown = gr.Dropdown(layers[4], label="Layer")
@@ -86,9 +100,17 @@ with gr.Blocks() as demo:
86
  cm = gr.HighlightedText()
87
  frame = gr.Highlightedtext()
88
 
89
- def update(revision, layer, feature, extra_tokens, tokenizer_name):
 
 
 
 
 
 
 
 
90
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
91
- table = pq.read_table(make_cache_name(layer, revision))
92
  table_feat = table.filter(pc.field("feature") == feature).to_pandas()
93
 
94
  # freq_t = table_feat[["activation", "freq"]]
@@ -100,7 +122,7 @@ with gr.Blocks() as demo:
100
  table_feat = table_feat.sort_values("activation", ascending=False)
101
 
102
  texts = table_feat["token"].apply(
103
- lambda x: [tokenizer.decode(y).replace("\n", " ") for y in token_table[max(0, x - nearby + 1 - extra_tokens):x + extra_tokens + nearby + 2]["tokens"].to_numpy()]
104
  ).tolist()
105
 
106
  # texts = [tokenizer.tokenize(text) for text in texts]
@@ -145,7 +167,7 @@ with gr.Blocks() as demo:
145
 
146
 
147
  # feature_dropdown.change(update, [layer_dropdown, feature_dropdown, tokenizer_name], [frame, cm, frequency, autoi_expl, selfe_expl])
148
- feature_input.change(update, [revision_dropdown, layer_dropdown, feature_input, extra_tokens, tokenizer_name], [frame, cm, frequency, autoi_expl, selfe_expl])
149
 
150
 
151
  if __name__ == "__main__":
 
7
  import numpy as np
8
 
9
 
 
10
  cache_path = "weights/caches"
11
  parquets = os.listdir(cache_path)
 
12
 
13
  dataset = load_dataset("kisate-team/feature-explanations", split="train")
14
 
 
18
  if parquet.endswith(".parquet"):
19
  parts = parquet.split("-")
20
  if len(parts) > 2:
21
+ revisions.add(int(parts[-3][1:]))
22
  return sorted(revisions)
23
 
24
  def find_layers(revision):
 
26
  for parquet in parquets:
27
  if parquet.endswith(".parquet"):
28
  parts = parquet.split("-")
29
+ if len(parts) > 2 and int(parts[-3][1:]) == revision:
30
+ layers.add(int(parts[-4][1:]))
31
  return sorted(layers)
32
 
33
  revisions = find_revions()
 
49
  stride = 0.25
50
  n_bins = 10
51
 
52
+ def make_cache_name(layer, revision, model):
53
+ return f"{cache_path}/{model}-l{layer}-r{revision}-st0.25x128-activations.parquet"
54
+
55
+ models = {
56
+ "gemma-2b-r": "gemma-2b-residuals",
57
+ "phi-3": "phi"
58
+ }
59
+
60
+ tokenizers = {
61
+ "gemma-2b-r": "alpindale/gemma-2b",
62
+ "phi-3": "microsoft/Phi-3-mini-4k-instruct"
63
+ }
64
+
65
+ token_tables = {
66
+ "gemma-2b-r": pq.read_table("weights/tokens_gemma.parquet"),
67
+ "phi-3": pq.read_table("weights/tokens.parquet")
68
+ }
69
 
70
  with gr.Blocks() as demo:
71
  feature_table = gr.State(None)
72
 
73
+ model_name = gr.Dropdown(["phi-3", "gemma-2b-r"], label="Model")
74
+
75
  revision_dropdown = gr.Dropdown(revisions, label="Revision")
76
 
77
  layer_dropdown = gr.Dropdown(layers[4], label="Layer")
 
100
  cm = gr.HighlightedText()
101
  frame = gr.Highlightedtext()
102
 
103
+ def update(model, revision, layer, feature, extra_tokens):
104
+ correction = 1
105
+ if "gemma" in model:
106
+ correction = 0
107
+
108
+ token_table = token_tables[model]
109
+
110
+ tokenizer_name = tokenizers[model]
111
+
112
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
113
+ table = pq.read_table(make_cache_name(layer, revision, models[model]))
114
  table_feat = table.filter(pc.field("feature") == feature).to_pandas()
115
 
116
  # freq_t = table_feat[["activation", "freq"]]
 
122
  table_feat = table_feat.sort_values("activation", ascending=False)
123
 
124
  texts = table_feat["token"].apply(
125
+ lambda x: [tokenizer.decode(y).replace("\n", " ") for y in token_table[max(0, x - nearby + correction - extra_tokens):x + extra_tokens + nearby + 1 + correction]["tokens"].to_numpy()]
126
  ).tolist()
127
 
128
  # texts = [tokenizer.tokenize(text) for text in texts]
 
167
 
168
 
169
  # feature_dropdown.change(update, [layer_dropdown, feature_dropdown, tokenizer_name], [frame, cm, frequency, autoi_expl, selfe_expl])
170
+ feature_input.change(update, [model_name, revision_dropdown, layer_dropdown, feature_input, extra_tokens], [frame, cm, frequency, autoi_expl, selfe_expl])
171
 
172
 
173
  if __name__ == "__main__":
weights/caches/gemma-2b-residuals-l10-r1-st0.25x128-activations.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8aa3f8e3cee3e390decaa5173056c1212c23e3a18e2386370d9194826199cf75
3
+ size 44542566
weights/caches/gemma-2b-residuals-l11-r1-st0.25x128-activations.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:91cbd63633a40b8bfa71bd6d44970ae6e743b1cb1c29e8475fbc36abfd710718
3
+ size 45054001
weights/caches/gemma-2b-residuals-l12-r1-st0.25x128-activations.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bfbaaf961204a4c4ef838b26cba6d9e746071860a36976ef18f8edfd899858ae
3
+ size 46113844
weights/caches/gemma-2b-residuals-l13-r1-st0.25x128-activations.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c1795d13266048de778367a568600c7f3ae37b28e2ef52d26f794db0924237d6
3
+ size 46947241
weights/caches/gemma-2b-residuals-l14-r1-st0.25x128-activations.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e7f4422af3da4a4d5c4bb39c84a1ec4d50d6dbf5003fc1b59fb480b7d205838
3
+ size 47186402
weights/caches/gemma-2b-residuals-l15-r1-st0.25x128-activations.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f655a992c99727eccf29f7b42a36eb8e2e7c0fa05f26c860e93808cacbaff756
3
+ size 47844586
weights/caches/gemma-2b-residuals-l16-r1-st0.25x128-activations.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:15ec9ec36bd9d7621dccd294ec827fc8caf21e7de4a21210a762702cff3e63b6
3
+ size 47697629
weights/caches/gemma-2b-residuals-l6-r1-st0.25x128-activations.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:170e7298b233ad62ed63a03e9fe3798d825f94152e8a5f21b1ce8f49b2841bdd
3
+ size 36831250
weights/caches/gemma-2b-residuals-l8-r1-st0.25x128-activations.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:86f3bfc067085879d68dc3dd7bf318e1c0888077418bb89e1c61e293c6f7d1f6
3
+ size 42722391
weights/caches/gemma-2b-residuals-l9-r1-st0.25x128-activations.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1824242f0cb49182607145e9149ebd4a92f3ef77463b68c16a2cd8147337994e
3
+ size 43738835
weights/tokens_gemma.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7908b97a2d86f42d61761e7f4383b51b124b486c6cb1c61cba0ece07fea6daae
3
+ size 16738082