Victarry committed on
Commit
423355f
·
1 Parent(s): d67abe0

Update UI.

Browse files
Files changed (3) hide show
  1. app.py +355 -144
  2. assets/clientside.js +62 -0
  3. assets/custom.css +129 -0
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import dash
2
  import dash_bootstrap_components as dbc
3
- from dash import dcc, html, Input, Output, State, callback_context
4
  import plotly.graph_objects as go
5
 
6
  from src.execution_model import ScheduleConfig, Schedule
@@ -23,7 +23,7 @@ STRATEGIES = {
23
  "dualpipe": generate_dualpipe_schedule,
24
  }
25
 
26
- app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP], suppress_callback_exceptions=True)
27
  app.title = "Pipeline Parallelism Schedule Visualizer"
28
 
29
  # Initial default values
@@ -36,107 +36,321 @@ default_values = {
36
  "op_time_backward_d": 1.0,
37
  "op_time_backward_w": 1.0,
38
  "op_time_backward": 2.0,
39
- "strategy": "1f1b_interleave",
40
  "op_time_overlapped_fwd_bwd": None,
41
  }
42
 
43
  # Define input groups using dbc components
 
 
44
  basic_params_card = dbc.Card(
45
  dbc.CardBody([
46
- html.H5("Basic Parameters", className="card-title"),
47
- html.Div([
48
- dbc.Label("Number of Devices (GPUs):"),
49
- dbc.Input(id='num_devices', type='number', value=default_values["num_devices"], min=1, step=1),
50
- ], className="mb-3"),
51
  html.Div([
52
- dbc.Label("Number of Stages (Model Chunks):"),
53
- dbc.Input(id='num_stages', type='number', value=default_values["num_stages"], min=1, step=1),
 
54
  ], className="mb-3"),
55
  html.Div([
56
- dbc.Label("Number of Microbatches:"),
57
- dbc.Input(id='num_batches', type='number', value=default_values["num_batches"], min=1, step=1),
 
58
  ], className="mb-3"),
59
  html.Div([
60
- dbc.Label("P2P Latency (ms):"),
61
- dbc.Input(id='p2p_latency', type='number', value=default_values["p2p_latency"], min=0, step=0.01),
 
62
  ], className="mb-3"),
63
- ])
 
64
  )
65
 
66
  scheduling_params_card = dbc.Card(
67
  dbc.CardBody([
68
- html.H5("Scheduling Parameters", className="card-title"),
69
- html.Div([
70
- dbc.Label("Scheduling Strategies:"),
71
- dbc.Checklist(
72
- id='strategy-checklist',
73
- options=[{'label': k, 'value': k} for k in STRATEGIES.keys()],
74
- value=list(STRATEGIES.keys()),
75
- inline=False,
76
- ),
77
- ], className="mb-3"),
78
- ])
 
 
 
 
 
 
 
 
79
  )
80
 
81
  timing_params_card = dbc.Card(
82
  dbc.CardBody([
83
- html.H5("Operation Timing (ms)", className="card-title"),
84
- html.Div([
85
- dbc.Label("Forward:"),
86
- dbc.Input(id='op_time_forward', type='number', value=default_values["op_time_forward"], min=0.01, step=0.01),
87
- ], className="mb-3"),
88
  html.Div([
89
- dbc.Label("Backward (Combined):"),
90
- dbc.Input(id='op_time_backward', type='number', value=default_values["op_time_backward"], min=0.01, step=0.01),
91
- dbc.FormText("Used when strategy does NOT require split backward."),
92
- ], className="mb-3"),
93
- html.Div([
94
- dbc.Label("Backward D (Data Grad):"),
95
- dbc.Input(id='op_time_backward_d', type='number', value=default_values["op_time_backward_d"], min=0.01, step=0.01),
96
- dbc.FormText("Used when strategy requires split backward (e.g., ZB-1P, DualPipe)."),
 
 
 
97
  ], className="mb-3"),
98
  html.Div([
99
- dbc.Label("Backward W (Weight Grad):"),
100
- dbc.Input(id='op_time_backward_w', type='number', value=default_values["op_time_backward_w"], min=0.01, step=0.01),
101
- dbc.FormText("Used when strategy requires split backward (e.g., ZB-1P, DualPipe)."),
 
 
 
 
 
 
 
 
102
  ], className="mb-3"),
103
  html.Div([
104
- dbc.Label("Overlapped Forward+Backward:"),
105
- dbc.Input(id='op_time_overlapped_fwd_bwd', type='number', placeholder="Optional: Defaults to Fwd + Bwd times", min=0.01, step=0.01, value=default_values["op_time_overlapped_fwd_bwd"]),
106
- dbc.FormText("Specify a custom duration if Forward and Backward ops overlap completely."),
 
 
 
 
 
 
 
 
 
107
  ], className="mb-3"),
108
- ])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  )
110
 
111
  # Updated app layout using dbc components and structure
