# demo/src/interfaces/game_feature_interface.py
# Author: Xmaster6y — commit 6d92394 ("multi lines", unverified signature)
"""
Gradio interface for plotting SAE feature activations along a game trajectory (root and traj boards).
"""
import chess
import gradio as gr
import uuid
import torch
from lczerolens.encodings import encode_move
from src import constants, global_variables, visualisation
def _replay_trajectory(board, move_seq, root_idx, traj_idx):
    """Replay ``move_seq`` on ``board`` and snapshot the root and traj positions.

    Index 0 is ``board`` as given; index ``i`` is the position after the
    ``i``-th move. A sequence starting with ``"1."`` is parsed as SAN (tokens
    ending in ``"."`` are move numbers and are skipped); anything else is
    parsed as UCI. ``board`` is mutated in place. Replay stops as soon as both
    snapshots have been taken.

    Returns:
        ``(root_board, traj_board)``; either is ``None`` if its index was never
        reached.

    Raises:
        ValueError: ``"Invalid move {token}."`` when a token cannot be played.
    """
    root_board = board.copy() if root_idx == 0 else None
    traj_board = board.copy() if traj_idx == 0 else None
    if not move_seq:
        return root_board, traj_board

    is_san = move_seq.startswith("1.")
    push = board.push_san if is_san else board.push_uci
    ply = 0
    for token in move_seq.split():
        if root_board is not None and traj_board is not None:
            break
        if is_san and token.endswith("."):
            continue
        try:
            push(token)
        except ValueError as err:
            raise ValueError(f"Invalid move {token}.") from err
        ply += 1
        if ply == root_idx:
            root_board = board.copy()
        if ply == traj_idx:
            traj_board = board.copy()
    return root_board, traj_board


def _format_feature_info(model_output, pixel_acts, x_hat, features, white_to_move):
    """Build the multi-line summary shown in the "Info" textbox.

    ``features`` is indexed per square on dim 0 (its dim-1 mean is reshaped to
    8x8 below, so dim 0 must hold 64 entries) and per dictionary feature on
    dim 1. The activation and dictionary dimensions are split in halves,
    labelled root/traj and c/d respectively — TODO confirm the meaning of the
    c/d halves.
    """
    half_a_dim = constants.ACTIVATION_DIM // 2
    half_f_dim = constants.DICTIONARY_SIZE // 2

    def recon_mse(cols):
        # Per-row summed squared error, averaged over rows.
        return torch.nn.functional.mse_loss(
            x_hat[:, cols], pixel_acts[:, cols], reduction="none"
        ).sum(dim=1).mean()

    def l0(feats):
        # Mean number of strictly positive features per row.
        return (feats > 0).sum(dim=1).float().mean()

    active = (features > 0).float()
    feature_avg = features.mean(dim=0)
    feature_active = active.mean(dim=0)
    square_avg = features.mean(dim=1)
    square_active = active.mean(dim=1)
    if not white_to_move:
        # Reverse the 8 rows when black is to move so indices line up with
        # chess.SQUARE_NAMES — presumably the encoding is side-to-move
        # relative. TODO confirm.
        square_avg = square_avg.view(8, 8).flip(0).view(64)
        square_active = square_active.view(8, 8).flip(0).view(64)
    most_avg_pixels = square_avg.topk(5).indices.tolist()
    most_active_pixels = square_active.topk(5).indices.tolist()

    lines = [
        f"Root WDL: {model_output['wdl'][0]}",
        f"Traj WDL: {model_output['wdl'][1]}",
        f"MSE loss: {recon_mse(slice(None))}",
        f"MSE loss (root): {recon_mse(slice(None, half_a_dim))}",
        f"MSE loss (traj): {recon_mse(slice(half_a_dim, None))}",
        f"L0 loss: {l0(features)}",
        f"L0 loss (c): {l0(features[:, :half_f_dim])}",
        f"L0 loss (d): {l0(features[:, half_f_dim:])}",
        f"Most active features (avg): {feature_avg.topk(5).indices.tolist()}",
        f"Most active features (active): {feature_active.topk(5).indices.tolist()}",
        f"Most active pixels (avg): {[chess.SQUARE_NAMES[p] for p in most_avg_pixels]}",
        f"Most active pixels (active): {[chess.SQUARE_NAMES[p] for p in most_active_pixels]}",
    ]
    return "\n".join(lines)


