import gradio as gr
import pyarrow.parquet as pq
import pyarrow.compute as pc
from transformers import AutoTokenizer
from datasets import load_dataset
import os
import numpy as np
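
# Gradio Space for browsing feature-activation caches: pick a model, revision,
# layer and feature, then view its top activating token contexts as highlighted
# text alongside AutoInterp / SelfExplain explanations.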
cache_path = "weights/caches"
parquets = os.listdir(cache_path)
dataset = load_dataset("kisate-team/feature-explanations", split="train")
def find_revisions():
    # Scan the cache directory for parquet files and collect the revision ids
    # encoded in their names ("...-l<layer>-r<revision>-...").
    revisions = set()
    for parquet in parquets:
        if parquet.endswith(".parquet"):
            parts = parquet.split("-")
            if len(parts) > 2:
                revisions.add(int(parts[-3][1:]))
    return sorted(revisions)
def find_layers(revision):
    layers = set()
    for parquet in parquets:
        if parquet.endswith(".parquet"):
            parts = parquet.split("-")
            if len(parts) > 2 and int(parts[-3][1:]) == revision:
                layers.add(int(parts[-4][1:]))
    return sorted(layers)
revisions = find_revisions()
layers = {
    revision: find_layers(revision) for revision in revisions
}
features = {
    revision: {
        layer: {
            item["feature"]: item for item in dataset if item["layer"] == layer and item["version"] == revision
        } for layer in layers[revision]
    } for revision in revisions
}
# layers = dataset.unique("layer")
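
# Display parameters: `nearby` context tokens are shown on each side of the
# max-activating token, `stride` rescales activation values read from the cache,
# and `n_bins` sets the number of color-map legend entries.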
nearby = 8
stride = 0.25
n_bins = 10
def make_cache_name(layer, revision, model):
    return f"{cache_path}/{model}-l{layer}-r{revision}-st0.25x128-activations.parquet"
models = {
"gemma-2b-r": "gemma-2b-residuals",
"phi-3": "phi"
}
tokenizers = {
"gemma-2b-r": "alpindale/gemma-2b",
"phi-3": "microsoft/Phi-3-mini-4k-instruct"
}
token_tables = {
"gemma-2b-r": pq.read_table("weights/tokens_gemma.parquet"),
"phi-3": pq.read_table("weights/tokens.parquet")
}
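
# UI: pick a model, revision, layer and feature; the `update` callback renders the
# feature's strongest activating contexts as highlighted text plus stored explanations.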
with gr.Blocks() as demo:
    feature_table = gr.State(None)

    model_name = gr.Dropdown(["phi-3", "gemma-2b-r"], label="Model")
    revision_dropdown = gr.Dropdown(revisions, label="Revision")
    # Layer choices start from revision 4's layers and are refreshed whenever
    # the revision dropdown changes.
    layer_dropdown = gr.Dropdown(layers[4], label="Layer")
    def update_features(layer):
        # Currently unused: the dropdown binding below is commented out.
        feature_dropdown = gr.Dropdown(features[layer].keys())
        return feature_dropdown

    def update_layers(revision):
        layer_dropdown = gr.Dropdown(layers[revision])
        return layer_dropdown
    frequency = gr.Number(0, label="Total frequency (%)")
    extra_tokens = gr.Number(0, label="Extra Max Act Tokens")
    # layer_dropdown.input(update_features, layer_dropdown, feature_dropdown)
    # histogram = gr.LinePlot(x="activation", y="freq")
    revision_dropdown.input(update_layers, revision_dropdown, layer_dropdown)

    feature_input = gr.Number(0, label="Feature")
    autoi_expl = gr.Textbox(label="AutoInterp Explanation")
    selfe_expl = gr.Textbox(label="SelfExplain Explanation")

    cm = gr.HighlightedText()
    frame = gr.HighlightedText()
    def update(model, revision, layer, feature, extra_tokens):
        # Token-window offset: gemma caches need no shift, other models shift by one.
        correction = 1
        if "gemma" in model:
            correction = 0

        token_table = token_tables[model]
        tokenizer_name = tokenizers[model]
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        table = pq.read_table(make_cache_name(layer, revision, models[model]))
        table_feat = table.filter(pc.field("feature") == feature).to_pandas()
        # freq_t = table_feat[["activation", "freq"]]
        total_freq = float(table_feat["freq"].sum()) * 100

        # Keep only rows where the feature actually fired, strongest activations first.
        table_feat = table_feat[table_feat["activation"] > 0]
        table_feat = table_feat[table_feat["freq"] > 0]
        table_feat = table_feat.sort_values("activation", ascending=False)

        # Decode a window of `nearby` (plus `extra_tokens`) tokens around each position.
        texts = table_feat["token"].apply(
            lambda x: [tokenizer.decode(y).replace("\n", " ") for y in token_table[max(0, x - nearby + correction - extra_tokens):x + extra_tokens + nearby + 1 + correction]["tokens"].to_numpy()]
        ).tolist()
        # texts = [tokenizer.tokenize(text) for text in texts]
        activations = table_feat["nearby"].to_numpy()
        activations = [[0] * extra_tokens + a.tolist() + [0] * extra_tokens for i, a in enumerate(activations) if len(texts[i]) > 0]
        texts = [text for text in texts if len(text) > 0]

        for t, a in zip(texts, activations):
            assert len(t) == len(a)

        if len(activations) > 0:
            activations = np.stack(activations) * stride
            max_act = table_feat["activation"].max()
            activations = activations / max_act

            highlight_data = [
                [(token, activation) for token, activation in zip(text, activation)] + [("\n", 0)]
                for text, activation in zip(texts, activations)
            ]
            flat_data = [item for sublist in highlight_data for item in sublist]

            color_map_data = [i / n_bins for i in range(n_bins + 1)]
            color_map_data = [(f"{i*max_act:.2f}", i) for i in color_map_data]
        else:
            flat_data = []
            color_map_data = []

        if feature in features[revision][layer]:
            autoi_expl = features[revision][layer][feature]["explanation"]
            selfe_expl = features[revision][layer][feature]["gen_explanations"]
            if selfe_expl is not None:
                selfe_expl = "\n".join(
                    f"{i+1}. \"{x}\"" for i, x in enumerate(selfe_expl)
                )
        else:
            autoi_expl = "No explanation found"
            selfe_expl = "No explanation found"

        return flat_data, color_map_data, total_freq, autoi_expl, selfe_expl
    # feature_dropdown.change(update, [layer_dropdown, feature_dropdown, tokenizer_name], [frame, cm, frequency, autoi_expl, selfe_expl])
    feature_input.change(update, [model_name, revision_dropdown, layer_dropdown, feature_input, extra_tokens], [frame, cm, frequency, autoi_expl, selfe_expl])
if __name__ == "__main__":
demo.launch(share=True)