Victarry committed
Commit c048b97 · 0 parents

Initial commit: PP schedule visualization.

Files changed (14):
  1. .gitattributes +2 -0
  2. .gitignore +12 -0
  3. Dockerfile +19 -0
  4. LICENSE +21 -0
  5. README.md +157 -0
  6. app.py +340 -0
  7. conf/config.yaml +25 -0
  8. main.py +156 -0
  9. pyproject.toml +69 -0
  10. requirements.txt +9 -0
  11. src/__init__.py +3 -0
  12. src/execution_model.py +401 -0
  13. src/strategies.py +581 -0
  14. src/visualizer.py +612 -0
.gitattributes ADDED
@@ -0,0 +1,2 @@
+ *.png filter=lfs diff=lfs merge=lfs -text
+ assets/*.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,12 @@
+ # Python
+ .venv/
+ uv.lock
+ outputs/
+ .cursor/*
+ *.json
+
+ # Uncomment below if you want to include these files
+ # !assets/*.png
+ # !assets/*.jpg
+ # !docs/*.png
+ # !docs/*.jpg
Dockerfile ADDED
@@ -0,0 +1,19 @@
+ # Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
+ # Note: the base image must satisfy requires-python >=3.10 from pyproject.toml
+ FROM python:3.10-slim
+
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+ ENV HOME="/home/user"
+ WORKDIR /home/user/app
+
+ COPY --chown=user requirements.txt ./
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+ COPY --chown=user . ./
+
+ # Expose the port the app will run on
+ EXPOSE 7860
+
+ # Start the app
+ CMD ["gunicorn", "-b", "0.0.0.0:7860", "app:server"]
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,157 @@
+ # Pipeline Parallelism Emulation and Visualization
+
+ This project provides tools for emulating and visualizing the pipeline-parallelism strategies used in large language model training.
+
+ ## Overview
+
+ Pipeline parallelism is a technique for training large models by partitioning the model across multiple devices and processing data in a pipelined fashion. This project allows you to:
+
+ - Simulate different pipeline-parallelism strategies (1F1B, Interleaved, Zero-Bubble, etc.)
+ - Visualize the execution schedule on multiple devices
+ - Compare the efficiency of different strategies
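For intuition about the efficiency comparison, the classic back-of-envelope bubble-ratio formula can be sketched as follows. This is an illustrative aid, not code from this repo, and it assumes uniform forward/backward times across stages:

```python
def bubble_fraction(num_stages: int, num_microbatches: int) -> float:
    """Idle ('bubble') share of an ideal 1F1B pipeline: (p - 1) / (m + p - 1)."""
    p, m = num_stages, num_microbatches
    return (p - 1) / (m + p - 1)

# More microbatches shrink the bubble: 4 stages, 8 vs. 32 microbatches.
print(bubble_fraction(4, 8))   # ~0.273
print(bubble_fraction(4, 32))  # ~0.086
```

Schedules such as Zero-Bubble and DualPipe exist precisely to push this idle fraction further down than plain 1F1B allows.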
+ ## Features
+
+ - **Supported pipeline strategies**:
+   - 1F1B (One-Forward-One-Backward)
+   - Interleaved 1F1B
+   - Zero-Bubble 1F1B (ZB-1P)
+   - 1F1B with computation-communication overlap
+   - Interleaved 1F1B with computation-communication overlap
+   - DualPipe (bidirectional pipeline parallelism with full forward-backward overlap)
+
+ - **Visualization**:
+   - Interactive visualization dashboard using Plotly/Dash
+
+ - **Configuration**:
+   - Configurable simulation parameters through Hydra
+   - Customizable stage latency and communication costs
+
+ ## Installation
+
+ This project uses [uv](https://github.com/astral-sh/uv) for dependency management.
+
+ Set up `uv` if it is not already installed on your computer:
+ ```bash
+ # On macOS and Linux
+ curl -LsSf https://astral.sh/uv/install.sh | sh
+ ```
+
+ ## Running the Interactive Server
+
+ To visualize schedules interactively:
+
+ ```bash
+ uv run app.py
+ ```
+
+ This starts a Dash server on port 7860. Open `http://127.0.0.1:7860/` in your web browser.
+
+ You can then adjust parameters such as the number of devices, stages, microbatches, and operation times, and select different scheduling strategies to see the resulting pipeline visualization.
+
+ ## Running from the Command Line
+
+ ### Running the 1F1B strategy:
+ ```bash
+ uv run python main.py strategy=1f1b num_devices=4 num_stages=4 num_batches=8
+ ```
+ ![1f1b](assets/1f1b.png)
+
+ ### Running the interleaved strategy:
+ ```bash
+ uv run python main.py strategy=interleave num_devices=4 num_stages=8 num_batches=8
+ ```
+ ![interleave](assets/interleave_1f1b.png)
+
+ ### Running the ZB-1P strategy:
+ ```bash
+ uv run python main.py strategy=zb1p num_devices=4 num_stages=4 num_batches=8
+ ```
+ ![zb1p](assets/zb1p.png)
+
+ ### Running the DualPipe strategy:
+ ```bash
+ uv run python main.py strategy=dualpipe num_devices=8 num_stages=8 num_batches=20
+ ```
+ ![dualpipe](assets/dualpipe.png)
+
+ ### Running the 1F1B-batch-overlap strategy:
+ ```bash
+ uv run python main.py strategy=1f1b_overlap num_devices=4 num_stages=4 num_batches=8
+ ```
+ ![1f1b_overlap](assets/1f1b_overlap.png)
+
+ ### Running the 1F1B-interleave-overlap strategy:
+ ```bash
+ uv run python main.py strategy=1f1b_interleave_overlap num_devices=4 num_stages=8 num_batches=8
+ ```
+ ![1f1b_interleave_overlap](assets/1f1b_interleave_overlap.png)
+
+ ## Configuration
+
+ The default configuration is in `conf/config.yaml`. You can override any parameter on the command line or create configuration groups for different scenarios.
+
+ ### Overriding Specific Parameters
+
+ You can override specific parameters at runtime:
+ ```bash
+ uv run python main.py op_times.forward=0.5 op_times.backward=1.0 num_batches=6
+ ```
+
+ Using DualPipe as an example, you can manually set different times for forward/backward/backward_D/backward_W/overlapped_forward_backward:
+ ```bash
+ uv run python main.py strategy=dualpipe num_devices=8 num_stages=8 num_batches=32 op_times.forward=1.0 op_times.backward=2.0 op_times.backward_D=1.0 op_times.backward_W=1.0 op_times.overlapped_forward_backward=2.5
+ ```
+
+ ### Using Different Configuration Files
+
+ You can use different configuration files with Hydra in several ways:
+
+ 1. Create multiple configuration files in the `conf` directory for different use cases:
+    ```
+    conf/
+    ├── config.yaml   # Default configuration
+    └── model_A.yaml  # Your own config, e.g. with stage-specific latency for performance projection
+    ```
+
+ 2. Run with your desired configuration using the `--config-name` flag:
+    ```bash
+    uv run python main.py --config-name=model_A
+    ```
+
+
+ ## Project Structure
+
+ ```
+ PP-Emulation/
+ ├── conf/                  # Hydra configuration files
+ │   └── config.yaml        # Default configuration
+ ├── src/                   # Source code
+ │   ├── __init__.py        # Package initialization
+ │   ├── execution_model.py # Schedule execution models
+ │   ├── strategies.py      # Pipeline parallelism strategies
+ │   └── visualizer.py      # Visualization utilities
+ ├── app.py                 # Interactive Dash app
+ ├── main.py                # Main command-line entry point
+ ├── pyproject.toml         # Project metadata and dependencies
+ └── README.md              # This file
+ ```
+
+ ## References
+
+ 1. _PipeDream: Fast and Efficient Pipeline Parallel DNN Training_. [arXiv](https://arxiv.org/abs/1806.03377)
+ 2. _Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM_. [arXiv](https://arxiv.org/abs/2104.04473)
+ 3. _Zero Bubble Pipeline Parallelism_. [arXiv](https://arxiv.org/abs/2401.10241)
+ 4. _Communication-Computation Overlap in MoE Training with 1F1B Pipeline Parallelism_. [blog](https://zhuanlan.zhihu.com/p/28463368206)
+
+ ## License
+
+ This project is licensed under the MIT License — see the LICENSE file for details.
+
+ ## Contributing
+
+ Contributions are welcome! Please feel free to submit a Pull Request.
app.py ADDED
@@ -0,0 +1,340 @@
+ import dash
+ import dash_bootstrap_components as dbc
+ from dash import dcc, html, Input, Output, State, callback_context
+ import plotly.graph_objects as go
+
+ from src.execution_model import ScheduleConfig, Schedule
+ from src.strategies import (
+     generate_1f1b_schedule,
+     generate_zero_bubble_1p_schedule,
+     generate_1f1b_overlap_schedule,
+     generate_1f1b_interleave_schedule,
+     generate_1f1b_interleave_overlap_schedule,
+     generate_dualpipe_schedule,
+ )
+ from src.visualizer import convert_schedule_to_visualization_format, create_pipeline_figure
+
+ STRATEGIES = {
+     "1f1b": generate_1f1b_schedule,
+     "zb1p": generate_zero_bubble_1p_schedule,
+     "1f1b_overlap": generate_1f1b_overlap_schedule,
+     "1f1b_interleave": generate_1f1b_interleave_schedule,
+     "1f1b_interleave_overlap": generate_1f1b_interleave_overlap_schedule,
+     "dualpipe": generate_dualpipe_schedule,
+ }
+
+ app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP], suppress_callback_exceptions=True)
+ app.title = "Pipeline Parallelism Schedule Visualizer"
+
+ # Initial default values
+ default_values = {
+     "num_devices": 4,
+     "num_stages": 8,
+     "num_batches": 16,
+     "p2p_latency": 0.0,
+     "op_time_forward": 1.0,
+     "op_time_backward_d": 1.0,
+     "op_time_backward_w": 1.0,
+     "op_time_backward": 2.0,
+     "strategy": "1f1b_interleave",
+     "op_time_overlapped_fwd_bwd": None,
+ }
+
+ # Define input groups using dbc components
+ basic_params_card = dbc.Card(
+     dbc.CardBody([
+         html.H5("Basic Parameters", className="card-title"),
+         html.Div([
+             dbc.Label("Number of Devices (GPUs):"),
+             dbc.Input(id='num_devices', type='number', value=default_values["num_devices"], min=1, step=1),
+         ], className="mb-3"),
+         html.Div([
+             dbc.Label("Number of Stages (Model Chunks):"),
+             dbc.Input(id='num_stages', type='number', value=default_values["num_stages"], min=1, step=1),
+         ], className="mb-3"),
+         html.Div([
+             dbc.Label("Number of Microbatches:"),
+             dbc.Input(id='num_batches', type='number', value=default_values["num_batches"], min=1, step=1),
+         ], className="mb-3"),
+         html.Div([
+             dbc.Label("P2P Latency (ms):"),
+             dbc.Input(id='p2p_latency', type='number', value=default_values["p2p_latency"], min=0, step=0.01),
+         ], className="mb-3"),
+     ])
+ )
+
+ scheduling_params_card = dbc.Card(
+     dbc.CardBody([
+         html.H5("Scheduling Parameters", className="card-title"),
+         html.Div([
+             dbc.Label("Scheduling Strategies:"),
+             dbc.Checklist(
+                 id='strategy-checklist',
+                 options=[{'label': k, 'value': k} for k in STRATEGIES.keys()],
+                 value=list(STRATEGIES.keys()),
+                 inline=False,
+             ),
+         ], className="mb-3"),
+     ])
+ )
+
+ timing_params_card = dbc.Card(
+     dbc.CardBody([
+         html.H5("Operation Timing (ms)", className="card-title"),
+         html.Div([
+             dbc.Label("Forward:"),
+             dbc.Input(id='op_time_forward', type='number', value=default_values["op_time_forward"], min=0.01, step=0.01),
+         ], className="mb-3"),
+         html.Div([
+             dbc.Label("Backward (Combined):"),
+             dbc.Input(id='op_time_backward', type='number', value=default_values["op_time_backward"], min=0.01, step=0.01),
+             dbc.FormText("Used when strategy does NOT require split backward."),
+         ], className="mb-3"),
+         html.Div([
+             dbc.Label("Backward D (Data Grad):"),
+             dbc.Input(id='op_time_backward_d', type='number', value=default_values["op_time_backward_d"], min=0.01, step=0.01),
+             dbc.FormText("Used when strategy requires split backward (e.g., ZB-1P, DualPipe)."),
+         ], className="mb-3"),
+         html.Div([
+             dbc.Label("Backward W (Weight Grad):"),
+             dbc.Input(id='op_time_backward_w', type='number', value=default_values["op_time_backward_w"], min=0.01, step=0.01),
+             dbc.FormText("Used when strategy requires split backward (e.g., ZB-1P, DualPipe)."),
+         ], className="mb-3"),
+         html.Div([
+             dbc.Label("Overlapped Forward+Backward:"),
+             dbc.Input(id='op_time_overlapped_fwd_bwd', type='number', placeholder="Optional: Defaults to Fwd + Bwd times", min=0.01, step=0.01, value=default_values["op_time_overlapped_fwd_bwd"]),
+             dbc.FormText("Specify a custom duration if Forward and Backward ops overlap completely."),
+         ], className="mb-3"),
+     ])
+ )
+
+ # App layout built from the dbc cards above
+ app.layout = dbc.Container([
+     html.H1("Pipeline Parallelism Schedule Visualizer", className="my-4 text-center"),
+
+     dbc.Row([
+         dbc.Col(basic_params_card, md=4),
+         dbc.Col(scheduling_params_card, md=4),
+         dbc.Col(timing_params_card, md=4),
+     ]),
+
+     dbc.Row([
+         dbc.Col([
+             dbc.Button('Generate Schedule', id='generate-button', n_clicks=0, color="primary", className="mt-4"),
+         ], className="text-center")
+     ]),
+
+     dbc.Row([
+         dbc.Col([
+             dcc.Loading(
+                 id="loading-graph-area",
+                 type="circle",
+                 children=html.Div(id='graph-output-container', className="mt-4")
+             )
+         ])
+     ])
+ ], fluid=True)
+
+ @app.callback(
+     Output('graph-output-container', 'children'),
+     Input('generate-button', 'n_clicks'),
+     State('num_devices', 'value'),
+     State('num_stages', 'value'),
+     State('num_batches', 'value'),
+     State('p2p_latency', 'value'),
+     State('op_time_forward', 'value'),
+     State('op_time_backward', 'value'),
+     State('op_time_backward_d', 'value'),
+     State('op_time_backward_w', 'value'),
+     State('op_time_overlapped_fwd_bwd', 'value'),
+     State('strategy-checklist', 'value'),
+     prevent_initial_call=True
+ )
+ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
+                  op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
+                  op_time_overlapped_fwd_bwd,
+                  selected_strategies):
+
+     # Define the desired display order for strategies
+     strategy_display_order = ["1f1b", "1f1b_interleave", "1f1b_overlap", "1f1b_interleave_overlap", "dualpipe", "zb1p"]
+
+     output_components = []
+     valid_results = []  # Store (strategy_name, schedule, vis_data) for valid schedules
+     error_messages = []  # Store (strategy_name, error_message) for errors
+     automatic_adjustments = []  # Store messages about automatic parameter adjustments
+
+     if not selected_strategies:
+         return [dbc.Alert("Please select at least one scheduling strategy.", color="warning")]
+
+     if not all([num_devices, num_stages, num_batches, op_time_forward]):
+         return [dbc.Alert("Missing required basic input values (Devices, Stages, Batches, Forward Time).", color="danger")]
+
+     for strategy in selected_strategies:
+         error_message = ""
+         placement_strategy = ""
+
+         # Use local copies of params that might be adjusted for this strategy
+         current_num_stages = num_stages
+         current_num_devices = num_devices
+
+         # Apply automatic adjustments for dualpipe
+         if strategy == "dualpipe" and num_stages != num_devices:
+             current_num_stages = num_devices  # Force num_stages = num_devices for dualpipe
+             automatic_adjustments.append(
+                 f"Strategy '{strategy}': Number of Stages automatically adjusted to {num_devices} to match Number of Devices."
+             )
+
+         # Apply automatic adjustments for strategies that require num_stages == num_devices
+         if strategy in ["1f1b", "1f1b_overlap", "zb1p"] and num_stages != num_devices:
+             current_num_stages = num_devices
+             automatic_adjustments.append(
+                 f"Strategy '{strategy}': Number of Stages automatically adjusted to {num_devices} to match Number of Devices."
+             )
+
+         split_backward = strategy in ["zb1p", "dualpipe"]
+
+         if split_backward and not all([op_time_backward_d, op_time_backward_w]):
+             error_message = f"Strategy '{strategy}': Backward D and Backward W times are required."
+         elif not split_backward and not op_time_backward:
+             error_message = f"Strategy '{strategy}': Combined Backward time is required."
+
+         if not error_message:
+             if strategy in ["1f1b", "1f1b_overlap", "zb1p"]:
+                 placement_strategy = "standard"
+                 # No need to check num_stages == num_devices as we've enforced it above
+             elif strategy in ["1f1b_interleave", "1f1b_interleave_overlap"]:
+                 placement_strategy = "interleave"
+                 if current_num_stages % current_num_devices != 0:
+                     error_message = f"Strategy '{strategy}': Requires Number of Stages to be divisible by Number of Devices."
+             elif strategy == "dualpipe":
+                 placement_strategy = "dualpipe"
+                 if current_num_stages % 2 != 0:
+                     error_message = f"Strategy '{strategy}' (DualPipe): Requires an even number of stages."
+
+         # Create adjusted operation times based on placement strategy
+         if not error_message:
+             try:
+                 # Calculate number of stages per device for time adjustment
+                 stages_per_device = current_num_stages // current_num_devices
+
+                 # Calculate scaling factor - this normalizes operation time by stages per device.
+                 # For standard placement (1:1 stage:device mapping), this remains 1.0.
+                 # For interleaved placement, this scales down the time proportionally.
+                 time_scale_factor = 1.0 / stages_per_device if stages_per_device > 0 else 1.0
+
+                 if stages_per_device > 1:
+                     automatic_adjustments.append(
+                         f"Strategy '{strategy}': Operation times scaled by 1/{stages_per_device} to account for {stages_per_device} stages per device."
+                     )
+
+                 # Apply scaling to operation times
+                 op_times = {
+                     "forward": float(op_time_forward) * time_scale_factor
+                 }
+
+                 if split_backward:
+                     op_times["backward_D"] = float(op_time_backward_d) * time_scale_factor
+                     op_times["backward_W"] = float(op_time_backward_w) * time_scale_factor
+                     # Keep combined for compatibility
+                     op_times["backward"] = (float(op_time_backward_d) + float(op_time_backward_w)) * time_scale_factor
+                 else:
+                     op_times["backward"] = float(op_time_backward) * time_scale_factor
+
+                 if op_time_overlapped_fwd_bwd is not None:
+                     try:
+                         overlapped_val = float(op_time_overlapped_fwd_bwd)
+                         if overlapped_val > 0:
+                             # Scale overlapped time too
+                             op_times["overlapped_forward_backward"] = overlapped_val * time_scale_factor
+                     except (ValueError, TypeError):
+                         pass
+
+                 config = ScheduleConfig(
+                     num_devices=int(current_num_devices),
+                     num_stages=int(current_num_stages),  # Use adjusted value
+                     num_batches=int(num_batches),
+                     p2p_latency=float(p2p_latency),
+                     placement_strategy=placement_strategy,
+                     split_backward=split_backward,
+                     op_times=op_times,
+                 )
+
+                 schedule_func = STRATEGIES.get(strategy)
+                 if not schedule_func:
+                     raise ValueError(f"Invalid strategy function for: {strategy}")
+
+                 schedule = schedule_func(config)
+                 schedule.execute()
+
+                 # Store valid results instead of creating figure immediately
+                 vis_data = convert_schedule_to_visualization_format(schedule)
+                 valid_results.append((strategy, schedule, vis_data))
+
+             except (AssertionError, ValueError, TypeError) as e:
+                 error_message = f"Error generating schedule for '{strategy}': {e}"
+                 import traceback
+                 traceback.print_exc()
+             except Exception as e:
+                 error_message = f"An unexpected error occurred for '{strategy}': {e}"
+                 import traceback
+                 traceback.print_exc()
+
+         if error_message:
+             error_messages.append((strategy, error_message))
+
+     # Add alerts for any automatic parameter adjustments
+     for adjustment in automatic_adjustments:
+         output_components.append(
+             dbc.Alert(adjustment, color="info", dismissable=True)
+         )
+
+     # If we have valid results, calculate the maximum execution time across all schedules
+     if valid_results:
+         # Find global maximum execution time
+         max_execution_time = max(schedule.get_total_execution_time() for _, schedule, _ in valid_results)
+
+         # Sort valid results according to the display order
+         sorted_valid_results = []
+
+         # First add strategies in the predefined order
+         for strategy_name in strategy_display_order:
+             for result in valid_results:
+                 if result[0] == strategy_name:
+                     sorted_valid_results.append(result)
+
+         # Then add any remaining strategies that might not be in the predefined order
+         for result in valid_results:
+             if result[0] not in strategy_display_order:
+                 sorted_valid_results.append(result)
+
+         # Create figures with aligned x-axis, using the sorted results
+         for strategy, _, vis_data in sorted_valid_results:
+             fig = create_pipeline_figure(vis_data, max_time=max_execution_time, show_progress=False)
+
+             # Force the x-axis range to be the same for all figures.
+             # Add a small margin (5%) for better visualization.
+             margin = max_execution_time * 0.05
+             fig.update_layout(
+                 xaxis=dict(
+                     range=[0, max_execution_time + margin]
+                 )
+             )
+
+             output_components.append(html.Div([
+                 html.H4(f"Schedule: {strategy}", className="text-center mt-3 mb-2"),
+                 dcc.Graph(figure=fig)
+             ]))
+
+     # Add error messages to output
+     for strategy, msg in error_messages:
+         output_components.append(
+             dbc.Alert(msg, color="danger", className="mt-3")
+         )
+
+     return output_components
+
+ # For Hugging Face Spaces deployment
+ server = app.server
+
+ if __name__ == '__main__':
+     app.run_server(debug=False, host='0.0.0.0', port=7860)
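The per-device time scaling inside the callback can be summarized in isolation. This is a standalone sketch with a hypothetical helper name (`scaled_op_times` is not part of the repo), assuming operation time is normalized by the number of model chunks placed on each device:

```python
def scaled_op_times(forward: float, backward: float,
                    num_stages: int, num_devices: int) -> dict:
    """Normalize per-op time by chunks per device (1F1B interleaved placement)."""
    # With interleaving, each device hosts num_stages // num_devices chunks,
    # so the time of each chunk's op is the whole-device time divided by that count.
    chunks_per_device = num_stages // num_devices
    scale = 1.0 / chunks_per_device if chunks_per_device > 0 else 1.0
    return {"forward": forward * scale, "backward": backward * scale}

print(scaled_op_times(1.0, 2.0, 8, 4))  # {'forward': 0.5, 'backward': 1.0}
```

With a standard 1:1 stage-to-device placement the scale factor is 1.0 and the times pass through unchanged, matching the callback's behavior.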
conf/config.yaml ADDED
@@ -0,0 +1,25 @@
+ # Default configuration for Pipeline Parallelism Emulation
+ num_devices: 4
+ num_stages: 4
+ num_batches: 8
+ visualization_port: 8050
+ strategy: "1f1b"  # Options: "1f1b", "interleave", "zb1p", "1f1b_overlap", "1f1b_interleave_overlap", "dualpipe"
+ p2p_latency: 0.0
+
+ # Operation time configurations
+ op_times:
+   # Option 1: Simple configuration (same time for all stages)
+   forward: 1.0
+   backward: 2.0
+   backward_D: 1.0
+   backward_W: 1.0
+   overlapped_forward_backward: 3.0
+
+   # Option 2: Commented example of stage-specific configuration
+   # forward:
+   #   0: 0.8  # Stage 0 forward time
+   #   1: 1.2  # Stage 1 forward time
+   #   2: 1.5  # Stage 2 forward time
+   #   3: 0.9  # Stage 3 forward time
+   # backward:
+   #   0: 1.0  # Stage 0 backward time
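The two `op_times` shapes above (a scalar for all stages, or a per-stage mapping) can be resolved uniformly. A minimal sketch with a hypothetical helper (`stage_time` is not the repo's actual loader, whose behavior may differ):

```python
from typing import Dict, Union

# Either a single latency for every stage, or a {stage_id: latency} mapping.
StageTimes = Union[float, Dict[int, float]]

def stage_time(op_times: StageTimes, stage_id: int) -> float:
    """A plain float applies to every stage; a dict gives per-stage values."""
    if isinstance(op_times, dict):
        return op_times[stage_id]
    return op_times

print(stage_time(1.0, 2))               # 1.0  (Option 1: scalar)
print(stage_time({0: 0.8, 1: 1.2}, 1))  # 1.2  (Option 2: per-stage)
```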
main.py ADDED
@@ -0,0 +1,156 @@
+ from src.execution_model import ScheduleConfig
+ from src.strategies import (
+     generate_1f1b_interleave_overlap_schedule,
+     generate_1f1b_interleave_schedule,
+     generate_1f1b_overlap_schedule,
+     generate_1f1b_schedule,
+     generate_zero_bubble_1p_schedule,
+     generate_dualpipe_schedule,
+ )
+ from src.visualizer import visualize_pipeline_parallelism_dash
+ import hydra
+ from omegaconf import DictConfig, OmegaConf
+
+
+ @hydra.main(config_path="conf", config_name="config", version_base=None)
+ def main(cfg: DictConfig) -> None:
+     """Run pipeline parallelism simulation with the specified configuration."""
+     print(f"Running with configuration: {cfg}")
+
+     if cfg.strategy == "1f1b":
+         run_1f1b(cfg)
+     elif cfg.strategy == "interleave":
+         run_interleave(cfg)
+     elif cfg.strategy == "zb1p":
+         run_zero_bubble_1p(cfg)
+     elif cfg.strategy == "1f1b_overlap":
+         run_1f1b_overlap(cfg)
+     elif cfg.strategy == "1f1b_interleave_overlap":
+         run_1f1b_interleave_overlap(cfg)
+     elif cfg.strategy == "dualpipe":
+         run_dualpipe(cfg)
+     else:
+         raise ValueError(f"Unknown strategy: {cfg.strategy}")
+
+
+ def run_1f1b(cfg: DictConfig) -> None:
+     """Run 1F1B pipeline parallelism simulation."""
+     # Convert OmegaConf to dict for op_times if it exists
+     op_times = (
+         OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
+     )
+
+     schedule_config = ScheduleConfig(
+         num_devices=cfg.num_devices,
+         num_stages=cfg.num_stages,
+         num_batches=cfg.num_batches,
+         p2p_latency=cfg.p2p_latency,
+         op_times=op_times,
+         placement_strategy="standard",
+     )
+     schedule = generate_1f1b_schedule(schedule_config)
+     schedule.execute()
+
+     visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
+
+
+ def run_interleave(cfg: DictConfig) -> None:
+     """Run interleaved pipeline parallelism simulation."""
+     # Convert OmegaConf to dict for op_times if it exists
+     op_times = (
+         OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
+     )
+
+     schedule_config = ScheduleConfig(
+         num_devices=cfg.num_devices,
+         num_stages=cfg.num_stages,
+         num_batches=cfg.num_batches,
+         p2p_latency=cfg.p2p_latency,
+         placement_strategy="interleave",
+         op_times=op_times,
+     )
+     schedule = generate_1f1b_interleave_schedule(schedule_config)
+     schedule.execute()
+     visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
+
+
+ def run_zero_bubble_1p(cfg: DictConfig) -> None:
+     """Run zero bubble 1P pipeline parallelism simulation."""
+     # Convert OmegaConf to dict for op_times if it exists
+     op_times = (
+         OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
+     )
+
+     schedule_config = ScheduleConfig(
+         num_devices=cfg.num_devices,
+         num_stages=cfg.num_stages,
+         num_batches=cfg.num_batches,
+         p2p_latency=cfg.p2p_latency,
+         op_times=op_times,
+         split_backward=True,
+     )
+     schedule = generate_zero_bubble_1p_schedule(schedule_config)
+     schedule.execute()
+     visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
+
+
+ def run_1f1b_overlap(cfg: DictConfig) -> None:
+     """Run 1F1B overlap pipeline parallelism simulation."""
+     # Convert OmegaConf to dict for op_times if it exists
+     op_times = (
+         OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
+     )
+
+     schedule_config = ScheduleConfig(
+         num_devices=cfg.num_devices,
+         num_stages=cfg.num_stages,
+         num_batches=cfg.num_batches,
+         p2p_latency=cfg.p2p_latency,
+         op_times=op_times,
+         split_backward=False,
+     )
+     schedule = generate_1f1b_overlap_schedule(schedule_config)
+     schedule.execute()
+     visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
+
+
+ def run_1f1b_interleave_overlap(cfg: DictConfig) -> None:
+     """Run 1F1B interleave overlapped pipeline parallelism simulation."""
+     # Convert OmegaConf to dict for op_times if it exists
+     op_times = (
+         OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
+     )
+
+     schedule_config = ScheduleConfig(
+         num_devices=cfg.num_devices,
+         num_stages=cfg.num_stages,
+         num_batches=cfg.num_batches,
+         p2p_latency=cfg.p2p_latency,
+         placement_strategy="interleave",
+         op_times=op_times,
+     )
+     schedule = generate_1f1b_interleave_overlap_schedule(schedule_config)
+     schedule.execute()
+     visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
+
+
+ def run_dualpipe(cfg: DictConfig) -> None:
+     """Run DualPipe pipeline parallelism simulation."""
+     # Convert OmegaConf to dict for op_times if it exists
+     op_times = (
+         OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
+     )
+
+     schedule_config = ScheduleConfig(
+         num_devices=cfg.num_devices,
+         num_stages=cfg.num_stages,
+         num_batches=cfg.num_batches,
+         p2p_latency=cfg.p2p_latency,
+         op_times=op_times,
+         split_backward=True,
+         placement_strategy="dualpipe",
+     )
+     schedule = generate_dualpipe_schedule(schedule_config)
+     schedule.execute()
+     visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
+
+
+ if __name__ == "__main__":
+     main()
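The `if/elif` chain in `main()` above could equivalently be table-driven, which keeps the strategy list in one place. A sketch under hypothetical names (`make_dispatcher` is not part of the repo; the string values stand in for the `run_*` functions):

```python
def make_dispatcher(runners: dict):
    """Map a strategy name to its runner, with a clear error for unknown names."""
    def dispatch(strategy: str):
        try:
            return runners[strategy]
        except KeyError:
            # Mirrors the ValueError raised by the else branch in main()
            raise ValueError(f"Unknown strategy: {strategy}") from None
    return dispatch

dispatch = make_dispatcher({"1f1b": "run_1f1b", "interleave": "run_interleave"})
print(dispatch("1f1b"))  # run_1f1b
```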
pyproject.toml ADDED
@@ -0,0 +1,69 @@
+ [build-system]
+ requires = ["hatchling"]
+ build-backend = "hatchling.build"
+
+ [project]
+ name = "pp-emulation"
+ version = "0.1.0"
+ description = "Pipeline Parallelism Emulation and Visualization"
+ readme = "README.md"
+ requires-python = ">=3.10"
+ authors = [
+     {name = "Zhenhuan Liu"}
+ ]
+ classifiers = [
+     "Programming Language :: Python :: 3",
+     "License :: OSI Approved :: MIT License",
+     "Operating System :: OS Independent",
+ ]
+ dependencies = [
+     "dash>=2.14.0",
+     "hydra-core>=1.3.2",
+     "omegaconf>=2.3.0",
+     "plotly>=5.18.0",
+     "pandas>=2.1.0",
+     "numpy>=1.26.0",
+     "tqdm>=4.67.0",
+     "dash-bootstrap-components>=1.7.1",
+     "gunicorn>=23.0.0",
+ ]
+
+ [project.optional-dependencies]
+ dev = [
+     "pytest>=7.4.0",
+     "black>=23.7.0",
+     "isort>=5.12.0",
+     "mypy>=1.5.1",
+ ]
+
+ # Hatch configuration to explicitly define where source code is located
+ [tool.hatch.build.targets.wheel]
+ packages = ["src"]
+
+ [tool.hatch.build.targets.sdist]
+ include = [
+     "src",
+     "main.py",
+     "conf",
+     "LICENSE",
+     "README.md",
+ ]
+
+ [tool.black]
+ line-length = 88
+ target-version = ["py310"]
+
+ [tool.isort]
+ profile = "black"
+ line_length = 88
+
+ [tool.mypy]
+ python_version = "3.10"
+ warn_return_any = true
+ warn_unused_configs = true
+ disallow_untyped_defs = true
+ disallow_incomplete_defs = true
+
+ # pytest only reads pyproject settings from [tool.pytest.ini_options]
+ [tool.pytest.ini_options]
+ testpaths = ["tests"]
+ pythonpath = ["."]
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ dash==2.14.2
+ dash-bootstrap-components==1.7.1
+ plotly==5.18.0
+ gunicorn==21.2.0
+ hydra-core==1.3.2
+ omegaconf==2.3.0
+ pandas==2.1.0
+ numpy==1.26.0
+ tqdm==4.67.0
src/__init__.py ADDED
@@ -0,0 +1,3 @@
+ """Pipeline Parallelism Emulation and Visualization package."""
+
+ __version__ = "0.1.0"
src/execution_model.py ADDED
@@ -0,0 +1,401 @@
+ from collections import defaultdict
+ from typing import Dict, List, Optional, Union
+
+
+ class Operation:
+     """Operation is a single operation in the pipeline."""
+
+     def __init__(self, batch_id: int, stage_id: int, op_type: str):
+         self.batch_id = batch_id
+         self.stage_id = stage_id
+         self.op_type = op_type
+         self.device_id = None
+
+         self.start_time = None
+         self.end_time = None
+
+     def set_end_time(self, end_time: float):
+         self.end_time = end_time
+
+     def set_start_time(self, start_time: float):
+         self.start_time = start_time
+
+     def __repr__(self) -> str:
+         return f"Operation(batch_id={self.batch_id}, stage_id={self.stage_id}, op_type={self.op_type})"
+
+
+ class OverlappedOperation:
+     """Represents multiple operations that are overlapped/executed concurrently."""
+
+     def __init__(self, operations: List[Operation]):
+         self.operations = operations
+         self.device_id = operations[0].device_id
+
+         # Validate all operations are on the same device
+         for op in operations:
+             assert op.device_id == self.device_id, "All operations must be on the same device"
+
+         # Create a combined op_type (e.g., "overlapped_forward_backward")
+         self.op_type = "overlapped_" + "_".join([op.op_type for op in operations])
+
+         # Use the batch_id and stage_id of the first operation for identification
+         # (though we'll track all operations internally)
+         self.batch_id = operations[0].batch_id
+         self.stage_id = operations[0].stage_id
+
+         # Initialize timing information
+         self.start_time = None
+         self.end_time = None
+
+     def set_end_time(self, end_time: float):
+         self.end_time = end_time
+         for op in self.operations:
+             op.set_end_time(end_time)
+
+     def set_start_time(self, start_time: float):
+         self.start_time = start_time
+         for op in self.operations:
+             op.set_start_time(start_time)
+
+     def __repr__(self) -> str:
+         op_str = ", ".join([f"({op.batch_id},{op.stage_id},{op.op_type})" for op in self.operations])
+         return f"OverlappedOperation([{op_str}])"
+
+
+ class DeviceQueue:
+     def __init__(self, stages: List[int], device_id: int):
+         self.stages = stages
+         self.device_id = device_id
+         self.ops = []  # List of operations
+
+     def add_operation(self, op: Operation):
+         assert op.stage_id in self.stages
+         self.ops.append(op)
+         assert op.device_id is None, f"Operation {op.batch_id}, {op.stage_id}, {op.op_type} already has a device id on {op.device_id}"
+         op.device_id = self.device_id
+
+
+ class ScheduleConfig:
+     def __init__(
+         self,
+         num_devices: int,
+         num_stages: int,
+         num_batches: int,
+         p2p_latency: float = 0.0,
+         placement_strategy: str = "standard",
+         split_backward: bool = False,
+         op_times: Optional[Dict[str, Union[float, Dict[int, float]]]] = None,
+     ):
+         self.num_devices = num_devices
+         self.num_stages = num_stages
+         self.num_batches = num_batches
+         self.p2p_latency = p2p_latency
+         self.placement_strategy = placement_strategy
+         self.split_backward = split_backward
+
+         # Initialize default operation times
+         if self.split_backward:
+             self.op_times = {
+                 "forward": 1.0,
+                 "backward_D": 1.0,
+                 "backward_W": 1.0,
+                 "backward": 2.0,
+             }
+         else:
+             self.op_times = {
+                 "forward": 1.0,
+                 "backward": 2.0,
+             }
+
+         # Update with user-provided operation times
+         if op_times:
+             for op_type, times in op_times.items():
+                 if isinstance(times, dict):
+                     # If a dict is provided, it maps stage_id -> time
+                     if op_type not in self.op_times:
+                         self.op_times[op_type] = {}
+                     elif not isinstance(self.op_times[op_type], dict):
+                         # Convert float to dict if needed
+                         self.op_times[op_type] = {i: self.op_times[op_type] for i in range(num_stages)}
+
+                     # Update with provided stage-specific times
+                     for stage_id, time in times.items():
+                         if not isinstance(self.op_times[op_type], dict):
+                             self.op_times[op_type] = {i: self.op_times[op_type] for i in range(num_stages)}
+                         self.op_times[op_type][stage_id] = time
+                 else:
+                     # If a float is provided, use same time for all stages
+                     self.op_times[op_type] = times
+
+         assert num_stages % num_devices == 0, "num_stages must be divisible by num_devices"
+         self.num_stages_per_device = num_stages // num_devices
+
+         self.init_device_to_stages()
+         if self.placement_strategy == "dualpipe":
+             assert (
+                 sum(len(stages) for stages in self.device_to_stages.values()) == num_stages * 2
+             )
+         else:
+             assert (
+                 sum(len(stages) for stages in self.device_to_stages.values()) == num_stages
+             )
+
+     def init_device_to_stages(self):
+         if self.placement_strategy == "standard":
+             # Evenly distributed
+             stages_per_device = self.num_stages // self.num_devices
+             self.device_to_stages = defaultdict(list)
+             for i in range(self.num_stages):
+                 device_to_put = i // stages_per_device
+                 self.device_to_stages[device_to_put].append(i)
+         elif self.placement_strategy == "interleave":
+             self.device_to_stages = defaultdict(list)
+             for i in range(self.num_stages):
+                 device_to_put = i % self.num_devices
+                 self.device_to_stages[device_to_put].append(i)
+         elif self.placement_strategy == "dualpipe":
+             # For DualPipe, each device has two stages
+             assert self.num_devices == self.num_stages, "DualPipe requires num_devices == num_stages"
+             assert self.num_devices % 2 == 0, "DualPipe requires an even number of devices"
+             self.device_to_stages = defaultdict(list)
+             for i in range(self.num_stages):
+                 self.device_to_stages[i] = [i, self.num_stages - i - 1]
+         else:
+             raise ValueError(f"Invalid placement strategy: {self.placement_strategy}")
+
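The three placement strategies map stages to devices quite differently (contiguous blocks, round-robin for virtual pipeline, and mirrored pairs for DualPipe). A minimal standalone sketch of the same mapping rules, using a hypothetical `place_stages` helper that is not part of the package:

```python
def place_stages(strategy: str, num_devices: int, num_stages: int) -> dict:
    """Sketch of the device -> stages mapping used by init_device_to_stages."""
    if strategy == "standard":  # contiguous blocks of stages per device
        per_dev = num_stages // num_devices
        return {d: [s for s in range(num_stages) if s // per_dev == d]
                for d in range(num_devices)}
    if strategy == "interleave":  # round-robin: stage i goes to device i % D
        return {d: [s for s in range(num_stages) if s % num_devices == d]
                for d in range(num_devices)}
    if strategy == "dualpipe":  # device i owns stage i and its mirror
        return {d: [d, num_stages - d - 1] for d in range(num_devices)}
    raise ValueError(strategy)

print(place_stages("standard", 2, 4))    # {0: [0, 1], 1: [2, 3]}
print(place_stages("interleave", 2, 4))  # {0: [0, 2], 1: [1, 3]}
print(place_stages("dualpipe", 4, 4))    # {0: [0, 3], 1: [1, 2], 2: [2, 1], 3: [3, 0]}
```

Note that under "dualpipe" every stage appears on two devices, which is why the config asserts a total of `num_stages * 2` placed stages for that strategy.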
+     def get_op_time(self, op_type: str, stage_id: int):
+         # For overlapped operations, extract the original operation types
+         if op_type.startswith("overlapped_"):
+             if op_type in self.op_times:
+                 if isinstance(self.op_times[op_type], dict):
+                     if stage_id in self.op_times[op_type]:
+                         return self.op_times[op_type][stage_id]
+                     else:
+                         raise ValueError(f"No time specified for operation {op_type} at stage {stage_id}")
+                 else:
+                     return self.op_times[op_type]
+             else:
+                 op_parts = op_type.split("_")[1:]
+                 if len(op_parts) >= 2:
+                     op_type1, op_type2 = op_parts[0], op_parts[1]
+                     return self.get_op_time(op_type1, stage_id) + self.get_op_time(op_type2, stage_id)
+
+         if op_type not in self.op_times:
+             raise ValueError(f"Invalid operation type: {op_type}")
+         times = self.op_times[op_type]
+         if isinstance(times, dict):
+             # If we have stage-specific times, use those
+             if stage_id not in times:
+                 raise ValueError(f"No time specified for operation {op_type} at stage {stage_id}")
+             return times[stage_id]
+         else:
+             # If we have a single float, use the same value for all stages
+             return times
+
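The `op_times` table accepts either a single float per op type (uniform across stages) or a per-stage dict. A condensed standalone sketch of that lookup rule, with hypothetical timing values:

```python
from typing import Dict, Union

# Hypothetical timing table mirroring ScheduleConfig.op_times:
# floats apply to every stage, dicts override per stage_id.
op_times: Dict[str, Union[float, Dict[int, float]]] = {
    "forward": 1.0,
    "backward": {0: 2.0, 1: 2.5, 2: 3.0, 3: 2.0},
}

def lookup(op_type: str, stage_id: int) -> float:
    times = op_times[op_type]
    if isinstance(times, dict):
        return times[stage_id]  # stage-specific override
    return times  # uniform time for all stages

print(lookup("forward", 3))   # 1.0 (uniform)
print(lookup("backward", 1))  # 2.5 (stage override)
```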
+
+ class Schedule:
+     def __init__(self, config: ScheduleConfig, init_ops: bool = True):
+         self.ops = {}  # (batch_id, stage_id, op_type) -> Operation
+         self.device_queues: List[DeviceQueue] = []
+         for dev_id in range(config.num_devices):
+             self.device_queues.append(DeviceQueue(config.device_to_stages[dev_id], dev_id))
+         self.config = config
+
+         if init_ops:
+             self.init_operations()
+         self.op_to_overlapped = {}
+
+     def register_overlapped_operation(self, overlapped_op: OverlappedOperation):
+         for op in overlapped_op.operations:
+             self.op_to_overlapped[(op.batch_id, op.stage_id, op.op_type)] = overlapped_op
+             self.ops[(op.batch_id, op.stage_id, op.op_type)] = overlapped_op
+
+     def register_operation(self, op: Operation):
+         assert (op.batch_id, op.stage_id, op.op_type) not in self.ops, f"Operation {op.batch_id}, {op.stage_id}, {op.op_type} already registered"
+         self.ops[(op.batch_id, op.stage_id, op.op_type)] = op
+
+     def init_operations(self):
+         op_types = ["forward", "backward"]
+         if self.config.split_backward:
+             op_types = ["forward", "backward_D", "backward_W"]
+         for batch_id in range(self.config.num_batches):
+             for stage_id in range(self.config.num_stages):
+                 for op_type in op_types:
+                     self.ops[(batch_id, stage_id, op_type)] = Operation(
+                         batch_id, stage_id, op_type
+                     )
+
+     def get_op(self, batch_id: int, stage_id: int, op_type: str, allow_none=False):
+         if (batch_id, stage_id, op_type) in self.op_to_overlapped:
+             return self.op_to_overlapped[(batch_id, stage_id, op_type)]
+         if allow_none:
+             if (batch_id, stage_id, op_type) not in self.ops:
+                 return None
+         return self.ops[(batch_id, stage_id, op_type)]
+
+     def get_dependencies(self, op: Operation, include_device_dependency=True):
+         deps = []
+         if isinstance(op, OverlappedOperation):
+             for sub_op in op.operations:
+                 deps.extend(self.get_dependencies(sub_op, include_device_dependency=False))
+
+             if include_device_dependency:
+                 device_index = self.device_queues[op.device_id].ops.index(op)
+                 if device_index > 0:
+                     deps.append((self.device_queues[op.device_id].ops[device_index - 1], 0.0))
+             return deps
+         if op.op_type == "forward":
+             if op.stage_id > 0:
+                 deps.append(
+                     (
+                         self.get_op(op.batch_id, op.stage_id - 1, "forward"),
+                         self.config.p2p_latency,
+                     )
+                 )
+         if self.config.split_backward:
+             if op.op_type == "backward_D":
+                 if op.stage_id < self.config.num_stages - 1:
+                     op_bwd_d = self.get_op(op.batch_id, op.stage_id + 1, "backward_D", allow_none=True)
+                     if op_bwd_d is not None:
+                         deps.append(
+                             (
+                                 op_bwd_d,
+                                 self.config.p2p_latency,
+                             )
+                         )
+                     else:
+                         deps.append(
+                             (
+                                 self.get_op(op.batch_id, op.stage_id + 1, "backward"),
+                                 self.config.p2p_latency,
+                             )
+                         )
+             elif op.op_type == "backward_W":
+                 if op.stage_id < self.config.num_stages - 1:
+                     op_bwd_d = self.get_op(op.batch_id, op.stage_id, "backward_D", allow_none=True)
+                     if op_bwd_d is not None:
+                         deps.append(
+                             (
+                                 op_bwd_d,
+                                 self.config.p2p_latency,
+                             )
+                         )
+                     else:
+                         deps.append(
+                             (
+                                 self.get_op(op.batch_id, op.stage_id, "backward"),
+                                 self.config.p2p_latency,
+                             )
+                         )
+             elif op.op_type == "backward":
+                 if op.stage_id < self.config.num_stages - 1:
+                     op_bwd = self.get_op(op.batch_id, op.stage_id + 1, "backward", allow_none=True)
+                     if op_bwd is not None:
+                         deps.append(
+                             (
+                                 op_bwd,
+                                 self.config.p2p_latency,
+                             )
+                         )
+                     else:
+                         deps.append(
+                             (
+                                 self.get_op(op.batch_id, op.stage_id + 1, "backward_D"),
+                                 self.config.p2p_latency,
+                             )
+                         )
+         else:
+             if op.op_type == "backward":
+                 if op.stage_id < self.config.num_stages - 1:
+                     deps.append(
+                         (
+                             self.get_op(op.batch_id, op.stage_id + 1, "backward"),
+                             self.config.p2p_latency,
+                         )
+                     )
+
+         if include_device_dependency:
+             device_index = self.device_queues[op.device_id].ops.index(op)
+             if device_index > 0:
+                 deps.append((self.device_queues[op.device_id].ops[device_index - 1], 0.0))
+         return deps
+
+     def show(self):
+         """Display detailed information about the schedule for debugging purposes."""
+         print("\n=== SCHEDULE DETAILS ===")
+         print(f"Devices: {self.config.num_devices}, Stages: {self.config.num_stages}, Batches: {self.config.num_batches}")
+         print(f"Placement Strategy: {self.config.placement_strategy}")
+         print("\n=== DEVICE QUEUES ===")
+
+         for dev_id in range(self.config.num_devices):
+             print(f"\nDEVICE {dev_id} (Stages: {self.device_queues[dev_id].stages}):")
+             print("-" * 80)
+             print(f"{'Batch':^6} | {'Stage':^6} | {'Type':^10} | {'Start':^10} | {'End':^10} | {'Duration':^10}")
+             print("-" * 80)
+
+             for op in self.device_queues[dev_id].ops:
+                 op_type = op.op_type
+                 start = f"{op.start_time:.2f}" if op.start_time is not None else "N/A"
+                 end = f"{op.end_time:.2f}" if op.end_time is not None else "N/A"
+
+                 duration = "N/A"
+                 if op.start_time is not None and op.end_time is not None:
+                     duration = f"{op.end_time - op.start_time:.2f}"
+
+                 print(f"{op.batch_id:^6} | {op.stage_id:^6} | {op_type:^10} | {start:^10} | {end:^10} | {duration:^10}")
+
+         # Find the total execution time (if timing info is available)
+         if all(op.end_time is not None for op in self.ops.values()):
+             total_time = max(op.end_time for op in self.ops.values())
+             print(f"\nTotal execution time: {total_time:.2f}")
+
+     def execute(self):
+         # TODO: change the execution order to topological order via DAG
+         def execute_op(op: Operation):
+             if op.end_time is not None:
+                 return
+             deps = self.get_dependencies(op)
+             if len(deps) == 0:
+                 op.set_start_time(0.0)
+             else:
+                 for dep, gap in deps:
+                     if dep.end_time is None or dep.start_time is None:
+                         execute_op(dep)
+                 op.set_start_time(max(dep.end_time + gap for dep, gap in deps))
+             op.set_end_time(op.start_time + self.config.get_op_time(
+                 op.op_type, op.stage_id
+             ))
+
+         op_num = len(self.device_queues[0].ops)
+         for i in range(op_num):
+             for dev_id in range(self.config.num_devices):
+                 if len(self.device_queues[dev_id].ops) <= i:
+                     continue
+                 op = self.device_queues[dev_id].ops[i]
+                 execute_op(op)
+
+         for op in self.ops.values():
+             assert (
+                 op.start_time is not None
+             ), f"op {op.batch_id}, {op.stage_id}, {op.op_type} has no start time"
+             assert (
+                 op.end_time is not None
+             ), f"op {op.batch_id}, {op.stage_id}, {op.op_type} has no end time"
+
+     def get_total_execution_time(self):
+         return max(op.end_time for op in self.ops.values())
+
+     def get_bubble_rate(self):
+         actual_time = self.get_total_execution_time()
+         ideal_time = 0
+         for stage_id in range(self.config.num_stages):
+             for op_type in ["forward", "backward"]:
+                 ideal_time += self.config.get_op_time(op_type, stage_id)
+         ideal_time = ideal_time * self.config.num_batches / self.config.num_devices
+
+         return (actual_time - ideal_time) / ideal_time
+
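`get_bubble_rate` compares the simulated makespan against the ideal fully packed per-device work. For plain 1F1B with uniform per-stage times, the makespan has the closed form `(m + p - 1) * (t_f + t_b)` for `p` devices and `m` microbatches, so the bubble rate reduces to `(p - 1) / m`. A quick standalone check of that arithmetic (hypothetical helper, not part of the package):

```python
def analytic_bubble_rate(p: int, m: int, t_f: float, t_b: float) -> float:
    # 1F1B with uniform per-stage times: makespan = (m + p - 1) * (t_f + t_b)
    actual = (m + p - 1) * (t_f + t_b)
    ideal = m * (t_f + t_b)  # perfectly packed work per device
    return (actual - ideal) / ideal

# 4 devices, 8 microbatches, forward 1.0 / backward 2.0 (the defaults above)
print(analytic_bubble_rate(4, 8, 1.0, 2.0))  # (p - 1) / m = 0.375
```

This analytic value is a useful sanity check against the simulated `get_bubble_rate` output for the 1F1B schedule with default op times.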
+     def get_device_running_time(self):
+         device_time = [0] * self.config.num_devices
+         for dev_id in range(self.config.num_devices):
+             for op in self.device_queues[dev_id].ops:
+                 device_time[dev_id] += op.end_time - op.start_time
+         return device_time
src/strategies.py ADDED
@@ -0,0 +1,581 @@
+ from collections import defaultdict, deque
+ from src.execution_model import OverlappedOperation, Operation, Schedule, ScheduleConfig
+
+
+ def generate_1f1b_schedule(config: ScheduleConfig):
+     schedule = Schedule(config)
+
+     assert config.num_devices == config.num_stages, "num_devices must be equal to num_stages for 1F1B"
+
+     for i in range(config.num_devices):
+         fwd_batch_id = 0
+         bwd_batch_id = 0
+         cooldown_batches = warmup_batches = config.num_devices - i - 1
+         steady_batches = config.num_batches - warmup_batches
+
+         for _ in range(warmup_batches):
+             schedule.device_queues[i].add_operation(
+                 schedule.get_op(fwd_batch_id, i, "forward")
+             )
+             fwd_batch_id += 1
+
+         for _ in range(steady_batches):
+             schedule.device_queues[i].add_operation(
+                 schedule.get_op(fwd_batch_id, i, "forward")
+             )
+             fwd_batch_id += 1
+             schedule.device_queues[i].add_operation(
+                 schedule.get_op(bwd_batch_id, i, "backward")
+             )
+             bwd_batch_id += 1
+
+         for _ in range(cooldown_batches):
+             schedule.device_queues[i].add_operation(
+                 schedule.get_op(bwd_batch_id, i, "backward")
+             )
+             bwd_batch_id += 1
+
+     return schedule
+
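The 1F1B phase counts above follow a simple pattern: device `i` of `p` runs `p - i - 1` warmup forwards, then `m - (p - i - 1)` steady 1F1B pairs, then the same number of cooldown backwards, so every device issues exactly `2 * m` operations. A standalone sketch of that bookkeeping (hypothetical helper, not part of the package):

```python
def one_f_one_b_phases(num_devices: int, num_batches: int):
    # Per-device (warmup, steady, cooldown) microbatch counts, matching the
    # loop structure of generate_1f1b_schedule.
    phases = []
    for i in range(num_devices):
        warmup = cooldown = num_devices - i - 1
        steady = num_batches - warmup
        phases.append((warmup, steady, cooldown))
    return phases

phases = one_f_one_b_phases(4, 8)
print(phases)  # [(3, 5, 3), (2, 6, 2), (1, 7, 1), (0, 8, 0)]
# Every device issues exactly 2 * num_batches ops (m forwards + m backwards).
assert all(w + 2 * s + c == 16 for w, s, c in phases)
```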
+ def generate_zero_bubble_1p_schedule(config: ScheduleConfig):
+     # Requires a config with split_backward=True so that backward_D and backward_W operations exist
+     schedule = Schedule(config)
+     total_batches = config.num_batches
+     assert config.num_devices == config.num_stages, "num_devices must be equal to num_stages for ZB-1P"
+     assert config.split_backward, "ZB-1P requires split_backward=True"
+
+     for i in range(config.num_devices):
+         fwd_batch_id = 0
+         bwd_d_batch_id = 0
+         bwd_w_batch_id = 0
+
+         cooldown_batches = warmup_batches = config.num_devices - i - 1
+         steady_batches = total_batches - warmup_batches
+
+         for _ in range(warmup_batches):
+             schedule.device_queues[i].add_operation(
+                 schedule.get_op(fwd_batch_id, i, "forward")
+             )
+             fwd_batch_id += 1
+
+         for _ in range(steady_batches):
+             schedule.device_queues[i].add_operation(
+                 schedule.get_op(fwd_batch_id, i, "forward")
+             )
+             schedule.device_queues[i].add_operation(
+                 schedule.get_op(bwd_d_batch_id, i, "backward_D")
+             )
+             if fwd_batch_id - bwd_w_batch_id >= config.num_devices - 1:
+                 schedule.device_queues[i].add_operation(
+                     schedule.get_op(bwd_w_batch_id, i, "backward_W")
+                 )
+                 bwd_w_batch_id += 1
+             bwd_d_batch_id += 1
+             fwd_batch_id += 1
+
+         for _ in range(cooldown_batches):
+             schedule.device_queues[i].add_operation(
+                 schedule.get_op(bwd_d_batch_id, i, "backward_D")
+             )
+
+             schedule.device_queues[i].add_operation(
+                 schedule.get_op(bwd_w_batch_id, i, "backward_W")
+             )
+
+             bwd_w_batch_id += 1
+             bwd_d_batch_id += 1
+
+         while bwd_w_batch_id < total_batches:
+             schedule.device_queues[i].add_operation(
+                 schedule.get_op(bwd_w_batch_id, i, "backward_W")
+             )
+             bwd_w_batch_id += 1
+
+     return schedule
+
+
+ def generate_1f1b_overlap_schedule(config: ScheduleConfig):
+     schedule = Schedule(config)
+
+     assert config.num_devices == config.num_stages, "num_devices must be equal to num_stages for 1F1B"
+
+     for i in range(config.num_devices):
+         fwd_batch_id = 0
+         bwd_batch_id = 0
+         cooldown_batches = warmup_batches = 2 * (config.num_devices - i - 1) + 1
+         steady_batches = config.num_batches - warmup_batches
+
+         for _ in range(warmup_batches):
+             schedule.device_queues[i].add_operation(
+                 schedule.get_op(fwd_batch_id, i, "forward")
+             )
+             fwd_batch_id += 1
+
+         for _ in range(steady_batches):
+             fwd_op = schedule.get_op(fwd_batch_id, i, "forward")
+             bwd_op = schedule.get_op(bwd_batch_id, i, "backward")
+             overlapped_op = OverlappedOperation([fwd_op, bwd_op])
+             schedule.register_overlapped_operation(overlapped_op)
+             schedule.device_queues[i].add_operation(overlapped_op)
+
+             fwd_batch_id += 1
+             bwd_batch_id += 1
+
+         for _ in range(cooldown_batches):
+             schedule.device_queues[i].add_operation(
+                 schedule.get_op(bwd_batch_id, i, "backward")
+             )
+             bwd_batch_id += 1
+
+     return schedule
+
+
+ def _get_pp_rank_microbatches(
+     num_microbatches,
+     num_devices,
+     device_id,
+     num_stages_per_device,
+     microbatch_group_size_per_vp_stage,
+ ):
+     """Get the number of warmup microbatches for a pipeline-parallel rank."""
+     total_num_microbatches = num_microbatches * num_stages_per_device
+
+     if num_devices > 1:
+         # Run (num_model_chunks-1)*microbatch_group_size_per_vp_stage on
+         # all workers, followed by more microbatches after depending on
+         # stage ID (more forward passes for earlier stages, later stages can
+         # immediately start with 1F1B).
+         num_warmup_microbatches = (num_devices - device_id - 1) * 2
+         num_warmup_microbatches += (num_stages_per_device - 1) * microbatch_group_size_per_vp_stage
+     else:
+         # forward_backward_no_pipelining
+         num_warmup_microbatches = 1
+
+     if num_warmup_microbatches >= total_num_microbatches:
+         num_warmup_microbatches = total_num_microbatches
+
+     return num_warmup_microbatches
+
+
+ def _get_schedule_table(num_microbatches, num_model_chunks, microbatch_group_size_per_vp_stage):
+     """Get the schedule table for PP scheduling.
+
+     Create a tunable schedule lookup table.
+     The schedule lookup table uses the virtual_microbatch_id to find the corresponding microbatch_id and model_chunk_id.
+     For example, the tunable schedule table for PP2 N3M5 with VP2 is constructed as below:
+     virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
+     microbatch_id         | 0 1 2 0 1 2 3 4 3 4
+     model_chunk_id        | 0 0 0 1 1 1 0 0 1 1
+     """
+     schedule_table = []
+     for min_microbatch_id_in_group in range(
+         0, num_microbatches, microbatch_group_size_per_vp_stage
+     ):
+         if min_microbatch_id_in_group + microbatch_group_size_per_vp_stage >= num_microbatches:
+             # Construct schedule for the last microbatch group
+             schedule_table.extend(
+                 [
+                     (microbatch_id, model_chunk_id)
+                     for model_chunk_id in range(num_model_chunks)
+                     for microbatch_id in range(min_microbatch_id_in_group, num_microbatches)
+                 ]
+             )
+         else:
+             # Construct schedule for other microbatch groups
+             schedule_table.extend(
+                 [
+                     (microbatch_id, model_chunk_id)
+                     for model_chunk_id in range(num_model_chunks)
+                     for microbatch_id in range(
+                         min_microbatch_id_in_group,
+                         min_microbatch_id_in_group + microbatch_group_size_per_vp_stage,
+                     )
+                 ]
+             )
+     return schedule_table
+
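The table construction can be checked against the docstring's PP2 N3M5 VP2 example with a condensed standalone re-implementation (hypothetical helper; clamping the group's upper bound is equivalent to the original's separate last-group branch):

```python
def schedule_table(num_microbatches: int, num_model_chunks: int, group_size: int):
    # Condensed sketch of _get_schedule_table: within each microbatch group,
    # iterate model chunks in the outer loop and microbatches in the inner one.
    table = []
    for lo in range(0, num_microbatches, group_size):
        hi = min(lo + group_size, num_microbatches)  # clamp the last group
        table.extend(
            (mb, chunk)
            for chunk in range(num_model_chunks)
            for mb in range(lo, hi)
        )
    return table

# PP2 N3M5 with VP2: 5 microbatches, 2 model chunks, groups of 3
table = schedule_table(5, 2, 3)
print([mb for mb, _ in table])       # [0, 1, 2, 0, 1, 2, 3, 4, 3, 4]
print([chunk for _, chunk in table])  # [0, 0, 0, 1, 1, 1, 0, 0, 1, 1]
```

The two printed rows reproduce the `microbatch_id` and `model_chunk_id` rows of the docstring's lookup table.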
+
+ def _convert_schedule_table_to_order(num_warmup_microbatches, num_model_chunks, schedule_table):
+     """Convert a tunable schedule lookup table to the te.make_graphed_callables() accepted
+     order format. For example, the tunable schedule table for PP2 N3M5 with VP2 is as below:
+     virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
+     microbatch_id         | 0 1 2 0 1 2 3 4 3 4
+     model_chunk_id        | 0 0 0 1 1 1 0 0 1 1
+
+     Then the forward backward separated order is:
+     forward  |  1  1  1  2  2  2  1  1  2  2
+     backward | -2 -2 -2 -1 -1 -1 -2 -2 -1 -1
+
+     If num_warmup_microbatches is 5, the output order is:
+     1 1 1 2 2 2 -2 1 -2 1 -2 2 -1 2 -1 -1 -2 -2 -1 -1
+     """
+     _, model_chunk_id_table = zip(*schedule_table)
+     forward_order = [chunk_id + 1 for chunk_id in model_chunk_id_table]
+     backward_order = [chunk_id - num_model_chunks for chunk_id in model_chunk_id_table]
+     order = forward_order[:num_warmup_microbatches]
+     for i in range(num_warmup_microbatches, len(forward_order)):
+         order.append(forward_order[i])
+         order.append(backward_order[i - num_warmup_microbatches])
+     if num_warmup_microbatches > 0:
+         order.extend(backward_order[-num_warmup_microbatches:])
+     return order
+
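The warmup-then-interleave conversion can be verified against the docstring's worked example with a standalone sketch (hypothetical helper mirroring the logic above):

```python
def interleave_order(chunk_ids, num_chunks, warmup):
    # Sketch of _convert_schedule_table_to_order: emit warmup forwards, then
    # alternate one forward with one backward, then drain remaining backwards.
    fwd = [c + 1 for c in chunk_ids]           # forward ids are 1-based chunk ids
    bwd = [c - num_chunks for c in chunk_ids]  # backward ids are negative
    order = fwd[:warmup]
    for i in range(warmup, len(fwd)):
        order += [fwd[i], bwd[i - warmup]]
    if warmup > 0:
        order += bwd[-warmup:]
    return order

# model_chunk_id row of the PP2 N3M5 VP2 table, with 5 warmup microbatches
order = interleave_order([0, 0, 0, 1, 1, 1, 0, 0, 1, 1], 2, 5)
print(order)
# [1, 1, 1, 2, 2, 2, -2, 1, -2, 1, -2, 2, -1, 2, -1, -1, -2, -2, -1, -1]
```

The printed sequence matches the docstring's expected output for that example.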
+ # Some codes are copied from Megatron-LM
+ def generate_1f1b_interleave_schedule(config: ScheduleConfig):
+     schedule = Schedule(config)
+
+     for device_id in range(config.num_devices):
+         microbatch_group_size_per_vp_stage = config.num_devices
+         num_warmup_microbatches = _get_pp_rank_microbatches(
+             config.num_batches,
+             config.num_devices,
+             device_id,
+             config.num_stages_per_device,
+             microbatch_group_size_per_vp_stage,
+         )
+
+         schedule_table = _get_schedule_table(
+             config.num_batches,
+             config.num_stages_per_device,
+             microbatch_group_size_per_vp_stage,
+         )
+
+         order = _convert_schedule_table_to_order(
+             num_warmup_microbatches,
+             num_model_chunks=config.num_stages_per_device,
+             schedule_table=schedule_table,
+         )
+
+         cur_stage_microbatch_id = {}
+         for i in range(1, config.num_stages_per_device + 1):
+             cur_stage_microbatch_id[i] = 0
+             cur_stage_microbatch_id[-i] = 0
+         for order_item in order:
+             stage_id = schedule.device_queues[device_id].stages[abs(order_item) - 1]
+
+             if order_item > 0:
+                 op_type = "forward"
+                 micro_batch_id = cur_stage_microbatch_id[order_item]
+                 cur_stage_microbatch_id[order_item] = cur_stage_microbatch_id[order_item] + 1
+             elif order_item < 0:
+                 op_type = "backward"
+                 micro_batch_id = cur_stage_microbatch_id[order_item]
+                 cur_stage_microbatch_id[order_item] = cur_stage_microbatch_id[order_item] + 1
+             else:
+                 raise ValueError(f"Invalid order item: {order_item}")
+             schedule.device_queues[device_id].add_operation(
+                 schedule.get_op(micro_batch_id, stage_id, op_type)
+             )
+     return schedule
+
+
+ def generate_1f1b_interleave_overlap_schedule(config: ScheduleConfig):
+     schedule = Schedule(config)
+
+     for device_id in range(config.num_devices):
+         microbatch_group_size_per_vp_stage = config.num_devices
+         num_warmup_microbatches = _get_pp_rank_microbatches(
+             config.num_batches,
+             config.num_devices,
+             device_id,
+             config.num_stages_per_device,
+             microbatch_group_size_per_vp_stage,
+         )
+
+         schedule_table = _get_schedule_table(
+             config.num_batches,
+             config.num_stages_per_device,
+             microbatch_group_size_per_vp_stage,
+         )
+
+         # NOTE: Add one more warmup microbatch for overlapped operations!
+         num_warmup_microbatches += 1
+         order = _convert_schedule_table_to_order(
+             num_warmup_microbatches,
+             num_model_chunks=config.num_stages_per_device,
+             schedule_table=schedule_table,
+         )
+
+         cur_stage_microbatch_id = {}
+         for i in range(1, config.num_stages_per_device + 1):
+             cur_stage_microbatch_id[i] = 0
+             cur_stage_microbatch_id[-i] = 0
+         i = 0
+
+         num_overlapped_batches = len(order) - num_warmup_microbatches * 2
+         while i < len(order):
+             if i < num_warmup_microbatches:
+                 order_item = order[i]
+                 assert order_item > 0
+                 op_type = "forward"
+                 micro_batch_id = cur_stage_microbatch_id[order_item]
+                 cur_stage_microbatch_id[order_item] = cur_stage_microbatch_id[order_item] + 1
+
+                 stage_id = schedule.device_queues[device_id].stages[abs(order_item) - 1]
+                 schedule.device_queues[device_id].add_operation(
+                     schedule.get_op(micro_batch_id, stage_id, op_type)
+                 )
+                 i += 1
+             elif i >= num_warmup_microbatches and i < num_warmup_microbatches + num_overlapped_batches - 1:
+                 order_item_a = order[i]
+                 order_item_b = order[i + 1]
+
+                 op_type_a = "forward" if order_item_a > 0 else "backward"
+                 micro_batch_id_a = cur_stage_microbatch_id[order_item_a]
+                 cur_stage_microbatch_id[order_item_a] = cur_stage_microbatch_id[order_item_a] + 1
+
+                 op_type_b = "forward" if order_item_b > 0 else "backward"
+                 micro_batch_id_b = cur_stage_microbatch_id[order_item_b]
+                 cur_stage_microbatch_id[order_item_b] = cur_stage_microbatch_id[order_item_b] + 1
+
+                 stage_id_a = schedule.device_queues[device_id].stages[abs(order_item_a) - 1]
+                 stage_id_b = schedule.device_queues[device_id].stages[abs(order_item_b) - 1]
+
+                 op_a = schedule.get_op(micro_batch_id_a, stage_id_a, op_type_a)
+                 op_b = schedule.get_op(micro_batch_id_b, stage_id_b, op_type_b)
+                 overlapped_op = OverlappedOperation([op_a, op_b])
+                 schedule.register_overlapped_operation(overlapped_op)
+                 schedule.device_queues[device_id].add_operation(overlapped_op)
+
+                 i += 2
+             else:
+                 assert i >= num_warmup_microbatches + num_overlapped_batches
+                 order_item = order[i]
+                 assert order_item < 0
+                 op_type = "backward"
+                 micro_batch_id = cur_stage_microbatch_id[order_item]
+                 cur_stage_microbatch_id[order_item] = cur_stage_microbatch_id[order_item] + 1
+
+                 stage_id = schedule.device_queues[device_id].stages[abs(order_item) - 1]
+                 schedule.device_queues[device_id].add_operation(
+                     schedule.get_op(micro_batch_id, stage_id, op_type)
+                 )
+                 i += 1
+
+     return schedule
+
+
+ def create_overlapped_ops(schedule, batch_id1, batch_id2, stage_id, type1, type2):
+     """
+     Helper function to create overlapped operations correctly.
+     This handles the underlying operation creation and registration to avoid device_id issues.
+     """
+     # Get the operations from the schedule
+     op1 = schedule.ops[(batch_id1, stage_id, type1)]
+     op2 = schedule.ops[(batch_id2, stage_id, type2)]
+
+     # Create the overlapped operation
+     overlapped_op = OverlappedOperation([op1, op2])
+
+     # Register in the schedule to ensure proper tracking
+     schedule.register_overlapped_operation(overlapped_op)
+
+     return overlapped_op
+
+
+ def generate_dualpipe_schedule(config: ScheduleConfig):
+     """
+     Implements the DualPipe scheduling strategy.
+
+     DualPipe is a bidirectional pipeline parallelism algorithm that achieves full overlap of forward
+     and backward computation-communication phases and reduces pipeline bubbles.
+
+     The DualPipe strategy has the following characteristics:
+     1. Requires placement_strategy="dualpipe" in ScheduleConfig (set automatically)
+     2. Each device handles both a forward stage and a reverse stage
+     3. Overlaps forward and backward operations to reduce bubble size
+     4. Assumes config.num_batches corresponds to half the total microbatches in the original paper (M).
+     5. Currently only supports split_backward=True.
+
+     Args:
+         config: The scheduling configuration
+
+     Returns:
+         A Schedule object with the DualPipe scheduling
+     """
+     # Ensure placement strategy is set for Schedule initialization
+     assert config.placement_strategy == "dualpipe", "DualPipe schedule currently only supports placement_strategy='dualpipe'"
+     # Assertions based on DualPipe requirements
+     assert config.num_stages % 2 == 0, "DualPipe requires an even number of stages (and devices)"
+     assert config.num_devices == config.num_stages, "DualPipe requires num_devices == num_stages"
+     assert config.num_batches % 2 == 0, "DualPipe requires an even number of microbatches (config.num_batches)"
+     # Assertion based on original implementation: num_chunks >= num_ranks * 2
+     # Here, M (config.num_batches) corresponds to half_num_chunks
+     assert config.num_batches >= config.num_devices, "DualPipe requires config.num_batches >= config.num_devices"
+     assert config.split_backward, "DualPipe schedule currently only supports split_backward=True"
+
+     schedule = Schedule(config, init_ops=False)
+
+     num_stages = config.num_stages
+     num_devices = config.num_devices
+     # config.num_batches is M in the original paper, which corresponds to half_num_chunks
+     half_num_chunks = config.num_batches // 2
+     num_half_ranks = num_devices // 2
+
+     fwd_batch_ids = defaultdict(int)  # (device_id, phase) -> batch_id
+     bwd_d_batch_ids = defaultdict(int)  # (device_id, phase) -> batch_id
+
+     waited_weight_grad = [deque() for _ in range(num_devices)]  # (device_id, ) -> List[(stage_id, batch_id)]
+
+     for device_id in range(num_devices):
+         is_in_second_half = device_id >= num_half_ranks
+         if is_in_second_half:
+             fwd_batch_ids[device_id, 1] = 0
+             fwd_batch_ids[device_id, 0] = config.num_batches // 2
+             bwd_d_batch_ids[device_id, 1] = 0
+             bwd_d_batch_ids[device_id, 0] = config.num_batches // 2
+         else:
+             fwd_batch_ids[device_id, 0] = 0
+             fwd_batch_ids[device_id, 1] = config.num_batches // 2
+             bwd_d_batch_ids[device_id, 0] = 0
+             bwd_d_batch_ids[device_id, 1] = config.num_batches // 2
+
+     def get_stage_for_phase(device_id, phase, num_stages, is_in_second_half):
435
+ stage_fwd_dir = device_id # Stage handled when moving forward (0 to N-1)
436
+ stage_rev_dir = num_stages - 1 - device_id # Stage handled when moving backward (N-1 to 0)
437
+ if not is_in_second_half:
438
+ # First half: phase 0 -> fwd_dir, phase 1 -> rev_dir
439
+ return stage_fwd_dir if phase == 0 else stage_rev_dir
440
+ else:
441
+ # Second half: phase 0 -> rev_dir, phase 1 -> fwd_dir
442
+ return stage_rev_dir if phase == 0 else stage_fwd_dir
443
+
444
+
445
+ def add_op_to_queue(device_id, stage_id, op_type, batch_id):
446
+ # Create a new Operation and register it with the schedule
447
+ op = Operation(batch_id, stage_id, op_type)
448
+ schedule.register_operation(op)
449
+ # Add to the device queue
450
+ schedule.device_queues[device_id].add_operation(op)
451
+
452
+ def _schedule_forward_chunk(device_id, phase, is_in_second_half):
453
+ """Schedules a forward compute operation."""
454
+ stage_id = get_stage_for_phase(device_id, phase, num_stages, is_in_second_half)
455
+ batch_id = fwd_batch_ids[device_id, phase]
456
+ add_op_to_queue(device_id, stage_id, "forward", batch_id)
457
+ fwd_batch_ids[device_id, phase] += 1
458
+
459
+ def _schedule_backward_chunk(device_id, phase, is_in_second_half):
460
+ """Schedules a combined backward (backward_D + backward_W) compute operation."""
461
+ stage_id = get_stage_for_phase(device_id, phase, num_stages, is_in_second_half)
462
+ batch_id = bwd_d_batch_ids[device_id, phase]
463
+ add_op_to_queue(device_id, stage_id, "backward", batch_id)
464
+ bwd_d_batch_ids[device_id, phase] += 1
465
+
466
+ def _schedule_backward_input_chunk(device_id, phase, is_in_second_half):
467
+ """Schedules a backward_D compute operation."""
468
+ stage_id = get_stage_for_phase(device_id, phase, num_stages, is_in_second_half)
469
+ batch_id = bwd_d_batch_ids[device_id, phase]
470
+ add_op_to_queue(device_id, stage_id, "backward_D", batch_id)
471
+ bwd_d_batch_ids[device_id, phase] += 1
472
+ waited_weight_grad[device_id].append((stage_id, batch_id))
473
+
474
+ def _schedule_backward_weight_chunk(device_id):
475
+ """Schedules a backward_W compute operation."""
476
+ stage_id, batch_id = waited_weight_grad[device_id].popleft()
477
+ add_op_to_queue(device_id, stage_id, "backward_W", batch_id)
478
+
479
+ def _schedule_forward_backward_chunk(device_id, fwd_phase, bwd_phase, is_in_second_half):
480
+ """Schedules an overlapped forward and backward_D compute operation."""
481
+ fwd_stage_id = get_stage_for_phase(device_id, fwd_phase, num_stages, is_in_second_half)
482
+ bwd_stage_id = get_stage_for_phase(device_id, bwd_phase, num_stages, is_in_second_half)
483
+
484
+ fwd_batch_id = fwd_batch_ids[device_id, fwd_phase]
485
+
486
+ fwd_op = Operation(fwd_batch_id, fwd_stage_id, "forward")
487
+ schedule.register_operation(fwd_op)
488
+ fwd_batch_ids[device_id, fwd_phase] += 1
489
+
490
+ bwd_batch_id_d = bwd_d_batch_ids[device_id, bwd_phase]
491
+ bwd_op = Operation(bwd_batch_id_d, bwd_stage_id, "backward")
492
+ schedule.register_operation(bwd_op)
493
+ bwd_d_batch_ids[device_id, bwd_phase] += 1
494
+
495
+ # Create and register the overlapped operation
496
+ overlapped_op = OverlappedOperation([fwd_op, bwd_op])
497
+ schedule.register_overlapped_operation(overlapped_op)
498
+
499
+ # Add the overlapped operation to the queue
500
+ schedule.device_queues[device_id].add_operation(overlapped_op)
501
+
502
+
503
+ # Process each device (rank in original code)
504
+ for device_id in range(num_devices):
505
+ half_rank = min(device_id, num_devices - 1 - device_id)
506
+ is_in_second_half = device_id >= num_half_ranks
507
+ is_middle_rank = (device_id == num_half_ranks - 1) or (device_id == num_half_ranks)
508
+
509
+ # Map original steps to operation additions
510
+ # Step 1: nF0
511
+ step_1_count = (num_half_ranks - half_rank - 1) * 2
512
+ for _ in range(step_1_count):
513
+ _schedule_forward_chunk(device_id, 0, is_in_second_half) # F0
514
+
515
+ # Step 2: nF0F1
516
+ step_2_count = half_rank + 1
517
+ for i in range(step_2_count):
518
+ _schedule_forward_chunk(device_id, 0, is_in_second_half) # F0
519
+ _schedule_forward_chunk(device_id, 1, is_in_second_half) # F1
520
+
521
+ # Step 3: nB1W1F1
522
+ step_3_count = num_half_ranks - half_rank - 1
523
+ for _ in range(step_3_count):
524
+ _schedule_backward_input_chunk(device_id, 1, is_in_second_half) # B1_D
525
+ _schedule_backward_weight_chunk(device_id) # W1
526
+ _schedule_forward_chunk(device_id, 1, is_in_second_half) # F1
527
+
528
+ # Step 4 (Main step): nF0B1F1B0
529
+ step_4_count = half_num_chunks - num_devices + half_rank + 1
530
+ for i in range(step_4_count):
531
+ # Overlap F0 with the combined B1 backward
538
+ _schedule_forward_backward_chunk(device_id, 0, 1, is_in_second_half) # F0+B1
539
+
540
+ # Overlap F1 with the combined B0 backward
541
+ _schedule_forward_backward_chunk(device_id, 1, 0, is_in_second_half) # F1+B0
542
+
543
+ # Step 5: nB1F1B0
544
+ step_5_count = num_half_ranks - half_rank - 1
545
+ for _ in range(step_5_count):
546
+ _schedule_backward_chunk(device_id, 1, is_in_second_half) # B1_D + B1_W
547
+ _schedule_forward_backward_chunk(device_id, 1, 0, is_in_second_half) # F1+B0
548
+
549
+ # Step 6: nB1B0
550
+ step_6_count = half_rank + 1
551
+ enable_zb = False
552
+ for i in range(step_6_count):
553
+ if i == step_6_count // 2 and half_rank % 2 == 1:
554
+ enable_zb = True
555
+ if enable_zb:
556
+ _schedule_backward_input_chunk(device_id, 1, is_in_second_half)
557
+ else:
558
+ _schedule_backward_chunk(device_id, 1, is_in_second_half)
559
+ if i == step_6_count // 2 and half_rank % 2 == 0:
560
+ enable_zb = True
561
+ if enable_zb:
562
+ _schedule_backward_input_chunk(device_id, 0, is_in_second_half)
563
+ else:
564
+ _schedule_backward_chunk(device_id, 0, is_in_second_half)
565
+
566
+ # Step 7: nWB0
567
+ step_7_count = num_half_ranks - half_rank - 1
568
+ for _ in range(step_7_count):
569
+ _schedule_backward_weight_chunk(device_id) # W1 (use gradient from B1_D scheduled previously)
570
+ _schedule_backward_input_chunk(device_id, 0, is_in_second_half) # B0_D
571
+
572
+ # Step 8: nW
573
+ step_8_count = half_rank + 1
574
+ for _ in range(step_8_count):
575
+ # W0 uses gradients from B0_D scheduled in steps 4, 5, 6.
576
+ # W1 uses gradients from B1_D scheduled in steps 3, 4, 5, 6.
577
+ # The last W0 gradients correspond to B0_D from step 6 or 7.
578
+ _schedule_backward_weight_chunk(device_id) # W0 (use gradient from B0_D scheduled previously)
579
+
580
+ return schedule
581
+
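The bidirectional placement described in the docstring (each device owns one forward-direction stage and one reverse-direction stage, with the phase meaning flipped in the second half of the ranks) can be sketched standalone. This mirrors the `get_stage_for_phase` helper above; it is an illustrative extract, not part of the repository.

```python
def stage_for_phase(device_id: int, phase: int, num_stages: int) -> int:
    """Mirror of get_stage_for_phase: each device serves two pipeline stages."""
    stage_fwd = device_id                   # stage when traversing 0 -> N-1
    stage_rev = num_stages - 1 - device_id  # stage when traversing N-1 -> 0
    second_half = device_id >= num_stages // 2
    if not second_half:
        # First half of ranks: phase 0 is the forward-direction stage
        return stage_fwd if phase == 0 else stage_rev
    # Second half of ranks: the phase meaning is swapped
    return stage_rev if phase == 0 else stage_fwd

# With 8 stages, device 0 and device 7 both serve stages 0 and 7,
# but with opposite phase assignments.
```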
src/visualizer.py ADDED
@@ -0,0 +1,612 @@
1
+ import dash
2
+ from dash import dcc, html
3
+ from dash.dependencies import Input, Output
4
+ import plotly.graph_objects as go
5
+ from typing import List, Dict
6
+ from tqdm import tqdm
7
+ from functools import lru_cache
8
+ import webbrowser
9
+ from threading import Timer
10
+
11
+ from src.execution_model import Schedule, OverlappedOperation
12
+
13
+
14
+ def convert_schedule_to_visualization_format(schedule: Schedule):
15
+ """
16
+ Converts a Schedule object to the format needed for visualization.
17
+
18
+ Returns:
19
+ Dict[int, List[Dict]]: Dictionary mapping device_id to a list of operation dictionaries
20
+ """
21
+ # Make sure all operations have start and end times
22
+ for op in schedule.ops.values():
23
+ if op.start_time is None or op.end_time is None:
24
+ raise ValueError(
25
+ "Operations must have start and end times. Run ScheduleExecutor.execute() first."
26
+ )
27
+
28
+ visualization_data = {}
29
+
30
+ # Organize operations by device
31
+ for device_id, device_queue in enumerate(schedule.device_queues):
32
+ visualization_data[device_id] = []
33
+
34
+ for op in device_queue.ops:
35
+ # Handle both regular Operations and OverlappedOperations
36
+ if isinstance(op, OverlappedOperation):
37
+ visualization_data[device_id].append(
38
+ {
39
+ "type": op.op_type,
40
+ "batch": op.batch_id + 1, # +1 because batch_id is 0-indexed
41
+ "stage": op.stage_id,
42
+ "start_time": op.start_time,
43
+ "duration": op.end_time - op.start_time,
44
+ "is_overlapped": True,
45
+ "operations": [
46
+ {
47
+ "type": nested_op.op_type,
48
+ "batch": nested_op.batch_id + 1,
49
+ "stage": nested_op.stage_id
50
+ }
51
+ for nested_op in op.operations
52
+ ]
53
+ }
54
+ )
55
+ else:
56
+ visualization_data[device_id].append(
57
+ {
58
+ "type": op.op_type,
59
+ "batch": op.batch_id + 1, # +1 because batch_id is 0-indexed
60
+ "stage": op.stage_id,
61
+ "start_time": op.start_time,
62
+ "duration": op.end_time - op.start_time,
63
+ "is_overlapped": False
64
+ }
65
+ )
66
+
67
+ return visualization_data
68
+
69
+
70
+ # Cache the color calculation as it's repeatedly called with the same parameters
71
+ @lru_cache(maxsize=128)
72
+ def get_color(op_type: str, stage_id: int, num_devices: int):
73
+ # A more harmonious blue palette with low saturation and high brightness
74
+ forward_colors = [
75
+ "#0a5aff", # Intense blue
76
+ "#4c88ff", # Blue (deeper)
77
+ "#7aa7ff", # Medium blue
78
+ "#a8c5ff", # Soft blue
79
+ "#d6e4ff", # Very light blue
80
+ ]
81
+
82
+ # Orange palette for backward operations with low saturation and high brightness
83
+ backward_colors = [
84
+ "#f47b00", # Intense orange
85
+ "#ffa952", # Orange
86
+ "#ffc78e", # Light orange
87
+ "#ffe6cc", # Very light orange
88
+ ]
89
+
90
+ # Improved teal/turquoise palette with low saturation and high brightness
91
+ backward_d_colors = [
92
+ "#4dcccc", # Light teal
93
+ "#33b3b3", # Teal
94
+ "#009999", # Medium teal
95
+ "#008080", # Dark teal
96
+ ]
97
+
98
+ # Improved green palette with low saturation and high brightness
99
+ backward_w_colors = [
100
+ "#33b373", # Medium green
101
+ "#009959", # Forest green
102
+ "#008040", # Dark green
103
+ ]
104
+
105
+ virtual_stage = stage_id // num_devices
106
+
107
+ # If virtual_stage is beyond our color list, cycle through the colors
108
+ color_index = virtual_stage % len(forward_colors)
109
+
110
+ if op_type == "forward":
111
+ return forward_colors[color_index]
112
+ elif op_type == "backward":
113
+ return backward_colors[color_index % len(backward_colors)]
114
+ elif op_type == "backward_D":
115
+ return backward_d_colors[color_index % len(backward_d_colors)]
116
+ elif op_type == "backward_W":
117
+ return backward_w_colors[color_index % len(backward_w_colors)]
118
+ else:
119
+ raise ValueError(f"Invalid operation type: {op_type}")
120
+
121
+
122
+ def create_pipeline_figure(
123
+ schedule_data: Dict[int, List[Dict]], max_time=None, show_progress=True
124
+ ):
125
+ """
126
+ Create a Plotly figure for pipeline parallelism scheduling.
127
+
128
+ Args:
129
+ schedule_data: Dictionary mapping device IDs to lists of tasks (converted from Schedule)
130
+ max_time: Optional maximum time to display
131
+ show_progress: Whether to show a progress bar
132
+ """
133
+ # Find the number of devices
134
+ num_devices = len(schedule_data)
135
+
136
+ empty_color = "whitesmoke"
137
+
138
+ # Find the maximum time in the schedule if not provided
139
+ if max_time is None:
140
+ max_time = 0
141
+ for device in schedule_data:
142
+ for task in schedule_data[device]:
143
+ end_time = task["start_time"] + task["duration"]
144
+ if end_time > max_time:
145
+ max_time = end_time
146
+
147
+ # Determine maximum batch number to decide whether to show text labels
148
+ max_batch = 0
149
+ for device in schedule_data:
150
+ for task in schedule_data[device]:
151
+ max_batch = max(max_batch, task["batch"])
152
+
153
+ # Flag to determine whether to show text labels
154
+ num_operations_per_device = len(schedule_data[0])
155
+ show_text_labels = num_operations_per_device <= 64
156
+
157
+ # Create a figure
158
+ fig = go.Figure()
159
+
160
+ # Initialize progress tracking
161
+ total_tasks = sum(len(tasks) for tasks in schedule_data.values())
162
+ tasks_processed = 0
163
+
164
+ if show_progress:
165
+ progress_bar = tqdm(
166
+ total=total_tasks + num_devices + 3, desc="Creating visualization"
167
+ )
168
+
169
+ # Create a custom y-axis with no gaps between devices
170
+ y_spacing = 1.0 # Use 1.0 for no gaps
171
+
172
+ # Batch processing for increased performance
173
+ shapes = []
174
+ annotations = []
175
+ hover_traces = []
176
+
177
+ # Add rectangles for each task
178
+ for device_idx, device in enumerate(schedule_data):
179
+ device_idx_reversed = num_devices - device_idx - 1
180
+
181
+ # Sort tasks by start time to ensure correct rendering
182
+ sorted_tasks = sorted(schedule_data[device], key=lambda t: t["start_time"])
183
+
184
+ for task in sorted_tasks:
185
+ # Calculate y positions with no gaps
186
+ y_pos = device_idx_reversed * y_spacing
187
+ start_time = task["start_time"]
188
+ duration = task["duration"]
189
+
190
+ # Special handling for overlapped operations
191
+ if task.get("is_overlapped", False) and "operations" in task:
192
+ # Prepare hover text for the entire overlapped operation
193
+ op_details = "<br>".join([
194
+ f"- {op['type']} (Batch {op['batch']}, Stage {op['stage']})"
195
+ for op in task["operations"]
196
+ ])
197
+ hover_text = (
198
+ f"Overlapped Operations:<br>{op_details}<br>"
199
+ f"Start: {task['start_time']:.2f}<br>"
200
+ f"End: {task['start_time'] + task['duration']:.2f}<br>"
201
+ f"Duration: {task['duration']:.2f}"
202
+ )
203
+
204
+ # Add invisible marker for hover info
205
+ hover_traces.append(
206
+ dict(
207
+ x=[start_time + duration / 2],
208
+ y=[y_pos],
209
+ mode="markers",
210
+ marker=dict(opacity=0), # Invisible marker
211
+ hoverinfo="text",
212
+ text=hover_text,
213
+ showlegend=False,
214
+ )
215
+ )
216
+
217
+ # Calculate height of each sub-operation
218
+ sub_height = 1.0 / len(task["operations"])
219
+
220
+ # Add rectangles and annotations for each sub-operation
221
+ for i, sub_op in enumerate(task["operations"]):
222
+ # Determine color for this sub-operation
223
+ color = get_color(sub_op["type"], sub_op["stage"], num_devices)
224
+
225
+ # Calculate y position for this sub-operation
226
+ sub_y_pos_bottom = y_pos - 0.5 + (i * sub_height)
227
+ sub_y_pos_top = sub_y_pos_bottom + sub_height
228
+ sub_y_center = (sub_y_pos_bottom + sub_y_pos_top) / 2
229
+
230
+ # Add rectangle for this sub-operation
231
+ shapes.append(
232
+ dict(
233
+ type="rect",
234
+ x0=start_time,
235
+ y0=sub_y_pos_bottom,
236
+ x1=start_time + duration,
237
+ y1=sub_y_pos_top,
238
+ line=dict(color="black", width=0.5),
239
+ fillcolor=color,
240
+ layer="above",
241
+ )
242
+ )
243
+
244
+ # Add batch number text for this sub-operation only if show_text_labels is True
245
+ if show_text_labels:
246
+ # Determine text color based on background color
247
+ if sub_op["type"] in ["backward", "backward_D", "backward_W"]:
248
+ text_color = "black"
249
+ else:
250
+ text_color = "white"
251
+
252
+ annotations.append(
253
+ dict(
254
+ x=start_time + duration / 2,
255
+ y=sub_y_center,
256
+ text=f"{sub_op['batch']}",
257
+ showarrow=False,
258
+ font=dict(color=text_color, size=12, family="Arial, bold"),
259
+ )
260
+ )
261
+ else:
262
+ # Regular (non-overlapped) operation
263
+ # Determine task color and text color
264
+ if task["type"] == "forward":
265
+ color = get_color(task["type"], task["stage"], num_devices)
266
+ text_color = "white"
267
+ name = "Forward"
268
+ elif task["type"] == "backward":
269
+ color = get_color(task["type"], task["stage"], num_devices)
270
+ text_color = "black"
271
+ name = "Backward"
272
+ elif task["type"] == "backward_D":
273
+ color = get_color(task["type"], task["stage"], num_devices)
274
+ text_color = "black"
275
+ name = "Backward (Grad)"
276
+ elif task["type"] == "backward_W":
277
+ color = get_color(task["type"], task["stage"], num_devices)
278
+ text_color = "black"
279
+ name = "Backward (Weight)"
280
+ else:
281
+ color = empty_color
282
+ text_color = "black"
283
+ name = "Unknown"
284
+
285
+ # Add rectangle for the task
286
+ shapes.append(
287
+ dict(
288
+ type="rect",
289
+ x0=start_time,
290
+ y0=y_pos - 0.5,
291
+ x1=start_time + duration,
292
+ y1=y_pos + 0.5,
293
+ line=dict(color="black", width=0.5),
294
+ fillcolor=color,
295
+ layer="above",
296
+ )
297
+ )
298
+
299
+ # Add batch number text only if show_text_labels is True
300
+ if show_text_labels:
301
+ annotations.append(
302
+ dict(
303
+ x=start_time + duration / 2,
304
+ y=y_pos,
305
+ text=f"{task['batch']}",
306
+ showarrow=False,
307
+ font=dict(color=text_color, size=12, family="Arial, bold"),
308
+ )
309
+ )
310
+
311
+ # Prepare hover data
312
+ hover_text = (
313
+ f"Batch: {task['batch']}<br>"
314
+ f"Stage: {task['stage']}<br>"
315
+ f"Type: {name}<br>"
316
+ f"Start: {task['start_time']:.2f}<br>"
317
+ f"End: {task['start_time'] + task['duration']:.2f}<br>"
318
+ f"Duration: {task['duration']:.2f}"
319
+ )
320
+
321
+ hover_traces.append(
322
+ dict(
323
+ x=[start_time + duration / 2],
324
+ y=[y_pos],
325
+ mode="markers",
326
+ marker=dict(opacity=0), # Invisible marker
327
+ hoverinfo="text",
328
+ text=hover_text,
329
+ showlegend=False,
330
+ )
331
+ )
332
+
333
+ # Update progress
334
+ if show_progress:
335
+ tasks_processed += 1
336
+ progress_bar.update(1)
337
+
338
+ # Add all shapes at once for better performance
339
+ fig.update_layout(shapes=shapes)
340
+
341
+ # Add all annotations at once
342
+ fig.update_layout(annotations=annotations)
343
+
344
+ # Add all hover traces at once
345
+ for trace in hover_traces:
346
+ fig.add_trace(go.Scatter(**trace))
347
+
348
+ # Add custom legend
349
+ legend_items = []
350
+
351
+ # Find the maximum virtual stage in the data
352
+ max_virtual_stage = 0
353
+ for device in schedule_data:
354
+ for task in schedule_data[device]:
355
+ virtual_stage = task["stage"] // num_devices
356
+ max_virtual_stage = max(max_virtual_stage, virtual_stage)
357
+
358
+ # Check if overlapped operations exist
359
+ has_overlapped = any(
360
+ task.get("is_overlapped", False)
361
+ for device in schedule_data
362
+ for task in schedule_data[device]
363
+ )
364
+
365
+ # Add forward and backward items for each virtual stage
366
+ for vs in range(max_virtual_stage + 1):
367
+ legend_items.append(
368
+ dict(
369
+ name=f"Forward (VS {vs})",
370
+ color=get_color("forward", vs * num_devices, num_devices),
371
+ )
372
+ )
373
+ legend_items.append(
374
+ dict(
375
+ name=f"Backward (VS {vs})",
376
+ color=get_color("backward", vs * num_devices, num_devices),
377
+ )
378
+ )
379
+ # Add entries for split backward operations if this is a zb1p schedule
380
+ if any(
381
+ task["type"] in ["backward_D", "backward_W"]
382
+ for device in schedule_data
383
+ for task in schedule_data[device]
384
+ ):
385
+ legend_items.append(
386
+ dict(
387
+ name=f"Backward Grad (VS {vs})",
388
+ color=get_color("backward_D", vs * num_devices, num_devices),
389
+ )
390
+ )
391
+ legend_items.append(
392
+ dict(
393
+ name=f"Backward Weight (VS {vs})",
394
+ color=get_color("backward_W", vs * num_devices, num_devices),
395
+ )
396
+ )
397
+
398
+ # If no tasks found, add default legend items
399
+ if not legend_items:
400
+ legend_items = [
401
+ dict(name="Forward (VS 0)", color=get_color("forward", 0, num_devices)),
402
+ dict(name="Backward (VS 0)", color=get_color("backward", 0, num_devices)),
403
+ dict(
404
+ name="Backward Grad (VS 0)",
405
+ color=get_color("backward_D", 0, num_devices),
406
+ ),
407
+ dict(
408
+ name="Backward Weight (VS 0)",
409
+ color=get_color("backward_W", 0, num_devices),
410
+ ),
411
+ ]
412
+
413
+ for i, item in enumerate(legend_items):
414
+ fig.add_trace(
415
+ go.Scatter(
416
+ x=[None],
417
+ y=[None],
418
+ mode="markers",
419
+ marker=dict(size=10, color=item["color"]),
420
+ name=item["name"],
421
+ showlegend=True,
422
+ )
423
+ )
424
+ if show_progress and i < len(legend_items) - 1:
425
+ progress_bar.update(1)
426
+
427
+ # Set axis properties
428
+ device_labels = [f"Device {i+1}" for i in range(num_devices)]
429
+
430
+ # Calculate tick positions with no gaps
431
+ tick_positions = [(num_devices - i - 1) * y_spacing for i in range(num_devices)]
432
+
433
+ # Adjust the range to ensure there are no empty spaces at the end
434
+ x_end = max_time * 1.05 # Add a small margin
435
+
436
+ title_text = "Pipeline Parallelism Schedule"
437
+
438
+ fig.update_layout(
439
+ yaxis=dict(
440
+ tickmode="array",
441
+ tickvals=tick_positions,
442
+ ticktext=device_labels,
443
+ showgrid=False,
444
+ zeroline=False,
445
+ ),
446
+ margin=dict(l=50, r=20, t=40, b=40),
447
+ plot_bgcolor="white",
448
+ title=dict(
449
+ text=title_text,
450
+ x=0.5,
451
+ y=0.98, # Move title position closer to the top
452
+ font=dict(size=20),
453
+ ),
454
+ legend=dict(
455
+ orientation="v", # Changed from horizontal to vertical
456
+ yanchor="top",
457
+ y=1.02, # Position at the top
458
+ xanchor="right",
459
+ x=1.20, # Position further to the right to accommodate more items
460
+ title=dict(text="<b>Operation Types:</b>"),
461
+ itemsizing="constant",
462
+ tracegroupgap=0,
463
+ ),
464
+ width=2000, # Increase width to accommodate the expanded legend
465
+ height=400, # Maintain current height
466
+ bargap=0,
467
+ bargroupgap=0,
468
+ )
469
+
470
+ if show_progress:
471
+ progress_bar.update(1)
472
+ progress_bar.close()
473
+
474
+ return fig
475
+
476
+
477
+ # Cache for storing processed schedule data
478
+ _schedule_data_cache = {}
479
+
480
+
481
+ def create_dash_app(
482
+ schedule: Schedule, schedule_type="1f1b", enable_caching: bool = True
483
+ ):
484
+ """
485
+ Create a Dash app to visualize the pipeline schedule.
486
+
487
+ Args:
488
+ schedule: Schedule object to visualize
489
+ schedule_type: Type of schedule ("1f1b", "zb1p", or custom description)
490
+ enable_caching: Whether to cache the schedule data and figure
491
+ """
492
+ # Process schedule data only once and cache it
493
+ global _schedule_data_cache
494
+ cache_key = id(schedule)
495
+
496
+ if enable_caching and cache_key in _schedule_data_cache:
497
+ schedule_data = _schedule_data_cache[cache_key]
498
+ print("Using cached schedule data")
499
+ else:
500
+ schedule_data = convert_schedule_to_visualization_format(schedule)
501
+ if enable_caching:
502
+ _schedule_data_cache[cache_key] = schedule_data
503
+ print("Cached schedule data")
504
+
505
+ total_tasks = sum(len(tasks) for tasks in schedule_data.values())
506
+ print(f"Total tasks in schedule: {total_tasks}")
507
+
508
+ app = dash.Dash(__name__)
509
+ app.title = f"Pipeline Parallelism Visualization - {schedule_type}"
510
+
511
+ # Create a more informative layout with data size information
512
+ app.layout = html.Div(
513
+ [
514
+ html.H1(
515
+ f"Pipeline Parallelism Visualization - {schedule_type}",
516
+ style={"textAlign": "center"},
517
+ ),
518
+ html.Div(
519
+ [
520
+ html.P(
521
+ f"Number of devices: {len(schedule_data)}",
522
+ style={"display": "inline-block", "marginRight": "20px"},
523
+ ),
524
+ html.P(
525
+ f"Total tasks: {total_tasks}",
526
+ style={"display": "inline-block", "marginRight": "20px"},
527
+ ),
528
+ ],
529
+ style={"marginBottom": "20px"},
530
+ ),
531
+ html.Div(id="graph-container", children=[]),
532
+ dcc.Loading(
533
+ id="loading-graph",
534
+ type="circle",
535
+ children=[
536
+ dcc.Graph(
537
+ id="pipeline-graph",
538
+ config={
539
+ "displayModeBar": True,
540
+ "toImageButtonOptions": {
541
+ "format": "png",
542
+ "filename": "pipeline_visualization",
543
+ },
544
+ },
545
+ ),
546
+ ],
547
+ ),
548
+ ]
549
+ )
550
+
551
+ # Cache for storing figure to avoid regenerating it
552
+ figure_cache = {}
553
+
554
+ @app.callback(
555
+ Output("pipeline-graph", "figure"),
556
+ Input("graph-container", "children"),
557
+ prevent_initial_call=False,
558
+ )
559
+ def load_graph(_):
560
+ # Use cached figure if available
561
+ cache_key = f"{id(schedule)}"
562
+ if enable_caching and cache_key in figure_cache:
563
+ print("Using cached figure")
564
+ return figure_cache[cache_key]
565
+
566
+ # Create the figure
567
+ figure = create_pipeline_figure(schedule_data, show_progress=True)
568
+
569
+ # Cache the figure
570
+ if enable_caching:
571
+ figure_cache[cache_key] = figure
572
+ print("Cached figure")
573
+
574
+ return figure
575
+
576
+ return app
577
+
578
+
579
+ def visualize_pipeline_parallelism_dash(
580
+ schedule: Schedule,
581
+ port: int = 8050,
582
+ debug: bool = False,
583
+ enable_caching: bool = True,
584
+ schedule_type="1f1b",
585
+ open_browser: bool = True,
586
+ ):
587
+ """
588
+ Launch a Dash app to visualize the pipeline schedule interactively.
589
+
590
+ Args:
591
+ schedule: Schedule object to visualize
592
+ port: Port to run the Dash app on
593
+ debug: Whether to run the Dash app in debug mode
594
+ enable_caching: Whether to cache schedule data and figures
595
+ schedule_type: Type of schedule ("1f1b", "zb1p", or custom description)
596
+ open_browser: Whether to automatically open a browser window
597
+ """
598
+ app = create_dash_app(
599
+ schedule, schedule_type=schedule_type, enable_caching=enable_caching
600
+ )
601
+
602
+ # Define function to open browser after a short delay
603
+ def open_browser_tab():
604
+ webbrowser.open_new_tab(f"http://localhost:{port}/")
605
+
606
+ # Open browser automatically if requested
607
+ if open_browser:
608
+ # Use a timer to open the browser after the server has started
609
+ Timer(1.0, open_browser_tab).start()
610
+
611
+ print(f"Starting Dash app on http://localhost:{port}/")
612
+ app.run_server(debug=debug, port=port)
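The color selection in `get_color` keys off the virtual stage (`stage_id // num_devices`, reflecting interleaved stage placement) and cycles when the virtual stage exceeds the palette length. A minimal standalone sketch of that indexing (the palette size is a placeholder, not tied to the hex palettes above):

```python
def palette_index(stage_id: int, num_devices: int, palette_size: int) -> int:
    # Same indexing as get_color: with interleaved placement, stage_id // num_devices
    # identifies the virtual stage; wrap around when it exceeds the palette length.
    virtual_stage = stage_id // num_devices
    return virtual_stage % palette_size
```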