112
  app.layout = dbc.Container([
113
  html.H1("Pipeline Parallelism Schedule Visualizer", className="my-4 text-center"),
114
 
 
115
  dbc.Row([
116
- dbc.Col(basic_params_card, md=4),
117
- dbc.Col(scheduling_params_card, md=4),
118
- dbc.Col(timing_params_card, md=4),
119
- ]),
120
-
121
- dbc.Row([
122
- dbc.Col([
123
- dbc.Button('Generate Schedule', id='generate-button', n_clicks=0, color="primary", className="mt-4"),
124
- ], className="text-center")
125
- ]),
126
-
127
- dbc.Row([
128
  dbc.Col([
 
129
  dcc.Loading(
130
  id="loading-graph-area",
131
  type="circle",
132
- children=html.Div(id='graph-output-container', className="mt-4")
133
  )
134
- ])
135
- ])
136
- ], fluid=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  @app.callback(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  Output('graph-output-container', 'children'),
 
140
  Input('generate-button', 'n_clicks'),
141
  State('num_devices', 'value'),
142
  State('num_stages', 'value'),
@@ -147,7 +361,7 @@ app.layout = dbc.Container([
147
  State('op_time_backward_d', 'value'),
148
  State('op_time_backward_w', 'value'),
149
  State('op_time_overlapped_fwd_bwd', 'value'),
150
- State('strategy-checklist', 'value'),
151
  prevent_initial_call=True
152
  )
153
  def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
@@ -155,19 +369,39 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
155
  op_time_overlapped_fwd_bwd,
156
  selected_strategies):
157
 
158
- # Define the desired display order for strategies
159
  strategy_display_order = ["1f1b", "1f1b_interleave", "1f1b_overlap", "1f1b_interleave_overlap", "dualpipe", "zb1p"]
160
-
161
- output_components = []
162
- valid_results = [] # Store (strategy_name, schedule, vis_data) for valid schedules
163
- error_messages = [] # Store (strategy_name, error_message) for errors
164
- automatic_adjustments = [] # Store messages about automatic parameter adjustments
 
 
 
 
165
 
166
  if not selected_strategies:
167
- return [dbc.Alert("Please select at least one scheduling strategy.", color="warning")]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
- if not all([num_devices, num_stages, num_batches, op_time_forward]):
170
- return [dbc.Alert("Missing required basic input values (Devices, Stages, Batches, Forward Time).", color="danger")]
 
171
 
172
  for strategy in selected_strategies:
173
  error_message = ""
@@ -179,17 +413,15 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
179
 
180
  # Apply automatic adjustments for dualpipe
181
  if strategy == "dualpipe" and num_stages != num_devices:
182
- current_num_stages = num_devices # Force num_stages = num_devices for dualpipe
183
- automatic_adjustments.append(
184
- f"Strategy '{strategy}': Number of Stages automatically adjusted to {num_devices} to match Number of Devices."
185
- )
186
 
187
  # Apply automatic adjustments for strategies that require num_stages == num_devices
188
  if strategy in ["1f1b", "1f1b_overlap", "zb1p"] and num_stages != num_devices:
189
  current_num_stages = num_devices
190
- automatic_adjustments.append(
191
- f"Strategy '{strategy}': Number of Stages automatically adjusted to {num_devices} to match Number of Devices."
192
- )
193
 
194
  split_backward = strategy in ["zb1p", "dualpipe"]
195
 
@@ -201,41 +433,32 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
201
  if not error_message:
202
  if strategy in ["1f1b", "1f1b_overlap", "zb1p"]:
203
  placement_strategy = "standard"
204
- # No need to check num_stages == num_devices as we've enforced it above
205
  elif strategy in ["1f1b_interleave", "1f1b_interleave_overlap"]:
206
  placement_strategy = "interleave"
207
  if current_num_stages % current_num_devices != 0:
208
- error_message = f"Strategy '{strategy}': Requires Number of Stages to be divisible by Number of Devices."
209
  elif strategy == "dualpipe":
210
  placement_strategy = "dualpipe"
211
  if current_num_stages % 2 != 0:
212
- error_message = f"Strategy '{strategy}' (DualPipe): Requires an even number of stages."
213
 
214
  # Create adjusted operation times based on placement strategy
215
  if not error_message:
216
  try:
217
- # Calculate number of stages per device for time adjustment
218
  stages_per_device = current_num_stages // current_num_devices
219
-
220
- # Calculate scaling factor - this normalizes operation time by stages per device
221
- # For standard placement (1:1 stage:device mapping), this remains 1.0
222
- # For interleaved, this scales down the time proportionally
223
  time_scale_factor = 1.0 / stages_per_device if stages_per_device > 0 else 1.0
224
-
225
  if stages_per_device > 1:
226
- automatic_adjustments.append(
227
- f"Strategy '{strategy}': Operation times scaled by 1/{stages_per_device} to account for {stages_per_device} stages per device."
228
- )
229
-
230
- # Apply scaling to operation times
231
- op_times = {
232
- "forward": float(op_time_forward) * time_scale_factor
233
- }
234
-
235
  if split_backward:
236
  op_times["backward_D"] = float(op_time_backward_d) * time_scale_factor
237
  op_times["backward_W"] = float(op_time_backward_w) * time_scale_factor
238
- # Keep combined for compatibility
239
  op_times["backward"] = (float(op_time_backward_d) + float(op_time_backward_w)) * time_scale_factor
240
  else:
241
  op_times["backward"] = float(op_time_backward) * time_scale_factor
@@ -244,14 +467,13 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
244
  try:
245
  overlapped_val = float(op_time_overlapped_fwd_bwd)
246
  if overlapped_val > 0:
247
- # Scale overlapped time too
248
  op_times["overlapped_forward_backward"] = overlapped_val * time_scale_factor
249
  except (ValueError, TypeError):
250
  pass
251
 
252
  config = ScheduleConfig(
253
  num_devices=int(current_num_devices),
254
- num_stages=int(current_num_stages), # Use adjusted value
255
  num_batches=int(num_batches),
256
  p2p_latency=float(p2p_latency),
257
  placement_strategy=placement_strategy,
@@ -265,73 +487,62 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
265
 
266
  schedule = schedule_func(config)
267
  schedule.execute()
268
-
269
- # Store valid results instead of creating figure immediately
270
  vis_data = convert_schedule_to_visualization_format(schedule)
271
  valid_results.append((strategy, schedule, vis_data))
272
 
273
  except (AssertionError, ValueError, TypeError) as e:
274
- error_message = f"Error generating schedule for '{strategy}': {e}"
275
- import traceback
276
- traceback.print_exc()
277
  except Exception as e:
278
- error_message = f"An unexpected error occurred for '{strategy}': {e}"
279
- import traceback
280
- traceback.print_exc()
281
 
282
  if error_message:
283
  error_messages.append((strategy, error_message))
284
 
285
- # Add alerts for any automatic parameter adjustments
 
286
  for adjustment in automatic_adjustments:
287
- output_components.append(
288
- dbc.Alert(adjustment, color="info", dismissable=True)
 
 
 
 
 
 
 
289
  )
290
 
291
- # If we have valid results, calculate the maximum execution time across all schedules
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  if valid_results:
293
- # Find global maximum execution time
294
  max_execution_time = max(schedule.get_total_execution_time() for _, schedule, _ in valid_results)
295
-
296
- # Sort valid results according to the display order
297
- sorted_valid_results = []
298
-
299
- # First add strategies in the predefined order
300
- for strategy_name in strategy_display_order:
301
- for result in valid_results:
302
- if result[0] == strategy_name:
303
- sorted_valid_results.append(result)
304
-
305
- # Then add any remaining strategies that might not be in the predefined order
306
- for result in valid_results:
307
- if result[0] not in strategy_display_order:
308
- sorted_valid_results.append(result)
309
-
310
- # Create figures with aligned x-axis, using the sorted results
311
  for strategy, _, vis_data in sorted_valid_results:
312
  fig = create_pipeline_figure(vis_data, max_time=max_execution_time, show_progress=False)
313
-
314
- # Force the x-axis range to be the same for all figures
315
- # Add a small margin (5%) for better visualization
316
  margin = max_execution_time * 0.05
317
  fig.update_layout(
318
- xaxis=dict(
319
- range=[0, max_execution_time + margin]
320
- )
321
  )
322
-
323
- output_components.append(html.Div([
324
  html.H4(f"Schedule: {strategy}", className="text-center mt-3 mb-2"),
325
  dcc.Graph(figure=fig)
326
  ]))
327
-
328
- # Add error messages to output
329
- for strategy, msg in error_messages:
330
- output_components.append(
331
- dbc.Alert(msg, color="danger", className="mt-3")
332
- )
333
 
334
- return output_components
 
335
 
336
  # For Hugging Face Spaces deployment
337
  server = app.server
 
1
  import dash
2
  import dash_bootstrap_components as dbc
3
+ from dash import dcc, html, Input, Output, State, callback_context, ALL, ClientsideFunction
4
  import plotly.graph_objects as go
5
 
6
  from src.execution_model import ScheduleConfig, Schedule
 
23
  "dualpipe": generate_dualpipe_schedule,
24
  }
25
 
26
+ app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP, dbc.icons.BOOTSTRAP], suppress_callback_exceptions=True)
27
  app.title = "Pipeline Parallelism Schedule Visualizer"
