File size: 14,884 Bytes
c048b97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
import traceback

import dash
import dash_bootstrap_components as dbc
from dash import dcc, html, Input, Output, State, callback_context
import plotly.graph_objects as go

from src.execution_model import ScheduleConfig, Schedule
from src.strategies import (
    generate_1f1b_schedule,
    generate_zero_bubble_1p_schedule,
    generate_1f1b_overlap_schedule,
    generate_1f1b_interleave_schedule,
    generate_1f1b_interleave_overlap_schedule,
    generate_dualpipe_schedule
)
from src.visualizer import convert_schedule_to_visualization_format, create_pipeline_figure

# Registry of selectable scheduling strategies: UI key -> schedule generator.
# Insertion order matters: the strategy checklist lists its options in this order.
STRATEGIES = dict([
    ("1f1b", generate_1f1b_schedule),
    ("zb1p", generate_zero_bubble_1p_schedule),
    ("1f1b_overlap", generate_1f1b_overlap_schedule),
    ("1f1b_interleave", generate_1f1b_interleave_schedule),
    ("1f1b_interleave_overlap", generate_1f1b_interleave_overlap_schedule),
    ("dualpipe", generate_dualpipe_schedule),
])

# The single Dash application instance. The Bootstrap stylesheet provides the
# styling for the dbc components; suppress_callback_exceptions lets callbacks
# reference component IDs that are not present in the initial layout.
app = dash.Dash(
    __name__,
    suppress_callback_exceptions=True,
    external_stylesheets=[dbc.themes.BOOTSTRAP],
)
app.title = "Pipeline Parallelism Schedule Visualizer"

# Initial values for the input widgets, keyed by parameter name. Times are in ms.
# NOTE(review): the "strategy" entry is not read by the visible layout (the
# checklist pre-selects every strategy) — confirm whether it is still needed.
default_values = dict(
    num_devices=4,
    num_stages=8,
    num_batches=16,
    p2p_latency=0.0,
    op_time_forward=1.0,
    op_time_backward_d=1.0,
    op_time_backward_w=1.0,
    op_time_backward=2.0,
    strategy="1f1b_interleave",
    op_time_overlapped_fwd_bwd=None,  # blank field: no overlapped time supplied
)

def _number_field(label, input_id, value, minimum, step, help_text=None, **extra_input_props):
    """Build one labelled numeric input, wrapped in a bottom-margined ``html.Div``.

    ``help_text``, when given, is rendered as a ``dbc.FormText`` below the input.
    Any ``extra_input_props`` (e.g. ``placeholder``) are forwarded to ``dbc.Input``.
    """
    field_children = [
        dbc.Label(label),
        dbc.Input(id=input_id, type='number', value=value, min=minimum, step=step, **extra_input_props),
    ]
    if help_text is not None:
        field_children.append(dbc.FormText(help_text))
    return html.Div(field_children, className="mb-3")


# Card 1: pipeline topology and workload size.
basic_params_card = dbc.Card(
    dbc.CardBody([
        html.H5("Basic Parameters", className="card-title"),
        _number_field("Number of Devices (GPUs):", 'num_devices', default_values["num_devices"], 1, 1),
        _number_field("Number of Stages (Model Chunks):", 'num_stages', default_values["num_stages"], 1, 1),
        _number_field("Number of Microbatches:", 'num_batches', default_values["num_batches"], 1, 1),
        _number_field("P2P Latency (ms):", 'p2p_latency', default_values["p2p_latency"], 0, 0.01),
    ])
)

# Card 2: which strategies to generate; every registered strategy starts selected.
scheduling_params_card = dbc.Card(
    dbc.CardBody([
        html.H5("Scheduling Parameters", className="card-title"),
        html.Div(
            [
                dbc.Label("Scheduling Strategies:"),
                dbc.Checklist(
                    id='strategy-checklist',
                    options=[{'label': name, 'value': name} for name in STRATEGIES],
                    value=list(STRATEGIES),
                    inline=False,
                ),
            ],
            className="mb-3",
        ),
    ])
)