def compute_features_fn(
    features,
    model_output,
    file_id,
    root_idx,
    traj_idx,
    start_fen,
    move_seq,
    feature_index
):
    """Recompute SAE features for the selected root/traj boards and render them.

    Args:
        features, model_output, file_id: Current Gradio state; returned
            unchanged on failure.
        root_idx, traj_idx: Board indices into the trajectory (0 = start
            position, ``i`` = after ``i`` moves).
        start_fen: Starting FEN; an invalid FEN falls back to the standard
            starting position with a warning.
        move_seq: Space-separated SAN (starting with ``"1."``) or UCI moves.
        feature_index: Dictionary feature to render on the board.

    Returns:
        10 values: the five state values, the SVG board path, the colorbar
        figure, the root FEN, the traj FEN and the info text. On an invalid
        move sequence the state is returned unchanged followed by five
        ``None`` outputs.
    """
    error_return = [features, model_output, file_id, root_idx, traj_idx] + [None] * 5
    try:
        board = chess.Board(start_fen)
    except ValueError:
        # Keep going from the standard start, as the warning promises.
        board = chess.Board()
        gr.Warning("Invalid FEN, using starting position.")

    try:
        root_board, traj_board = _replay_trajectory(board, move_seq, root_idx, traj_idx)
    except ValueError as err:
        gr.Warning(str(err))
        return error_return
    if root_board is None or traj_board is None:
        gr.Warning("Invalid move sequence.")
        return error_return

    model_output, pixel_acts, sae_output = global_variables.generator.generate(
        root_board=root_board,
        traj_board=traj_board
    )
    current_root_fen = root_board.fen()
    current_traj_fen = traj_board.fen()
    features = sae_output["features"]

    first_output = render_feature_index(
        features,
        model_output,
        file_id,
        root_idx,
        traj_idx,
        current_traj_fen,
        feature_index
    )
    # Orient squares by the traj board, matching render_feature_index.
    info = _format_feature_info(
        model_output,
        pixel_acts,
        sae_output["x_hat"],
        features,
        white_to_move=traj_board.turn,
    )
    return *first_output, current_root_fen, current_traj_fen, info
def render_feature_index(
    features,
    model_output,
    file_id,
    root_idx,
    traj_idx,
    traj_fen,
    feature_index,
):
    """Render one SAE feature as a heatmap on the trajectory board.

    Writes the board SVG to ``{constants.FIGURES_FOLER}/{file_id}.svg``. It
    overlays an arrow for the legal move with the highest logit in row 1 of
    ``model_output["policy"]`` (presumably the traj board's policy — TODO
    confirm).

    Args:
        features: Activations indexed as ``features[:, feature_index]``. The
            selected column is reshaped to 64 squares, so dim 0 must hold 64
            entries. ``None`` before anything has been computed.
        model_output: Model outputs containing a ``"policy"`` tensor; ``None``
            before anything has been computed.
        file_id: SVG file stem; a fresh UUID4 string is generated when ``None``.
        root_idx, traj_idx: Board indices, passed through unchanged.
        traj_fen: FEN of the trajectory board.
        feature_index: Column of ``features`` to display.

    Returns:
        ``(features, model_output, file_id, root_idx, traj_idx, svg_path,
        colorbar_figure)``. The last two are ``None`` when nothing has been
        computed yet.
    """
    if features is None or model_output is None:
        # The slider can be moved before any board was computed.
        gr.Warning("No features computed yet, use the root/traj buttons first.")
        return features, model_output, file_id, root_idx, traj_idx, None, None
    if file_id is None:
        file_id = str(uuid.uuid4())

    board = chess.Board(traj_fen)
    # Slider values may arrive as floats; tensor indexing needs an int.
    pixel_features = features[:, int(feature_index)]
    if board.turn:
        heatmap = pixel_features.view(64)
    else:
        # Reverse the 8 rows when black is to move — presumably the encoding
        # is side-to-move relative. TODO confirm.
        heatmap = pixel_features.view(8, 8).flip(0).view(64)

    # Highest-logit legal move; ties keep the first generated move.
    us_them = (board.turn, not board.turn)
    best_legal_move = max(
        board.legal_moves,
        key=lambda move: model_output["policy"][1, encode_move(move, us_them)].item(),
        default=None,
    )
    # Checkmate/stalemate positions have no legal move: draw no arrow.
    arrows = (
        []
        if best_legal_move is None
        else [(best_legal_move.from_square, best_legal_move.to_square)]
    )
    svg_board, fig = visualisation.render_heatmap(
        board,
        heatmap,
        arrows=arrows,
    )

    svg_path = f"{constants.FIGURES_FOLER}/{file_id}.svg"
    with open(svg_path, "w", encoding="utf-8") as f:
        f.write(svg_board)
    return (
        features,
        model_output,
        file_id,
        root_idx,
        traj_idx,
        svg_path,
        fig
    )