28
 
29
  # Initial default values
 
36
  "op_time_backward_d": 1.0,
37
  "op_time_backward_w": 1.0,
38
  "op_time_backward": 2.0,
39
+ "strategy": ["1f1b_interleave"],
40
  "op_time_overlapped_fwd_bwd": None,
41
  }
42
 
43
  # Define input groups using dbc components
44
+ card_style = {"marginBottom": "24px"}
45
+
46
  basic_params_card = dbc.Card(
47
  dbc.CardBody([
48
+ html.H5("Basic Parameters", className="card-title mb-4"),
 
 
 
 
49
  html.Div([
50
+ dbc.Label("Number of Devices (GPUs)", html_for='num_devices', className="form-label"),
51
+ dbc.Input(id='num_devices', type='number', value=default_values["num_devices"], min=1, step=1, required=True),
52
+ dbc.FormFeedback("Please provide a positive integer for the number of devices.", type="invalid", id="feedback-num_devices"),
53
  ], className="mb-3"),
54
  html.Div([
55
+ dbc.Label("Number of Stages (Model Chunks)", html_for='num_stages', className="form-label"),
56
+ dbc.Input(id='num_stages', type='number', value=default_values["num_stages"], min=1, step=1, required=True),
57
+ dbc.FormFeedback("Please provide a positive integer for the number of stages.", type="invalid", id="feedback-num_stages"),
58
  ], className="mb-3"),
59
  html.Div([
60
+ dbc.Label("Number of Microbatches", html_for='num_batches', className="form-label"),
61
+ dbc.Input(id='num_batches', type='number', value=default_values["num_batches"], min=1, step=1, required=True),
62
+ dbc.FormFeedback("Please provide a positive integer for the number of microbatches.", type="invalid", id="feedback-num_batches"),
63
  ], className="mb-3"),
64
+ ]),
65
+ style=card_style
66
  )
67
 
68
  scheduling_params_card = dbc.Card(
69
  dbc.CardBody([
70
+ html.H5("Scheduling Strategy", className="card-title mb-4"),
71
+ dbc.ButtonGroup(
72
+ [
73
+ dbc.Button(
74
+ strategy,
75
+ id={"type": "strategy-button", "index": strategy},
76
+ color="secondary",
77
+ outline=True,
78
+ active=strategy in default_values["strategy"],
79
+ className="me-1"
80
+ )
81
+ for strategy in STRATEGIES.keys()
82
+ ],
83
+ className="d-flex flex-wrap"
84
+ ),
85
+ dcc.Store(id='selected-strategies-store', data=default_values["strategy"]),
86
+ html.Div(id='strategy-selection-feedback', className='invalid-feedback d-block mt-2')
87
+ ]),
88
+ style=card_style
89
  )
90
 