# Card 3: per-operation durations in milliseconds.
timing_params_card = dbc.Card(
    dbc.CardBody([
        html.H5("Operation Timing (ms)", className="card-title"),
        _number_field(
            "Forward:", 'op_time_forward', default_values["op_time_forward"], 0.01, 0.01,
        ),
        _number_field(
            "Backward (Combined):", 'op_time_backward', default_values["op_time_backward"], 0.01, 0.01,
            help_text="Used when strategy does NOT require split backward.",
        ),
        _number_field(
            "Backward D (Data Grad):", 'op_time_backward_d', default_values["op_time_backward_d"], 0.01, 0.01,
            help_text="Used when strategy requires split backward (e.g., ZB-1P, DualPipe).",
        ),
        _number_field(
            "Backward W (Weight Grad):", 'op_time_backward_w', default_values["op_time_backward_w"], 0.01, 0.01,
            help_text="Used when strategy requires split backward (e.g., ZB-1P, DualPipe).",
        ),
        _number_field(
            "Overlapped Forward+Backward:", 'op_time_overlapped_fwd_bwd',
            default_values["op_time_overlapped_fwd_bwd"], 0.01, 0.01,
            help_text="Specify a custom duration if Forward and Backward ops overlap completely.",
            placeholder="Optional: Defaults to Fwd + Bwd times",
        ),
    ])
)

# Page layout: title, the three parameter cards side by side, the generate
# button, and a loading-wrapped container whose children the callback replaces.
_parameter_cards_row = dbc.Row([
    dbc.Col(card, md=4)
    for card in (basic_params_card, scheduling_params_card, timing_params_card)
])

_generate_button_row = dbc.Row([
    dbc.Col(
        [dbc.Button('Generate Schedule', id='generate-button', n_clicks=0, color="primary", className="mt-4")],
        className="text-center",
    )
])

_graph_output_row = dbc.Row([
    dbc.Col([
        dcc.Loading(
            id="loading-graph-area",
            type="circle",
            children=html.Div(id='graph-output-container', className="mt-4"),
        )
    ])
])

app.layout = dbc.Container(
    [
        html.H1("Pipeline Parallelism Schedule Visualizer", className="my-4 text-center"),
        _parameter_cards_row,
        _generate_button_row,
        _graph_output_row,
    ],
    fluid=True,
)

def _placement_for(strategy, num_stages, num_devices):
    """Choose the stage-placement scheme for *strategy* and validate the stage count.

    Returns:
        ``(placement_strategy, error_message)`` where ``error_message`` is ``""``
        when the stage/device combination is acceptable. Unknown strategies get
        an empty placement and no error.
    """
    if strategy in ("1f1b", "1f1b_overlap", "zb1p"):
        # The callback forces num_stages == num_devices for these, so nothing to check.
        return "standard", ""
    if strategy in ("1f1b_interleave", "1f1b_interleave_overlap"):
        if num_stages % num_devices != 0:
            return "interleave", (
                f"Strategy '{strategy}': Requires Number of Stages to be divisible by Number of Devices."
            )
        return "interleave", ""
    if strategy == "dualpipe":
        if num_stages % 2 != 0:
            return "dualpipe", f"Strategy '{strategy}' (DualPipe): Requires an even number of stages."
        return "dualpipe", ""
    return "", ""


