# Dmitrii
# add our gemma residuals
# 157fcd6
import gradio as gr
import pyarrow.parquet as pq
import pyarrow.compute as pc
from transformers import AutoTokenizer
from datasets import load_dataset
import os
import numpy as np
# Directory holding the activation caches; the app opens files here via
# make_cache_name() and discovers revisions/layers by parsing these names.
cache_path = "weights/caches"
# Directory listing taken once at import time; files added later are not
# picked up until the app is restarted.
parquets = os.listdir(cache_path)
# Feature explanations from the Hugging Face Hub. Rows are accessed below via
# the "layer", "version", "feature", "explanation" and "gen_explanations" keys.
dataset = load_dataset("kisate-team/feature-explanations", split="train")
def find_revions(files=None):
    """Return the sorted revision numbers found among the cached parquet files.

    Cache files are expected to be named
    ``{model}-l{layer}-r{revision}-st0.25x128-activations.parquet`` (the format
    written by ``make_cache_name``). The layer and revision fields are located
    by counting from the end of the name, so hyphens inside the model name
    (e.g. ``gemma-2b-residuals``) do not matter.

    Args:
        files: File names to scan. Defaults to the module-level ``parquets``
            directory listing.

    Returns:
        Sorted list of distinct integer revisions. Names that do not follow the
        scheme are skipped instead of raising.
    """
    if files is None:
        files = parquets
    revisions = set()
    for name in files:
        if not name.endswith(".parquet"):
            continue
        parts = name.split("-")
        # The layer and revision fields sit 4th- and 3rd-from-last, so at least
        # four fields are needed (a ``> 2`` check let three-field names through).
        if len(parts) < 4:
            continue
        layer_field, revision_field = parts[-4], parts[-3]
        if not (layer_field.startswith("l") and revision_field.startswith("r")):
            continue
        try:
            int(layer_field[1:])
            revision = int(revision_field[1:])
        except ValueError:
            continue  # not a cache file in the make_cache_name format
        revisions.add(revision)
    return sorted(revisions)
def find_layers(revision, files=None):
    """Return the sorted layer numbers that have a cache file for ``revision``.

    Cache files are expected to be named
    ``{model}-l{layer}-r{revision}-st0.25x128-activations.parquet`` (the format
    written by ``make_cache_name``); fields are located from the end of the
    name so hyphens inside the model name do not matter.

    Args:
        revision: Revision number to look up.
        files: File names to scan. Defaults to the module-level ``parquets``
            directory listing.

    Returns:
        Sorted list of distinct integer layers. Names that do not follow the
        scheme are skipped instead of raising.
    """
    if files is None:
        files = parquets
    layers = set()
    for name in files:
        if not name.endswith(".parquet"):
            continue
        parts = name.split("-")
        # parts[-4] is read below, so four fields are required; a ``> 2``
        # check let three-field names through and raised IndexError.
        if len(parts) < 4:
            continue
        layer_field, revision_field = parts[-4], parts[-3]
        if not (layer_field.startswith("l") and revision_field.startswith("r")):
            continue
        try:
            layer = int(layer_field[1:])
            file_revision = int(revision_field[1:])
        except ValueError:
            continue  # not a cache file in the make_cache_name format
        if file_revision == revision:
            layers.add(layer)
    return sorted(layers)
revisions = find_revions()
# revision -> sorted list of layers that have a cache file for that revision.
layers = {
    revision: find_layers(revision) for revision in revisions
}


def _index_explanations(rows, layers_by_revision):
    """Group explanation rows as ``{revision: {layer: {feature: row}}}``.

    Every (revision, layer) pair in ``layers_by_revision`` gets an entry, even
    when no row matches it; rows for other pairs are ignored. Rows are matched
    on their ``"version"`` and ``"layer"`` keys and stored under their
    ``"feature"`` key; when several rows share a feature id, the last one wins.

    This is a single pass over ``rows``; a nested comprehension would rescan the
    whole dataset once per (revision, layer) pair.
    """
    index = {
        revision: {layer: {} for layer in revision_layers}
        for revision, revision_layers in layers_by_revision.items()
    }
    for row in rows:
        by_layer = index.get(row["version"])
        if by_layer is None:
            continue
        by_feature = by_layer.get(row["layer"])
        if by_feature is not None:
            by_feature[row["feature"]] = row
    return index


# revision -> layer -> feature id -> explanation row from `dataset`.
features = _index_explanations(dataset, layers)

