# Hugging Face Spaces app (page status when this file was captured: Running).
from typing import Optional

import dash
import dash_bootstrap_components as dbc
import plotly.graph_objects as go
from dash import dcc, html, Input, Output, State, callback_context, ALL, ClientsideFunction

from src.execution_model import ScheduleConfig, Schedule
from src.strategies import (
    generate_1f1b_schedule,
    generate_zero_bubble_1p_schedule,
    generate_1f1b_overlap_schedule,
    generate_1f1b_interleave_schedule,
    generate_1f1b_interleave_overlap_schedule,
    generate_dualpipe_schedule,
)
from src.visualizer import convert_schedule_to_visualization_format, create_pipeline_figure
# Maps each strategy key to its schedule generator. The keys double as the
# strategy button labels/ids and as the values kept in the selection store.
STRATEGIES = {
    "1f1b": generate_1f1b_schedule,
    "zb1p": generate_zero_bubble_1p_schedule,
    "1f1b_overlap": generate_1f1b_overlap_schedule,
    "1f1b_interleave": generate_1f1b_interleave_schedule,
    "1f1b_interleave_overlap": generate_1f1b_interleave_overlap_schedule,
    "dualpipe": generate_dualpipe_schedule,
}
# Dash application instance. Loads the Bootstrap theme plus the Bootstrap icon
# stylesheet (the "bi bi-info-circle" tooltip icons in the layout use it).
app = dash.Dash(
    __name__,
    external_stylesheets=[dbc.themes.BOOTSTRAP, dbc.icons.BOOTSTRAP],
    suppress_callback_exceptions=True,
)
app.title = "Pipeline Parallelism Schedule Visualizer"
# Initial default values shown in the form. Keys match the component ids of
# the corresponding inputs; "strategy" seeds the selected-strategies store.
default_values = {
    "num_devices": 4,
    "num_stages": 8,
    "num_batches": 16,
    "p2p_latency": 0.0,
    "op_time_forward": 1.0,
    "op_time_backward_d": 1.0,
    "op_time_backward_w": 1.0,
    "op_time_backward": 2.0,
    "strategy": ["1f1b_interleave"],
    # None leaves the overlapped-time input empty so its placeholder shows.
    "op_time_overlapped_fwd_bwd": None,
}
# Define input groups using dbc components
card_style = {"marginBottom": "24px"}


def _positive_int_field(field_id, label, noun):
    """Build one labelled integer input (min 1, step 1) with invalid feedback.

    Args:
        field_id: Component id of the input; also the key into
            ``default_values`` and the suffix of the feedback id.
        label: Text shown above the input.
        noun: Plural noun inserted into the invalid-feedback sentence.
    """
    return html.Div(
        [
            dbc.Label(label, html_for=field_id, className="form-label"),
            dbc.Input(
                id=field_id,
                type='number',
                value=default_values[field_id],
                min=1,
                step=1,
                required=True,
            ),
            dbc.FormFeedback(
                f"Please provide a positive integer for the number of {noun}.",
                type="invalid",
                id=f"feedback-{field_id}",
            ),
        ],
        className="mb-3",
    )