def _build_op_times(op_time_forward, op_time_backward, op_time_backward_d,
                    op_time_backward_w, op_time_overlapped_fwd_bwd,
                    split_backward, time_scale_factor):
    """Return the ``op_times`` mapping for ``ScheduleConfig``, every duration scaled by *time_scale_factor*.

    With *split_backward*, ``backward_D``/``backward_W`` come from the split
    inputs and ``backward`` is their sum; otherwise ``backward`` is the combined
    input. ``overlapped_forward_backward`` is added only when the optional
    overlapped time parses as a positive number.

    Raises:
        ValueError, TypeError: if a required time cannot be converted to float.
    """
    op_times = {"forward": float(op_time_forward) * time_scale_factor}
    if split_backward:
        op_times["backward_D"] = float(op_time_backward_d) * time_scale_factor
        op_times["backward_W"] = float(op_time_backward_w) * time_scale_factor
        # The combined time is kept alongside the split ops for compatibility.
        op_times["backward"] = (float(op_time_backward_d) + float(op_time_backward_w)) * time_scale_factor
    else:
        op_times["backward"] = float(op_time_backward) * time_scale_factor

    if op_time_overlapped_fwd_bwd is not None:
        try:
            overlapped_val = float(op_time_overlapped_fwd_bwd)
        except (ValueError, TypeError):
            # Optional field: an unparseable value is treated the same as "not set".
            overlapped_val = 0.0
        if overlapped_val > 0:
            op_times["overlapped_forward_backward"] = overlapped_val * time_scale_factor
    return op_times


def _sort_for_display(valid_results, display_order):
    """Return *valid_results* reordered for display.

    Entries whose strategy name (element 0) appears in *display_order* come
    first, in that order. All others follow in their original relative order,
    because ``sorted`` is stable.
    """
    rank = {name: position for position, name in enumerate(display_order)}
    unranked = len(display_order)
    return sorted(valid_results, key=lambda result: rank.get(result[0], unranked))


