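# NOTE(review): the three lines that stood here ("Spaces:" / "Build error" /
# "Build error") look like web-page residue rather than part of the module;
# kept as a comment so they are not parsed as code.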
"""
Gradio interface for visualizing the policy of a model.
"""
import chess | |
import chess.svg | |
import gradio as gr | |
import torch | |
from demo import constants, utils, visualisation | |
from lczerolens import move_encodings | |
from lczerolens.board import LczeroBoard | |
from lczerolens.xai import PolicyLens | |
# Module-level cache of the most recent prediction. `compute_policy` fills it
# and `make_plot` / `play_move` read it. Being module globals, these values are
# shared by every caller of this module.
current_board = None  # board the cached prediction was computed for
current_raw_policy = None  # `output["policy"][0]` as returned by the model
current_policy = None  # softmax of the raw policy, zero outside the legal moves
current_value = None  # `output.get("value")`, None when the model has no such head
current_outcome = None  # `output.get("wdl")`, None when the model has no such head
def list_models():
    """
    Build the rows shown in the "Available models" table.

    Takes the first field of every entry returned by
    ``utils.get_models_info(leela=False)``, wraps each one in a single-cell
    row, and returns the rows in sorted order.
    """
    rows = []
    for model_info in utils.get_models_info(leela=False):
        rows.append([model_info[0]])
    rows.sort()
    return rows
def on_select_model_df(
    evt: gr.SelectData,
):
    """
    Forward the value carried by a selection event.

    Wired to the model table's ``select`` event, so the returned value ends up
    in the "Selected model" textbox.
    """
    selected_value = evt.value
    return selected_value
def compute_policy(
    board_fen,
    action_seq,
    model_name,
):
    """
    Run the selected model on a position and cache the result.

    The position is ``board_fen`` followed by the whitespace-separated UCI
    moves in ``action_seq``. On success the module-level cache
    (``current_board``, ``current_raw_policy``, ``current_policy``,
    ``current_value``, ``current_outcome``) is replaced in one go.

    Args:
        board_fen: FEN string of the starting position.
        action_seq: UCI moves separated by whitespace; may be empty.
        model_name: name handed to ``utils.get_wrapper_from_state``.

    Returns:
        ``None`` on success. On invalid input a ``gr.Warning`` is raised and
        ``(None, None, "", None)`` is returned (one empty value per UI output);
        the cache is left untouched in that case.
    """
    global current_board
    global current_policy
    global current_raw_policy
    global current_value
    global current_outcome

    # Same arity on every failure path, matching the four UI outputs.
    empty_outputs = (None, None, "", None)

    # Falsy check so that a missing (None) model name is rejected like "".
    if not model_name:
        gr.Warning(
            "Please select a model.",
        )
        return empty_outputs
    try:
        board = LczeroBoard(board_fen)
    except ValueError:
        gr.Warning("Invalid FEN.")
        return empty_outputs
    if action_seq:
        try:
            for action in action_seq.split():
                board.push_uci(action)
        except ValueError:
            gr.Warning("Invalid action sequence.")
            return empty_outputs

    wrapper = utils.get_wrapper_from_state(model_name)
    (output,) = wrapper.predict(board)
    raw_policy = output["policy"][0]
    policy = torch.softmax(raw_policy, dim=-1)

    # Keep the probability of legal moves only; every other entry stays at 0.
    # zeros_like follows the policy's own shape, dtype and device instead of
    # assuming a 1858-entry CPU float tensor.
    legal_moves = [move_encodings.encode_move(move, (board.turn, not board.turn)) for move in board.legal_moves]
    filtered_policy = torch.zeros_like(policy)
    filtered_policy[legal_moves] = policy[legal_moves]

    # Publish everything together, only once every step above has succeeded,
    # so the cached board and the cached tensors always describe the same run.
    current_board = board
    current_raw_policy = raw_policy
    current_policy = filtered_policy
    current_value = output.get("value", None)
    current_outcome = output.get("wdl", None)
def make_plot(
    view,
    aggregate_topk,
    move_to_play,
):
    """
    Render the cached policy as a board heatmap and a distribution plot.

    Only reads the module-level cache (``current_board``, ``current_policy``,
    ``current_raw_policy``, ``current_value``, ``current_outcome``).

    Args:
        view: ``"from"`` selects the pick-up aggregation; any other value
            selects the drop-off aggregation.
        aggregate_topk: passed, as an ``int``, to ``PolicyLens.aggregate_policy``.
        move_to_play: 1-based rank, among the 50 largest entries of the cached
            policy, of the move drawn as an arrow.

    Returns:
        ``(svg_path, fig, info_text, fig_dist)``. When nothing is cached yet a
        ``gr.Warning`` is raised and ``(None, None, "", None)`` is returned.
    """
    if current_board is None or current_policy is None:
        gr.Warning("Please compute a policy first.")
        return (None, None, "", None)

    pickup_agg, dropoff_agg = PolicyLens.aggregate_policy(current_policy, int(aggregate_topk))
    heatmap = pickup_agg if view == "from" else dropoff_agg
    if current_board.turn != chess.WHITE:
        # With Black to move the 64 values are flipped along the first axis of
        # their 8x8 layout (presumably because the policy is expressed from the
        # side to move's point of view - confirm against lczerolens).
        heatmap = heatmap.view(8, 8).flip(0).view(64)

    us_them = (current_board.turn, not current_board.turn)
    topk_moves = torch.topk(current_policy, 50)
    # Cast like aggregate_topk above: a float coming from the slider cannot be
    # used as a tensor index.
    move = move_encodings.decode_move(topk_moves.indices[int(move_to_play) - 1], us_them)
    arrows = [(move.from_square, move.to_square)]
    svg_board, fig = visualisation.render_heatmap(current_board, heatmap, arrows=arrows)

    svg_path = f"{constants.FIGURE_DIRECTORY}/policy.svg"
    # Explicit encoding so the file content does not depend on the platform default.
    with open(svg_path, "w", encoding="utf-8") as f:
        f.write(svg_board)

    fig_dist = visualisation.render_policy_distribution(
        current_raw_policy,
        [move_encodings.encode_move(move, us_them) for move in current_board.legal_moves],
    )
    return (
        svg_path,
        fig,
        (f"Value: {current_value} - WDL: {current_outcome}"),
        fig_dist,
    )
def make_policy_plot(
    board_fen,
    action_seq,
    view,
    model_name,
    aggregate_topk,
    move_to_play,
):
    """
    Recompute the policy for the given position, then render it.

    Args:
        board_fen, action_seq, model_name: forwarded to ``compute_policy``.
        view, aggregate_topk, move_to_play: forwarded to ``make_plot``.

    Returns:
        The four values produced by ``make_plot``, or ``(None, None, "", None)``
        when ``compute_policy`` rejected the input.
    """
    failure = compute_policy(
        board_fen,
        action_seq,
        model_name,
    )
    # compute_policy falls off its end (None) on success and returns a tuple
    # after warning about bad input. In the latter case the cache still holds
    # the previous position, so rendering it would show a stale plot next to
    # the warning; clear the four outputs instead.
    if failure is not None:
        return (None, None, "", None)
    return make_plot(
        view,
        aggregate_topk,
        move_to_play,
    )
def play_move(
    board_fen,
    action_seq,
    view,
    model_name,
    aggregate_topk,
    move_to_play,
):
    """
    Append the selected move to the action sequence and recompute the policy.

    The move is the ``move_to_play``-th (1-based) of the 50 largest entries of
    the cached policy, decoded for the cached board's side to move.

    Args:
        board_fen, model_name: forwarded to ``compute_policy``.
        action_seq: current whitespace-separated UCI moves.
        view, aggregate_topk: forwarded to ``make_plot``.
        move_to_play: 1-based rank of the move to play.

    Returns:
        A six-item list: the four ``make_plot`` outputs, the action sequence
        and the "move to play" slider value. On success the sequence includes
        the new move and the slider is reset to 1. When nothing is cached or
        the recomputation is rejected, the sequence and slider are returned
        unchanged together with a render of the current cache.
    """
    if current_board is None or current_policy is None:
        # make_plot raises the "compute a policy first" warning and yields the
        # empty outputs; previously this path crashed on `None.topk`.
        return [*make_plot(view, aggregate_topk, move_to_play), action_seq, move_to_play]

    move = move_encodings.decode_move(
        current_policy.topk(50).indices[int(move_to_play) - 1],
        (current_board.turn, not current_board.turn),
    )
    # Rebuild the sequence from its tokens so an empty (or missing) sequence
    # does not produce a leading space or a literal "None".
    new_action_seq = " ".join([*(action_seq or "").split(), move.uci()])

    # The cached board is deliberately NOT mutated here: compute_policy builds
    # a fresh board from the FEN and the sequence. Pushing onto the cached
    # board first would leave it out of sync with the cached policy whenever
    # the recomputation is rejected (e.g. the selected entry is not a legal
    # move, which push_uci refuses).
    failure = compute_policy(
        board_fen,
        new_action_seq,
        model_name,
    )
    if failure is not None:
        # compute_policy already warned; keep the previous sequence and state.
        return [*make_plot(view, aggregate_topk, move_to_play), action_seq, move_to_play]

    return [
        *make_plot(
            view,
            aggregate_topk,
            1,
        ),
        new_action_seq,
        1,
    ]
# NOTE(review): the nesting of the layout below was reconstructed from a copy of
# this file whose indentation had been stripped - confirm it against the
# intended page layout.
with gr.Blocks() as interface:
    # Model selection: table of available models next to the read-only name box.
    with gr.Row():
        with gr.Column(scale=2):
            model_df = gr.Dataframe(
                value=list_models,
                headers=["Available models"],
                datatype=["str"],
                type="array",
                interactive=False,
            )
        with gr.Column(scale=1):
            with gr.Row():
                model_name = gr.Textbox(
                    label="Selected model",
                    lines=1,
                    interactive=False,
                    scale=7,
                )

    # Selecting a table cell copies its value into the "Selected model" box.
    model_df.select(fn=on_select_model_df, inputs=None, outputs=model_name)

    with gr.Row():
        # Position inputs and plot controls.
        with gr.Column():
            board_fen = gr.Textbox(
                label="Board FEN",
                value=chess.STARTING_FEN,
                lines=1,
                max_lines=1,
            )
            action_seq = gr.Textbox(
                label="Action sequence",
                value=(
                    "e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 "
                    "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"
                ),
                lines=1,
            )
            with gr.Group():
                with gr.Row():
                    aggregate_topk = gr.Slider(
                        label="Aggregate top k",
                        value=1858,
                        minimum=1,
                        maximum=1858,
                        step=1,
                        scale=3,
                    )
                    view = gr.Radio(
                        label="View",
                        value="from",
                        choices=["from", "to"],
                        scale=1,
                    )
                with gr.Row():
                    move_to_play = gr.Slider(
                        label="Move to play",
                        value=1,
                        minimum=1,
                        maximum=50,
                        step=1,
                        scale=3,
                    )
                    play_button = gr.Button("Play")
            policy_button = gr.Button("Compute policy")
            colorbar = gr.Plot(label="Colorbar")
            game_info = gr.Textbox(label="Game info", value="", lines=1, max_lines=1)
        # Rendered board and policy distribution.
        with gr.Column():
            image = gr.Image(label="Board")
            density_plot = gr.Plot(label="Density")

    # Full recomputation: runs the model, then renders.
    policy_inputs = [board_fen, action_seq, view, model_name, aggregate_topk, move_to_play]
    policy_outputs = [image, colorbar, game_info, density_plot]
    policy_button.click(fn=make_policy_plot, inputs=policy_inputs, outputs=policy_outputs)
    board_fen.submit(fn=make_policy_plot, inputs=policy_inputs, outputs=policy_outputs)
    action_seq.submit(fn=make_policy_plot, inputs=policy_inputs, outputs=policy_outputs)

    # Display-only controls: re-render from the cached prediction.
    fast_inputs = [view, aggregate_topk, move_to_play]
    aggregate_topk.change(fn=make_plot, inputs=fast_inputs, outputs=policy_outputs)
    view.change(fn=make_plot, inputs=fast_inputs, outputs=policy_outputs)
    move_to_play.change(fn=make_plot, inputs=fast_inputs, outputs=policy_outputs)

    # Playing a move also rewrites the action sequence and the move slider.
    play_button.click(
        fn=play_move,
        inputs=policy_inputs,
        outputs=[*policy_outputs, action_seq, move_to_play],
    )