# Half-width of the token window shown around each cached token position
# (the window spans 2 * nearby + 1 tokens before any extra tokens are added).
nearby = 8
# Factor applied to the stored per-token ("nearby") activation values before
# they are normalised for display; presumably the quantisation step of the
# cache (cf. "st0.25" in the cache file names) -- TODO confirm.
stride = 0.25
# Number of steps in the colour legend (n_bins + 1 labelled levels).
n_bins = 10
def make_cache_name(layer, revision, model, cache_dir=None):
    """Return the path of the activation cache for one model/layer/revision.

    The file name has the form
    ``{model}-l{layer}-r{revision}-st0.25x128-activations.parquet``; the
    ``l``/``r`` fields are what ``find_layers``/``find_revions`` parse back.

    Args:
        layer: Layer number.
        revision: Revision number.
        model: Model prefix used in the cache file names (a value of ``models``).
        cache_dir: Directory containing the caches. Defaults to the
            module-level ``cache_path``.

    Returns:
        The ``"/"``-joined path string.
    """
    if cache_dir is None:
        cache_dir = cache_path
    return f"{cache_dir}/{model}-l{layer}-r{revision}-st0.25x128-activations.parquet"
# UI model key -> model prefix used in the activation-cache file names
# (passed as `model` to make_cache_name).
models = {
    "gemma-2b-r": "gemma-2b-residuals",
    "phi-3": "phi"
}
# UI model key -> tokenizer name loaded with AutoTokenizer.from_pretrained to
# decode cached token ids.
tokenizers = {
    "gemma-2b-r": "alpindale/gemma-2b",
    "phi-3": "microsoft/Phi-3-mini-4k-instruct"
}
# UI model key -> token-id table (read eagerly at import time). Its "tokens"
# column is sliced by the token positions stored in the activation caches.
token_tables = {
    "gemma-2b-r": pq.read_table("weights/tokens_gemma.parquet"),
    "phi-3": pq.read_table("weights/tokens.parquet")
}
# Tokenizers already loaded by update(), keyed by UI model name, so each one is
# loaded once instead of on every request.
_loaded_tokenizers = {}


def _get_tokenizer(model):
    """Return the (cached) tokenizer for UI model key ``model``."""
    if model not in _loaded_tokenizers:
        _loaded_tokenizers[model] = AutoTokenizer.from_pretrained(tokenizers[model])
    return _loaded_tokenizers[model]


def _window_tokens(token_table, tokenizer, position, offset, extra):
    """Decode the token window around ``position`` in ``token_table``.

    The window is ``[position - nearby + offset - extra,
    position + nearby + 1 + offset + extra)``, i.e. ``2 * (nearby + extra) + 1``
    tokens. Slots that fall outside the table are filled with ``""`` so the
    result stays aligned, position by position, with the activation window.
    Returns ``[]`` when no token of the window lies inside the table.
    """
    start = position - nearby + offset - extra
    stop = position + nearby + 1 + offset + extra
    lo = max(0, start)
    hi = min(stop, token_table.num_rows)
    if hi <= lo:
        return []
    ids = token_table[lo:hi]["tokens"].to_numpy()
    decoded = [tokenizer.decode(token_id).replace("\n", " ") for token_id in ids]
    return [""] * (lo - start) + decoded + [""] * (stop - hi)


def _explanations(revision, layer, feature):
    """Return ``(AutoInterp explanation, numbered SelfExplain list)`` for a feature."""
    row = features.get(revision, {}).get(layer, {}).get(feature)
    if row is None:
        return "No explanation found", "No explanation found"
    autoi = row["explanation"]
    selfe = row["gen_explanations"]
    if selfe is not None:
        selfe = "\n".join(f"{i+1}. \"{text}\"" for i, text in enumerate(selfe))
    return autoi, selfe


# Revision whose layers pre-populate the layer dropdown: 4 when it is cached
# (the previous hard-coded default), otherwise the first cached revision.
# Indexing layers[4] directly raised KeyError when revision 4 was missing.
_initial_revision = 4 if 4 in layers else (revisions[0] if revisions else None)


def update_features(layer, revision=None):
    """Return a feature dropdown for ``layer`` of ``revision``.

    ``revision`` defaults to the initial revision. Not wired to any event in
    the visible code.
    """
    if revision is None:
        revision = _initial_revision
    return gr.Dropdown(list(features.get(revision, {}).get(layer, {}).keys()))