91
  timing_params_card = dbc.Card(
92
  dbc.CardBody([
93
+ html.H5("Operation Timing (ms)", className="card-title mb-4"),
 
 
 
 
94
  html.Div([
95
+ html.Div([
96
+ dbc.Label("P2P Latency", html_for='p2p_latency', className="form-label d-inline-block me-1"),
97
+ html.I(className="bi bi-info-circle", id="tooltip-target-p2p", style={"cursor": "pointer"})
98
+ ]),
99
+ dbc.Input(id='p2p_latency', type='number', value=default_values["p2p_latency"], min=0, step=0.01, required=True),
100
+ dbc.FormFeedback("P2P latency must be a number >= 0.", type="invalid", id="feedback-p2p_latency"),
101
+ dbc.Tooltip(
102
+ "Time (ms) for point-to-point communication between adjacent devices.",
103
+ target="tooltip-target-p2p",
104
+ placement="right"
105
+ )
106
  ], className="mb-3"),
107
  html.Div([
108
+ html.Div([
109
+ dbc.Label("Forward Operation Time", html_for='op_time_forward', className="form-label d-inline-block me-1"),
110
+ html.I(className="bi bi-info-circle", id="tooltip-target-fwd", style={"cursor": "pointer"})
111
+ ]),
112
+ dbc.Input(id='op_time_forward', type='number', value=default_values["op_time_forward"], min=0.01, step=0.01, required=True),
113
+ dbc.FormFeedback("Forward time must be a number > 0.", type="invalid", id="feedback-op_time_forward"),
114
+ dbc.Tooltip(
115
+ "Time (ms) for a single forward pass of one microbatch through one stage.",
116
+ target="tooltip-target-fwd",
117
+ placement="right"
118
+ )
119
  ], className="mb-3"),
120
  html.Div([
121
+ html.Div([
122
+ dbc.Label("Backward (Combined)", html_for='op_time_backward', className="form-label d-inline-block me-1"),
123
+ html.I(className="bi bi-info-circle", id="tooltip-target-bwd", style={"cursor": "pointer"})
124
+ ]),
125
+ dbc.Input(id='op_time_backward', type='number', value=default_values["op_time_backward"], min=0.01, step=0.01),
126
+ dbc.FormText("Used when strategy does NOT require split backward."),
127
+ dbc.FormFeedback("Backward time must be > 0 if specified.", type="invalid", id="feedback-op_time_backward"),
128
+ dbc.Tooltip(
129
+ "Time (ms) for a combined backward pass (data gradient + weight gradient) of one microbatch through one stage.",
130
+ target="tooltip-target-bwd",
131
+ placement="right"
132
+ )
133
  ], className="mb-3"),
134
+
135
+ # --- Collapsible Advanced Options (Item 3) ---
136
+ html.Hr(className="my-3"),
137
+ dbc.Switch(
138
+ id="advanced-timing-switch",
139
+ label="Show Advanced Timing Options",
140
+ value=False,
141
+ className="mb-3"
142
+ ),
143
+ dbc.Collapse(
144
+ id="advanced-timing-collapse",
145
+ is_open=False,
146
+ children=[
147
+ html.Div([
148
+ html.Div([
149
+ dbc.Label("Backward D (Data Grad)", html_for='op_time_backward_d', className="form-label d-inline-block me-1"),
150
+ html.I(className="bi bi-info-circle", id="tooltip-target-bwd-d", style={"cursor": "pointer"})
151
+ ]),
152
+ dbc.Input(id='op_time_backward_d', type='number', value=default_values["op_time_backward_d"], min=0.01, step=0.01),
153
+ dbc.FormText("Used when strategy requires split backward (e.g., ZB-1P, DualPipe)."),
154
+ dbc.FormFeedback("Backward D time must be > 0 if specified.", type="invalid", id="feedback-op_time_backward_d"),
155
+ dbc.Tooltip(
156
+ "Time (ms) for the data gradient part of the backward pass.",
157
+ target="tooltip-target-bwd-d",
158
+ placement="right"
159
+ )
160
+ ], className="mb-3"),
161
+ html.Div([
162
+ html.Div([
163
+ dbc.Label("Backward W (Weight Grad)", html_for='op_time_backward_w', className="form-label d-inline-block me-1"),
164
+ html.I(className="bi bi-info-circle", id="tooltip-target-bwd-w", style={"cursor": "pointer"})
165
+ ]),
166
+ dbc.Input(id='op_time_backward_w', type='number', value=default_values["op_time_backward_w"], min=0.01, step=0.01),
167
+ dbc.FormText("Used when strategy requires split backward (e.g., ZB-1P, DualPipe)."),
168
+ dbc.FormFeedback("Backward W time must be > 0 if specified.", type="invalid", id="feedback-op_time_backward_w"),
169
+ dbc.Tooltip(
170
+ "Time (ms) for the weight gradient part of the backward pass.",
171
+ target="tooltip-target-bwd-w",
172
+ placement="right"
173
+ )
174
+ ], className="mb-3"),
175
+ html.Div([
176
+ html.Div([
177
+ dbc.Label("Overlapped Forward+Backward", html_for='op_time_overlapped_fwd_bwd', className="form-label d-inline-block me-1"),
178
+ html.I(className="bi bi-info-circle", id="tooltip-target-overlap", style={"cursor": "pointer"})
179
+ ]),
180
+ dbc.Input(id='op_time_overlapped_fwd_bwd', type='number', placeholder="Defaults to Fwd + Bwd", min=0.01, step=0.01, value=default_values["op_time_overlapped_fwd_bwd"]),
181
+ dbc.FormText("Specify if Forward and Backward ops overlap completely."),
182
+ dbc.FormFeedback("Overlapped time must be > 0 if specified.", type="invalid", id="feedback-op_time_overlapped_fwd_bwd"),
183
+ dbc.Tooltip(
184
+ "Optional: Specify a single time (ms) if the forward and backward passes for a microbatch can be fully overlapped within the same stage execution slot.",
185
+ target="tooltip-target-overlap",
186
+ placement="right"
187
+ )
188
+ ], className="mb-3"),
189
+ ]
190
+ )
191
+ ]),
192
+ style=card_style
193
  )