# Card holding the three count inputs: devices, stages and microbatches.
basic_params_card = dbc.Card(
    dbc.CardBody([
        html.H5("Basic Parameters", className="card-title mb-4"),
        _positive_int_field('num_devices', "Number of Devices (GPUs)", "devices"),
        _positive_int_field('num_stages', "Number of Stages (Model Chunks)", "stages"),
        _positive_int_field('num_batches', "Number of Microbatches", "microbatches"),
    ]),
    style=card_style,
)
# Card with one toggle button per strategy. The current selection lives in the
# 'selected-strategies-store' dcc.Store; validation text is rendered into
# 'strategy-selection-feedback'.
scheduling_params_card = dbc.Card(
    dbc.CardBody([
        html.H5("Scheduling Strategy", className="card-title mb-4"),
        dbc.ButtonGroup(
            [
                dbc.Button(
                    strategy,
                    # Pattern-matching id so a single callback can address all buttons.
                    id={"type": "strategy-button", "index": strategy},
                    color="secondary",
                    outline=True,
                    active=strategy in default_values["strategy"],
                    className="me-1",
                )
                for strategy in STRATEGIES
            ],
            className="d-flex flex-wrap",
        ),
        dcc.Store(id='selected-strategies-store', data=default_values["strategy"]),
        html.Div(id='strategy-selection-feedback', className='invalid-feedback d-block mt-2'),
    ]),
    style=card_style,
)
def _timing_field(field_id, label, tooltip_slug, tooltip_text, feedback_text,
                  help_text=None, **input_kwargs):
    """Build one timing input group: label + info icon, input, feedback, tooltip.

    Args:
        field_id: Component id of the input; also the key into
            ``default_values`` and the suffix of the feedback id.
        label: Text shown next to the info icon.
        tooltip_slug: Suffix of the icon id (``tooltip-target-<slug>``) that
            the tooltip attaches to.
        tooltip_text: Text displayed in the tooltip.
        feedback_text: Message displayed when the input is flagged invalid.
        help_text: Optional hint rendered as ``dbc.FormText`` under the input.
        **input_kwargs: Extra keyword arguments forwarded to ``dbc.Input``
            (e.g. ``min``, ``required``, ``placeholder``).
    """
    tooltip_target = f"tooltip-target-{tooltip_slug}"
    children = [
        html.Div([
            dbc.Label(label, html_for=field_id, className="form-label d-inline-block me-1"),
            html.I(className="bi bi-info-circle", id=tooltip_target, style={"cursor": "pointer"}),
        ]),
        dbc.Input(
            id=field_id,
            type='number',
            value=default_values[field_id],
            step=0.01,
            **input_kwargs,
        ),
    ]
    if help_text is not None:
        children.append(dbc.FormText(help_text))
    children.append(dbc.FormFeedback(feedback_text, type="invalid", id=f"feedback-{field_id}"))
    children.append(dbc.Tooltip(tooltip_text, target=tooltip_target, placement="right"))
    return html.Div(children, className="mb-3")


# Card with the operation timings. The split-backward and overlapped timings
# sit inside a collapse that is toggled by 'advanced-timing-switch'.
timing_params_card = dbc.Card(
    dbc.CardBody([
        html.H5("Operation Timing (ms)", className="card-title mb-4"),
        _timing_field(
            'p2p_latency',
            "P2P Latency",
            "p2p",
            "Time (ms) for point-to-point communication between adjacent devices.",
            "P2P latency must be a number >= 0.",
            min=0,
            required=True,
        ),
        _timing_field(
            'op_time_forward',
            "Forward Operation Time",
            "fwd",
            "Time (ms) for a single forward pass of one microbatch through one stage.",
            "Forward time must be a number > 0.",
            min=0.01,
            required=True,
        ),
        _timing_field(
            'op_time_backward',
            "Backward (Combined)",
            "bwd",
            "Time (ms) for a combined backward pass (data gradient + weight gradient) of one microbatch through one stage.",
            "Backward time must be > 0 if specified.",
            help_text="Used when strategy does NOT require split backward.",
            min=0.01,
        ),
        # --- Collapsible Advanced Options (Item 3) ---
        html.Hr(className="my-3"),
        dbc.Switch(
            id="advanced-timing-switch",
            label="Show Advanced Timing Options",
            value=False,
            className="mb-3",
        ),
        dbc.Collapse(
            id="advanced-timing-collapse",
            is_open=False,
            children=[
                _timing_field(
                    'op_time_backward_d',
                    "Backward D (Data Grad)",
                    "bwd-d",
                    "Time (ms) for the data gradient part of the backward pass.",
                    "Backward D time must be > 0 if specified.",
                    help_text="Used when strategy requires split backward (e.g., ZB-1P, DualPipe).",
                    min=0.01,
                ),
                _timing_field(
                    'op_time_backward_w',
                    "Backward W (Weight Grad)",
                    "bwd-w",
                    "Time (ms) for the weight gradient part of the backward pass.",
                    "Backward W time must be > 0 if specified.",
                    help_text="Used when strategy requires split backward (e.g., ZB-1P, DualPipe).",
                    min=0.01,
                ),
                _timing_field(
                    'op_time_overlapped_fwd_bwd',
                    "Overlapped Forward+Backward",
                    "overlap",
                    "Optional: Specify a single time (ms) if the forward and backward passes for a microbatch can be fully overlapped within the same stage execution slot.",
                    "Overlapped time must be > 0 if specified.",
                    help_text="Specify if Forward and Backward ops overlap completely.",
                    placeholder="Defaults to Fwd + Bwd",
                    min=0.01,
                ),
            ],
        ),
    ]),
    style=card_style,
)
# Updated app layout using dbc components and structure:
# graphs on the left, parameter cards and the generate button on the right,
# and a fixed-position container that receives toast notifications.
app.layout = dbc.Container(
    [
        html.H1("Pipeline Parallelism Schedule Visualizer", className="my-4 text-center"),
        # Main Row with Left (Graphs) and Right (Controls) Columns
        dbc.Row([
            # --- Left Column (Graphs Area) ---
            dbc.Col(
                [
                    # Output Area for Graphs
                    dcc.Loading(
                        id="loading-graph-area",
                        type="circle",
                        children=html.Div(id='graph-output-container', style={"minHeight": "600px"}),
                    ),
                ],
                lg=10,
                md=9,
                sm=12,
                className="mb-4 mb-lg-0",
            ),
            # --- Right Column (Controls Area) ---
            dbc.Col(
                [
                    # Parameter Cards Stacked Vertically
                    basic_params_card,
                    scheduling_params_card,
                    timing_params_card,
                    # Generate Button below the cards in the right column
                    dbc.Row(
                        [
                            dbc.Col(
                                dbc.Button(
                                    'Generate Schedule',
                                    id='generate-button',
                                    n_clicks=0,
                                    color="primary",
                                    className="w-100",
                                    disabled=False,
                                ),
                            ),
                        ],
                        className="mt-3",
                    ),
                ],
                lg=2,
                md=3,
                sm=12,
            ),
        ]),
        # --- Toast Container (Positioned Fixed) ---
        html.Div(
            id="toast-container",
            style={"position": "fixed", "top": 20, "right": 20, "zIndex": 1050},
        ),
    ],
    fluid=True,
    className="py-4",
)
# --- Callback for Input Validation and Generate Button State ---
# NOTE(review): no `@app.callback(...)` decorator is visible above this
# function in this view; confirm it is registered with Dash.
def _is_missing_or_not_positive_int(value):
    """Return True unless *value* is a whole number greater than or equal to 1."""
    return value is None or value < 1 or value != int(value)


