# Gradio demo for browsing cached feature activations and their
# AutoInterp / SelfExplain explanations.
import gradio as gr
import pyarrow.parquet as pq
import pyarrow.compute as pc
from transformers import AutoTokenizer
from datasets import load_dataset
import os
import numpy as np

cache_path = "weights/caches"
parquets = os.listdir(cache_path)

dataset = load_dataset("kisate-team/feature-explanations", split="train")


def find_revisions():
    """Collect the revision numbers that appear in the cache file names."""
    revisions = set()
    for parquet in parquets:
        if parquet.endswith(".parquet"):
            parts = parquet.split("-")
            if len(parts) > 2:
                revisions.add(int(parts[-3][1:]))
    return sorted(revisions)


def find_layers(revision):
    """Collect the layer numbers cached for a given revision."""
    layers = set()
    for parquet in parquets:
        if parquet.endswith(".parquet"):
            parts = parquet.split("-")
            if len(parts) > 2 and int(parts[-3][1:]) == revision:
                layers.add(int(parts[-4][1:]))
    return sorted(layers)


revisions = find_revisions()
layers = {revision: find_layers(revision) for revision in revisions}

# Explanations indexed as features[revision][layer][feature].
features = {
    revision: {
        layer: {
            item["feature"]: item
            for item in dataset
            if item["layer"] == layer and item["version"] == revision
        }
        for layer in layers[revision]
    }
    for revision in revisions
}

# layers = dataset.unique("layer")

nearby = 8     # context tokens shown on each side of a max-activating position
stride = 0.25  # scale factor for cached activation values (matches "st0.25" in the file names)
n_bins = 10    # number of entries in the color-map legend


def make_cache_name(layer, revision, model):
    return f"{cache_path}/{model}-l{layer}-r{revision}-st0.25x128-activations.parquet"


models = {
    "gemma-2b-r": "gemma-2b-residuals",
    "phi-3": "phi"
}

tokenizers = {
    "gemma-2b-r": "alpindale/gemma-2b",
    "phi-3": "microsoft/Phi-3-mini-4k-instruct"
}

token_tables = {
    "gemma-2b-r": pq.read_table("weights/tokens_gemma.parquet"),
    "phi-3": pq.read_table("weights/tokens.parquet")
}

with gr.Blocks() as demo:
    feature_table = gr.State(None)

    model_name = gr.Dropdown(["phi-3", "gemma-2b-r"], label="Model")
    revision_dropdown = gr.Dropdown(revisions, label="Revision")
    # Defaults to the layers of revision 4; updated when the revision changes.
    layer_dropdown = gr.Dropdown(layers[4], label="Layer")

    def update_features(layer):
        # Currently unused: the corresponding .input hook below is commented out.
        feature_dropdown = gr.Dropdown(features[layer].keys())
        return feature_dropdown

    def update_layers(revision):
        layer_dropdown = gr.Dropdown(layers[revision])
        return layer_dropdown

    frequency = gr.Number(0, label="Total frequency (%)")
    extra_tokens = gr.Number(0, label="Extra Max Act Tokens")
    # layer_dropdown.input(update_features, layer_dropdown, feature_dropdown)
    # histogram = gr.LinePlot(x="activation", y="freq")

    revision_dropdown.input(update_layers, revision_dropdown, layer_dropdown)

    feature_input = gr.Number(0, label="Feature")
    autoi_expl = gr.Textbox(label="AutoInterp Explanation")
    selfe_expl = gr.Textbox(label="SelfExplain Explanation")

    cm = gr.HighlightedText()
    frame = gr.HighlightedText()

    def update(model, revision, layer, feature, extra_tokens):
        # The token window is shifted by one position for non-gemma models.
        correction = 1
        if "gemma" in model:
            correction = 0
        token_table = token_tables[model]
        tokenizer_name = tokenizers[model]
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        table = pq.read_table(make_cache_name(layer, revision, models[model]))
        table_feat = table.filter(pc.field("feature") == feature).to_pandas()
        # freq_t = table_feat[["activation", "freq"]]
        total_freq = float(table_feat["freq"].sum()) * 100

        table_feat = table_feat[table_feat["activation"] > 0]
        table_feat = table_feat[table_feat["freq"] > 0]
        table_feat = table_feat.sort_values("activation", ascending=False)

        # Decode a window of `nearby` (plus extra) tokens around each max-activating position.
        texts = table_feat["token"].apply(
            lambda x: [
                tokenizer.decode(y).replace("\n", " ")
                for y in token_table[
                    max(0, x - nearby + correction - extra_tokens):
                    x + extra_tokens + nearby + 1 + correction
                ]["tokens"].to_numpy()
            ]
        ).tolist()
        # texts = [tokenizer.tokenize(text) for text in texts]

        activations = table_feat["nearby"].to_numpy()
        activations = [
            [0] * extra_tokens + a.tolist() + [0] * extra_tokens
            for i, a in enumerate(activations)
            if len(texts[i]) > 0
        ]
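        # Keep only positions with a non-empty decoded window, so texts and activations stay aligned.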
        texts = [text for text in texts if len(text) > 0]

        for t, a in zip(texts, activations):
            assert len(t) == len(a)

        if len(activations) > 0:
            activations = np.stack(activations) * stride
            max_act = table_feat["activation"].max()
            activations = activations / max_act

            highlight_data = [
                [(token, activation) for token, activation in zip(text, activation)] + [("\n", 0)]
                for text, activation in zip(texts, activations)
            ]

            flat_data = [item for sublist in highlight_data for item in sublist]

            color_map_data = [i / n_bins for i in range(n_bins + 1)]
            color_map_data = [(f"{i*max_act:.2f}", i) for i in color_map_data]
        else:
            flat_data = []
            color_map_data = []

        if feature in features[revision][layer]:
            autoi_expl = features[revision][layer][feature]["explanation"]
            selfe_expl = features[revision][layer][feature]["gen_explanations"]
            if selfe_expl is not None:
                selfe_expl = "\n".join(
                    f"{i+1}. \"{x}\"" for i, x in enumerate(selfe_expl)
                )
        else:
            autoi_expl = "No explanation found"
            selfe_expl = "No explanation found"

        return flat_data, color_map_data, total_freq, autoi_expl, selfe_expl

    # feature_dropdown.change(update, [layer_dropdown, feature_dropdown, tokenizer_name], [frame, cm, frequency, autoi_expl, selfe_expl])
    feature_input.change(
        update,
        [model_name, revision_dropdown, layer_dropdown, feature_input, extra_tokens],
        [frame, cm, frequency, autoi_expl, selfe_expl],
    )

if __name__ == "__main__":
    demo.launch(share=True)