194
 
195
  # Updated app layout using dbc components and structure
196
  app.layout = dbc.Container([
197
  html.H1("Pipeline Parallelism Schedule Visualizer", className="my-4 text-center"),
198
 
199
+ # Main Row with Left (Graphs) and Right (Controls) Columns
200
  dbc.Row([
201
+ # --- Left Column (Graphs Area) ---
 
 
 
 
 
 
 
 
 
 
 
202
  dbc.Col([
203
+ # Output Area for Graphs
204
  dcc.Loading(
205
  id="loading-graph-area",
206
  type="circle",
207
+ children=html.Div(id='graph-output-container', style={"minHeight": "600px"})
208
  )
209
+ ], lg=8, md=7, sm=12, className="mb-4 mb-lg-0"),
210
+
211
+ # --- Right Column (Controls Area) ---
212
+ dbc.Col([
213
+ # Parameter Cards Stacked Vertically
214
+ basic_params_card,
215
+ scheduling_params_card,
216
+ timing_params_card,
217
+
218
+ # Generate Button below the cards in the right column
219
+ dbc.Row([
220
+ dbc.Col(
221
+ dbc.Button(
222
+ 'Generate Schedule',
223
+ id='generate-button',
224
+ n_clicks=0,
225
+ color="primary",
226
+ className="w-100",
227
+ disabled=False
228
+ ),
229
+ )
230
+ ], className="mt-3")
231
+ ], lg=4, md=5, sm=12)
232
+ ]),
233
+
234
+ # --- Toast Container (Positioned Fixed) ---
235
+ html.Div(id="toast-container", style={"position": "fixed", "top": 20, "right": 20, "zIndex": 1050})
236
 
237
+ ], fluid=True, className="py-4")
238
+
239
+ # --- Callback for Input Validation and Generate Button State ---
240
+ @app.callback(
241
+ Output('generate-button', 'disabled'),
242
+ # Outputs to control the 'invalid' state of Inputs
243
+ Output('num_devices', 'invalid'),
244
+ Output('num_stages', 'invalid'),
245
+ Output('num_batches', 'invalid'),
246
+ Output('p2p_latency', 'invalid'),
247
+ Output('op_time_forward', 'invalid'),
248
+ Output('op_time_backward', 'invalid'),
249
+ Output('op_time_backward_d', 'invalid'),
250
+ Output('op_time_backward_w', 'invalid'),
251
+ Output('op_time_overlapped_fwd_bwd', 'invalid'),
252
+ # Outputs to control the visibility/content of FormFeedback (can also just control Input's invalid state)
253
+ # We are primarily using the Input's `invalid` prop which automatically shows/hides associated FormFeedback
254
+ # Output('feedback-num_devices', 'children'), ... (Add if more specific messages needed per validation type)
255
+ Output('strategy-selection-feedback', 'children', allow_duplicate=True), # Update feedback from validation callback too
256
+ # Inputs: Trigger validation whenever any relevant input changes
257
+ Input('num_devices', 'value'),
258
+ Input('num_stages', 'value'),
259
+ Input('num_batches', 'value'),
260
+ Input('p2p_latency', 'value'),
261
+ Input('op_time_forward', 'value'),
262
+ Input('op_time_backward', 'value'),
263
+ Input('op_time_backward_d', 'value'),
264
+ Input('op_time_backward_w', 'value'),
265
+ Input('op_time_overlapped_fwd_bwd', 'value'),
266
+ Input('selected-strategies-store', 'data'), # Validate strategy selection
267
+ prevent_initial_call=True # Prevent callback running on page load before user interaction
268
+ )
269
+ def validate_inputs(num_devices, num_stages, num_batches, p2p_latency,
270
+ op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
271
+ op_time_overlapped_fwd_bwd, selected_strategies):
272
+ is_invalid = {
273
+ "num_devices": num_devices is None or num_devices < 1,
274
+ "num_stages": num_stages is None or num_stages < 1,
275
+ "num_batches": num_batches is None or num_batches < 1,
276
+ "p2p_latency": p2p_latency is None or p2p_latency < 0,
277
+ "op_time_forward": op_time_forward is None or op_time_forward <= 0,
278
+ "op_time_backward": op_time_backward is not None and op_time_backward <= 0,
279
+ "op_time_backward_d": op_time_backward_d is not None and op_time_backward_d <= 0,
280
+ "op_time_backward_w": op_time_backward_w is not None and op_time_backward_w <= 0,
281
+ "op_time_overlapped_fwd_bwd": op_time_overlapped_fwd_bwd is not None and op_time_overlapped_fwd_bwd <= 0,
282
+ }
283
+
284
+ # Validate strategy selection
285
+ strategy_feedback = "" # Default empty feedback
286
+ if not selected_strategies or len(selected_strategies) == 0:
287
+ is_invalid["strategies"] = True
288
+ strategy_feedback = "Please select at least one strategy."
289
+ else:
290
+ is_invalid["strategies"] = False
291
+ # Additional validation: Check if required timings are provided for selected strategies
292
+ needs_split_backward = any(s in ["zb1p", "dualpipe"] for s in selected_strategies)
293
+ needs_combined_backward = any(s not in ["zb1p", "dualpipe"] for s in selected_strategies)
294
+
295
+ if needs_split_backward and (op_time_backward_d is None or op_time_backward_w is None):
296
+ is_invalid["op_time_backward_d"] = op_time_backward_d is None or op_time_backward_d <= 0
297
+ is_invalid["op_time_backward_w"] = op_time_backward_w is None or op_time_backward_w <= 0
298
+ # We might want specific feedback here, but setting invalid=True is often enough
299
+
300
+ if needs_combined_backward and op_time_backward is None:
301
+ is_invalid["op_time_backward"] = op_time_backward is None or op_time_backward <= 0
302
+
303
+ # Check if any input is invalid
304
+ overall_invalid = any(is_invalid.values())
305
+
306
+ # Disable button if any validation fails
307
+ disable_button = overall_invalid
308
+
309
+ # Return button state and invalid states for each input
310
+ return (
311
+ disable_button,
312
+ is_invalid["num_devices"],
313
+ is_invalid["num_stages"],
314
+ is_invalid["num_batches"],
315
+ is_invalid["p2p_latency"],
316
+ is_invalid["op_time_forward"],
317
+ is_invalid["op_time_backward"],
318
+ is_invalid["op_time_backward_d"],
319
+ is_invalid["op_time_backward_w"],
320
+ is_invalid["op_time_overlapped_fwd_bwd"],
321
+ strategy_feedback # Update strategy feedback based on validation
322
+ )
323
+
324
+ # --- Callback to toggle Advanced Options Collapse ---
325
  @app.callback(
326
+ Output("advanced-timing-collapse", "is_open"),
327
+ Input("advanced-timing-switch", "value"),
328
+ prevent_initial_call=True,
329
+ )
330
+ def toggle_advanced_options(switch_value):
331
+ return switch_value
332
+
333
+ # --- Client-side Callback for Strategy ButtonGroup ---
334
+ app.clientside_callback(
335
+ ClientsideFunction(
336
+ namespace='clientside',
337
+ function_name='update_strategy_selection'
338
+ ),
339
+ Output('selected-strategies-store', 'data'),
340
+ Output({'type': 'strategy-button', 'index': ALL}, 'active'),
341
+ Output({'type': 'strategy-button', 'index': ALL}, 'color'),
342
+ Output({'type': 'strategy-button', 'index': ALL}, 'outline'),
343
+ Output('strategy-selection-feedback', 'children'),
344
+ Input({'type': 'strategy-button', 'index': ALL}, 'n_clicks'),
345
+ State('selected-strategies-store', 'data'),
346
+ prevent_initial_call=True
347
+ )
348
+
349
+ # --- Main Graph Update Callback ---
350
+ @app.callback(
351
+ # Output graph container and toast container separately
352
  Output('graph-output-container', 'children'),
353
+ Output('toast-container', 'children'), # Output for toasts
354
  Input('generate-button', 'n_clicks'),
355
  State('num_devices', 'value'),
356
  State('num_stages', 'value'),
 
361
  State('op_time_backward_d', 'value'),
362
  State('op_time_backward_w', 'value'),
363
  State('op_time_overlapped_fwd_bwd', 'value'),
364
+ State('selected-strategies-store', 'data'),
365
  prevent_initial_call=True
366
  )