def validate_inputs(num_devices: Optional[float], num_stages: Optional[float],
                    num_batches: Optional[float], p2p_latency: Optional[float],
                    op_time_forward: Optional[float], op_time_backward: Optional[float],
                    op_time_backward_d: Optional[float], op_time_backward_w: Optional[float],
                    op_time_overlapped_fwd_bwd: Optional[float],
                    selected_strategies: Optional[list]) -> tuple:
    """Validate the form values and decide whether generation must be disabled.

    Args:
        num_devices, num_stages, num_batches: Count inputs; each must be a
            whole number >= 1 (the feedback text asks for a positive integer).
        p2p_latency: Must be present and >= 0.
        op_time_forward: Must be present and > 0.
        op_time_backward: Combined backward time; must be > 0 when given and
            is required when any selected strategy is not split-backward.
        op_time_backward_d, op_time_backward_w: Split backward times; must be
            > 0 when given and are required when "zb1p" or "dualpipe" is
            selected.
        op_time_overlapped_fwd_bwd: Optional; must be > 0 when given.
        selected_strategies: Currently selected strategy keys.

    Returns:
        An 11-tuple: the button-disabled flag, then the invalid flag for
        num_devices, num_stages, num_batches, p2p_latency, op_time_forward,
        op_time_backward, op_time_backward_d, op_time_backward_w and
        op_time_overlapped_fwd_bwd (in that order), then the strategy
        feedback text (empty when a strategy is selected).
    """
    is_invalid = {
        "num_devices": _is_missing_or_not_positive_int(num_devices),
        "num_stages": _is_missing_or_not_positive_int(num_stages),
        "num_batches": _is_missing_or_not_positive_int(num_batches),
        "p2p_latency": p2p_latency is None or p2p_latency < 0,
        "op_time_forward": op_time_forward is None or op_time_forward <= 0,
        # Optional timings: only wrong when a non-positive value was entered.
        "op_time_backward": op_time_backward is not None and op_time_backward <= 0,
        "op_time_backward_d": op_time_backward_d is not None and op_time_backward_d <= 0,
        "op_time_backward_w": op_time_backward_w is not None and op_time_backward_w <= 0,
        "op_time_overlapped_fwd_bwd": (
            op_time_overlapped_fwd_bwd is not None and op_time_overlapped_fwd_bwd <= 0
        ),
    }

    # Validate strategy selection
    strategy_feedback = ""  # Default empty feedback
    if not selected_strategies:
        is_invalid["strategies"] = True
        strategy_feedback = "Please select at least one strategy."
    else:
        is_invalid["strategies"] = False
        # Timings that are optional in general become mandatory once a
        # selected strategy actually consumes them.
        split_keys = ("zb1p", "dualpipe")
        needs_split_backward = any(s in split_keys for s in selected_strategies)
        needs_combined_backward = any(s not in split_keys for s in selected_strategies)
        if needs_split_backward:
            is_invalid["op_time_backward_d"] = op_time_backward_d is None or op_time_backward_d <= 0
            is_invalid["op_time_backward_w"] = op_time_backward_w is None or op_time_backward_w <= 0
        if needs_combined_backward:
            is_invalid["op_time_backward"] = op_time_backward is None or op_time_backward <= 0

    # Disable the button if any validation fails
    disable_button = any(is_invalid.values())

    # Return button state and invalid states for each input
    return (
        disable_button,
        is_invalid["num_devices"],
        is_invalid["num_stages"],
        is_invalid["num_batches"],
        is_invalid["p2p_latency"],
        is_invalid["op_time_forward"],
        is_invalid["op_time_backward"],
        is_invalid["op_time_backward_d"],
        is_invalid["op_time_backward_w"],
        is_invalid["op_time_overlapped_fwd_bwd"],
        strategy_feedback,  # Update strategy feedback based on validation
    )
