Spaces:
Running
Running
Add table results.
Browse files
app.py
CHANGED
@@ -371,11 +371,12 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
|
371 |
|
372 |
strategy_display_order = ["1f1b", "1f1b_interleave", "1f1b_overlap", "1f1b_interleave_overlap", "dualpipe", "zb1p"]
|
373 |
|
374 |
-
graph_components = []
|
375 |
-
toast_components = []
|
376 |
valid_results = []
|
377 |
error_messages = []
|
378 |
automatic_adjustments = []
|
|
|
379 |
|
380 |
# Use a variable to track if initial validation fails
|
381 |
initial_validation_error = None
|
@@ -489,6 +490,8 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
|
489 |
schedule.execute()
|
490 |
vis_data = convert_schedule_to_visualization_format(schedule)
|
491 |
valid_results.append((strategy, schedule, vis_data))
|
|
|
|
|
492 |
|
493 |
except (AssertionError, ValueError, TypeError) as e:
|
494 |
error_message = f"Error for '{strategy}': {e}"
|
@@ -530,19 +533,69 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
|
530 |
max_execution_time = max(schedule.get_total_execution_time() for _, schedule, _ in valid_results)
|
531 |
sorted_valid_results = sorted(valid_results, key=lambda x: strategy_display_order.index(x[0]) if x[0] in strategy_display_order else float('inf'))
|
532 |
|
|
|
|
|
533 |
for strategy, _, vis_data in sorted_valid_results:
|
534 |
fig = create_pipeline_figure(vis_data, max_time=max_execution_time, show_progress=False)
|
535 |
margin = max_execution_time * 0.05
|
536 |
fig.update_layout(
|
537 |
xaxis=dict(range=[0, max_execution_time + margin])
|
538 |
)
|
539 |
-
|
540 |
-
|
541 |
-
|
542 |
-
|
|
|
|
|
|
|
543 |
|
544 |
-
|
545 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
546 |
|
547 |
# For Hugging Face Spaces deployment
|
548 |
server = app.server
|
|
|
371 |
|
372 |
strategy_display_order = ["1f1b", "1f1b_interleave", "1f1b_overlap", "1f1b_interleave_overlap", "dualpipe", "zb1p"]
|
373 |
|
374 |
+
graph_components = []
|
375 |
+
toast_components = []
|
376 |
valid_results = []
|
377 |
error_messages = []
|
378 |
automatic_adjustments = []
|
379 |
+
execution_times = [] # Add list to store execution times
|
380 |
|
381 |
# Use a variable to track if initial validation fails
|
382 |
initial_validation_error = None
|
|
|
490 |
schedule.execute()
|
491 |
vis_data = convert_schedule_to_visualization_format(schedule)
|
492 |
valid_results.append((strategy, schedule, vis_data))
|
493 |
+
# Store execution time
|
494 |
+
execution_times.append((strategy, schedule.get_total_execution_time()))
|
495 |
|
496 |
except (AssertionError, ValueError, TypeError) as e:
|
497 |
error_message = f"Error for '{strategy}': {e}"
|
|
|
533 |
max_execution_time = max(schedule.get_total_execution_time() for _, schedule, _ in valid_results)
|
534 |
sorted_valid_results = sorted(valid_results, key=lambda x: strategy_display_order.index(x[0]) if x[0] in strategy_display_order else float('inf'))
|
535 |
|
536 |
+
# Prepare graphs for single-column layout
|
537 |
+
graph_components = [] # Use graph_components again
|
538 |
for strategy, _, vis_data in sorted_valid_results:
|
539 |
fig = create_pipeline_figure(vis_data, max_time=max_execution_time, show_progress=False)
|
540 |
margin = max_execution_time * 0.05
|
541 |
fig.update_layout(
|
542 |
xaxis=dict(range=[0, max_execution_time + margin])
|
543 |
)
|
544 |
+
# Append the Div directly for vertical stacking
|
545 |
+
graph_components.append(
|
546 |
+
html.Div([
|
547 |
+
html.H4(f"Schedule: {strategy}", className="text-center mt-3 mb-2"),
|
548 |
+
dcc.Graph(figure=fig)
|
549 |
+
])
|
550 |
+
)
|
551 |
|
552 |
+
# No grid arrangement needed for single column
|
553 |
+
# rows = [] ... removed ...
|
554 |
+
|
555 |
+
# If there are graphs, use the component list, otherwise show a message
|
556 |
+
output_content = []
|
557 |
+
if graph_components: # Check if graph_components list is populated
|
558 |
+
output_content = graph_components # Assign the list of components
|
559 |
+
elif not toast_components: # Only show 'no results' if no errors/adjustments either
|
560 |
+
output_content = dbc.Alert("Click 'Generate Schedule' to see results.", color="info", className="mt-3")
|
561 |
+
|
562 |
+
# Add the execution time table if there are results
|
563 |
+
if execution_times:
|
564 |
+
# Sort times based on execution time (ascending)
|
565 |
+
sorted_times = sorted(execution_times, key=lambda x: x[1])
|
566 |
+
min_time = sorted_times[0][1] if sorted_times else None
|
567 |
+
|
568 |
+
table_header = [html.Thead(html.Tr([html.Th("Strategy"), html.Th("Total Execution Time (ms)")]))]
|
569 |
+
table_rows = []
|
570 |
+
for strategy, time in sorted_times:
|
571 |
+
row_class = "table-success" if time == min_time else ""
|
572 |
+
table_rows.append(html.Tr([html.Td(strategy), html.Td(f"{time:.2f}")], className=row_class))
|
573 |
+
|
574 |
+
table_body = [html.Tbody(table_rows)]
|
575 |
+
summary_table = dbc.Table(
|
576 |
+
table_header + table_body,
|
577 |
+
bordered=True,
|
578 |
+
striped=True,
|
579 |
+
hover=True,
|
580 |
+
responsive=True,
|
581 |
+
color="light", # Apply a light theme color
|
582 |
+
className="mt-5" # Add margin top
|
583 |
+
)
|
584 |
+
# Prepend title to the table
|
585 |
+
table_component = html.Div([
|
586 |
+
html.H4("Execution Time Summary", className="text-center mt-4 mb-3"),
|
587 |
+
summary_table
|
588 |
+
])
|
589 |
+
|
590 |
+
# Append the table component to the output content
|
591 |
+
# If output_content is just the alert, replace it. Otherwise, append.
|
592 |
+
if isinstance(output_content, list):
|
593 |
+
output_content.append(table_component)
|
594 |
+
else: # It must be the Alert
|
595 |
+
output_content = [output_content, table_component] # Replace Alert with list
|
596 |
+
|
597 |
+
# Return graph components (single column list or message) and toast components
|
598 |
+
return output_content, toast_components
|
599 |
|
600 |
# For Hugging Face Spaces deployment
|
601 |
server = app.server
|