Victarry's picture
Update column width.
8ba3766
import dash
import dash_bootstrap_components as dbc
from dash import dcc, html, Input, Output, State, callback_context, ALL, ClientsideFunction
import plotly.graph_objects as go
from src.execution_model import ScheduleConfig, Schedule
from src.strategies import (
generate_1f1b_schedule,
generate_zero_bubble_1p_schedule,
generate_1f1b_overlap_schedule,
generate_1f1b_interleave_schedule,
generate_1f1b_interleave_overlap_schedule,
generate_dualpipe_schedule
)
from src.visualizer import convert_schedule_to_visualization_format, create_pipeline_figure
STRATEGIES = {
"1f1b": generate_1f1b_schedule,
"zb1p": generate_zero_bubble_1p_schedule,
"1f1b_overlap": generate_1f1b_overlap_schedule,
"1f1b_interleave": generate_1f1b_interleave_schedule,
"1f1b_interleave_overlap": generate_1f1b_interleave_overlap_schedule,
"dualpipe": generate_dualpipe_schedule,
}
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP, dbc.icons.BOOTSTRAP], suppress_callback_exceptions=True)
app.title = "Pipeline Parallelism Schedule Visualizer"
# Initial default values
default_values = {
"num_devices": 4,
"num_stages": 8,
"num_batches": 16,
"p2p_latency": 0.0,
"op_time_forward": 1.0,
"op_time_backward_d": 1.0,
"op_time_backward_w": 1.0,
"op_time_backward": 2.0,
"strategy": ["1f1b_interleave"],
"op_time_overlapped_fwd_bwd": None,
}
# Define input groups using dbc components
card_style = {"marginBottom": "24px"}
basic_params_card = dbc.Card(
dbc.CardBody([
html.H5("Basic Parameters", className="card-title mb-4"),
html.Div([
dbc.Label("Number of Devices (GPUs)", html_for='num_devices', className="form-label"),
dbc.Input(id='num_devices', type='number', value=default_values["num_devices"], min=1, step=1, required=True),
dbc.FormFeedback("Please provide a positive integer for the number of devices.", type="invalid", id="feedback-num_devices"),
], className="mb-3"),
html.Div([
dbc.Label("Number of Stages (Model Chunks)", html_for='num_stages', className="form-label"),
dbc.Input(id='num_stages', type='number', value=default_values["num_stages"], min=1, step=1, required=True),
dbc.FormFeedback("Please provide a positive integer for the number of stages.", type="invalid", id="feedback-num_stages"),
], className="mb-3"),
html.Div([
dbc.Label("Number of Microbatches", html_for='num_batches', className="form-label"),
dbc.Input(id='num_batches', type='number', value=default_values["num_batches"], min=1, step=1, required=True),
dbc.FormFeedback("Please provide a positive integer for the number of microbatches.", type="invalid", id="feedback-num_batches"),
], className="mb-3"),
]),
style=card_style
)
scheduling_params_card = dbc.Card(
dbc.CardBody([
html.H5("Scheduling Strategy", className="card-title mb-4"),
dbc.ButtonGroup(
[
dbc.Button(
strategy,
id={"type": "strategy-button", "index": strategy},
color="secondary",
outline=True,
active=strategy in default_values["strategy"],
className="me-1"
)
for strategy in STRATEGIES.keys()
],
className="d-flex flex-wrap"
),
dcc.Store(id='selected-strategies-store', data=default_values["strategy"]),
html.Div(id='strategy-selection-feedback', className='invalid-feedback d-block mt-2')
]),
style=card_style
)
timing_params_card = dbc.Card(
dbc.CardBody([
html.H5("Operation Timing (ms)", className="card-title mb-4"),
html.Div([
html.Div([
dbc.Label("P2P Latency", html_for='p2p_latency', className="form-label d-inline-block me-1"),
html.I(className="bi bi-info-circle", id="tooltip-target-p2p", style={"cursor": "pointer"})
]),
dbc.Input(id='p2p_latency', type='number', value=default_values["p2p_latency"], min=0, step=0.01, required=True),
dbc.FormFeedback("P2P latency must be a number >= 0.", type="invalid", id="feedback-p2p_latency"),
dbc.Tooltip(
"Time (ms) for point-to-point communication between adjacent devices.",
target="tooltip-target-p2p",
placement="right"
)
], className="mb-3"),
html.Div([
html.Div([
dbc.Label("Forward Operation Time", html_for='op_time_forward', className="form-label d-inline-block me-1"),
html.I(className="bi bi-info-circle", id="tooltip-target-fwd", style={"cursor": "pointer"})
]),
dbc.Input(id='op_time_forward', type='number', value=default_values["op_time_forward"], min=0.01, step=0.01, required=True),
dbc.FormFeedback("Forward time must be a number > 0.", type="invalid", id="feedback-op_time_forward"),
dbc.Tooltip(
"Time (ms) for a single forward pass of one microbatch through one stage.",
target="tooltip-target-fwd",
placement="right"
)
], className="mb-3"),
html.Div([
html.Div([
dbc.Label("Backward (Combined)", html_for='op_time_backward', className="form-label d-inline-block me-1"),
html.I(className="bi bi-info-circle", id="tooltip-target-bwd", style={"cursor": "pointer"})
]),
dbc.Input(id='op_time_backward', type='number', value=default_values["op_time_backward"], min=0.01, step=0.01),
dbc.FormText("Used when strategy does NOT require split backward."),
dbc.FormFeedback("Backward time must be > 0 if specified.", type="invalid", id="feedback-op_time_backward"),
dbc.Tooltip(
"Time (ms) for a combined backward pass (data gradient + weight gradient) of one microbatch through one stage.",
target="tooltip-target-bwd",
placement="right"
)
], className="mb-3"),
# --- Collapsible Advanced Options (Item 3) ---
html.Hr(className="my-3"),
dbc.Switch(
id="advanced-timing-switch",
label="Show Advanced Timing Options",
value=False,
className="mb-3"
),
dbc.Collapse(
id="advanced-timing-collapse",
is_open=False,
children=[
html.Div([
html.Div([
dbc.Label("Backward D (Data Grad)", html_for='op_time_backward_d', className="form-label d-inline-block me-1"),
html.I(className="bi bi-info-circle", id="tooltip-target-bwd-d", style={"cursor": "pointer"})
]),
dbc.Input(id='op_time_backward_d', type='number', value=default_values["op_time_backward_d"], min=0.01, step=0.01),
dbc.FormText("Used when strategy requires split backward (e.g., ZB-1P, DualPipe)."),
dbc.FormFeedback("Backward D time must be > 0 if specified.", type="invalid", id="feedback-op_time_backward_d"),
dbc.Tooltip(
"Time (ms) for the data gradient part of the backward pass.",
target="tooltip-target-bwd-d",
placement="right"
)
], className="mb-3"),
html.Div([
html.Div([
dbc.Label("Backward W (Weight Grad)", html_for='op_time_backward_w', className="form-label d-inline-block me-1"),
html.I(className="bi bi-info-circle", id="tooltip-target-bwd-w", style={"cursor": "pointer"})
]),
dbc.Input(id='op_time_backward_w', type='number', value=default_values["op_time_backward_w"], min=0.01, step=0.01),
dbc.FormText("Used when strategy requires split backward (e.g., ZB-1P, DualPipe)."),
dbc.FormFeedback("Backward W time must be > 0 if specified.", type="invalid", id="feedback-op_time_backward_w"),
dbc.Tooltip(
"Time (ms) for the weight gradient part of the backward pass.",
target="tooltip-target-bwd-w",
placement="right"
)
], className="mb-3"),
html.Div([
html.Div([
dbc.Label("Overlapped Forward+Backward", html_for='op_time_overlapped_fwd_bwd', className="form-label d-inline-block me-1"),
html.I(className="bi bi-info-circle", id="tooltip-target-overlap", style={"cursor": "pointer"})
]),
dbc.Input(id='op_time_overlapped_fwd_bwd', type='number', placeholder="Defaults to Fwd + Bwd", min=0.01, step=0.01, value=default_values["op_time_overlapped_fwd_bwd"]),
dbc.FormText("Specify if Forward and Backward ops overlap completely."),
dbc.FormFeedback("Overlapped time must be > 0 if specified.", type="invalid", id="feedback-op_time_overlapped_fwd_bwd"),
dbc.Tooltip(
"Optional: Specify a single time (ms) if the forward and backward passes for a microbatch can be fully overlapped within the same stage execution slot.",
target="tooltip-target-overlap",
placement="right"
)
], className="mb-3"),
]
)
]),
style=card_style
)
# Updated app layout using dbc components and structure
app.layout = dbc.Container([
html.H1("Pipeline Parallelism Schedule Visualizer", className="my-4 text-center"),
# Main Row with Left (Graphs) and Right (Controls) Columns
dbc.Row([
# --- Left Column (Graphs Area) ---
dbc.Col([
# Output Area for Graphs
dcc.Loading(
id="loading-graph-area",
type="circle",
children=html.Div(id='graph-output-container', style={"minHeight": "600px"})
)
], lg=10, md=9, sm=12, className="mb-4 mb-lg-0"),
# --- Right Column (Controls Area) ---
dbc.Col([
# Parameter Cards Stacked Vertically
basic_params_card,
scheduling_params_card,
timing_params_card,
# Generate Button below the cards in the right column
dbc.Row([
dbc.Col(
dbc.Button(
'Generate Schedule',
id='generate-button',
n_clicks=0,
color="primary",
className="w-100",
disabled=False
),
)
], className="mt-3")
], lg=2, md=3, sm=12)
]),
# --- Toast Container (Positioned Fixed) ---
html.Div(id="toast-container", style={"position": "fixed", "top": 20, "right": 20, "zIndex": 1050})
], fluid=True, className="py-4")
# --- Callback for Input Validation and Generate Button State ---
@app.callback(
Output('generate-button', 'disabled'),
# Outputs to control the 'invalid' state of Inputs
Output('num_devices', 'invalid'),
Output('num_stages', 'invalid'),
Output('num_batches', 'invalid'),
Output('p2p_latency', 'invalid'),
Output('op_time_forward', 'invalid'),
Output('op_time_backward', 'invalid'),
Output('op_time_backward_d', 'invalid'),
Output('op_time_backward_w', 'invalid'),
Output('op_time_overlapped_fwd_bwd', 'invalid'),
# Outputs to control the visibility/content of FormFeedback (can also just control Input's invalid state)
# We are primarily using the Input's `invalid` prop which automatically shows/hides associated FormFeedback
# Output('feedback-num_devices', 'children'), ... (Add if more specific messages needed per validation type)
Output('strategy-selection-feedback', 'children', allow_duplicate=True), # Update feedback from validation callback too
# Inputs: Trigger validation whenever any relevant input changes
Input('num_devices', 'value'),
Input('num_stages', 'value'),
Input('num_batches', 'value'),
Input('p2p_latency', 'value'),
Input('op_time_forward', 'value'),
Input('op_time_backward', 'value'),
Input('op_time_backward_d', 'value'),
Input('op_time_backward_w', 'value'),
Input('op_time_overlapped_fwd_bwd', 'value'),
Input('selected-strategies-store', 'data'), # Validate strategy selection
prevent_initial_call=True # Prevent callback running on page load before user interaction
)
def validate_inputs(num_devices, num_stages, num_batches, p2p_latency,
op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
op_time_overlapped_fwd_bwd, selected_strategies):
is_invalid = {
"num_devices": num_devices is None or num_devices < 1,
"num_stages": num_stages is None or num_stages < 1,
"num_batches": num_batches is None or num_batches < 1,
"p2p_latency": p2p_latency is None or p2p_latency < 0,
"op_time_forward": op_time_forward is None or op_time_forward <= 0,
"op_time_backward": op_time_backward is not None and op_time_backward <= 0,
"op_time_backward_d": op_time_backward_d is not None and op_time_backward_d <= 0,
"op_time_backward_w": op_time_backward_w is not None and op_time_backward_w <= 0,
"op_time_overlapped_fwd_bwd": op_time_overlapped_fwd_bwd is not None and op_time_overlapped_fwd_bwd <= 0,
}
# Validate strategy selection
strategy_feedback = "" # Default empty feedback
if not selected_strategies or len(selected_strategies) == 0:
is_invalid["strategies"] = True
strategy_feedback = "Please select at least one strategy."
else:
is_invalid["strategies"] = False
# Additional validation: Check if required timings are provided for selected strategies
needs_split_backward = any(s in ["zb1p", "dualpipe"] for s in selected_strategies)
needs_combined_backward = any(s not in ["zb1p", "dualpipe"] for s in selected_strategies)
if needs_split_backward and (op_time_backward_d is None or op_time_backward_w is None):
is_invalid["op_time_backward_d"] = op_time_backward_d is None or op_time_backward_d <= 0
is_invalid["op_time_backward_w"] = op_time_backward_w is None or op_time_backward_w <= 0
# We might want specific feedback here, but setting invalid=True is often enough
if needs_combined_backward and op_time_backward is None:
is_invalid["op_time_backward"] = op_time_backward is None or op_time_backward <= 0
# Check if any input is invalid
overall_invalid = any(is_invalid.values())
# Disable button if any validation fails
disable_button = overall_invalid
# Return button state and invalid states for each input
return (
disable_button,
is_invalid["num_devices"],
is_invalid["num_stages"],
is_invalid["num_batches"],
is_invalid["p2p_latency"],
is_invalid["op_time_forward"],
is_invalid["op_time_backward"],
is_invalid["op_time_backward_d"],
is_invalid["op_time_backward_w"],
is_invalid["op_time_overlapped_fwd_bwd"],
strategy_feedback # Update strategy feedback based on validation
)
# --- Callback to toggle Advanced Options Collapse ---
@app.callback(
Output("advanced-timing-collapse", "is_open"),
Input("advanced-timing-switch", "value"),
prevent_initial_call=True,
)
def toggle_advanced_options(switch_value):
return switch_value
# --- Client-side Callback for Strategy ButtonGroup ---
app.clientside_callback(
ClientsideFunction(
namespace='clientside',
function_name='update_strategy_selection'
),
Output('selected-strategies-store', 'data'),
Output({'type': 'strategy-button', 'index': ALL}, 'active'),
Output({'type': 'strategy-button', 'index': ALL}, 'color'),
Output({'type': 'strategy-button', 'index': ALL}, 'outline'),
Output('strategy-selection-feedback', 'children'),
Input({'type': 'strategy-button', 'index': ALL}, 'n_clicks'),
State('selected-strategies-store', 'data'),
prevent_initial_call=True
)
# --- Main Graph Update Callback ---
@app.callback(
# Output graph container and toast container separately
Output('graph-output-container', 'children'),
Output('toast-container', 'children'), # Output for toasts
Input('generate-button', 'n_clicks'),
State('num_devices', 'value'),
State('num_stages', 'value'),
State('num_batches', 'value'),
State('p2p_latency', 'value'),
State('op_time_forward', 'value'),
State('op_time_backward', 'value'),
State('op_time_backward_d', 'value'),
State('op_time_backward_w', 'value'),
State('op_time_overlapped_fwd_bwd', 'value'),
State('selected-strategies-store', 'data'),
prevent_initial_call=True
)
def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
op_time_overlapped_fwd_bwd,
selected_strategies):
strategy_display_order = ["1f1b", "1f1b_interleave", "1f1b_overlap", "1f1b_interleave_overlap", "dualpipe", "zb1p"]
graph_components = []
toast_components = []
valid_results = []
error_messages = []
automatic_adjustments = []
execution_times = [] # Add list to store execution times
# Use a variable to track if initial validation fails
initial_validation_error = None
if not selected_strategies:
initial_validation_error = dbc.Toast(
"Please select at least one scheduling strategy.",
header="Input Error",
icon="warning",
duration=4000,
is_open=True,
className="border-warning"
)
elif not all([num_devices, num_stages, num_batches, op_time_forward]):
initial_validation_error = dbc.Toast(
"Missing required basic input values (Devices, Stages, Batches, Forward Time).",
header="Input Error",
icon="danger",
duration=4000,
is_open=True,
className="border-danger"
)
if initial_validation_error:
# Return empty graph list and the validation error toast
return [], [initial_validation_error]
for strategy in selected_strategies:
error_message = ""
placement_strategy = ""
# Use local copies of params that might be adjusted for this strategy
current_num_stages = num_stages
current_num_devices = num_devices
# Apply automatic adjustments for dualpipe
if strategy == "dualpipe" and num_stages != num_devices:
current_num_stages = num_devices
adjustment_msg = f"Strategy '{strategy}': Number of Stages auto-adjusted to {num_devices} to match Devices."
automatic_adjustments.append(adjustment_msg)
# Apply automatic adjustments for strategies that require num_stages == num_devices
if strategy in ["1f1b", "1f1b_overlap", "zb1p"] and num_stages != num_devices:
current_num_stages = num_devices
adjustment_msg = f"Strategy '{strategy}': Number of Stages auto-adjusted to {num_devices} to match Devices."
automatic_adjustments.append(adjustment_msg)
split_backward = strategy in ["zb1p", "dualpipe"]
if split_backward and not all([op_time_backward_d, op_time_backward_w]):
error_message = f"Strategy '{strategy}': Backward D and Backward W times are required."
elif not split_backward and not op_time_backward:
error_message = f"Strategy '{strategy}': Combined Backward time is required."
if not error_message:
if strategy in ["1f1b", "1f1b_overlap", "zb1p"]:
placement_strategy = "standard"
elif strategy in ["1f1b_interleave", "1f1b_interleave_overlap"]:
placement_strategy = "interleave"
if current_num_stages % current_num_devices != 0:
error_message = f"Strategy '{strategy}': Requires Stages divisible by Devices."
elif strategy == "dualpipe":
placement_strategy = "dualpipe"
if current_num_stages % 2 != 0:
error_message = f"Strategy '{strategy}': Requires an even number of stages."
# Create adjusted operation times based on placement strategy
if not error_message:
try:
stages_per_device = current_num_stages // current_num_devices
time_scale_factor = 1.0 / stages_per_device if stages_per_device > 0 else 1.0
if stages_per_device > 1:
adjustment_msg = f"Strategy '{strategy}': Op times scaled by 1/{stages_per_device} ({stages_per_device} stages/device)."
# Avoid adding duplicate adjustment messages if already added above
if adjustment_msg not in automatic_adjustments:
automatic_adjustments.append(adjustment_msg)
op_times = { "forward": float(op_time_forward) * time_scale_factor }
if split_backward:
op_times["backward_D"] = float(op_time_backward_d) * time_scale_factor
op_times["backward_W"] = float(op_time_backward_w) * time_scale_factor
op_times["backward"] = (float(op_time_backward_d) + float(op_time_backward_w)) * time_scale_factor
else:
op_times["backward"] = float(op_time_backward) * time_scale_factor
if op_time_overlapped_fwd_bwd is not None:
try:
overlapped_val = float(op_time_overlapped_fwd_bwd)
if overlapped_val > 0:
op_times["overlapped_forward_backward"] = overlapped_val * time_scale_factor
except (ValueError, TypeError):
pass
config = ScheduleConfig(
num_devices=int(current_num_devices),
num_stages=int(current_num_stages),
num_batches=int(num_batches),
p2p_latency=float(p2p_latency),
placement_strategy=placement_strategy,
split_backward=split_backward,
op_times=op_times,
)
schedule_func = STRATEGIES.get(strategy)
if not schedule_func:
raise ValueError(f"Invalid strategy function for: {strategy}")
schedule = schedule_func(config)
schedule.execute()
vis_data = convert_schedule_to_visualization_format(schedule)
valid_results.append((strategy, schedule, vis_data))
# Store execution time
execution_times.append((strategy, schedule.get_total_execution_time()))
except (AssertionError, ValueError, TypeError) as e:
error_message = f"Error for '{strategy}': {e}"
except Exception as e:
error_message = f"Unexpected error for '{strategy}': {e}"
if error_message:
error_messages.append((strategy, error_message))
# --- Generate Toasts ---
# Add toasts for automatic adjustments
for adjustment in automatic_adjustments:
toast_components.append(
dbc.Toast(
adjustment,
header="Parameter Adjustment",
icon="info",
duration=5000, # Slightly longer duration for info
is_open=True,
className="border-info"
)
)
# Add toasts for errors
for strategy, msg in error_messages:
toast_components.append(
dbc.Toast(
msg,
header=f"Error: {strategy}",
icon="danger",
duration=8000, # Longer duration for errors
is_open=True,
className="border-danger"
)
)
# --- Generate Graphs ---
if valid_results:
max_execution_time = max(schedule.get_total_execution_time() for _, schedule, _ in valid_results)
sorted_valid_results = sorted(valid_results, key=lambda x: strategy_display_order.index(x[0]) if x[0] in strategy_display_order else float('inf'))
# Prepare graphs for single-column layout
graph_components = [] # Use graph_components again
for strategy, _, vis_data in sorted_valid_results:
fig = create_pipeline_figure(vis_data, max_time=max_execution_time, show_progress=False)
margin = max_execution_time * 0.05
fig.update_layout(
xaxis=dict(range=[0, max_execution_time + margin])
)
# Append the Div directly for vertical stacking
graph_components.append(
html.Div([
html.H4(f"Schedule: {strategy}", className="text-center mt-3 mb-2"),
dcc.Graph(figure=fig)
])
)
# No grid arrangement needed for single column
# rows = [] ... removed ...
# If there are graphs, use the component list, otherwise show a message
output_content = []
if graph_components: # Check if graph_components list is populated
output_content = graph_components # Assign the list of components
elif not toast_components: # Only show 'no results' if no errors/adjustments either
output_content = dbc.Alert("Click 'Generate Schedule' to see results.", color="info", className="mt-3")
# Add the execution time table if there are results
if execution_times:
# Sort times based on execution time (ascending)
sorted_times = sorted(execution_times, key=lambda x: x[1])
min_time = sorted_times[0][1] if sorted_times else None
table_header = [html.Thead(html.Tr([html.Th("Strategy"), html.Th("Total Execution Time (ms)")]))]
table_rows = []
for strategy, time in sorted_times:
row_class = "table-success" if time == min_time else ""
table_rows.append(html.Tr([html.Td(strategy), html.Td(f"{time:.2f}")], className=row_class))
table_body = [html.Tbody(table_rows)]
summary_table = dbc.Table(
table_header + table_body,
bordered=True,
striped=True,
hover=True,
responsive=True,
color="light", # Apply a light theme color
className="mt-5" # Add margin top
)
# Prepend title to the table
table_component = html.Div([
html.H4("Execution Time Summary", className="text-center mt-4 mb-3"),
summary_table
])
# Append the table component to the output content
# If output_content is just the alert, replace it. Otherwise, append.
if isinstance(output_content, list):
output_content.append(table_component)
else: # It must be the Alert
output_content = [output_content, table_component] # Replace Alert with list
# Return graph components (single column list or message) and toast components
return output_content, toast_components
# For Hugging Face Spaces deployment
server = app.server
if __name__ == '__main__':
app.run_server(debug=False, host='0.0.0.0', port=7860)