367
  def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
 
369
  op_time_overlapped_fwd_bwd,
370
  selected_strategies):
371
 
 
372
  strategy_display_order = ["1f1b", "1f1b_interleave", "1f1b_overlap", "1f1b_interleave_overlap", "dualpipe", "zb1p"]
373
+
374
+ graph_components = [] # Renamed from output_components
375
+ toast_components = [] # New list for toasts
376
+ valid_results = []
377
+ error_messages = []
378
+ automatic_adjustments = []
379
+
380
+ # Use a variable to track if initial validation fails
381
+ initial_validation_error = None
382
 
383
  if not selected_strategies:
384
+ initial_validation_error = dbc.Toast(
385
+ "Please select at least one scheduling strategy.",
386
+ header="Input Error",
387
+ icon="warning",
388
+ duration=4000,
389
+ is_open=True,
390
+ className="border-warning"
391
+ )
392
+ elif not all([num_devices, num_stages, num_batches, op_time_forward]):
393
+ initial_validation_error = dbc.Toast(
394
+ "Missing required basic input values (Devices, Stages, Batches, Forward Time).",
395
+ header="Input Error",
396
+ icon="danger",
397
+ duration=4000,
398
+ is_open=True,
399
+ className="border-danger"
400
+ )
401
 
402
+ if initial_validation_error:
403
+ # Return empty graph list and the validation error toast
404
+ return [], [initial_validation_error]
405
 
406
  for strategy in selected_strategies:
407
  error_message = ""
 
413
 
414
  # Apply automatic adjustments for dualpipe
415
  if strategy == "dualpipe" and num_stages != num_devices:
416
+ current_num_stages = num_devices
417
+ adjustment_msg = f"Strategy '{strategy}': Number of Stages auto-adjusted to {num_devices} to match Devices."
418
+ automatic_adjustments.append(adjustment_msg)
 
419
 
420
  # Apply automatic adjustments for strategies that require num_stages == num_devices
421
  if strategy in ["1f1b", "1f1b_overlap", "zb1p"] and num_stages != num_devices:
422
  current_num_stages = num_devices
423
+ adjustment_msg = f"Strategy '{strategy}': Number of Stages auto-adjusted to {num_devices} to match Devices."
424
+ automatic_adjustments.append(adjustment_msg)
 
425
 
426
  split_backward = strategy in ["zb1p", "dualpipe"]
427
 
 
433
  if not error_message:
434
  if strategy in ["1f1b", "1f1b_overlap", "zb1p"]:
435
  placement_strategy = "standard"
 
436
  elif strategy in ["1f1b_interleave", "1f1b_interleave_overlap"]:
437
  placement_strategy = "interleave"
438
  if current_num_stages % current_num_devices != 0:
439
+ error_message = f"Strategy '{strategy}': Requires Stages divisible by Devices."
440
  elif strategy == "dualpipe":
441
  placement_strategy = "dualpipe"
442
  if current_num_stages % 2 != 0:
443
+ error_message = f"Strategy '{strategy}': Requires an even number of stages."
444
 
445
  # Create adjusted operation times based on placement strategy
446
  if not error_message:
447
  try:
 
448
  stages_per_device = current_num_stages // current_num_devices
 
 
 
 
449
  time_scale_factor = 1.0 / stages_per_device if stages_per_device > 0 else 1.0
450
+
451
  if stages_per_device > 1:
452
+ adjustment_msg = f"Strategy '{strategy}': Op times scaled by 1/{stages_per_device} ({stages_per_device} stages/device)."
453
+ # Avoid adding duplicate adjustment messages if already added above
454
+ if adjustment_msg not in automatic_adjustments:
455
+ automatic_adjustments.append(adjustment_msg)
456
+
457
+ op_times = { "forward": float(op_time_forward) * time_scale_factor }
458
+
 
 
459
  if split_backward:
460
  op_times["backward_D"] = float(op_time_backward_d) * time_scale_factor
461
  op_times["backward_W"] = float(op_time_backward_w) * time_scale_factor
 
462
  op_times["backward"] = (float(op_time_backward_d) + float(op_time_backward_w)) * time_scale_factor
463
  else:
464
  op_times["backward"] = float(op_time_backward) * time_scale_factor
 
467
  try:
468
  overlapped_val = float(op_time_overlapped_fwd_bwd)
469
  if overlapped_val > 0:
 
470
  op_times["overlapped_forward_backward"] = overlapped_val * time_scale_factor
471
  except (ValueError, TypeError):
472
  pass
473
 
474
  config = ScheduleConfig(
475
  num_devices=int(current_num_devices),
476
+ num_stages=int(current_num_stages),
477
  num_batches=int(num_batches),
478
  p2p_latency=float(p2p_latency),
479
  placement_strategy=placement_strategy,
 
487
 
488
  schedule = schedule_func(config)
489
  schedule.execute()
 
 
490
  vis_data = convert_schedule_to_visualization_format(schedule)
491
  valid_results.append((strategy, schedule, vis_data))
492
 
493
  except (AssertionError, ValueError, TypeError) as e:
494
+ error_message = f"Error for '{strategy}': {e}"
 
 
495
  except Exception as e:
496
+ error_message = f"Unexpected error for '{strategy}': {e}"
 
 
497
 
498
  if error_message:
499
  error_messages.append((strategy, error_message))
500
 
501
+ # --- Generate Toasts ---
502
+ # Add toasts for automatic adjustments
503
  for adjustment in automatic_adjustments:
504
+ toast_components.append(
505
+ dbc.Toast(
506
+ adjustment,
507
+ header="Parameter Adjustment",
508
+ icon="info",
509
+ duration=5000, # Slightly longer duration for info
510
+ is_open=True,
511
+ className="border-info"
512
+ )
513
  )
514
 
515
+ # Add toasts for errors
516
+ for strategy, msg in error_messages:
517
+ toast_components.append(
518
+ dbc.Toast(
519
+ msg,
520
+ header=f"Error: {strategy}",
521
+ icon="danger",
522
+ duration=8000, # Longer duration for errors
523
+ is_open=True,
524
+ className="border-danger"
525
+ )
526
+ )
527
+
528
+ # --- Generate Graphs ---
529
  if valid_results:
 
530
  max_execution_time = max(schedule.get_total_execution_time() for _, schedule, _ in valid_results)
531
+ sorted_valid_results = sorted(valid_results, key=lambda x: strategy_display_order.index(x[0]) if x[0] in strategy_display_order else float('inf'))
532
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
533
  for strategy, _, vis_data in sorted_valid_results:
534
  fig = create_pipeline_figure(vis_data, max_time=max_execution_time, show_progress=False)
 
 
 
535
  margin = max_execution_time * 0.05
536
  fig.update_layout(
537
+ xaxis=dict(range=[0, max_execution_time + margin])
 
 
538
  )
539
+ graph_components.append(html.Div([
 
540
  html.H4(f"Schedule: {strategy}", className="text-center mt-3 mb-2"),
541
  dcc.Graph(figure=fig)
542
  ]))
 
 
 
 
 
 
543
 
544
+ # Return graph components and toast components
545
+ return graph_components, toast_components
546
 
547
  # For Hugging Face Spaces deployment
548
  server = app.server
assets/clientside.js ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // assets/clientside.js
2
+
3
+ // Make sure the assets folder is configured correctly in Dash for this to be loaded.
4
+ // Dash automatically serves files from a folder named 'assets' in the root directory.
5
+
6
+ if (!window.dash_clientside) { window.dash_clientside = {}; }
7
+
8
// Register the callbacks under the "clientside" namespace. Merging with
// Object.assign keeps any functions that another asset file has already put
// into this namespace instead of replacing the whole object.
window.dash_clientside.clientside = Object.assign({}, window.dash_clientside.clientside, {
    /**
     * Toggle the clicked strategy button in or out of the current selection.
     *
     * @param {Array} n_clicks_all - Click counts of the strategy buttons. The
     *     values are not read; the button that fired is taken from
     *     callback_context.
     * @param {?Array} current_selection - Currently selected strategy names
     *     (null/undefined is treated as an empty selection).
     * @returns {(Array|Object)} `[new_selection, active_states, colors,
     *     outlines, feedback]`, where the three middle arrays hold one entry
     *     per input button, or `no_update` when the triggering button cannot
     *     be identified.
     */
    update_strategy_selection: function(n_clicks_all, current_selection) {
        const dc = window.dash_clientside;
        const ctx = dc.callback_context;

        // Nothing triggered the callback: leave every output untouched.
        if (!ctx.triggered || ctx.triggered.length === 0) {
            return dc.no_update;
        }

        // prop_id has the form "<component id>.<property>". Cut at the LAST
        // dot so that a '.' inside the JSON-encoded id cannot truncate it.
        const prop_id = ctx.triggered[0].prop_id;
        const last_dot = prop_id.lastIndexOf('.');
        const triggered_id_str = last_dot > 0 ? prop_id.slice(0, last_dot) : '';
        if (!triggered_id_str) {
            return dc.no_update;
        }

        // The id is a JSON object; its "index" field is the strategy name.
        let triggered_index;
        try {
            triggered_index = JSON.parse(triggered_id_str).index;
        } catch (e) {
            console.error("Error parsing callback context ID:", e);
            return dc.no_update;
        }
        if (triggered_index === undefined) {
            // An id without an "index" key would otherwise push `undefined`
            // into the selection.
            return dc.no_update;
        }

        // Work on a copy so the incoming selection array is never mutated.
        const new_selection = current_selection ? [...current_selection] : [];
        const position = new_selection.indexOf(triggered_index);
        if (position > -1) {
            // Already selected: deselect (an empty selection is allowed and
            // reported through the feedback message below).
            new_selection.splice(position, 1);
        } else {
            new_selection.push(triggered_index);
        }

        // Strategy names of ALL buttons, in the order of the first Input.
        const all_indices = ctx.inputs_list[0].map(input => input.id.index);

        // One entry per button: selected buttons are solid 'primary',
        // unselected ones are outlined 'secondary'.
        const active_states = all_indices.map(index => new_selection.includes(index));
        const colors = active_states.map(active => active ? 'primary' : 'secondary');
        const outlines = active_states.map(active => !active);

        const feedback = new_selection.length === 0 ? "Please select at least one strategy." : "";

        return [new_selection, active_states, colors, outlines, feedback];
    }
    // Add other clientside functions here if needed
});
assets/custom.css ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* assets/custom.css */
2
+
3
/* --- General & Typography (Item 7, 11) --- */
body {
    background-color: #F7F9FC; /* Neutral background */
    color: #212B36; /* Dark text */
    /* Multi-word family names use plain double quotes. Backslash-escaping
       the quotes makes the quote characters part of the family name, which
       then matches no installed font. */
    font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
    font-size: 14px;
}
10
+
11
/* Page heading: applies to the H1 rendered directly by the app layout */
.h1, h1 {
    font-size: 24px; /* H2 equivalent in request */
    font-weight: 600;
}

/* Card titles: match both a real <h5 class="card-title"> element and any
   element that carries the .h5 utility class together with .card-title */
h5.card-title, .card-title.h5 {
    font-size: 18px; /* H3 equivalent */
    font-weight: 600;
    margin-bottom: 1rem; /* Add space below title */
}

/* Form labels (Item 2) */
.form-label {
    font-size: 14px;
    font-weight: 500;
    margin-bottom: 0.3rem; /* Space between label and input */
    display: block; /* Ensure it takes full width */
}

/* Form inputs (Item 2) */
.form-control,
.form-select {
    font-size: 14px;
    /* No explicit width: Bootstrap already sizes inputs to their column */
    padding: 0.5rem 0.75rem;
    border-radius: 0.375rem; /* Softer corners */
}

/* Form help text */
.form-text {
    font-size: 12px;
    color: #6c757d; /* Muted color */
}
46
+
47
/* --- Layout & Spacing (Item 1, 7) --- */

/* Vertical breathing room above and below the page content */
.container-fluid {
    padding-bottom: 2rem;
    padding-top: 2rem;
}

/* Form rows inside cards: pin the gap to 1rem (Bootstrap's own default)
   so it stays consistent */
.card-body .mb-3 {
    margin-bottom: 1rem !important;
}

/* Cards: gap between cards plus a light frame; inner padding comes from
   .card-body */
.card {
    border: 1px solid #dee2e6; /* Subtle border */
    border-radius: 0.5rem; /* Consistent radius */
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05); /* Subtle shadow */
    margin-bottom: 24px;
}
66
+
67
/* --- Button Styling (Item 4, 5, 11) --- */

