Commit 218e8a1 by ggcristian
Parent(s): 8e9d8db

use dropdowns for bubble plot and add cursor-pointer as css

Files changed (2):
  1. app.py +67 -22
  2. css_html_js.py +3 -2
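
In short: the bubble plot's two gr.Radio selectors become gr.Dropdown components with new "All" and "Aggregated ⬆️" options, and a CSS rule gives the dropdown inputs a pointer cursor. Below is a minimal, self-contained sketch of the dropdown-driven plot pattern the commit adopts; the toy dataframe, choice lists, and make_plot helper are illustrative stand-ins, not the leaderboard's actual data or code.

```python
# Sketch only: toy data and names, not the leaderboard's real df/handlers.
import gradio as gr
import pandas as pd
import plotly.express as px

toy = pd.DataFrame({
    "Model": ["A", "B", "C"],
    "Params": [7, 34, 70],
    "Score": [42.0, 55.5, 61.2],
    "Model Type": ["General", "Coding", "RTL-Specific"],
})

def make_plot(benchmark, metric):
    # The real app filters df by benchmark/metric; the toy frame stands in here.
    return px.scatter(
        toy, x="Params", y="Score", log_x=True, size="Params",
        color="Model Type", text="Model", template="plotly_white",
        title=f"Params vs. {metric} for {benchmark}",
    )

with gr.Blocks() as demo:
    with gr.Row(equal_height=True):
        benchmark_dd = gr.Dropdown(choices=["All", "Bench A"], value="All", label="Select Benchmark")
        metric_dd = gr.Dropdown(choices=["Aggregated ⬆️", "Score"], value="Aggregated ⬆️", label="Select Metric")
    plot = gr.Plot(value=make_plot("All", "Aggregated ⬆️"), label="Bubble Chart")
    # Re-render the plot whenever either dropdown changes.
    benchmark_dd.change(make_plot, [benchmark_dd, metric_dd], plot)
    metric_dd.change(make_plot, [benchmark_dd, metric_dd], plot)

demo.launch()
```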
app.py CHANGED
@@ -76,20 +76,64 @@ def update_benchmarks_by_task(task):
     return gr.update(choices=["All"] + benchmarks, value="All")
 
 def generate_scatter_plot(benchmark, metric):
-    benchmark, metric = handle_special_cases(benchmark, metric)
-
-    subset = df[df['Benchmark'] == benchmark]
-    if benchmark == "RTL-Repo":
-        subset = subset[subset['Metric'].str.contains('EM', case=False, na=False)]
-        detailed_scores = subset.groupby('Model', as_index=False)['Score'].mean()
-        detailed_scores.rename(columns={'Score': 'Exact Matching (EM)'}, inplace=True)
-        detailed_scores['Average ⬆️'] = detailed_scores['Exact Matching (EM)']
+    if benchmark == "All":
+        models_data = []
+
+        for bench in benchmarks:
+            subset = df[df['Benchmark'] == bench]
+            if bench == "RTL-Repo":
+                subset = subset[subset['Metric'].str.contains('EM', case=False, na=False)]
+            models_in_bench = subset['Model'].unique()
+            models_data.extend([(model, bench) for model in models_in_bench])
+
+        all_models = list(set([m[0] for m in models_data]))
+        details = df[['Model', 'Params', 'Model Type']].drop_duplicates('Model')
+
+        if metric == "Aggregated ⬆️":
+            agg_columns = [col for col in df_agg.columns if col.startswith('Agg ')]
+            if agg_columns:
+                agg_data = df_agg.copy()
+                agg_data['Aggregated ⬆️'] = agg_data[agg_columns].mean(axis=1).round(2)
+                scatter_data = pd.merge(details, agg_data[['Model', 'Aggregated ⬆️']], on='Model', how='inner')
+            else:
+                scatter_data = details.copy()
+                scatter_data['Aggregated ⬆️'] = 50  # default
+        else:
+            scatter_data = details.copy()
+            metric_data = df[df['Metric'] == metric].groupby('Model')['Score'].mean().reset_index()
+            metric_data = metric_data.rename(columns={'Score': metric})
+            scatter_data = pd.merge(scatter_data, metric_data, on='Model', how='left')
+            scatter_data = scatter_data.dropna(subset=[metric] if metric in scatter_data.columns else ['Aggregated ⬆️'])
+
     else:
-        detailed_scores = subset.pivot_table(index='Model', columns='Metric', values='Score').reset_index()
-        detailed_scores['Average ⬆️'] = detailed_scores[['Syntax (STX)', 'Functionality (FNC)', 'Synthesis (SYN)', 'Power', 'Performance', 'Area']].mean(axis=1)
+        # Code we already had for individual benchmark selection
+        benchmark, metric = handle_special_cases(benchmark, metric)
+
+        subset = df[df['Benchmark'] == benchmark]
+        if benchmark == "RTL-Repo":
+            subset = subset[subset['Metric'].str.contains('EM', case=False, na=False)]
+            detailed_scores = subset.groupby('Model', as_index=False)['Score'].mean()
+            detailed_scores.rename(columns={'Score': 'Exact Matching (EM)'}, inplace=True)
+            detailed_scores['Aggregated ⬆️'] = detailed_scores['Exact Matching (EM)']
+        else:
+            agg_column = None
+            detailed_scores = subset.pivot_table(index='Model', columns='Metric', values='Score').reset_index()
+            if benchmark == 'VerilogEval S2R':
+                agg_column = 'Agg VerilogEval S2R'
+            elif benchmark == 'VerilogEval MC':
+                agg_column = 'Agg VerilogEval MC'
+            elif benchmark == 'RTLLM':
+                agg_column = 'Agg RTLLM'
+            elif benchmark == 'VeriGen':
+                agg_column = 'Agg VeriGen'
+            if agg_column and agg_column in df_agg.columns:
+                agg_data = df_agg[['Model', agg_column]].rename(columns={agg_column: 'Aggregated ⬆️'})
+                detailed_scores = pd.merge(detailed_scores, agg_data, on='Model', how='left')
+            else:
+                detailed_scores['Aggregated ⬆️'] = detailed_scores[['Syntax (STX)', 'Functionality (FNC)', 'Synthesis (SYN)', 'Power', 'Performance', 'Area']].mean(axis=1).round(2)
 
-    details = df[['Model', 'Params', 'Model Type']].drop_duplicates('Model')
-    scatter_data = pd.merge(detailed_scores, details, on='Model', how='left').dropna(subset=['Params', metric])
+        details = df[['Model', 'Params', 'Model Type']].drop_duplicates('Model')
+        scatter_data = pd.merge(detailed_scores, details, on='Model', how='left').dropna(subset=['Params', metric])
 
     scatter_data['x'] = scatter_data['Params']
     scatter_data['y'] = scatter_data[metric]
@@ -101,7 +145,7 @@ def generate_scatter_plot(benchmark, metric):
     y_axis_limits = {
         'Functionality (FNC)': [5, 90], 'Syntax (STX)': [20, 100], 'Synthesis (SYN)': [5, 90],
         'Power': [0, 50], 'Performance': [0, 50], 'Area': [0, 50], 'Exact Matching (EM)': [0, 50],
-        'Average ⬆️': [0, 80]
+        'Aggregated ⬆️': [0, 80]
     }
     y_range = y_axis_limits.get(metric, [0, 80])
 
@@ -109,10 +153,6 @@ def generate_scatter_plot(benchmark, metric):
         scatter_data, x='x', y='y', log_x=True, size='size', color='Model Type', text='Model',
         hover_data={metric: ':.2f'}, title=f'Params vs. {metric} for {benchmark}',
         labels={'x': '# Params (Log Scale)', 'y': metric}, template="plotly_white",
-        # color_discrete_map={"General": "
-        #A8D5BA", "Coding": "
-        #F7DC6F", "RTL-Specific": "
-        #87CEFA"},
         height=600, width=1200
     )
 
@@ -223,9 +263,10 @@ with gr.Blocks(css=custom_css, js=js_func, theme=gr.themes.Default(primary_hue=c
 
     with gr.Tab("Interactive Bubble Plot"):
         with gr.Row(equal_height=True):
-            bubble_benchmark = gr.Radio(choices=benchmarks, label="Select Benchmark", value='VerilogEval S2R')
-            bubble_metric = gr.Radio(choices=non_rtl_metrics[:-1], label="Select Metric", value="Syntax (STX)")
-            scatter_plot = gr.Plot(value=generate_scatter_plot('VerilogEval S2R', default_metric), label="Bubble Chart", elem_id="full-width-plot")
+            bubble_benchmark = gr.Dropdown(choices=["All"] + benchmarks, label="Select Benchmark", value='All', elem_classes="gr-dropdown")
+            bubble_metric = gr.Dropdown(choices=["Aggregated ⬆️"] + non_rtl_metrics[:-1], label="Select Metric", value="Aggregated ⬆️")
+        with gr.Row(equal_height=True):
+            scatter_plot = gr.Plot(value=generate_scatter_plot('All', "Aggregated ⬆️"), label="Bubble Chart", elem_id="full-width-plot")
 
     with gr.Tab("About Us"):
         gr.HTML(
@@ -282,8 +323,12 @@ with gr.Blocks(css=custom_css, js=js_func, theme=gr.themes.Default(primary_hue=c
         metric = "Exact Matching (EM)"
         return gr.update(choices=rtl_metrics, value=metric), generate_scatter_plot(benchmark, metric)
     else:
-        metric = non_rtl_metrics[0]  # default to Syntax
-        return gr.update(choices=non_rtl_metrics[:-1], value=metric), generate_scatter_plot(benchmark, metric)
+        if benchmark == "All":
+            metric = "Aggregated ⬆️"  # default to Aggregated
+            return gr.update(choices=["Aggregated ⬆️"] + non_rtl_metrics[:-1], value=metric), generate_scatter_plot(benchmark, metric)
+        else:
+            metric = non_rtl_metrics[0]
+            return gr.update(choices=non_rtl_metrics[:-1], value=metric), generate_scatter_plot(benchmark, metric)
 
 def on_metric_change(benchmark, metric):
     benchmark, metric = handle_special_cases(benchmark, metric)
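
The new "All" + "Aggregated ⬆️" branch reduces to a row-wise mean over the per-benchmark Agg columns of df_agg. A toy pandas sketch of that reduction (models and scores made up; only the 'Agg ' column-prefix convention comes from the diff):

```python
import pandas as pd

# Toy stand-in for df_agg; the 'Agg ' prefix matches the convention in the diff.
df_agg = pd.DataFrame({
    "Model": ["A", "B"],
    "Agg VerilogEval S2R": [60.0, 40.0],
    "Agg RTLLM": [50.0, 30.0],
})

agg_columns = [col for col in df_agg.columns if col.startswith("Agg ")]
df_agg["Aggregated ⬆️"] = df_agg[agg_columns].mean(axis=1).round(2)
print(df_agg[["Model", "Aggregated ⬆️"]])  # A -> 55.0, B -> 35.0
```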
css_html_js.py CHANGED
@@ -51,11 +51,9 @@ custom_css = """
     background: none;
     border: none;
 }
-
 #search-bar {
     padding: 0px;
 }
-/* Limit the width of the first AutoEvalColumn so that names don't expand too much */
 #leaderboard-table td:nth-child(2),
 #leaderboard-table th:nth-child(2) {
     max-width: 400px;
@@ -111,6 +109,9 @@ custom_css = """
 .slider_input_container {
     padding-top: 8px;
 }
+input[role="listbox"] {
+    cursor: pointer !important;
+}
 """
 
 get_window_url_params = """
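
The added rule targets the text input that Gradio renders for a dropdown (it carries role="listbox" in the Gradio build this Space appears to use), which otherwise shows a text-editing cursor on hover. A minimal sketch of threading such a rule into an app via gr.Blocks(css=...); the dropdown here is illustrative:

```python
import gradio as gr

# Pointer cursor on dropdown inputs, as in the rule added above.
css = """
input[role="listbox"] {
    cursor: pointer !important;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Dropdown(choices=["All", "RTLLM"], value="All", label="Select Benchmark")

demo.launch()
```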