def update_layers(revision):
    """Return a layer dropdown listing the cached layers of ``revision``."""
    return gr.Dropdown(layers.get(revision, []))


def update(model, revision, layer, feature, extra_tokens):
    """Render the activating contexts and explanations for one feature.

    Args:
        model: UI model key (a key of ``models``/``tokenizers``/``token_tables``).
        revision: Cache revision number.
        layer: Layer number.
        feature: Feature id.
        extra_tokens: Additional context tokens on each side of the cached
            activation window; their activations are shown as 0.

    Returns:
        ``(highlighted tokens, colour legend, total frequency in %,
        AutoInterp explanation, SelfExplain explanations)``.

    Raises:
        gr.Error: if a selection is missing or no cache file exists for it.
    """
    if model is None or revision is None or layer is None or feature is None:
        raise gr.Error("Select a model, revision, layer and feature first.")
    # gr.Number can deliver floats; slicing and list repetition below need ints.
    feature = int(feature)
    extra_tokens = max(0, int(extra_tokens or 0))
    # Every model except gemma is read with a +1 offset between the cached token
    # position and the token table; presumably those caches are shifted by one
    # token -- TODO confirm.
    correction = 0 if "gemma" in model else 1
    token_table = token_tables[model]
    tokenizer = _get_tokenizer(model)

    try:
        table = pq.read_table(make_cache_name(layer, revision, models[model]))
    except FileNotFoundError:
        raise gr.Error(
            f"No activation cache for {model}, layer {layer}, revision {revision}."
        ) from None
    table_feat = table.filter(pc.field("feature") == feature).to_pandas()
    # Computed before the zero rows are dropped below.
    total_freq = float(table_feat["freq"].sum()) * 100
    table_feat = table_feat[(table_feat["activation"] > 0) & (table_feat["freq"] > 0)]
    table_feat = table_feat.sort_values("activation", ascending=False)

    pad = [0] * extra_tokens
    texts, activations = [], []
    for position, nearby_acts in zip(table_feat["token"], table_feat["nearby"]):
        tokens = _window_tokens(token_table, tokenizer, int(position), correction, extra_tokens)
        if not tokens:
            continue
        acts = pad + nearby_acts.tolist() + pad
        if len(tokens) != len(acts):
            raise ValueError(
                f"token window ({len(tokens)}) and activation window ({len(acts)}) differ in length"
            )
        texts.append(tokens)
        activations.append(acts)

    if activations:
        max_act = table_feat["activation"].max()
        # Scale to activation units, then normalise to [0, 1] for highlighting.
        scaled = np.stack(activations) * stride / max_act
        flat_data = []
        for tokens, row in zip(texts, scaled):
            flat_data.extend(zip(tokens, row))
            flat_data.append(("\n", 0))
        color_map_data = [
            (f"{i / n_bins * max_act:.2f}", i / n_bins) for i in range(n_bins + 1)
        ]
    else:
        flat_data = []
        color_map_data = []

    autoi, selfe = _explanations(revision, layer, feature)
    return flat_data, color_map_data, total_freq, autoi, selfe


with gr.Blocks() as demo:
    # NOTE(review): feature_table is not read anywhere in the visible code.
    feature_table = gr.State(None)
    model_name = gr.Dropdown(["phi-3", "gemma-2b-r"], label="Model")
    revision_dropdown = gr.Dropdown(revisions, label="Revision")
    layer_dropdown = gr.Dropdown(layers.get(_initial_revision, []), label="Layer")
    frequency = gr.Number(0, label="Total frequency (%)")
    # precision=0 makes Gradio deliver ints, which update() uses for slicing.
    extra_tokens = gr.Number(0, label="Extra Max Act Tokens", precision=0)
    # Refresh the layer choices whenever the user picks a revision.
    revision_dropdown.input(update_layers, revision_dropdown, layer_dropdown)
    feature_input = gr.Number(0, label="Feature", precision=0)
    autoi_expl = gr.Textbox(label="AutoInterp Explanation")
    selfe_expl = gr.Textbox(label="SelfExplain Explanation")
    cm = gr.HighlightedText()
    frame = gr.HighlightedText()
    feature_input.change(
        update,
        [model_name, revision_dropdown, layer_dropdown, feature_input, extra_tokens],
        [frame, cm, frequency, autoi_expl, selfe_expl],
    )
# Launch the UI only when run as a script, not on import.
if __name__ == "__main__":
    # share=True requests a Gradio share link in addition to the local server.
    demo.launch(share=True)