def make_features_fn(var, direction):
    """Build a Gradio callback that steps the root or traj index.

    Args:
        var: ``"root"`` to step the root index, ``"traj"`` to step the traj
            index; any other value leaves both indices unchanged.
        direction: Signed step added to the selected index (the UI uses -1/1).

    The returned callback clamps the new index to ``[0, move_count]`` and
    keeps ``root_idx <= traj_idx``, warning whenever it clamps. Index ``i`` is
    the board after ``i`` moves, so ``move_count`` is the final position. It
    then recomputes everything via ``compute_features_fn`` and returns its
    outputs.
    """
    def _make_features_fn(
        features,
        model_output,
        file_id,
        root_idx,
        traj_idx,
        start_fen,
        move_seq,
        feature_index
    ):
        # SAN move numbers such as "1." are not moves.
        move_count = len([mv for mv in (move_seq or "").split() if not mv.endswith(".")])
        if var == "root":
            root_idx += direction
            if root_idx < 0:
                gr.Warning("Already at first board.")
                root_idx = 0
            elif root_idx > move_count:
                gr.Warning("Already at last board.")
                root_idx = move_count
            elif root_idx > traj_idx:
                gr.Warning("Root should be before traj.")
                root_idx = traj_idx
        elif var == "traj":
            traj_idx += direction
            if traj_idx < 0:
                gr.Warning("Already at first board.")
                traj_idx = 0
            elif traj_idx > move_count:
                gr.Warning("Already at last board.")
                traj_idx = move_count
            elif traj_idx < root_idx:
                gr.Warning("Traj should be after root.")
                traj_idx = root_idx
        return compute_features_fn(
            features,
            model_output,
            file_id,
            root_idx,
            traj_idx,
            start_fen,
            move_seq,
            feature_index
        )
    return _make_features_fn
# Layout: left column holds the inputs, navigation and stats; right column
# shows the rendered board. Event handlers are wired at the bottom.
with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            start_fen = gr.Textbox(
                label="Starting FEN",
                lines=1,
                max_lines=1,
                value=chess.STARTING_FEN,
            )
            move_seq = gr.Textbox(
                label="Move sequence",
                lines=1,
                max_lines=20,
                value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"),
            )
            with gr.Group():
                with gr.Row():
                    previous_root_button = gr.Button("Previous root")
                    next_root_button = gr.Button("Next root")
                with gr.Row():
                    previous_traj_button = gr.Button("Previous traj")
                    next_traj_button = gr.Button("Next traj")
            with gr.Group():
                with gr.Row():
                    current_root_fen = gr.Textbox(
                        label="Root FEN", lines=1, max_lines=1, interactive=False
                    )
                with gr.Row():
                    current_traj_fen = gr.Textbox(
                        label="Traj FEN", lines=1, max_lines=1, interactive=False
                    )
            with gr.Row():
                feature_index = gr.Slider(
                    label="Feature index",
                    minimum=0,
                    maximum=constants.DICTIONARY_SIZE - 1,
                    step=1,
                    value=0,
                )
            with gr.Group():
                with gr.Row():
                    info = gr.Textbox(label="Info", lines=1, max_lines=20, value="")
                with gr.Row():
                    colorbar = gr.Plot(label="Colorbar")
        with gr.Column():
            board_image = gr.Image(label="Board")

    # Hidden per-session state threaded through every callback.
    features = gr.State(None)
    model_output = gr.State(None)
    file_id = gr.State(None)
    root_idx = gr.State(0)
    traj_idx = gr.State(0)

    state = [features, model_output, file_id, root_idx, traj_idx]
    base_inputs = [start_fen, move_seq, feature_index]
    base_outputs = [board_image, colorbar, current_root_fen, current_traj_fen, info]

    # (button, which index to move, step) — registered in this exact order.
    _navigation = (
        (previous_root_button, "root", -1),
        (next_root_button, "root", 1),
        (previous_traj_button, "traj", -1),
        (next_traj_button, "traj", 1),
    )
    for _button, _var, _direction in _navigation:
        _button.click(
            make_features_fn(var=_var, direction=_direction),
            inputs=state + base_inputs,
            outputs=state + base_outputs,
        )

    # Changing the slider only re-renders; it does not recompute features.
    feature_index.change(
        render_feature_index,
        inputs=state + [current_traj_fen, feature_index],
        outputs=state + [board_image, colorbar],
    )