# TabArena-WIP / app.py
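"""Gradio app for the TabArena leaderboard.

Loads grouped result dataframes, merges them with per-model metadata, and
renders them as sortable, filterable leaderboard tabs.
"""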
import gradio as gr
from apscheduler.schedulers.background import BackgroundScheduler
from gradio_leaderboard import Leaderboard, ColumnFilter, SelectColumns
from src.about import (
CITATION_BUTTON_LABEL,
CITATION_BUTTON_TEXT,
INTRODUCTION_TEXT,
LLM_BENCHMARKS_TEXT,
TITLE,
)
from src.constants import ProblemTypes, MetricNames
from src.display.css_html_js import custom_css
from src.display.utils import (
ModelInfoColumn,
fields
)
from src.envs import API, EVAL_REQUESTS_PATH, EVAL_RESULTS_PATH, REPO_ID
from src.populate import get_model_info_df, get_merged_df
from src.utils import get_grouped_dfs, pivot_existed_df, rename_metrics, format_df
def restart_space():
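    """Restart the Space via the Hub API (used by the optional scheduler job below)."""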
API.restart_space(repo_id=REPO_ID)
def init_leaderboard(ori_dataframe, model_info_df, sort_val: str|None = None):
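    """Build a Leaderboard component from a metrics dataframe merged with model metadata.

    `sort_val`, if given, should be a column of the merged dataframe; the table
    is sorted by it in ascending order.
    """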
if ori_dataframe is None or ori_dataframe.empty:
raise ValueError("Leaderboard DataFrame is empty or None.")
    # Model-info columns shown by default, minus metadata columns that clutter the table.
    model_info_col_list = [
        c.name for c in fields(ModelInfoColumn)
        if c.displayed_by_default
        and c.name not in ['#Params (B)', 'available_on_hub', 'hub', 'Model sha', 'Hub License']
    ]
col2type_dict = {c.name: c.type for c in fields(ModelInfoColumn)}
default_selection_list = list(ori_dataframe.columns) + model_info_col_list
merged_df = get_merged_df(ori_dataframe, model_info_df)
    # Move the model-type column 'T' to the front of the table.
    new_cols = ['T'] + [col for col in merged_df.columns if col != 'T']
    merged_df = merged_df[new_cols]
    if sort_val:
        if sort_val in merged_df.columns:
            merged_df = merged_df.sort_values(by=[sort_val])
        else:
            print(f'Warning: cannot sort by {sort_val!r}: column not found in merged dataframe.')
    # Map each column to its display type, defaulting to 'number' for metric columns.
    datatype_list = [col2type_dict.get(col, 'number') for col in merged_df.columns]
return Leaderboard(
value=merged_df,
datatype=datatype_list,
        select_columns=SelectColumns(
            default_selection=default_selection_list,
            cant_deselect=[c.name for c in fields(ModelInfoColumn) if c.never_hidden],
            label="Select Columns to Display:",
        ),
        hide_columns=[c.name for c in fields(ModelInfoColumn) if c.hidden],
        search_columns=['model'],
        filter_columns=[
            ColumnFilter(ModelInfoColumn.model_type.name, type="checkboxgroup", label="Model types"),
        ],
        # Narrow widths for the leading type and model columns, uniform width for the rest.
        column_widths=[40, 150] + [180 for _ in range(len(merged_df.columns) - 2)],
        interactive=False,
    )
def load_results():
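    """Load grouped result dataframes and model metadata for the leaderboard tabs."""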
    grouped_dfs = get_grouped_dfs()
    domain_df, overall_df = grouped_dfs[ProblemTypes.col_name], grouped_dfs['overall']
    overall_df = rename_metrics(overall_df)
    overall_df = format_df(overall_df)
    # Sort by rank for now, since ELO is not computed yet.
    overall_df = overall_df.sort_values(by=[MetricNames.rank])
    # Pivot the per-problem-type results into one column per problem type.
    domain_df = pivot_existed_df(domain_df, tab_name=ProblemTypes.col_name)
    model_info_df = get_model_info_df(EVAL_RESULTS_PATH, EVAL_REQUESTS_PATH)
return overall_df, model_info_df, domain_df
def main():
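    """Build the Gradio app: one leaderboard tab per view, plus About and citation sections."""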
    overall_df, model_info_df, domain_df = load_results()
    demo = gr.Blocks(css=custom_css)
    with demo:
        gr.HTML(TITLE)
        gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")
        with gr.Tabs(elem_classes="tab-buttons") as tabs:
            with gr.TabItem('🏅 Overall', elem_id="llm-benchmark-tab-table", id=5):
                init_leaderboard(overall_df, model_info_df, sort_val=MetricNames.rank)
            with gr.TabItem("🏅 By Domain", elem_id="llm-benchmark-tab-table", id=0):
                init_leaderboard(domain_df, model_info_df)
            with gr.TabItem("📝 About", elem_id="llm-benchmark-tab-table", id=4):
gr.Markdown(LLM_BENCHMARKS_TEXT, elem_classes="markdown-text")
with gr.Row():
            with gr.Accordion("📙 Citation", open=False):
citation_button = gr.Textbox(
value=CITATION_BUTTON_TEXT,
label=CITATION_BUTTON_LABEL,
lines=20,
elem_id="citation-button",
show_copy_button=True,
)
    scheduler = BackgroundScheduler()
    # Periodic Space restart is disabled for now; re-enable to refresh results every 30 minutes.
    # scheduler.add_job(restart_space, "interval", seconds=1800)
    scheduler.start()
demo.queue(default_concurrency_limit=40).launch()
if __name__ == '__main__':
main()