/* Primary action button (Generate Schedule): accent colour, slightly
   larger padding */
#generate-button.btn-primary {
    padding: 0.6rem 1.2rem;
    font-weight: 500;
    border-color: #0A74DA;
    background-color: #0A74DA;
}

/* Darker accent while hovered or focused */
#generate-button.btn-primary:hover,
#generate-button.btn-primary:focus {
    border-color: #085ead;
    background-color: #085ead;
}

/* Lighter, muted accent while disabled */
#generate-button.btn-primary:disabled {
    border-color: #a0cff7;
    background-color: #a0cff7;
}

/* Strategy toggle buttons: pill-shaped, with gaps to the right and below
   so they can wrap onto several lines */
.btn-group .btn {
    font-size: 13px;
    padding: 0.4rem 0.8rem;
    border-radius: 1rem;
    margin-bottom: 0.5rem;
    margin-right: 0.5rem;
}

/* Selected (active) strategy button: solid accent, no active shadow */
.btn-group .btn.btn-primary:not(.disabled):not(:disabled).active,
.btn-group .btn.btn-primary:not(.disabled):not(:disabled):active {
    color: white;
    border-color: #0A74DA;
    background-color: #0A74DA;
    box-shadow: none;
}

/* Unselected strategy button (outline secondary) */
.btn-group .btn.btn-outline-secondary {
    color: #495057;
    border-color: #ced4da;
}

.btn-group .btn.btn-outline-secondary:hover {
    background-color: #e9ecef;
}
115
+
116
/* --- Validation Feedback --- */
.invalid-feedback {
    margin-top: 0.25rem;
    font-size: 12px;
}

/* --- Responsive Adjustments (Item 10) --- */
/* Column stacking is left to Bootstrap. More specific rules may be needed
   later, e.g. chart container width/scrolling on smaller screens. */

/* Chart container (Item 9): placeholder rule, to be refined */
#graph-output-container .plotly.graph-div {
    /* Styles for the chart itself go here if needed */
}