# --- Callback to toggle Advanced Options Collapse ---
# NOTE(review): no `@app.callback(...)` decorator is visible above this
# function in this view; confirm it is registered with Dash.
def toggle_advanced_options(switch_value: bool) -> bool:
    """Return the switch value unchanged so the collapse mirrors the switch."""
    return switch_value
# --- Client-side Callback for Strategy ButtonGroup ---
# Runs in the browser whenever any strategy button is clicked. It receives the
# click counts of all buttons plus the stored selection, and writes back the
# new selection, per-button active/color/outline props and the feedback text.
# NOTE(review): the JavaScript function `clientside.update_strategy_selection`
# is not part of this file; presumably it is served from the app's assets.
app.clientside_callback(
    ClientsideFunction(
        namespace='clientside',
        function_name='update_strategy_selection',
    ),
    Output('selected-strategies-store', 'data'),
    Output({'type': 'strategy-button', 'index': ALL}, 'active'),
    Output({'type': 'strategy-button', 'index': ALL}, 'color'),
    Output({'type': 'strategy-button', 'index': ALL}, 'outline'),
    Output('strategy-selection-feedback', 'children'),
    Input({'type': 'strategy-button', 'index': ALL}, 'n_clicks'),
    State('selected-strategies-store', 'data'),
    prevent_initial_call=True,
)
# --- Main Graph Update Callback ---
# NOTE(review): no `@app.callback(...)` decorator is visible above this
# function in this view; confirm it is registered with Dash.
def _make_toast(message, header, icon, duration):
    """Build an open toast whose border colour matches its icon category."""
    return dbc.Toast(
        message,
        header=header,
        icon=icon,
        duration=duration,
        is_open=True,
        className=f"border-{icon}",
    )


def _build_summary_table(execution_times):
    """Build the execution-time table, fastest first, highlighting the minimum.

    ``execution_times`` is a non-empty list of ``(strategy, total_time)`` pairs.
    """
    ranked = sorted(execution_times, key=lambda item: item[1])
    fastest = ranked[0][1]
    table_rows = [
        html.Tr(
            [html.Td(strategy), html.Td(f"{total_time:.2f}")],
            className="table-success" if total_time == fastest else "",
        )
        for strategy, total_time in ranked
    ]
    summary_table = dbc.Table(
        [
            html.Thead(html.Tr([html.Th("Strategy"), html.Th("Total Execution Time (ms)")])),
            html.Tbody(table_rows),
        ],
        bordered=True,
        striped=True,
        hover=True,
        responsive=True,
        color="light",  # Apply a light theme color
        className="mt-5",  # Add margin top
    )
    return html.Div([
        html.H4("Execution Time Summary", className="text-center mt-4 mb-3"),
        summary_table,
    ])


