import gradio as gr
import pyarrow.parquet as pq
import pyarrow.compute as pc
from transformers import AutoTokenizer
from datasets import load_dataset
import os
import numpy as np


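# Precomputed activation caches, one parquet file per model/layer/revision.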
cache_path = "weights/caches"
parquets = os.listdir(cache_path)

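# Per-feature explanations; each record has a layer, version (revision), feature id,
# an AutoInterp explanation, and optional SelfExplain generations.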
dataset = load_dataset("kisate-team/feature-explanations", split="train")

def find_revisions():
    # Collect the revisions encoded in the cache filenames (the "r{revision}" part).
    revisions = set()
    for parquet in parquets:
        if parquet.endswith(".parquet"):
            parts = parquet.split("-")
            if len(parts) > 2:
                revisions.add(int(parts[-3][1:]))
    return sorted(revisions)

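# Layers that have a cache file for the given revision.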
def find_layers(revision):
    layers = set()
    for parquet in parquets:
        if parquet.endswith(".parquet"):
            parts = parquet.split("-")
            if len(parts) > 2 and int(parts[-3][1:]) == revision:
                layers.add(int(parts[-4][1:]))
    return sorted(layers)

revisions = find_revisions()
layers = {
    revision: find_layers(revision) for revision in revisions
}

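# Explanations indexed as features[revision][layer][feature_id].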
features = {
    revision: {
        layer: {
            item["feature"]:item for item in dataset if item["layer"] == layer and item["version"] == revision
        } for layer in layers[revision]
    } for revision in revisions
}

# layers = dataset.unique("layer")

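# Display/decoding parameters:
#   nearby - context tokens shown on each side of a max-activating token
#   stride - cached activations appear to be quantized; multiplying by this recovers
#            approximate values (matches the "st0.25" in the cache filenames)
#   n_bins - number of entries in the color-map legend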
nearby = 8
stride = 0.25
n_bins = 10

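# Cache files are named "{model}-l{layer}-r{revision}-st0.25x128-activations.parquet".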
def make_cache_name(layer, revision, model):
    return f"{cache_path}/{model}-l{layer}-r{revision}-st0.25x128-activations.parquet"

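# UI model name -> prefix used in the cache filenames.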
models = {
    "gemma-2b-r": "gemma-2b-residuals",
    "phi-3": "phi"
}

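# Tokenizers used to decode cached token ids back into text.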
tokenizers = {
    "gemma-2b-r": "alpindale/gemma-2b",
    "phi-3": "microsoft/Phi-3-mini-4k-instruct"
}

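# Pre-tokenized corpora; context windows are sliced from their "tokens" column.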
token_tables = {
    "gemma-2b-r": pq.read_table("weights/tokens_gemma.parquet"),
    "phi-3": pq.read_table("weights/tokens.parquet")
}

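# Gradio UI: choose a model, revision, layer, and feature id; the app shows the contexts
# where the feature activates, its activation frequency, and its explanations.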
with gr.Blocks() as demo:
    feature_table = gr.State(None)

    model_name = gr.Dropdown(["phi-3", "gemma-2b-r"], label="Model")

    revision_dropdown = gr.Dropdown(revisions, label="Revision")

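    # Initial layer choices come from revision 4's caches; this assumes that revision exists.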
    layer_dropdown = gr.Dropdown(layers[4], label="Layer")

    def update_features(revision, layer):
        # Currently unused (its hookup below is commented out); explanations are keyed
        # by revision first, then layer.
        feature_dropdown = gr.Dropdown(list(features[revision][layer].keys()))
        return feature_dropdown
    
    def update_layers(revision):
        layer_dropdown = gr.Dropdown(layers[revision])
        return layer_dropdown

    frequency = gr.Number(0, label="Total frequency (%)")
    extra_tokens = gr.Number(0, label="Extra Max Act Tokens")

    # layer_dropdown.input(update_features, layer_dropdown, feature_dropdown)
    # histogram = gr.LinePlot(x="activation", y="freq")

    revision_dropdown.input(update_layers, revision_dropdown, layer_dropdown)

    feature_input = gr.Number(0, label="Feature")

    autoi_expl = gr.Textbox(label="AutoInterp Explanation")
    selfe_expl = gr.Textbox(label="SelfExplain Explanation")

    cm = gr.HighlightedText()     # color-map legend: activation value -> highlight intensity
    frame = gr.HighlightedText()  # token contexts highlighted by feature activation

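    # Main callback: read the activation cache for the selected model/revision/layer,
    # filter to the requested feature, and build the highlighted contexts, color-map
    # legend, total frequency, and explanations.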
    def update(model, revision, layer, feature, extra_tokens):
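        # Offset between positions in the activation cache and the token table;
        # the phi cache appears to need a +1 shift while gemma needs none.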
        correction = 1
        if "gemma" in model:
            correction = 0

        token_table = token_tables[model]

        tokenizer_name = tokenizers[model]

        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        table = pq.read_table(make_cache_name(layer, revision, models[model]))
        table_feat = table.filter(pc.field("feature") == feature).to_pandas()

        # freq_t = table_feat[["activation", "freq"]]
        total_freq = float(table_feat["freq"].sum()) * 100
        
        table_feat = table_feat[table_feat["activation"] > 0]
        table_feat = table_feat[table_feat["freq"] > 0]

        table_feat = table_feat.sort_values("activation", ascending=False)

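        # "token" holds the position of each max-activating token in the token table;
        # decode a window of surrounding tokens around that position.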
        texts = table_feat["token"].apply(
            lambda x: [tokenizer.decode(y).replace("\n", " ") for y in token_table[max(0, x - nearby + correction - extra_tokens):x + extra_tokens + nearby + 1 + correction]["tokens"].to_numpy()]
        ).tolist()

        # texts = [tokenizer.tokenize(text) for text in texts]
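        # "nearby" holds the activations over the same context window; pad with zeros
        # for the extra tokens requested on each side.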
        activations = table_feat["nearby"].to_numpy()

        activations = [[0] * extra_tokens + a.tolist()  + [0] * extra_tokens for i, a in enumerate(activations) if len(texts[i]) > 0]
        texts = [text for text in texts if len(text) > 0]

        for t, a in zip(texts, activations):
            assert len(t) == len(a)

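        # Rescale the quantized activations and normalize by the feature's max activation
        # so highlight intensities are comparable across contexts.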
        if len(activations) > 0:
            activations = np.stack(activations) * stride
            max_act = table_feat["activation"].max()
            activations = activations / max_act

            highlight_data = [
                [(token, activation) for token, activation in zip(text, activation)] + [("\n", 0)]
                for text, activation in zip(texts, activations)
            ]

            flat_data = [item for sublist in highlight_data for item in sublist]
            
            color_map_data = [i / n_bins for i in range(n_bins + 1)]
            color_map_data = [(f"{i*max_act:.2f}", i) for i in color_map_data]
        else:
            flat_data = []
            color_map_data = []

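        # Look up the AutoInterp and SelfExplain explanations for this feature, if present.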
        if feature in features[revision][layer]:
            autoi_expl = features[revision][layer][feature]["explanation"]
            selfe_expl = features[revision][layer][feature]["gen_explanations"]
            if selfe_expl is not None:
                selfe_expl = "\n".join(
                    f"{i+1}. \"{x}\"" for i, x in enumerate(selfe_expl)
                )

        else:
            autoi_expl = "No explanation found"
            selfe_expl = "No explanation found"
        return flat_data, color_map_data, total_freq, autoi_expl, selfe_expl
        

    # feature_dropdown.change(update, [layer_dropdown, feature_dropdown, tokenizer_name], [frame, cm, frequency, autoi_expl, selfe_expl])
    feature_input.change(update, [model_name, revision_dropdown, layer_dropdown, feature_input, extra_tokens], [frame, cm, frequency, autoi_expl, selfe_expl])


if __name__ == "__main__":
    demo.launch(share=True)