@app.callback(
    Output('graph-output-container', 'children'),
    Input('generate-button', 'n_clicks'),
    State('num_devices', 'value'),
    State('num_stages', 'value'),
    State('num_batches', 'value'),
    State('p2p_latency', 'value'),
    State('op_time_forward', 'value'),
    State('op_time_backward', 'value'),
    State('op_time_backward_d', 'value'),
    State('op_time_backward_w', 'value'),
    State('op_time_overlapped_fwd_bwd', 'value'),
    State('strategy-checklist', 'value'),
    prevent_initial_call=True
)
def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
                 op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
                 op_time_overlapped_fwd_bwd,
                 selected_strategies):
    """Generate and render one schedule figure per selected strategy.

    Triggered by the "Generate Schedule" button. ``n_clicks`` is only the
    trigger; every other argument is the current value of an input widget.

    Returns:
        A list of Dash components, in this order:
        - an info alert for each automatic parameter adjustment;
        - a titled graph for each successfully generated schedule, in a fixed
          display order and sharing one x-axis range;
        - a danger alert for each strategy that failed validation or generation.
        When input validation fails up front, a single alert is returned instead.
    """
    # Strategies that support exactly one stage per device. For these, Number
    # of Stages is overridden to equal Number of Devices.
    one_stage_per_device = ("1f1b", "1f1b_overlap", "zb1p", "dualpipe")
    # Strategies that take separate Backward D / Backward W times.
    split_backward_strategies = ("zb1p", "dualpipe")
    # Desired display order for the generated figures.
    strategy_display_order = ["1f1b", "1f1b_interleave", "1f1b_overlap", "1f1b_interleave_overlap", "dualpipe", "zb1p"]

    if not selected_strategies:
        return [dbc.Alert("Please select at least one scheduling strategy.", color="warning")]

    if not all([num_devices, num_stages, num_batches, op_time_forward]):
        return [dbc.Alert("Missing required basic input values (Devices, Stages, Batches, Forward Time).", color="danger")]

    # 0 is a valid latency, so it cannot go through the truthiness check above.
    # A cleared field (None) would otherwise fail in float() once per strategy.
    if p2p_latency is None:
        return [dbc.Alert("Missing required P2P Latency value (use 0 for no latency).", color="danger")]

    valid_results = []  # (strategy_name, schedule, vis_data) for each successful schedule
    error_messages = []  # (strategy_name, error_message) for each failure
    automatic_adjustments = []  # messages about parameters changed on the user's behalf

    for strategy in selected_strategies:
        error_message = ""
        placement_strategy = ""

        # Per-strategy copies, since some strategies override the stage count.
        current_num_stages = num_stages
        current_num_devices = num_devices

        if strategy in one_stage_per_device and num_stages != num_devices:
            current_num_stages = num_devices
            automatic_adjustments.append(
                f"Strategy '{strategy}': Number of Stages automatically adjusted to {num_devices} to match Number of Devices."
            )

        split_backward = strategy in split_backward_strategies

        if split_backward and not all([op_time_backward_d, op_time_backward_w]):
            error_message = f"Strategy '{strategy}': Backward D and Backward W times are required."
        elif not split_backward and not op_time_backward:
            error_message = f"Strategy '{strategy}': Combined Backward time is required."

        if not error_message:
            placement_strategy, error_message = _placement_for(
                strategy, current_num_stages, current_num_devices
            )

        if not error_message:
            try:
                # Normalize op times by stages per device. A 1:1 placement leaves
                # them unchanged; interleaved placement divides them proportionally.
                stages_per_device = current_num_stages // current_num_devices
                time_scale_factor = 1.0 / stages_per_device if stages_per_device > 0 else 1.0
                if stages_per_device > 1:
                    automatic_adjustments.append(
                        f"Strategy '{strategy}': Operation times scaled by 1/{stages_per_device} to account for {stages_per_device} stages per device."
                    )

                op_times = _build_op_times(
                    op_time_forward, op_time_backward, op_time_backward_d,
                    op_time_backward_w, op_time_overlapped_fwd_bwd,
                    split_backward, time_scale_factor,
                )

                config = ScheduleConfig(
                    num_devices=int(current_num_devices),
                    num_stages=int(current_num_stages),  # possibly adjusted above
                    num_batches=int(num_batches),
                    p2p_latency=float(p2p_latency),
                    placement_strategy=placement_strategy,
                    split_backward=split_backward,
                    op_times=op_times,
                )

                schedule_func = STRATEGIES.get(strategy)
                if not schedule_func:
                    raise ValueError(f"Invalid strategy function for: {strategy}")

                schedule = schedule_func(config)
                schedule.execute()

                # Figures are built after the loop, once the shared x-axis range is known.
                valid_results.append((strategy, schedule, convert_schedule_to_visualization_format(schedule)))

            except (AssertionError, ValueError, TypeError) as e:
                error_message = f"Error generating schedule for '{strategy}': {e}"
                traceback.print_exc()
            except Exception as e:
                error_message = f"An unexpected error occurred for '{strategy}': {e}"
                traceback.print_exc()

        if error_message:
            error_messages.append((strategy, error_message))

    output_components = [
        dbc.Alert(adjustment, color="info", dismissable=True)
        for adjustment in automatic_adjustments
    ]

    if valid_results:
        # Every figure uses the same x-axis range (longest schedule + 5% margin)
        # so the schedules can be compared visually.
        max_execution_time = max(schedule.get_total_execution_time() for _, schedule, _ in valid_results)
        margin = max_execution_time * 0.05

        for strategy, _, vis_data in _sort_for_display(valid_results, strategy_display_order):
            fig = create_pipeline_figure(vis_data, max_time=max_execution_time, show_progress=False)
            fig.update_layout(xaxis=dict(range=[0, max_execution_time + margin]))

            output_components.append(html.Div([
                html.H4(f"Schedule: {strategy}", className="text-center mt-3 mb-2"),
                dcc.Graph(figure=fig),
            ]))

    output_components.extend(
        dbc.Alert(msg, color="danger", className="mt-3") for _, msg in error_messages
    )

    return output_components

# WSGI entry point (the Flask server underlying the Dash app), used by the
# Hugging Face Spaces deployment.
server = app.server

if __name__ == '__main__':
    # Dash 3.0 removed ``Dash.run_server`` in favour of ``Dash.run``.
    # Fall back to ``run_server`` only on older Dash releases that lack ``run``.
    run = getattr(app, "run", None) or app.run_server
    run(debug=False, host='0.0.0.0', port=7860)