def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
                 op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
                 op_time_overlapped_fwd_bwd,
                 selected_strategies):
    """Generate one schedule figure per selected strategy plus status toasts.

    Args:
        n_clicks: Click count of the generate button (not read here).
        num_devices, num_stages, num_batches: Pipeline shape parameters.
        p2p_latency: Point-to-point latency passed to ``ScheduleConfig``.
        op_time_forward: Forward op time (required).
        op_time_backward: Combined backward time, required for strategies
            other than "zb1p" and "dualpipe".
        op_time_backward_d, op_time_backward_w: Split backward times, required
            for "zb1p" and "dualpipe".
        op_time_overlapped_fwd_bwd: Optional overlapped forward+backward time.
        selected_strategies: Strategy keys to generate.

    Returns:
        A pair ``(output_content, toast_components)``. ``output_content`` is a
        list of graph blocks followed by the summary table when at least one
        strategy succeeded, otherwise an empty list (or an info alert if
        there is nothing to report). ``toast_components`` lists adjustment
        toasts followed by error toasts.
    """
    strategy_display_order = ["1f1b", "1f1b_interleave", "1f1b_overlap", "1f1b_interleave_overlap", "dualpipe", "zb1p"]

    # Guard clauses: bail out with a single toast when basic input is missing.
    if not selected_strategies:
        return [], [_make_toast(
            "Please select at least one scheduling strategy.",
            "Input Error", "warning", 4000,
        )]
    if not all([num_devices, num_stages, num_batches, op_time_forward]):
        return [], [_make_toast(
            "Missing required basic input values (Devices, Stages, Batches, Forward Time).",
            "Input Error", "danger", 4000,
        )]

    results = []  # (strategy, vis_data, total_execution_time) per success
    error_messages = []
    automatic_adjustments = []

    for strategy in selected_strategies:
        error_message = ""
        placement_strategy = ""
        # Use local copies of params that might be adjusted for this strategy
        current_num_stages = num_stages
        current_num_devices = num_devices

        # These strategies require num_stages == num_devices; adjust silently
        # and tell the user via an info toast.
        if strategy in ("dualpipe", "1f1b", "1f1b_overlap", "zb1p") and num_stages != num_devices:
            current_num_stages = num_devices
            automatic_adjustments.append(
                f"Strategy '{strategy}': Number of Stages auto-adjusted to {num_devices} to match Devices."
            )

        split_backward = strategy in ("zb1p", "dualpipe")
        if split_backward and not all([op_time_backward_d, op_time_backward_w]):
            error_message = f"Strategy '{strategy}': Backward D and Backward W times are required."
        elif not split_backward and not op_time_backward:
            error_message = f"Strategy '{strategy}': Combined Backward time is required."

        if not error_message:
            if strategy in ("1f1b", "1f1b_overlap", "zb1p"):
                placement_strategy = "standard"
            elif strategy in ("1f1b_interleave", "1f1b_interleave_overlap"):
                placement_strategy = "interleave"
                if current_num_stages % current_num_devices != 0:
                    error_message = f"Strategy '{strategy}': Requires Stages divisible by Devices."
            elif strategy == "dualpipe":
                placement_strategy = "dualpipe"
                if current_num_stages % 2 != 0:
                    error_message = f"Strategy '{strategy}': Requires an even number of stages."

        # Create adjusted operation times based on placement strategy
        if not error_message:
            try:
                stages_per_device = current_num_stages // current_num_devices
                time_scale_factor = 1.0 / stages_per_device if stages_per_device > 0 else 1.0
                if stages_per_device > 1:
                    adjustment_msg = f"Strategy '{strategy}': Op times scaled by 1/{stages_per_device} ({stages_per_device} stages/device)."
                    # Avoid adding duplicate adjustment messages
                    if adjustment_msg not in automatic_adjustments:
                        automatic_adjustments.append(adjustment_msg)

                op_times = {"forward": float(op_time_forward) * time_scale_factor}
                if split_backward:
                    op_times["backward_D"] = float(op_time_backward_d) * time_scale_factor
                    op_times["backward_W"] = float(op_time_backward_w) * time_scale_factor
                    op_times["backward"] = (float(op_time_backward_d) + float(op_time_backward_w)) * time_scale_factor
                else:
                    op_times["backward"] = float(op_time_backward) * time_scale_factor

                if op_time_overlapped_fwd_bwd is not None:
                    # Best effort: an unusable overlapped value is ignored so
                    # the schedule is generated without an explicit override.
                    try:
                        overlapped_val = float(op_time_overlapped_fwd_bwd)
                    except (ValueError, TypeError):
                        overlapped_val = None
                    if overlapped_val is not None and overlapped_val > 0:
                        op_times["overlapped_forward_backward"] = overlapped_val * time_scale_factor

                config = ScheduleConfig(
                    num_devices=int(current_num_devices),
                    num_stages=int(current_num_stages),
                    num_batches=int(num_batches),
                    p2p_latency=float(p2p_latency),
                    placement_strategy=placement_strategy,
                    split_backward=split_backward,
                    op_times=op_times,
                )

                schedule_func = STRATEGIES.get(strategy)
                if not schedule_func:
                    raise ValueError(f"Invalid strategy function for: {strategy}")

                schedule = schedule_func(config)
                schedule.execute()
                vis_data = convert_schedule_to_visualization_format(schedule)
                # Read the total time BEFORE recording the result, so a
                # failure here cannot leave a half-recorded strategy that is
                # both graphed and reported as an error.
                total_time = schedule.get_total_execution_time()
                results.append((strategy, vis_data, total_time))
            except (AssertionError, ValueError, TypeError) as e:
                error_message = f"Error for '{strategy}': {e}"
            except Exception as e:
                # UI boundary: surface anything unexpected as a toast.
                error_message = f"Unexpected error for '{strategy}': {e}"

        if error_message:
            error_messages.append((strategy, error_message))

    # --- Generate Toasts ---
    # Adjustments first (slightly longer duration for info), then errors
    # (longer duration still).
    toast_components = [
        _make_toast(adjustment, "Parameter Adjustment", "info", 5000)
        for adjustment in automatic_adjustments
    ]
    toast_components.extend(
        _make_toast(msg, f"Error: {strategy}", "danger", 8000)
        for strategy, msg in error_messages
    )

    # --- Generate Graphs ---
    graph_components = []
    if results:
        # All figures share one x-axis range so they are visually comparable.
        max_execution_time = max(total_time for _, _, total_time in results)
        margin = max_execution_time * 0.05
        sorted_results = sorted(
            results,
            key=lambda item: strategy_display_order.index(item[0])
            if item[0] in strategy_display_order else float('inf'),
        )
        for strategy, vis_data, _ in sorted_results:
            fig = create_pipeline_figure(vis_data, max_time=max_execution_time, show_progress=False)
            fig.update_layout(
                xaxis=dict(range=[0, max_execution_time + margin])
            )
            # Append the Div directly for vertical stacking
            graph_components.append(
                html.Div([
                    html.H4(f"Schedule: {strategy}", className="text-center mt-3 mb-2"),
                    dcc.Graph(figure=fig),
                ])
            )

    if graph_components:
        # Graphs first, then the execution time summary table.
        execution_times = [(strategy, total_time) for strategy, _, total_time in results]
        output_content = graph_components + [_build_summary_table(execution_times)]
    elif not toast_components:
        # Defensive fallback: nothing to show and nothing to report.
        output_content = dbc.Alert("Click 'Generate Schedule' to see results.", color="info", className="mt-3")
    else:
        output_content = []

    # Return graph components (single column list or message) and toast components
    return output_content, toast_components
# For Hugging Face Spaces deployment: expose the underlying server object at
# module level.
server = app.server

if __name__ == '__main__':
    # Listen on all interfaces on port 7860 when run as a script.
    # NOTE(review): newer Dash releases favour `app.run` over `app.run_server`;
    # check the pinned Dash version before switching.
    app.run_server(debug=False, host='0.0.0.0', port=7860)