Commit c048b97
Initial commit: PP schedule visualization.

Files changed:
- .gitattributes +2 -0
- .gitignore +12 -0
- Dockerfile +19 -0
- LICENSE +21 -0
- README.md +157 -0
- app.py +340 -0
- conf/config.yaml +25 -0
- main.py +156 -0
- pyproject.toml +69 -0
- requirements.txt +9 -0
- src/__init__.py +3 -0
- src/execution_model.py +401 -0
- src/strategies.py +581 -0
- src/visualizer.py +612 -0
.gitattributes
ADDED
@@ -0,0 +1,2 @@
+*.png filter=lfs diff=lfs merge=lfs -text
+assets/*.png filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,12 @@
+# Python
+./venv
+uv.lock
+outputs/
+.cursor/*
+*.json
+
+# Uncomment below if you want to include these files
+# !assets/*.png
+# !assets/*.jpg
+# !docs/*.png
+# !docs/*.jpg
Dockerfile
ADDED
@@ -0,0 +1,19 @@
+# Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
+FROM python:3.9-slim
+
+RUN useradd -m -u 1000 user
+USER user
+ENV PATH="/home/user/.local/bin:$PATH"
+ENV HOME="/home/user"
+WORKDIR /home/user/app
+
+COPY --chown=user requirements.txt ./
+RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+COPY --chown=user . ./
+
+# Expose the port app will run on
+EXPOSE 7860
+
+# Start the app
+CMD ["gunicorn", "-b", "0.0.0.0:7860", "app:server"]
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md
ADDED
@@ -0,0 +1,157 @@
+# Pipeline Parallelism Emulation and Visualization
+
+This project provides tools for emulating and visualizing pipeline parallelism strategies used in large language model training.
+
+## Overview
+
+Pipeline parallelism is a technique used to train large models by partitioning the model across multiple devices and processing data in a pipelined fashion. This project allows you to:
+
+- Simulate different pipeline parallelism strategies (1F1B, Interleaved, Zero-Bubble, etc.)
+- Visualize the execution schedule on multiple devices
+- Compare different strategies for efficiency
+
+## Features
+
+- **Supported Pipeline Strategies**:
+  - 1F1B (One-Forward-One-Backward)
+  - Interleaved 1F1B
+  - Zero-Bubble 1F1B (ZB-1P)
+  - 1F1B with computation-communication overlap
+  - Interleaved 1F1B with computation-communication overlap
+  - DualPipe (bidirectional pipeline parallelism with full forward-backward overlap)
+
+- **Visualization**:
+  - Interactive visualization dashboard using Plotly/Dash
+
+- **Configuration**:
+  - Configurable simulation parameters through Hydra
+  - Customizable stage latency and communication costs
+
+## Installation
+
+This project uses [uv](https://github.com/astral-sh/uv) for dependency management.
+
+Set up `uv` if it is not already installed on your computer:
+```bash
+# On macOS and Linux
+curl -LsSf https://astral.sh/uv/install.sh | sh
+```
+
+## Running the Interactive Server
+
+To visualize schedules interactively:
+
+```bash
+uv run src/server.py
+```
+
+This will start a Dash server (usually on `http://127.0.0.1:8050/`). Open this URL in your web browser.
+
+You can then adjust parameters such as the number of devices, stages, and batches and the operation times, and select different scheduling strategies to see the resulting pipeline visualization.
+
+## Running from Command Line
+
+### Running the 1F1B strategy:
+```bash
+uv run python main.py strategy=1f1b num_devices=4 num_stages=4 num_batches=8
+```
+
+
+### Running the interleaved strategy:
+```bash
+uv run python main.py strategy=interleave num_devices=4 num_stages=8 num_batches=8
+```
+
+
+### Running the ZB-1P strategy:
+```bash
+uv run python main.py strategy=zb1p num_devices=4 num_stages=4 num_batches=8
+```
+
+
+### Running the DualPipe strategy:
+```bash
+uv run python main.py strategy=dualpipe num_devices=8 num_stages=8 num_batches=20
+```
+
+
+### Running the 1F1B-batch-overlap strategy:
+```bash
+uv run python main.py strategy=1f1b_overlap num_devices=4 num_stages=4 num_batches=8
+```
+
+
+### Running the 1F1B-interleave-overlap strategy:
+```bash
+uv run python main.py strategy=1f1b_interleave_overlap num_devices=4 num_stages=8 num_batches=8
+```
+
+
+## Configuration
+
+The default configuration is in `conf/config.yaml`. You can override any parameter on the command line or create configuration groups for different scenarios.
+
+#### Override Specific Parameters
+
+You can override specific parameters at runtime:
+```bash
+uv run python main.py op_times.forward=0.5 op_times.backward=1.0 num_batches=6
+```
+
+Using DualPipe as an example, you can manually set different times for forward/backward/backward_D/backward_W/overlapped_forward_backward:
+```bash
+uv run python main.py strategy=dualpipe num_devices=8 num_stages=8 num_batches=32 op_times.forward=1.0 op_times.backward=2.0 op_times.backward_D=1.0 op_times.backward_W=1.0 op_times.overlapped_forward_backward=2.5
+```
+
+### Using Different Configuration Files
+
+You can use different configuration files with Hydra in several ways:
+
+#### Recommended Approach
+
+1. Create multiple configuration files in the `conf` directory for different use cases:
+```
+conf/
+├── config.yaml   # Default configuration
+└── model_A.yaml  # Create your own config with stage-specific latency for performance projection
+```
+
+2. Run with your desired configuration using the `--config-name` flag:
+```bash
+uv run python main.py --config-name=model_A
+```
+
+## Project Structure
+
+```
+PP-Emulation/
+├── conf/                  # Hydra configuration files
+│   └── config.yaml        # Default configuration
+├── src/                   # Source code
+│   ├── __init__.py        # Package initialization
+│   ├── execution_model.py # Schedule execution models
+│   ├── strategies.py      # Pipeline parallelism strategies
+│   └── visualizer.py      # Visualization utilities
+├── main.py                # Main entry point
+├── pyproject.toml         # Project metadata and dependencies
+└── README.md              # This file
+```
+
+## References
+
+1. _PipeDream: Fast and Efficient Pipeline Parallel DNN Training_. [arxiv](https://arxiv.org/abs/1806.03377)
+2. _Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM_. [arxiv](https://arxiv.org/abs/2104.04473)
+3. _Zero Bubble Pipeline Parallelism_. [arxiv](https://arxiv.org/abs/2401.10241)
+4. _Communication-Computation Overlap in MoE Training with 1F1B Pipeline Parallelism_. [blog](https://zhuanlan.zhihu.com/p/28463368206)
+
+## License
+
+This project is licensed under the MIT License - see the LICENSE file for details.
+
+## Contributing
+
+Contributions are welcome! Please feel free to submit a Pull Request.
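Note: beyond the Hydra entry point shown in the README, the modules added in this commit can be driven directly from Python. Below is a minimal sketch using only names defined in `src/execution_model.py` and `src/strategies.py` as committed here; the parameter values are arbitrary example choices.

```python
# Minimal programmatic use of the scheduling API from this commit.
from src.execution_model import ScheduleConfig
from src.strategies import generate_1f1b_schedule

# 1F1B requires num_devices == num_stages (asserted inside generate_1f1b_schedule).
config = ScheduleConfig(
    num_devices=4,
    num_stages=4,
    num_batches=8,
    op_times={"forward": 1.0, "backward": 2.0},
    placement_strategy="standard",
)
schedule = generate_1f1b_schedule(config)
schedule.execute()  # resolves dependencies and assigns start/end times

print(schedule.get_total_execution_time())  # makespan of the pipeline
print(schedule.get_bubble_rate())           # (actual - ideal) / ideal
```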
app.py
ADDED
@@ -0,0 +1,340 @@
+import dash
+import dash_bootstrap_components as dbc
+from dash import dcc, html, Input, Output, State, callback_context
+import plotly.graph_objects as go
+
+from src.execution_model import ScheduleConfig, Schedule
+from src.strategies import (
+    generate_1f1b_schedule,
+    generate_zero_bubble_1p_schedule,
+    generate_1f1b_overlap_schedule,
+    generate_1f1b_interleave_schedule,
+    generate_1f1b_interleave_overlap_schedule,
+    generate_dualpipe_schedule
+)
+from src.visualizer import convert_schedule_to_visualization_format, create_pipeline_figure
+
+STRATEGIES = {
+    "1f1b": generate_1f1b_schedule,
+    "zb1p": generate_zero_bubble_1p_schedule,
+    "1f1b_overlap": generate_1f1b_overlap_schedule,
+    "1f1b_interleave": generate_1f1b_interleave_schedule,
+    "1f1b_interleave_overlap": generate_1f1b_interleave_overlap_schedule,
+    "dualpipe": generate_dualpipe_schedule,
+}
+
+app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP], suppress_callback_exceptions=True)
+app.title = "Pipeline Parallelism Schedule Visualizer"
+
+# Initial default values
+default_values = {
+    "num_devices": 4,
+    "num_stages": 8,
+    "num_batches": 16,
+    "p2p_latency": 0.0,
+    "op_time_forward": 1.0,
+    "op_time_backward_d": 1.0,
+    "op_time_backward_w": 1.0,
+    "op_time_backward": 2.0,
+    "strategy": "1f1b_interleave",
+    "op_time_overlapped_fwd_bwd": None,
+}
+
+# Define input groups using dbc components
+basic_params_card = dbc.Card(
+    dbc.CardBody([
+        html.H5("Basic Parameters", className="card-title"),
+        html.Div([
+            dbc.Label("Number of Devices (GPUs):"),
+            dbc.Input(id='num_devices', type='number', value=default_values["num_devices"], min=1, step=1),
+        ], className="mb-3"),
+        html.Div([
+            dbc.Label("Number of Stages (Model Chunks):"),
+            dbc.Input(id='num_stages', type='number', value=default_values["num_stages"], min=1, step=1),
+        ], className="mb-3"),
+        html.Div([
+            dbc.Label("Number of Microbatches:"),
+            dbc.Input(id='num_batches', type='number', value=default_values["num_batches"], min=1, step=1),
+        ], className="mb-3"),
+        html.Div([
+            dbc.Label("P2P Latency (ms):"),
+            dbc.Input(id='p2p_latency', type='number', value=default_values["p2p_latency"], min=0, step=0.01),
+        ], className="mb-3"),
+    ])
+)
+
+scheduling_params_card = dbc.Card(
+    dbc.CardBody([
+        html.H5("Scheduling Parameters", className="card-title"),
+        html.Div([
+            dbc.Label("Scheduling Strategies:"),
+            dbc.Checklist(
+                id='strategy-checklist',
+                options=[{'label': k, 'value': k} for k in STRATEGIES.keys()],
+                value=list(STRATEGIES.keys()),
+                inline=False,
+            ),
+        ], className="mb-3"),
+    ])
+)
+
+timing_params_card = dbc.Card(
+    dbc.CardBody([
+        html.H5("Operation Timing (ms)", className="card-title"),
+        html.Div([
+            dbc.Label("Forward:"),
+            dbc.Input(id='op_time_forward', type='number', value=default_values["op_time_forward"], min=0.01, step=0.01),
+        ], className="mb-3"),
+        html.Div([
+            dbc.Label("Backward (Combined):"),
+            dbc.Input(id='op_time_backward', type='number', value=default_values["op_time_backward"], min=0.01, step=0.01),
+            dbc.FormText("Used when strategy does NOT require split backward."),
+        ], className="mb-3"),
+        html.Div([
+            dbc.Label("Backward D (Data Grad):"),
+            dbc.Input(id='op_time_backward_d', type='number', value=default_values["op_time_backward_d"], min=0.01, step=0.01),
+            dbc.FormText("Used when strategy requires split backward (e.g., ZB-1P, DualPipe)."),
+        ], className="mb-3"),
+        html.Div([
+            dbc.Label("Backward W (Weight Grad):"),
+            dbc.Input(id='op_time_backward_w', type='number', value=default_values["op_time_backward_w"], min=0.01, step=0.01),
+            dbc.FormText("Used when strategy requires split backward (e.g., ZB-1P, DualPipe)."),
+        ], className="mb-3"),
+        html.Div([
+            dbc.Label("Overlapped Forward+Backward:"),
+            dbc.Input(id='op_time_overlapped_fwd_bwd', type='number', placeholder="Optional: Defaults to Fwd + Bwd times", min=0.01, step=0.01, value=default_values["op_time_overlapped_fwd_bwd"]),
+            dbc.FormText("Specify a custom duration if Forward and Backward ops overlap completely."),
+        ], className="mb-3"),
+    ])
+)
+
+# Updated app layout using dbc components and structure
+app.layout = dbc.Container([
+    html.H1("Pipeline Parallelism Schedule Visualizer", className="my-4 text-center"),
+
+    dbc.Row([
+        dbc.Col(basic_params_card, md=4),
+        dbc.Col(scheduling_params_card, md=4),
+        dbc.Col(timing_params_card, md=4),
+    ]),
+
+    dbc.Row([
+        dbc.Col([
+            dbc.Button('Generate Schedule', id='generate-button', n_clicks=0, color="primary", className="mt-4"),
+        ], className="text-center")
+    ]),
+
+    dbc.Row([
+        dbc.Col([
+            dcc.Loading(
+                id="loading-graph-area",
+                type="circle",
+                children=html.Div(id='graph-output-container', className="mt-4")
+            )
+        ])
+    ])
+], fluid=True)
+
+@app.callback(
+    Output('graph-output-container', 'children'),
+    Input('generate-button', 'n_clicks'),
+    State('num_devices', 'value'),
+    State('num_stages', 'value'),
+    State('num_batches', 'value'),
+    State('p2p_latency', 'value'),
+    State('op_time_forward', 'value'),
+    State('op_time_backward', 'value'),
+    State('op_time_backward_d', 'value'),
+    State('op_time_backward_w', 'value'),
+    State('op_time_overlapped_fwd_bwd', 'value'),
+    State('strategy-checklist', 'value'),
+    prevent_initial_call=True
+)
+def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
+                 op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
+                 op_time_overlapped_fwd_bwd,
+                 selected_strategies):
+
+    # Define the desired display order for strategies
+    strategy_display_order = ["1f1b", "1f1b_interleave", "1f1b_overlap", "1f1b_interleave_overlap", "dualpipe", "zb1p"]
+
+    output_components = []
+    valid_results = []  # Store (strategy_name, schedule, vis_data) for valid schedules
+    error_messages = []  # Store (strategy_name, error_message) for errors
+    automatic_adjustments = []  # Store messages about automatic parameter adjustments
+
+    if not selected_strategies:
+        return [dbc.Alert("Please select at least one scheduling strategy.", color="warning")]
+
+    if not all([num_devices, num_stages, num_batches, op_time_forward]):
+        return [dbc.Alert("Missing required basic input values (Devices, Stages, Batches, Forward Time).", color="danger")]
+
+    for strategy in selected_strategies:
+        error_message = ""
+        placement_strategy = ""
+
+        # Use local copies of params that might be adjusted for this strategy
+        current_num_stages = num_stages
+        current_num_devices = num_devices
+
+        # Apply automatic adjustments for dualpipe
+        if strategy == "dualpipe" and num_stages != num_devices:
+            current_num_stages = num_devices  # Force num_stages = num_devices for dualpipe
+            automatic_adjustments.append(
+                f"Strategy '{strategy}': Number of Stages automatically adjusted to {num_devices} to match Number of Devices."
+            )
+
+        # Apply automatic adjustments for strategies that require num_stages == num_devices
+        if strategy in ["1f1b", "1f1b_overlap", "zb1p"] and num_stages != num_devices:
+            current_num_stages = num_devices
+            automatic_adjustments.append(
+                f"Strategy '{strategy}': Number of Stages automatically adjusted to {num_devices} to match Number of Devices."
+            )
+
+        split_backward = strategy in ["zb1p", "dualpipe"]
+
+        if split_backward and not all([op_time_backward_d, op_time_backward_w]):
+            error_message = f"Strategy '{strategy}': Backward D and Backward W times are required."
+        elif not split_backward and not op_time_backward:
+            error_message = f"Strategy '{strategy}': Combined Backward time is required."
+
+        if not error_message:
+            if strategy in ["1f1b", "1f1b_overlap", "zb1p"]:
+                placement_strategy = "standard"
+                # No need to check num_stages == num_devices as we've enforced it above
+            elif strategy in ["1f1b_interleave", "1f1b_interleave_overlap"]:
+                placement_strategy = "interleave"
+                if current_num_stages % current_num_devices != 0:
+                    error_message = f"Strategy '{strategy}': Requires Number of Stages to be divisible by Number of Devices."
+            elif strategy == "dualpipe":
+                placement_strategy = "dualpipe"
+                if current_num_stages % 2 != 0:
+                    error_message = f"Strategy '{strategy}' (DualPipe): Requires an even number of stages."
+
+        # Create adjusted operation times based on placement strategy
+        if not error_message:
+            try:
+                # Calculate number of stages per device for time adjustment
+                stages_per_device = current_num_stages // current_num_devices
+
+                # Calculate scaling factor - this normalizes operation time by stages per device
+                # For standard placement (1:1 stage:device mapping), this remains 1.0
+                # For interleaved, this scales down the time proportionally
+                time_scale_factor = 1.0 / stages_per_device if stages_per_device > 0 else 1.0
+
+                if stages_per_device > 1:
+                    automatic_adjustments.append(
+                        f"Strategy '{strategy}': Operation times scaled by 1/{stages_per_device} to account for {stages_per_device} stages per device."
+                    )
+
+                # Apply scaling to operation times
+                op_times = {
+                    "forward": float(op_time_forward) * time_scale_factor
+                }
+
+                if split_backward:
+                    op_times["backward_D"] = float(op_time_backward_d) * time_scale_factor
+                    op_times["backward_W"] = float(op_time_backward_w) * time_scale_factor
+                    # Keep combined for compatibility
+                    op_times["backward"] = (float(op_time_backward_d) + float(op_time_backward_w)) * time_scale_factor
+                else:
+                    op_times["backward"] = float(op_time_backward) * time_scale_factor
+
+                if op_time_overlapped_fwd_bwd is not None:
+                    try:
+                        overlapped_val = float(op_time_overlapped_fwd_bwd)
+                        if overlapped_val > 0:
+                            # Scale overlapped time too
+                            op_times["overlapped_forward_backward"] = overlapped_val * time_scale_factor
+                    except (ValueError, TypeError):
+                        pass
+
+                config = ScheduleConfig(
+                    num_devices=int(current_num_devices),
+                    num_stages=int(current_num_stages),  # Use adjusted value
+                    num_batches=int(num_batches),
+                    p2p_latency=float(p2p_latency),
+                    placement_strategy=placement_strategy,
+                    split_backward=split_backward,
+                    op_times=op_times,
+                )
+
+                schedule_func = STRATEGIES.get(strategy)
+                if not schedule_func:
+                    raise ValueError(f"Invalid strategy function for: {strategy}")
+
+                schedule = schedule_func(config)
+                schedule.execute()
+
+                # Store valid results instead of creating figure immediately
+                vis_data = convert_schedule_to_visualization_format(schedule)
+                valid_results.append((strategy, schedule, vis_data))
+
+            except (AssertionError, ValueError, TypeError) as e:
+                error_message = f"Error generating schedule for '{strategy}': {e}"
+                import traceback
+                traceback.print_exc()
+            except Exception as e:
+                error_message = f"An unexpected error occurred for '{strategy}': {e}"
+                import traceback
+                traceback.print_exc()
+
+        if error_message:
+            error_messages.append((strategy, error_message))
+
+    # Add alerts for any automatic parameter adjustments
+    for adjustment in automatic_adjustments:
+        output_components.append(
+            dbc.Alert(adjustment, color="info", dismissable=True)
+        )
+
+    # If we have valid results, calculate the maximum execution time across all schedules
+    if valid_results:
+        # Find global maximum execution time
+        max_execution_time = max(schedule.get_total_execution_time() for _, schedule, _ in valid_results)
+
+        # Sort valid results according to the display order
+        sorted_valid_results = []
+
+        # First add strategies in the predefined order
+        for strategy_name in strategy_display_order:
+            for result in valid_results:
+                if result[0] == strategy_name:
+                    sorted_valid_results.append(result)
+
+        # Then add any remaining strategies that might not be in the predefined order
+        for result in valid_results:
+            if result[0] not in strategy_display_order:
+                sorted_valid_results.append(result)
+
+        # Create figures with aligned x-axis, using the sorted results
+        for strategy, _, vis_data in sorted_valid_results:
+            fig = create_pipeline_figure(vis_data, max_time=max_execution_time, show_progress=False)
+
+            # Force the x-axis range to be the same for all figures
+            # Add a small margin (5%) for better visualization
+            margin = max_execution_time * 0.05
+            fig.update_layout(
+                xaxis=dict(
+                    range=[0, max_execution_time + margin]
+                )
+            )
+
+            output_components.append(html.Div([
+                html.H4(f"Schedule: {strategy}", className="text-center mt-3 mb-2"),
+                dcc.Graph(figure=fig)
+            ]))
+
+    # Add error messages to output
+    for strategy, msg in error_messages:
+        output_components.append(
+            dbc.Alert(msg, color="danger", className="mt-3")
+        )
+
+    return output_components
+
+# For Hugging Face Spaces deployment
+server = app.server
+
+if __name__ == '__main__':
+    app.run_server(debug=False, host='0.0.0.0', port=7860)
conf/config.yaml
ADDED
@@ -0,0 +1,25 @@
+# Default configuration for Pipeline Parallelism Emulation
+num_devices: 4
+num_stages: 4
+num_batches: 8
+visualization_port: 8050
+strategy: "1f1b"  # Options: "1f1b", "interleave"
+p2p_latency: 0.0
+
+# Operation time configurations
+op_times:
+  # Option 1: Simple configuration (same time for all stages)
+  forward: 1.0
+  backward: 2.0
+  backward_D: 1.0
+  backward_W: 1.0
+  overlapped_forward_backward: 3.0
+
+  # Option 2: Commented example of stage-specific configuration
+  # forward:
+  #   0: 0.8  # Stage 0 forward time
+  #   1: 1.2  # Stage 1 forward time
+  #   2: 1.5  # Stage 2 forward time
+  #   3: 0.9  # Stage 3 forward time
+  # backward:
+  #   0: 1.0  # Stage 0 backward time
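Note: "Option 2" above maps stage IDs to per-stage latencies. `ScheduleConfig` (in `src/execution_model.py` below) accepts such a mapping anywhere a scalar time is accepted, so the same thing can be expressed directly in Python. A small sketch, with the numbers chosen arbitrarily:

```python
from src.execution_model import ScheduleConfig

# Per-stage forward times (stage_id -> time); backward stays uniform.
config = ScheduleConfig(
    num_devices=4,
    num_stages=4,
    num_batches=8,
    op_times={
        "forward": {0: 0.8, 1: 1.2, 2: 1.5, 3: 0.9},
        "backward": 2.0,
    },
)
print(config.get_op_time("forward", 2))   # 1.5 (stage-specific)
print(config.get_op_time("backward", 2))  # 2.0 (uniform across stages)
```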
main.py
ADDED
@@ -0,0 +1,156 @@
+from src.execution_model import ScheduleConfig
+from src.strategies import (
+    generate_1f1b_interleave_overlap_schedule,
+    generate_1f1b_interleave_schedule,
+    generate_1f1b_overlap_schedule,
+    generate_1f1b_schedule,
+    generate_zero_bubble_1p_schedule,
+    generate_dualpipe_schedule,
+)
+from src.visualizer import visualize_pipeline_parallelism_dash
+import hydra
+from omegaconf import DictConfig, OmegaConf
+
+
+@hydra.main(config_path="conf", config_name="config", version_base=None)
+def main(cfg: DictConfig) -> None:
+    """Run pipeline parallelism simulation with the specified configuration."""
+    print(f"Running with configuration: {cfg}")
+
+    if cfg.strategy == "1f1b":
+        run_1f1b(cfg)
+    elif cfg.strategy == "interleave":
+        run_interleave(cfg)
+    elif cfg.strategy == "zb1p":
+        run_zero_bubble_1p(cfg)
+    elif cfg.strategy == "1f1b_overlap":
+        run_1f1b_overlap(cfg)
+    elif cfg.strategy == "1f1b_interleave_overlap":
+        run_1f1b_interleave_overlap(cfg)
+    elif cfg.strategy == "dualpipe":
+        run_dualpipe(cfg)
+    else:
+        raise ValueError(f"Unknown strategy: {cfg.strategy}")
+
+
+def run_1f1b(cfg: DictConfig) -> None:
+    """Run 1F1B pipeline parallelism simulation."""
+    # Convert OmegaConf to dict for op_times if it exists
+    op_times = (
+        OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
+    )
+
+    schedule_config = ScheduleConfig(
+        num_devices=cfg.num_devices,
+        num_stages=cfg.num_stages,
+        num_batches=cfg.num_batches,
+        p2p_latency=cfg.p2p_latency,
+        op_times=op_times,
+        placement_strategy="standard",
+    )
+    schedule = generate_1f1b_schedule(schedule_config)
+    schedule.execute()
+
+    visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
+
+
+def run_interleave(cfg: DictConfig) -> None:
+    """Run interleaved pipeline parallelism simulation."""
+    # Convert OmegaConf to dict for op_times if it exists
+    op_times = (
+        OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
+    )
+
+    schedule_config = ScheduleConfig(
+        num_devices=cfg.num_devices,
+        num_stages=cfg.num_stages,
+        num_batches=cfg.num_batches,
+        p2p_latency=cfg.p2p_latency,
+        placement_strategy="interleave",
+        op_times=op_times,
+    )
+    schedule = generate_1f1b_interleave_schedule(schedule_config)
+    schedule.execute()
+    visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
+
+
+def run_zero_bubble_1p(cfg: DictConfig) -> None:
+    """Run zero bubble 1P pipeline parallelism simulation."""
+    # Convert OmegaConf to dict for op_times if it exists
+    op_times = (
+        OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
+    )
+
+    schedule_config = ScheduleConfig(
+        num_devices=cfg.num_devices,
+        num_stages=cfg.num_stages,
+        num_batches=cfg.num_batches,
+        p2p_latency=cfg.p2p_latency,
+        op_times=op_times,
+        split_backward=True,
+    )
+    schedule = generate_zero_bubble_1p_schedule(schedule_config)
+    schedule.execute()
+    visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
+
+
+def run_1f1b_overlap(cfg: DictConfig) -> None:
+    """Run 1F1B overlap pipeline parallelism simulation."""
+    # Convert OmegaConf to dict for op_times if it exists
+    op_times = (
+        OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
+    )
+
+    schedule_config = ScheduleConfig(
+        num_devices=cfg.num_devices,
+        num_stages=cfg.num_stages,
+        num_batches=cfg.num_batches,
+        p2p_latency=cfg.p2p_latency,
+        op_times=op_times,
+        split_backward=False,
+    )
+    schedule = generate_1f1b_overlap_schedule(schedule_config)
+    schedule.execute()
+    visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
+
+def run_1f1b_interleave_overlap(cfg: DictConfig) -> None:
+    """Run 1F1B interleave overlapped pipeline parallelism simulation."""
+    # Convert OmegaConf to dict for op_times if it exists
+    op_times = (
+        OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
+    )
+
+    schedule_config = ScheduleConfig(
+        num_devices=cfg.num_devices,
+        num_stages=cfg.num_stages,
+        num_batches=cfg.num_batches,
+        p2p_latency=cfg.p2p_latency,
+        placement_strategy="interleave",
+        op_times=op_times,
+    )
+    schedule = generate_1f1b_interleave_overlap_schedule(schedule_config)
+    schedule.execute()
+    visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
+
+def run_dualpipe(cfg: DictConfig) -> None:
+    """Run DualPipe pipeline parallelism simulation."""
+    # Convert OmegaConf to dict for op_times if it exists
+    op_times = (
+        OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
+    )
+
+    schedule_config = ScheduleConfig(
+        num_devices=cfg.num_devices,
+        num_stages=cfg.num_stages,
+        num_batches=cfg.num_batches,
+        p2p_latency=cfg.p2p_latency,
+        op_times=op_times,
+        split_backward=True,
+        placement_strategy="dualpipe",
+    )
+    schedule = generate_dualpipe_schedule(schedule_config)
+    schedule.execute()
+    visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
+
+if __name__ == "__main__":
+    main()
pyproject.toml
ADDED
@@ -0,0 +1,69 @@
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project]
+name = "pp-emulation"
+version = "0.1.0"
+description = "Pipeline Parallelism Emulation and Visualization"
+readme = "README.md"
+requires-python = ">=3.10"
+authors = [
+    {name = "Zhenhuan Liu"}
+]
+classifiers = [
+    "Programming Language :: Python :: 3",
+    "License :: OSI Approved :: MIT License",
+    "Operating System :: OS Independent",
+]
+dependencies = [
+    "dash>=2.14.0",
+    "hydra-core>=1.3.2",
+    "omegaconf>=2.3.0",
+    "plotly>=5.18.0",
+    "pandas>=2.1.0",
+    "numpy>=1.26.0",
+    "tqdm>=4.67.0",
+    "dash-bootstrap-components>=1.7.1",
+    "gunicorn>=23.0.0",
+]
+
+[project.optional-dependencies]
+dev = [
+    "pytest>=7.4.0",
+    "black>=23.7.0",
+    "isort>=5.12.0",
+    "mypy>=1.5.1",
+]
+
+# Add Hatch configuration to explicitly define where source code is located
+[tool.hatch.build.targets.wheel]
+packages = ["src"]
+
+[tool.hatch.build.targets.sdist]
+include = [
+    "src",
+    "main.py",
+    "conf",
+    "LICENSE",
+    "README.md",
+]
+
+[tool.black]
+line-length = 88
+target-version = ["py310"]
+
+[tool.isort]
+profile = "black"
+line_length = 88
+
+[tool.mypy]
+python_version = "3.10"
+warn_return_any = true
+warn_unused_configs = true
+disallow_untyped_defs = true
+disallow_incomplete_defs = true
+
+[tool.pytest]
+testpaths = ["tests"]
+pythonpath = ["."]
requirements.txt
ADDED
@@ -0,0 +1,9 @@
+dash==2.14.2
+dash-bootstrap-components==1.7.1
+plotly==5.18.0
+gunicorn==21.2.0
+hydra-core==1.3.2
+omegaconf==2.3.0
+pandas==2.1.0
+numpy==1.26.0
+tqdm==4.67.0
src/__init__.py
ADDED
@@ -0,0 +1,3 @@
+"""Pipeline Parallelism Emulation and Visualization package."""
+
+__version__ = "0.1.0"
src/execution_model.py
ADDED
@@ -0,0 +1,401 @@
+from collections import defaultdict
+from typing import Dict, List, Optional, Union
+
+
+class Operation:
+    """Operation is a single operation in the pipeline."""
+
+    def __init__(self, batch_id: int, stage_id: int, op_type: str):
+        self.batch_id = batch_id
+        self.stage_id = stage_id
+        self.op_type = op_type
+        self.device_id = None
+
+        self.start_time = None
+        self.end_time = None
+
+    def set_end_time(self, end_time: float):
+        self.end_time = end_time
+
+    def set_start_time(self, start_time: float):
+        self.start_time = start_time
+
+    def __repr__(self) -> str:
+        return f"Operation(batch_id={self.batch_id}, stage_id={self.stage_id}, op_type={self.op_type})"
+
+class OverlappedOperation:
+    """Represents multiple operations that are overlapped/executed concurrently."""
+
+    def __init__(self, operations: List[Operation]):
+        self.operations = operations
+        self.device_id = operations[0].device_id
+
+        # Validate all operations are on the same device
+        for op in operations:
+            assert op.device_id == self.device_id, "All operations must be on the same device"
+
+        # Create a combined op_type (e.g., "overlapped_forward_backward")
+        self.op_type = "overlapped_" + "_".join([op.op_type for op in operations])
+
+        # Use the batch_id and stage_id of the first operation for identification
+        # (though we'll track all operations internally)
+        self.batch_id = operations[0].batch_id
+        self.stage_id = operations[0].stage_id
+
+        # Initialize timing information
+        self.start_time = None
+        self.end_time = None
+
+    def set_end_time(self, end_time: float):
+        self.end_time = end_time
+        for op in self.operations:
+            op.set_end_time(end_time)
+
+    def set_start_time(self, start_time: float):
+        self.start_time = start_time
+        for op in self.operations:
+            op.set_start_time(start_time)
+
+    def __repr__(self) -> str:
+        op_str = ", ".join([f"({op.batch_id},{op.stage_id},{op.op_type})" for op in self.operations])
+        return f"OverlappedOperation([{op_str}])"
+
+class DeviceQueue:
+    def __init__(self, stages: List[int], device_id: int):
+        self.stages = stages
+        self.device_id = device_id
+        self.ops = []  # List of operations
+
+    def add_operation(self, op: Operation):
+        assert op.stage_id in self.stages
+        self.ops.append(op)
+        assert op.device_id is None, f"Operation {op.batch_id}, {op.stage_id}, {op.op_type} already has a device id on {op.device_id}"
+        op.device_id = self.device_id
+
+
+class ScheduleConfig:
+    def __init__(
+        self,
+        num_devices: int,
+        num_stages: int,
+        num_batches: int,
+        p2p_latency: float = 0.0,
+        placement_strategy: str = "standard",
+        split_backward: bool = False,
+        op_times: Optional[Dict[str, Union[float, Dict[int, float]]]] = None,
+    ):
+        self.num_devices = num_devices
+        self.num_stages = num_stages
+        self.num_batches = num_batches
+        self.p2p_latency = p2p_latency
+        self.placement_strategy = placement_strategy
+        self.split_backward = split_backward
+
+        # Initialize default operation times
+        if self.split_backward:
+            self.op_times = {
+                "forward": 1.0,
+                "backward_D": 1.0,
+                "backward_W": 1.0,
+                "backward": 2.0,
+            }
+        else:
+            self.op_times = {
+                "forward": 1.0,
+                "backward": 2.0,
+            }
+
+        # Update with user-provided operation times
+        if op_times:
+            for op_type, times in op_times.items():
+                if isinstance(times, dict):
+                    # If a dict is provided, it maps stage_id -> time
+                    if op_type not in self.op_times:
+                        self.op_times[op_type] = {}
+                    elif not isinstance(self.op_times[op_type], dict):
+                        # Convert float to dict if needed
+                        self.op_times[op_type] = {i: self.op_times[op_type] for i in range(num_stages)}
+
+                    # Update with provided stage-specific times
+                    for stage_id, time in times.items():
+                        if not isinstance(self.op_times[op_type], dict):
+                            self.op_times[op_type] = {i: self.op_times[op_type] for i in range(num_stages)}
+                        self.op_times[op_type][stage_id] = time
+                else:
+                    # If a float is provided, use same time for all stages
+                    self.op_times[op_type] = times
+
+        assert num_stages % num_devices == 0, "num_stages must be divisible by num_devices"
+        self.num_stages_per_device = num_stages // num_devices
+
+        self.init_device_to_stages()
+        if self.placement_strategy == "dualpipe":
+            assert (
+                sum(len(stages) for stages in self.device_to_stages.values()) == num_stages * 2
+            )
+        else:
+            assert (
+                sum(len(stages) for stages in self.device_to_stages.values()) == num_stages
+            )
+
+    def init_device_to_stages(self):
+        if self.placement_strategy == "standard":
+            # Evenly distributed
+            stages_per_device = self.num_stages // self.num_devices
+            self.device_to_stages = defaultdict(list)
+            for i in range(self.num_stages):
+                device_to_put = i // stages_per_device
+                self.device_to_stages[device_to_put].append(i)
+        elif self.placement_strategy == "interleave":
+            self.device_to_stages = defaultdict(list)
+            for i in range(self.num_stages):
+                device_to_put = i % self.num_devices
+                self.device_to_stages[device_to_put].append(i)
+        elif self.placement_strategy == "dualpipe":
+            # For DualPipe, each device has two stages
+            assert self.num_devices == self.num_stages, "DualPipe requires num_devices == num_stages"
+            assert self.num_devices % 2 == 0, "DualPipe requires an even number of devices"
+            self.device_to_stages = defaultdict(list)
+            for i in range(self.num_stages):
+                self.device_to_stages[i] = [i, self.num_stages - i - 1]
+        else:
+            raise ValueError(f"Invalid placement strategy: {self.placement_strategy}")
+
+    def get_op_time(self, op_type: str, stage_id: int):
+        # For overlapped operations, extract the original operation types
+        if op_type.startswith("overlapped_"):
+            if op_type in self.op_times:
+                if isinstance(self.op_times[op_type], dict):
+                    if stage_id in self.op_times[op_type]:
+                        return self.op_times[op_type][stage_id]
+                    else:
+                        raise ValueError(f"No time specified for operation {op_type} at stage {stage_id}")
+                else:
+                    return self.op_times[op_type]
+            else:
+                op_parts = op_type.split("_")[1:]
+                if len(op_parts) >= 2:
+                    op_type1, op_type2 = op_parts[0], op_parts[1]
+                    return self.get_op_time(op_type1, stage_id) + self.get_op_time(op_type2, stage_id)
+
+        if op_type not in self.op_times:
+            raise ValueError(f"Invalid operation type: {op_type}")
+        times = self.op_times[op_type]
+        if isinstance(times, dict):
+            # If we have stage-specific times, use those
+            if stage_id not in times:
+                raise ValueError(f"No time specified for operation {op_type} at stage {stage_id}")
+            return times[stage_id]
+        else:
+            # If we have a single float, use the same value for all stages
+            return times
+
+
+class Schedule:
+    def __init__(self, config: ScheduleConfig, init_ops: bool = True):
+        self.ops = {}  # (batch_id, stage_id, op_type) -> Operation
+        self.device_queues: List[DeviceQueue] = []
+        for dev_id in range(config.num_devices):
+            self.device_queues.append(DeviceQueue(config.device_to_stages[dev_id], dev_id))
+        self.config = config
+
+        if init_ops:
+            self.init_operations()
+        self.op_to_overlapped = {}
+
+    def register_overlapped_operation(self, overlapped_op: OverlappedOperation):
+        for op in overlapped_op.operations:
+            self.op_to_overlapped[(op.batch_id, op.stage_id, op.op_type)] = overlapped_op
+            self.ops[(op.batch_id, op.stage_id, op.op_type)] = overlapped_op
+
+    def register_operation(self, op: Operation):
+        assert (op.batch_id, op.stage_id, op.op_type) not in self.ops, f"Operation {op.batch_id}, {op.stage_id}, {op.op_type} already registered"
+        self.ops[(op.batch_id, op.stage_id, op.op_type)] = op
+
+    def init_operations(self):
+        op_types = ["forward", "backward"]
+        if self.config.split_backward:
+            op_types = ["forward", "backward_D", "backward_W"]
+        for batch_id in range(self.config.num_batches):
+            for stage_id in range(self.config.num_stages):
+                for op_type in op_types:
+                    self.ops[(batch_id, stage_id, op_type)] = Operation(
+                        batch_id, stage_id, op_type
+                    )
+
+    def get_op(self, batch_id: int, stage_id: int, op_type: str, allow_none=False):
+        if (batch_id, stage_id, op_type) in self.op_to_overlapped:
+            return self.op_to_overlapped[(batch_id, stage_id, op_type)]
+        if allow_none:
+            if (batch_id, stage_id, op_type) not in self.ops:
+                return None
+        return self.ops[(batch_id, stage_id, op_type)]
+
+    def get_dependencies(self, op: Operation, include_device_dependency=True):
+        deps = []
+        if isinstance(op, OverlappedOperation):
+            for sub_op in op.operations:
+                deps.extend(self.get_dependencies(sub_op, include_device_dependency=False))
+
+            if include_device_dependency:
+                device_index = self.device_queues[op.device_id].ops.index(op)
+                if device_index > 0:
+                    deps.append((self.device_queues[op.device_id].ops[device_index - 1], 0.0))
+            return deps
+        if op.op_type == "forward":
+            if op.stage_id > 0:
+                deps.append(
+                    (
+                        self.get_op(op.batch_id, op.stage_id - 1, "forward"),
+                        self.config.p2p_latency,
+                    )
+                )
+        if self.config.split_backward:
+            if op.op_type == "backward_D":
+                if op.stage_id < self.config.num_stages - 1:
+                    op_bwd_d = self.get_op(op.batch_id, op.stage_id + 1, "backward_D", allow_none=True)
+                    if op_bwd_d is not None:
+                        deps.append(
+                            (
+                                op_bwd_d,
+                                self.config.p2p_latency,
+                            )
+                        )
+                    else:
+                        deps.append(
+                            (
+                                self.get_op(op.batch_id, op.stage_id + 1, "backward"),
+                                self.config.p2p_latency,
+                            )
+                        )
+            elif op.op_type == "backward_W":
+                if op.stage_id < self.config.num_stages - 1:
+                    op_bwd_d = self.get_op(op.batch_id, op.stage_id, "backward_D", allow_none=True)
+                    if op_bwd_d is not None:
+                        deps.append(
+                            (
+                                op_bwd_d,
+                                self.config.p2p_latency,
+                            )
+                        )
+                    else:
+                        deps.append(
+                            (
+                                self.get_op(op.batch_id, op.stage_id, "backward"),
+                                self.config.p2p_latency,
+                            )
+                        )
+            elif op.op_type == "backward":
+                if op.stage_id < self.config.num_stages - 1:
+                    op_bwd = self.get_op(op.batch_id, op.stage_id + 1, "backward", allow_none=True)
+                    if op_bwd is not None:
+                        deps.append(
+                            (
+                                op_bwd,
+                                self.config.p2p_latency,
+                            )
+                        )
+                    else:
+                        deps.append(
+                            (
+                                self.get_op(op.batch_id, op.stage_id + 1, "backward_D"),
+                                self.config.p2p_latency,
+                            )
+                        )
+        else:
+            if op.op_type == "backward":
+                if op.stage_id < self.config.num_stages - 1:
+                    deps.append(
+                        (
+                            self.get_op(op.batch_id, op.stage_id + 1, "backward"),
+                            self.config.p2p_latency,
+                        )
+                    )
+
+        if include_device_dependency:
+            device_index = self.device_queues[op.device_id].ops.index(op)
+            if device_index > 0:
+                deps.append((self.device_queues[op.device_id].ops[device_index - 1], 0.0))
+        return deps
+
+    def show(self):
+        """Display detailed information about the schedule for debugging purposes."""
+        print("\n=== SCHEDULE DETAILS ===")
+        print(f"Devices: {self.config.num_devices}, Stages: {self.config.num_stages}, Batches: {self.config.num_batches}")
+        print(f"Placement Strategy: {self.config.placement_strategy}")
+        print("\n=== DEVICE QUEUES ===")
+
+        for dev_id in range(self.config.num_devices):
+            print(f"\nDEVICE {dev_id} (Stages: {self.device_queues[dev_id].stages}):")
+            print("-" * 80)
+            print(f"{'Batch':^6} | {'Stage':^6} | {'Type':^10} | {'Start':^10} | {'End':^10} | {'Duration':^10}")
+            print("-" * 80)
+
+            for op in self.device_queues[dev_id].ops:
+                op_type = op.op_type
+                start = f"{op.start_time:.2f}" if op.start_time is not None else "N/A"
+                end = f"{op.end_time:.2f}" if op.end_time is not None else "N/A"
+
+                duration = "N/A"
+                if op.start_time is not None and op.end_time is not None:
+                    duration = f"{op.end_time - op.start_time:.2f}"
+
+                print(f"{op.batch_id:^6} | {op.stage_id:^6} | {op_type:^10} | {start:^10} | {end:^10} | {duration:^10}")
+
+        # Find the total execution time (if timing info is available)
+        if all(op.end_time is not None for op in self.ops.values()):
+            total_time = max(op.end_time for op in self.ops.values())
+            print(f"\nTotal execution time: {total_time:.2f}")
+
+    def execute(self):
+        # TODO: change the execution order to topological order via DAG
+        def execute_op(op: Operation):
+            if op.end_time is not None:
+                return
+            deps = self.get_dependencies(op)
+            if len(deps) == 0:
+                op.set_start_time(0.0)
+            else:
+                for dep, gap in deps:
+                    if dep.end_time is None or dep.start_time is None:
+                        execute_op(dep)
+                op.set_start_time(max(dep.end_time + gap for dep, gap in deps))
+            op.set_end_time(op.start_time + self.config.get_op_time(
+                op.op_type, op.stage_id
+            ))
+
+        op_num = len(self.device_queues[0].ops)
+        for i in range(op_num):
+            for dev_id in range(self.config.num_devices):
+                if len(self.device_queues[dev_id].ops) <= i:
+                    continue
+                op = self.device_queues[dev_id].ops[i]
+                execute_op(op)
+
+        for op in self.ops.values():
+            assert (
+                op.start_time is not None
+            ), f"op {op.batch_id}, {op.stage_id}, {op.op_type} has no start time"
+            assert (
+                op.end_time is not None
+            ), f"op {op.batch_id}, {op.stage_id}, {op.op_type} has no end time"
+
+    def get_total_execution_time(self):
+        return max(op.end_time for op in self.ops.values())
+
+    def get_bubble_rate(self):
+        actual_time = self.get_total_execution_time()
+        ideal_time = 0
+        for stage_id in range(self.config.num_stages):
+            for op_type in ["forward", "backward"]:
+                ideal_time += self.config.get_op_time(op_type, stage_id)
+        ideal_time = ideal_time * self.config.num_batches / self.config.num_devices
+
+        return (actual_time - ideal_time) / ideal_time
+
+    def get_device_running_time(self):
+        device_time = [0] * self.config.num_devices
+        for dev_id in range(self.config.num_devices):
+            for op in self.device_queues[dev_id].ops:
+                device_time[dev_id] += op.end_time - op.start_time
+        return device_time
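Note: to make `get_bubble_rate` concrete, here is a worked example under the default op times above (forward = 1.0, backward = 2.0) and the plain 1F1B schedule from `src/strategies.py` below. The expected values follow from the classic 1F1B makespan formula, which this execution model reproduces when `p2p_latency` is 0.

```python
# Worked bubble-rate example: 1F1B, 4 devices/stages, 8 microbatches.
from src.execution_model import ScheduleConfig
from src.strategies import generate_1f1b_schedule

config = ScheduleConfig(num_devices=4, num_stages=4, num_batches=8)
schedule = generate_1f1b_schedule(config)
schedule.execute()

# ideal_time = sum over stages of (f + b) * num_batches / num_devices
#            = 4 * (1 + 2) * 8 / 4 = 24
# 1F1B makespan = (num_batches + num_devices - 1) * (f + b) = 11 * 3 = 33
print(schedule.get_total_execution_time())  # expected: 33.0
print(schedule.get_bubble_rate())           # expected: (33 - 24) / 24 = 0.375
```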
src/strategies.py
ADDED
@@ -0,0 +1,581 @@
from collections import defaultdict, deque
from src.execution_model import OverlappedOperation, Operation, Schedule, ScheduleConfig


def generate_1f1b_schedule(config: ScheduleConfig):
    schedule = Schedule(config)

    assert config.num_devices == config.num_stages, "num_devices must be equal to num_stages for 1F1B"

    for i in range(config.num_devices):
        fwd_batch_id = 0
        bwd_batch_id = 0
        cooldown_batches = warmup_batches = config.num_devices - i - 1
        steady_batches = config.num_batches - warmup_batches

        for _ in range(warmup_batches):
            schedule.device_queues[i].add_operation(
                schedule.get_op(fwd_batch_id, i, "forward")
            )
            fwd_batch_id += 1

        for _ in range(steady_batches):
            schedule.device_queues[i].add_operation(
                schedule.get_op(fwd_batch_id, i, "forward")
            )
            fwd_batch_id += 1
            schedule.device_queues[i].add_operation(
                schedule.get_op(bwd_batch_id, i, "backward")
            )
            bwd_batch_id += 1

        for _ in range(cooldown_batches):
            schedule.device_queues[i].add_operation(
                schedule.get_op(bwd_batch_id, i, "backward")
            )
            bwd_batch_id += 1

    return schedule

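# Illustrative phase counts for the 1F1B loop above (a sketch, assuming
# 4 devices and 8 microbatches):
#   device 0: warmup = 3 forwards, steady = 5 forward/backward pairs, cooldown = 3 backwards
#   device 3: warmup = 0, so it alternates forward and backward from the start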

def generate_zero_bubble_1p_schedule(config: ScheduleConfig):
    # config.split_backward must be True (asserted below) so the schedule
    # exposes separate backward_D and backward_W operations
    schedule = Schedule(config)
    total_batches = config.num_batches
    assert config.num_devices == config.num_stages, "num_devices must be equal to num_stages for ZB-1P"
    assert config.split_backward, "ZB-1P requires split_backward=True"

    for i in range(config.num_devices):
        fwd_batch_id = 0
        bwd_d_batch_id = 0
        bwd_w_batch_id = 0

        cooldown_batches = warmup_batches = config.num_devices - i - 1
        steady_batches = total_batches - warmup_batches

        for _ in range(warmup_batches):
            schedule.device_queues[i].add_operation(
                schedule.get_op(fwd_batch_id, i, "forward")
            )
            fwd_batch_id += 1

        for _ in range(steady_batches):
            schedule.device_queues[i].add_operation(
                schedule.get_op(fwd_batch_id, i, "forward")
            )
            schedule.device_queues[i].add_operation(
                schedule.get_op(bwd_d_batch_id, i, "backward_D")
            )
            if fwd_batch_id - bwd_w_batch_id >= config.num_devices - 1:
                schedule.device_queues[i].add_operation(
                    schedule.get_op(bwd_w_batch_id, i, "backward_W")
                )
                bwd_w_batch_id += 1
            bwd_d_batch_id += 1
            fwd_batch_id += 1

        for _ in range(cooldown_batches):
            schedule.device_queues[i].add_operation(
                schedule.get_op(bwd_d_batch_id, i, "backward_D")
            )

            schedule.device_queues[i].add_operation(
                schedule.get_op(bwd_w_batch_id, i, "backward_W")
            )

            bwd_w_batch_id += 1
            bwd_d_batch_id += 1

        while bwd_w_batch_id < total_batches:
            schedule.device_queues[i].add_operation(
                schedule.get_op(bwd_w_batch_id, i, "backward_W")
            )
            bwd_w_batch_id += 1

    return schedule

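# Illustrative reading of the steady-state gate above (a sketch): backward_W
# for microbatch b is only emitted once fwd_batch_id - b >= num_devices - 1,
# i.e. weight-gradient work is deferred so it can fill what would otherwise be
# a pipeline bubble; any backward_W still pending after cooldown is flushed by
# the trailing while-loop.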

def generate_1f1b_overlap_schedule(config: ScheduleConfig):
    schedule = Schedule(config)

    assert config.num_devices == config.num_stages, "num_devices must be equal to num_stages for 1F1B"

    for i in range(config.num_devices):
        fwd_batch_id = 0
        bwd_batch_id = 0
        cooldown_batches = warmup_batches = 2 * (config.num_devices - i - 1) + 1
        steady_batches = config.num_batches - warmup_batches

        for _ in range(warmup_batches):
            schedule.device_queues[i].add_operation(
                schedule.get_op(fwd_batch_id, i, "forward")
            )
            fwd_batch_id += 1

        for _ in range(steady_batches):
            fwd_op = schedule.get_op(fwd_batch_id, i, "forward")
            bwd_op = schedule.get_op(bwd_batch_id, i, "backward")
            overlapped_op = OverlappedOperation([fwd_op, bwd_op])
            schedule.register_overlapped_operation(overlapped_op)
            schedule.device_queues[i].add_operation(overlapped_op)

            fwd_batch_id += 1
            bwd_batch_id += 1

        for _ in range(cooldown_batches):
            schedule.device_queues[i].add_operation(
                schedule.get_op(bwd_batch_id, i, "backward")
            )
            bwd_batch_id += 1

    return schedule

def _get_pp_rank_microbatches(
    num_microbatches,
    num_devices,
    device_id,
    num_stages_per_device,
    microbatch_group_size_per_vp_stage,
):
    """Get the number of warmup microbatches for a pipeline rank."""
    total_num_microbatches = num_microbatches * num_stages_per_device

    if num_devices > 1:
        # Run (num_model_chunks - 1) * microbatch_group_size_per_vp_stage
        # microbatches on all workers, followed by more microbatches depending
        # on the stage ID (more forward passes for earlier stages; later
        # stages can start 1F1B immediately).
        num_warmup_microbatches = (num_devices - device_id - 1) * 2
        num_warmup_microbatches += (num_stages_per_device - 1) * microbatch_group_size_per_vp_stage
    else:
        # forward_backward_no_pipelining
        num_warmup_microbatches = 1

    if num_warmup_microbatches >= total_num_microbatches:
        num_warmup_microbatches = total_num_microbatches

    return num_warmup_microbatches


def _get_schedule_table(num_microbatches, num_model_chunks, microbatch_group_size_per_vp_stage):
    """Get the schedule table for PP scheduling.

    Creates a tunable schedule lookup table that maps a virtual_microbatch_id to
    its microbatch_id and model_chunk_id. For example, the tunable schedule
    table for PP2 N3M5 with VP2 is constructed as below:
    virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
    microbatch_id         | 0 1 2 0 1 2 3 4 3 4
    model_chunk_id        | 0 0 0 1 1 1 0 0 1 1
    """
    schedule_table = []
    for min_microbatch_id_in_group in range(
        0, num_microbatches, microbatch_group_size_per_vp_stage
    ):
        if min_microbatch_id_in_group + microbatch_group_size_per_vp_stage >= num_microbatches:
            # Construct the schedule for the last microbatch group
            schedule_table.extend(
                [
                    (microbatch_id, model_chunk_id)
                    for model_chunk_id in range(num_model_chunks)
                    for microbatch_id in range(min_microbatch_id_in_group, num_microbatches)
                ]
            )
        else:
            # Construct the schedule for the other microbatch groups
            schedule_table.extend(
                [
                    (microbatch_id, model_chunk_id)
                    for model_chunk_id in range(num_model_chunks)
                    for microbatch_id in range(
                        min_microbatch_id_in_group,
                        min_microbatch_id_in_group + microbatch_group_size_per_vp_stage,
                    )
                ]
            )
    return schedule_table

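# Illustrative check of _get_schedule_table (a sketch), reproducing the
# PP2 N3M5 VP2 table from the docstring above:
#   _get_schedule_table(num_microbatches=5, num_model_chunks=2,
#                       microbatch_group_size_per_vp_stage=3)
#   -> [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1),
#       (3, 0), (4, 0), (3, 1), (4, 1)]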

def _convert_schedule_table_to_order(num_warmup_microbatches, num_model_chunks, schedule_table):
    """Convert a tunable schedule lookup table to the order format accepted by
    te.make_graphed_callables(). For example, the tunable schedule table for
    PP2 N3M5 with VP2 is as below:
    virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
    microbatch_id         | 0 1 2 0 1 2 3 4 3 4
    model_chunk_id        | 0 0 0 1 1 1 0 0 1 1

    Then the forward/backward separated order is:
    forward  |  1  1  1  2  2  2  1  1  2  2
    backward | -2 -2 -2 -1 -1 -1 -2 -2 -1 -1

    If num_warmup_microbatches is 5, the output order is:
    1 1 1 2 2 2 -2 1 -2 1 -2 2 -1 2 -1 -1 -2 -2 -1 -1
    """
    _, model_chunk_id_table = zip(*schedule_table)
    forward_order = [chunk_id + 1 for chunk_id in model_chunk_id_table]
    backward_order = [chunk_id - num_model_chunks for chunk_id in model_chunk_id_table]
    order = forward_order[:num_warmup_microbatches]
    for i in range(num_warmup_microbatches, len(forward_order)):
        order.append(forward_order[i])
        order.append(backward_order[i - num_warmup_microbatches])
    if num_warmup_microbatches > 0:
        order.extend(backward_order[-num_warmup_microbatches:])
    return order

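# Illustrative check of _convert_schedule_table_to_order (a sketch), using the
# PP2 N3M5 VP2 table from the docstring with num_warmup_microbatches=5:
#   table = _get_schedule_table(5, 2, 3)
#   _convert_schedule_table_to_order(5, num_model_chunks=2, schedule_table=table)
#   -> [1, 1, 1, 2, 2, 2, -2, 1, -2, 1, -2, 2, -1, 2, -1, -1, -2, -2, -1, -1]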

# Some code is adapted from Megatron-LM
def generate_1f1b_interleave_schedule(config: ScheduleConfig):
    schedule = Schedule(config)

    for device_id in range(config.num_devices):
        microbatch_group_size_per_vp_stage = config.num_devices
        num_warmup_microbatches = _get_pp_rank_microbatches(
            config.num_batches,
            config.num_devices,
            device_id,
            config.num_stages_per_device,
            microbatch_group_size_per_vp_stage,
        )

        schedule_table = _get_schedule_table(
            config.num_batches,
            config.num_stages_per_device,
            microbatch_group_size_per_vp_stage,
        )

        order = _convert_schedule_table_to_order(
            num_warmup_microbatches,
            num_model_chunks=config.num_stages_per_device,
            schedule_table=schedule_table,
        )

        cur_stage_microbatch_id = {}
        for i in range(1, config.num_stages_per_device + 1):
            cur_stage_microbatch_id[i] = 0
            cur_stage_microbatch_id[-i] = 0
        for order_item in order:
            stage_id = schedule.device_queues[device_id].stages[abs(order_item) - 1]

            if order_item > 0:
                op_type = "forward"
                micro_batch_id = cur_stage_microbatch_id[order_item]
                cur_stage_microbatch_id[order_item] += 1
            elif order_item < 0:
                op_type = "backward"
                micro_batch_id = cur_stage_microbatch_id[order_item]
                cur_stage_microbatch_id[order_item] += 1
            else:
                raise ValueError(f"Invalid order item: {order_item}")
            schedule.device_queues[device_id].add_operation(
                schedule.get_op(micro_batch_id, stage_id, op_type)
            )
    return schedule

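# Illustrative reading of the order encoding (a sketch): a positive entry k is
# a forward pass of model chunk k - 1, a negative entry -k is a backward pass
# of model chunk k - 1, and cur_stage_microbatch_id keeps separate per-chunk
# microbatch counters for the positive (forward) and negative (backward) keys.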

def generate_1f1b_interleave_overlap_schedule(config: ScheduleConfig):
    schedule = Schedule(config)

    for device_id in range(config.num_devices):
        microbatch_group_size_per_vp_stage = config.num_devices
        num_warmup_microbatches = _get_pp_rank_microbatches(
            config.num_batches,
            config.num_devices,
            device_id,
            config.num_stages_per_device,
            microbatch_group_size_per_vp_stage,
        )

        schedule_table = _get_schedule_table(
            config.num_batches,
            config.num_stages_per_device,
            microbatch_group_size_per_vp_stage,
        )

        # NOTE: Add one more warmup microbatch for overlapped operations!
        num_warmup_microbatches += 1
        order = _convert_schedule_table_to_order(
            num_warmup_microbatches,
            num_model_chunks=config.num_stages_per_device,
            schedule_table=schedule_table,
        )

        cur_stage_microbatch_id = {}
        for i in range(1, config.num_stages_per_device + 1):
            cur_stage_microbatch_id[i] = 0
            cur_stage_microbatch_id[-i] = 0
        i = 0

        num_overlapped_batches = len(order) - num_warmup_microbatches * 2
        while i < len(order):
            if i < num_warmup_microbatches:
                order_item = order[i]
                assert order_item > 0
                op_type = "forward"
                micro_batch_id = cur_stage_microbatch_id[order_item]
                cur_stage_microbatch_id[order_item] += 1

                stage_id = schedule.device_queues[device_id].stages[abs(order_item) - 1]
                schedule.device_queues[device_id].add_operation(
                    schedule.get_op(micro_batch_id, stage_id, op_type)
                )
                i += 1
            elif i >= num_warmup_microbatches and i < num_warmup_microbatches + num_overlapped_batches - 1:
                order_item_a = order[i]
                order_item_b = order[i + 1]

                op_type_a = "forward" if order_item_a > 0 else "backward"
                micro_batch_id_a = cur_stage_microbatch_id[order_item_a]
                cur_stage_microbatch_id[order_item_a] += 1

                op_type_b = "forward" if order_item_b > 0 else "backward"
                micro_batch_id_b = cur_stage_microbatch_id[order_item_b]
                cur_stage_microbatch_id[order_item_b] += 1

                stage_id_a = schedule.device_queues[device_id].stages[abs(order_item_a) - 1]
                stage_id_b = schedule.device_queues[device_id].stages[abs(order_item_b) - 1]

                op_a = schedule.get_op(micro_batch_id_a, stage_id_a, op_type_a)
                op_b = schedule.get_op(micro_batch_id_b, stage_id_b, op_type_b)
                overlapped_op = OverlappedOperation([op_a, op_b])
                schedule.register_overlapped_operation(overlapped_op)
                schedule.device_queues[device_id].add_operation(overlapped_op)

                i += 2
            else:
                assert i >= num_warmup_microbatches + num_overlapped_batches
                order_item = order[i]
                assert order_item < 0
                op_type = "backward"
                micro_batch_id = cur_stage_microbatch_id[order_item]
                cur_stage_microbatch_id[order_item] += 1

                stage_id = schedule.device_queues[device_id].stages[abs(order_item) - 1]
                schedule.device_queues[device_id].add_operation(
                    schedule.get_op(micro_batch_id, stage_id, op_type)
                )
                i += 1

    return schedule


def create_overlapped_ops(schedule, batch_id1, batch_id2, stage_id, type1, type2):
    """
    Helper function to create overlapped operations correctly.
    This handles the underlying operation creation and registration to avoid device_id issues.
    """
    # Get the operations from the schedule
    op1 = schedule.ops[(batch_id1, stage_id, type1)]
    op2 = schedule.ops[(batch_id2, stage_id, type2)]

    # Create the overlapped operation
    overlapped_op = OverlappedOperation([op1, op2])

    # Register it in the schedule to ensure proper tracking
    schedule.register_overlapped_operation(overlapped_op)

    return overlapped_op

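# Illustrative use of the helper above (a sketch; the batch/stage/type keys
# are assumed to be already present in schedule.ops):
#   fused = create_overlapped_ops(schedule, batch_id1=0, batch_id2=1,
#                                 stage_id=2, type1="forward", type2="backward")
#   schedule.device_queues[2].add_operation(fused)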

def generate_dualpipe_schedule(config: ScheduleConfig):
    """
    Implements the DualPipe scheduling strategy.

    DualPipe is a bidirectional pipeline parallelism algorithm that achieves full overlap of
    forward and backward computation-communication phases and reduces pipeline bubbles.

    The DualPipe strategy has the following characteristics:
    1. Requires placement_strategy="dualpipe" in ScheduleConfig (set automatically)
    2. Each device handles both a forward stage and a reverse stage
    3. Overlaps forward and backward operations to reduce bubble size
    4. Assumes config.num_batches corresponds to M, half the total number of microbatches in the original paper
    5. Currently only supports split_backward=True

    Args:
        config: The scheduling configuration

    Returns:
        A Schedule object with the DualPipe scheduling
    """
    # Ensure the placement strategy is set for Schedule initialization
    assert config.placement_strategy == "dualpipe", "DualPipe schedule currently only supports placement_strategy='dualpipe'"
    # Assertions based on DualPipe requirements
    assert config.num_stages % 2 == 0, "DualPipe requires an even number of stages (and devices)"
    assert config.num_devices == config.num_stages, "DualPipe requires num_devices == num_stages"
    assert config.num_batches % 2 == 0, "DualPipe requires an even number of microbatches (config.num_batches)"
    # Assertion based on the original implementation: num_chunks >= num_ranks * 2.
    # Here, M (config.num_batches) corresponds to half_num_chunks.
    assert config.num_batches >= config.num_devices, "DualPipe requires config.num_batches >= config.num_devices"
    assert config.split_backward, "DualPipe schedule currently only supports split_backward=True"

    schedule = Schedule(config, init_ops=False)

    num_stages = config.num_stages
    num_devices = config.num_devices
    # config.num_batches is M in the original paper, which corresponds to half_num_chunks
    half_num_chunks = config.num_batches // 2
    num_half_ranks = num_devices // 2

    fwd_batch_ids = defaultdict(int)  # (device_id, phase) -> batch_id
    bwd_d_batch_ids = defaultdict(int)  # (device_id, phase) -> batch_id

    waited_weight_grad = [deque() for _ in range(num_devices)]  # device_id -> deque of (stage_id, batch_id)

    for device_id in range(num_devices):
        is_in_second_half = device_id >= num_half_ranks
        if is_in_second_half:
            fwd_batch_ids[device_id, 1] = 0
            fwd_batch_ids[device_id, 0] = config.num_batches // 2
            bwd_d_batch_ids[device_id, 1] = 0
            bwd_d_batch_ids[device_id, 0] = config.num_batches // 2
        else:
            fwd_batch_ids[device_id, 0] = 0
            fwd_batch_ids[device_id, 1] = config.num_batches // 2
            bwd_d_batch_ids[device_id, 0] = 0
            bwd_d_batch_ids[device_id, 1] = config.num_batches // 2

    def get_stage_for_phase(device_id, phase, num_stages, is_in_second_half):
        stage_fwd_dir = device_id  # Stage handled when moving forward (0 to N-1)
        stage_rev_dir = num_stages - 1 - device_id  # Stage handled when moving backward (N-1 to 0)
        if not is_in_second_half:
            # First half: phase 0 -> fwd_dir, phase 1 -> rev_dir
            return stage_fwd_dir if phase == 0 else stage_rev_dir
        else:
            # Second half: phase 0 -> rev_dir, phase 1 -> fwd_dir
            return stage_rev_dir if phase == 0 else stage_fwd_dir

    def add_op_to_queue(device_id, stage_id, op_type, batch_id):
        # Create and register the Operation object (the schedule was built with
        # init_ops=False, so operations are created here on demand)
        op = Operation(batch_id, stage_id, op_type)
        schedule.register_operation(op)
        # Add it to the device queue
        schedule.device_queues[device_id].add_operation(op)

    def _schedule_forward_chunk(device_id, phase, is_in_second_half):
        """Schedules a forward compute operation."""
        stage_id = get_stage_for_phase(device_id, phase, num_stages, is_in_second_half)
        batch_id = fwd_batch_ids[device_id, phase]
        add_op_to_queue(device_id, stage_id, "forward", batch_id)
        fwd_batch_ids[device_id, phase] += 1

    def _schedule_backward_chunk(device_id, phase, is_in_second_half):
        """Schedules a full backward (backward_D fused with backward_W) compute operation."""
        stage_id = get_stage_for_phase(device_id, phase, num_stages, is_in_second_half)
        batch_id = bwd_d_batch_ids[device_id, phase]
        add_op_to_queue(device_id, stage_id, "backward", batch_id)
        bwd_d_batch_ids[device_id, phase] += 1

    def _schedule_backward_input_chunk(device_id, phase, is_in_second_half):
        """Schedules a backward_D compute operation."""
        stage_id = get_stage_for_phase(device_id, phase, num_stages, is_in_second_half)
        batch_id = bwd_d_batch_ids[device_id, phase]
        add_op_to_queue(device_id, stage_id, "backward_D", batch_id)
        bwd_d_batch_ids[device_id, phase] += 1
        waited_weight_grad[device_id].append((stage_id, batch_id))

    def _schedule_backward_weight_chunk(device_id):
        """Schedules a backward_W compute operation."""
        stage_id, batch_id = waited_weight_grad[device_id].popleft()
        add_op_to_queue(device_id, stage_id, "backward_W", batch_id)

    def _schedule_forward_backward_chunk(device_id, fwd_phase, bwd_phase, is_in_second_half):
        """Schedules an overlapped forward and backward compute operation."""
        fwd_stage_id = get_stage_for_phase(device_id, fwd_phase, num_stages, is_in_second_half)
        bwd_stage_id = get_stage_for_phase(device_id, bwd_phase, num_stages, is_in_second_half)

        fwd_batch_id = fwd_batch_ids[device_id, fwd_phase]

        fwd_op = Operation(fwd_batch_id, fwd_stage_id, "forward")
        schedule.register_operation(fwd_op)
        fwd_batch_ids[device_id, fwd_phase] += 1

        bwd_batch_id_d = bwd_d_batch_ids[device_id, bwd_phase]
        bwd_op = Operation(bwd_batch_id_d, bwd_stage_id, "backward")
        schedule.register_operation(bwd_op)
        bwd_d_batch_ids[device_id, bwd_phase] += 1

        # Create and register the overlapped operation
        overlapped_op = OverlappedOperation([fwd_op, bwd_op])
        schedule.register_overlapped_operation(overlapped_op)

        # Add the overlapped operation to the queue
        schedule.device_queues[device_id].add_operation(overlapped_op)

    # Process each device (rank in the original code)
    for device_id in range(num_devices):
        half_rank = min(device_id, num_devices - 1 - device_id)
        is_in_second_half = device_id >= num_half_ranks
        is_middle_rank = (device_id == num_half_ranks - 1) or (device_id == num_half_ranks)

        # Map the original steps to operation additions
        # Step 1: nF0
        step_1_count = (num_half_ranks - half_rank - 1) * 2
        for _ in range(step_1_count):
            _schedule_forward_chunk(device_id, 0, is_in_second_half)  # F0

        # Step 2: nF0F1
        step_2_count = half_rank + 1
        for i in range(step_2_count):
            _schedule_forward_chunk(device_id, 0, is_in_second_half)  # F0
            _schedule_forward_chunk(device_id, 1, is_in_second_half)  # F1

        # Step 3: nB1W1F1
        step_3_count = num_half_ranks - half_rank - 1
        for _ in range(step_3_count):
            _schedule_backward_input_chunk(device_id, 1, is_in_second_half)  # B1_D
            _schedule_backward_weight_chunk(device_id)  # W1
            _schedule_forward_chunk(device_id, 1, is_in_second_half)  # F1

        # Step 4 (main step): nF0B1F1B0
        # (An earlier variant scheduled F0, B1_D, W1 sequentially for the middle
        # ranks on the first iteration; the overlapped form is used unconditionally here.)
        step_4_count = half_num_chunks - num_devices + half_rank + 1
        for i in range(step_4_count):
            # Overlap F0 and B1
            _schedule_forward_backward_chunk(device_id, 0, 1, is_in_second_half)  # F0+B1
            # Overlap F1 and B0
            _schedule_forward_backward_chunk(device_id, 1, 0, is_in_second_half)  # F1+B0

        # Step 5: nB1F1B0
        step_5_count = num_half_ranks - half_rank - 1
        for _ in range(step_5_count):
            _schedule_backward_chunk(device_id, 1, is_in_second_half)  # B1_D + B1_W
            _schedule_forward_backward_chunk(device_id, 1, 0, is_in_second_half)  # F1+B0

        # Step 6: nB1B0
        step_6_count = half_rank + 1
        enable_zb = False
        for i in range(step_6_count):
            if i == step_6_count // 2 and half_rank % 2 == 1:
                enable_zb = True
            if enable_zb:
                _schedule_backward_input_chunk(device_id, 1, is_in_second_half)
            else:
                _schedule_backward_chunk(device_id, 1, is_in_second_half)
            if i == step_6_count // 2 and half_rank % 2 == 0:
                enable_zb = True
            if enable_zb:
                _schedule_backward_input_chunk(device_id, 0, is_in_second_half)
            else:
                _schedule_backward_chunk(device_id, 0, is_in_second_half)

        # Step 7: nWB0
        step_7_count = num_half_ranks - half_rank - 1
        for _ in range(step_7_count):
            _schedule_backward_weight_chunk(device_id)  # W1 (uses a gradient from a B1_D scheduled previously)
            _schedule_backward_input_chunk(device_id, 0, is_in_second_half)  # B0_D

        # Step 8: nW
        step_8_count = half_rank + 1
        for _ in range(step_8_count):
            # W0 uses gradients from B0_D scheduled in steps 4, 5, and 6;
            # W1 uses gradients from B1_D scheduled in steps 3, 4, 5, and 6.
            # The last W0 gradients correspond to B0_D from steps 6 or 7.
            _schedule_backward_weight_chunk(device_id)  # W0 (uses a gradient from a B0_D scheduled previously)

    return schedule

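# Illustrative phase-to-stage mapping for 8 devices (a sketch): device 1 is in
# the first half, so phase 0 maps to stage 1 (forward direction) and phase 1
# to stage 6 (reverse direction); device 6 is in the second half, so phase 0
# maps to stage 1 and phase 1 to stage 6 as well. Each stage pair is shared by
# a mirrored pair of ranks, which is what makes the pipeline bidirectional.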
src/visualizer.py
ADDED
@@ -0,0 +1,612 @@
import dash
from dash import dcc, html
from dash.dependencies import Input, Output
import plotly.graph_objects as go
from typing import List, Dict
from tqdm import tqdm
from functools import lru_cache
import webbrowser
from threading import Timer

from src.execution_model import Schedule, OverlappedOperation


def convert_schedule_to_visualization_format(schedule: Schedule):
    """
    Converts a Schedule object to the format needed for visualization.

    Returns:
        Dict[int, List[Dict]]: Dictionary mapping device_id to a list of operation dictionaries
    """
    # Make sure all operations have start and end times
    for op in schedule.ops.values():
        if op.start_time is None or op.end_time is None:
            raise ValueError(
                "Operations must have start and end times. Run ScheduleExecutor.execute() first."
            )

    visualization_data = {}

    # Organize operations by device
    for device_id, device_queue in enumerate(schedule.device_queues):
        visualization_data[device_id] = []

        for op in device_queue.ops:
            # Handle both regular Operations and OverlappedOperations
            if isinstance(op, OverlappedOperation):
                visualization_data[device_id].append(
                    {
                        "type": op.op_type,
                        "batch": op.batch_id + 1,  # +1 because batch_id is 0-indexed
                        "stage": op.stage_id,
                        "start_time": op.start_time,
                        "duration": op.end_time - op.start_time,
                        "is_overlapped": True,
                        "operations": [
                            {
                                "type": nested_op.op_type,
                                "batch": nested_op.batch_id + 1,
                                "stage": nested_op.stage_id,
                            }
                            for nested_op in op.operations
                        ],
                    }
                )
            else:
                visualization_data[device_id].append(
                    {
                        "type": op.op_type,
                        "batch": op.batch_id + 1,  # +1 because batch_id is 0-indexed
                        "stage": op.stage_id,
                        "start_time": op.start_time,
                        "duration": op.end_time - op.start_time,
                        "is_overlapped": False,
                    }
                )

    return visualization_data

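# Illustrative shape of one record produced above (a sketch):
#   {"type": "forward", "batch": 3, "stage": 1,
#    "start_time": 4.0, "duration": 1.0, "is_overlapped": False}
# Overlapped entries additionally carry an "operations" list holding one
# {"type", "batch", "stage"} dict per fused sub-operation.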

# Cache the color calculation as it's repeatedly called with the same parameters
@lru_cache(maxsize=128)
def get_color(op_type: str, stage_id: int, num_devices: int):
    # A harmonious blue palette with low saturation and high brightness
    forward_colors = [
        "#0a5aff",  # Intense blue
        "#4c88ff",  # Blue (deeper)
        "#7aa7ff",  # Medium blue
        "#a8c5ff",  # Soft blue
        "#d6e4ff",  # Very light blue
    ]

    # Orange palette for backward operations with low saturation and high brightness
    backward_colors = [
        "#f47b00",  # Intense orange
        "#ffa952",  # Orange
        "#ffc78e",  # Light orange
        "#ffe6cc",  # Very light orange
    ]

    # Teal/turquoise palette with low saturation and high brightness
    backward_d_colors = [
        "#4dcccc",  # Light teal
        "#33b3b3",  # Teal
        "#009999",  # Medium teal
        "#008080",  # Dark teal
    ]

    # Green palette with low saturation and high brightness
    backward_w_colors = [
        "#33b373",  # Medium green
        "#009959",  # Forest green
        "#008040",  # Dark green
    ]

    virtual_stage = stage_id // num_devices

    # If virtual_stage is beyond the color list, cycle through the colors
    color_index = virtual_stage % len(forward_colors)

    if op_type == "forward":
        return forward_colors[color_index]
    elif op_type == "backward":
        return backward_colors[color_index % len(backward_colors)]
    elif op_type == "backward_D":
        return backward_d_colors[color_index % len(backward_d_colors)]
    elif op_type == "backward_W":
        return backward_w_colors[color_index % len(backward_w_colors)]
    else:
        raise ValueError(f"Invalid operation type: {op_type}")

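# Illustrative color lookup (a sketch): with 4 devices, stage 5 belongs to
# virtual stage 5 // 4 = 1, so get_color("forward", 5, 4) returns the second
# blue ("#4c88ff"), while get_color("forward", 1, 4) stays in virtual stage 0
# and returns "#0a5aff".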

def create_pipeline_figure(
    schedule_data: Dict[int, List[Dict]], max_time=None, show_progress=True
):
    """
    Create a Plotly figure for pipeline parallelism scheduling.

    Args:
        schedule_data: Dictionary mapping device IDs to lists of tasks (converted from Schedule)
        max_time: Optional maximum time to display
        show_progress: Whether to show a progress bar
    """
    # Find the number of devices
    num_devices = len(schedule_data)

    empty_color = "whitesmoke"

    # Find the maximum time in the schedule if not provided
    if max_time is None:
        max_time = 0
        for device in schedule_data:
            for task in schedule_data[device]:
                end_time = task["start_time"] + task["duration"]
                if end_time > max_time:
                    max_time = end_time

    # Determine the maximum batch number
    max_batch = 0
    for device in schedule_data:
        for task in schedule_data[device]:
            max_batch = max(max_batch, task["batch"])

    # Only show text labels when the plot is not too crowded
    num_operations_per_device = len(schedule_data[0])
    show_text_labels = num_operations_per_device <= 64

    # Create a figure
    fig = go.Figure()

    # Initialize progress tracking
    total_tasks = sum(len(tasks) for tasks in schedule_data.values())
    tasks_processed = 0

    if show_progress:
        progress_bar = tqdm(
            total=total_tasks + num_devices + 3, desc="Creating visualization"
        )

    # Create a custom y-axis with no gaps between devices
    y_spacing = 1.0  # Use 1.0 for no gaps

    # Batch the drawing primitives for performance
    shapes = []
    annotations = []
    hover_traces = []

    # Add rectangles for each task
    for device_idx, device in enumerate(schedule_data):
        device_idx_reversed = num_devices - device_idx - 1

        # Sort tasks by start time to ensure correct rendering
        sorted_tasks = sorted(schedule_data[device], key=lambda t: t["start_time"])

        for task in sorted_tasks:
            # Calculate y positions with no gaps
            y_pos = device_idx_reversed * y_spacing
            start_time = task["start_time"]
            duration = task["duration"]

            # Special handling for overlapped operations
            if task.get("is_overlapped", False) and "operations" in task:
                # Prepare hover text for the entire overlapped operation
                op_details = "<br>".join([
                    f"- {op['type']} (Batch {op['batch']}, Stage {op['stage']})"
                    for op in task["operations"]
                ])
                hover_text = (
                    f"Overlapped Operations:<br>{op_details}<br>"
                    f"Start: {task['start_time']:.2f}<br>"
                    f"End: {task['start_time'] + task['duration']:.2f}<br>"
                    f"Duration: {task['duration']:.2f}"
                )

                # Add an invisible marker for hover info
                hover_traces.append(
                    dict(
                        x=[start_time + duration / 2],
                        y=[y_pos],
                        mode="markers",
                        marker=dict(opacity=0),  # Invisible marker
                        hoverinfo="text",
                        text=hover_text,
                        showlegend=False,
                    )
                )

                # Calculate the height of each sub-operation
                sub_height = 1.0 / len(task["operations"])

                # Add rectangles and annotations for each sub-operation
                for i, sub_op in enumerate(task["operations"]):
                    # Determine the color for this sub-operation
                    color = get_color(sub_op["type"], sub_op["stage"], num_devices)

                    # Calculate the y position for this sub-operation
                    sub_y_pos_bottom = y_pos - 0.5 + (i * sub_height)
                    sub_y_pos_top = sub_y_pos_bottom + sub_height
                    sub_y_center = (sub_y_pos_bottom + sub_y_pos_top) / 2

                    # Add a rectangle for this sub-operation
                    shapes.append(
                        dict(
                            type="rect",
                            x0=start_time,
                            y0=sub_y_pos_bottom,
                            x1=start_time + duration,
                            y1=sub_y_pos_top,
                            line=dict(color="black", width=0.5),
                            fillcolor=color,
                            layer="above",
                        )
                    )

                    # Add batch number text for this sub-operation only if show_text_labels is True
                    if show_text_labels:
                        # Determine the text color based on the background color
                        if sub_op["type"] in ["backward", "backward_D", "backward_W"]:
                            text_color = "black"
                        else:
                            text_color = "white"

                        annotations.append(
                            dict(
                                x=start_time + duration / 2,
                                y=sub_y_center,
                                text=f"{sub_op['batch']}",
                                showarrow=False,
                                font=dict(color=text_color, size=12, family="Arial, bold"),
                            )
                        )
            else:
                # Regular (non-overlapped) operation:
                # determine the task color, text color, and display name
                if task["type"] == "forward":
                    color = get_color(task["type"], task["stage"], num_devices)
                    text_color = "white"
                    name = "Forward"
                elif task["type"] == "backward":
                    color = get_color(task["type"], task["stage"], num_devices)
                    text_color = "black"
                    name = "Backward"
                elif task["type"] == "backward_D":
                    color = get_color(task["type"], task["stage"], num_devices)
                    text_color = "black"
                    name = "Backward (Grad)"
                elif task["type"] == "backward_W":
                    color = get_color(task["type"], task["stage"], num_devices)
                    text_color = "black"
                    name = "Backward (Weight)"
                else:
                    color = empty_color
                    text_color = "black"
                    name = "Unknown"

                # Add a rectangle for the task
                shapes.append(
                    dict(
                        type="rect",
                        x0=start_time,
                        y0=y_pos - 0.5,
                        x1=start_time + duration,
                        y1=y_pos + 0.5,
                        line=dict(color="black", width=0.5),
                        fillcolor=color,
                        layer="above",
                    )
                )

                # Add batch number text only if show_text_labels is True
                if show_text_labels:
                    annotations.append(
                        dict(
                            x=start_time + duration / 2,
                            y=y_pos,
                            text=f"{task['batch']}",
                            showarrow=False,
                            font=dict(color=text_color, size=12, family="Arial, bold"),
                        )
                    )

                # Prepare hover data
                hover_text = (
                    f"Batch: {task['batch']}<br>"
                    f"Stage: {task['stage']}<br>"
                    f"Type: {name}<br>"
                    f"Start: {task['start_time']:.2f}<br>"
                    f"End: {task['start_time'] + task['duration']:.2f}<br>"
                    f"Duration: {task['duration']:.2f}"
                )

                hover_traces.append(
                    dict(
                        x=[start_time + duration / 2],
                        y=[y_pos],
                        mode="markers",
                        marker=dict(opacity=0),  # Invisible marker
                        hoverinfo="text",
                        text=hover_text,
                        showlegend=False,
                    )
                )

            # Update progress
            if show_progress:
                tasks_processed += 1
                progress_bar.update(1)

    # Add all shapes at once for better performance
    fig.update_layout(shapes=shapes)

    # Add all annotations at once
    fig.update_layout(annotations=annotations)

    # Add all hover traces at once
    for trace in hover_traces:
        fig.add_trace(go.Scatter(**trace))

    # Add a custom legend
    legend_items = []

    # Find the maximum virtual stage in the data
    max_virtual_stage = 0
    for device in schedule_data:
        for task in schedule_data[device]:
            virtual_stage = task["stage"] // num_devices
            max_virtual_stage = max(max_virtual_stage, virtual_stage)

    # Check whether overlapped operations exist
    has_overlapped = any(
        task.get("is_overlapped", False)
        for device in schedule_data
        for task in schedule_data[device]
    )

    # Add forward and backward items for each virtual stage
    for vs in range(max_virtual_stage + 1):
        legend_items.append(
            dict(
                name=f"Forward (VS {vs})",
                color=get_color("forward", vs * num_devices, num_devices),
            )
        )
        legend_items.append(
            dict(
                name=f"Backward (VS {vs})",
                color=get_color("backward", vs * num_devices, num_devices),
            )
        )
        # Add entries for split backward operations if this is a zb1p schedule
        if any(
            task["type"] in ["backward_D", "backward_W"]
            for device in schedule_data
            for task in schedule_data[device]
        ):
            legend_items.append(
                dict(
                    name=f"Backward Grad (VS {vs})",
                    color=get_color("backward_D", vs * num_devices, num_devices),
                )
            )
            legend_items.append(
                dict(
                    name=f"Backward Weight (VS {vs})",
                    color=get_color("backward_W", vs * num_devices, num_devices),
                )
            )

    # If no tasks were found, add default legend items
    if not legend_items:
        legend_items = [
            dict(name="Forward (VS 0)", color=get_color("forward", 0, num_devices)),
            dict(name="Backward (VS 0)", color=get_color("backward", 0, num_devices)),
            dict(
                name="Backward Grad (VS 0)",
                color=get_color("backward_D", 0, num_devices),
            ),
            dict(
                name="Backward Weight (VS 0)",
                color=get_color("backward_W", 0, num_devices),
            ),
        ]

    for i, item in enumerate(legend_items):
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                marker=dict(size=10, color=item["color"]),
                name=item["name"],
                showlegend=True,
            )
        )
        if show_progress and i < len(legend_items) - 1:
            progress_bar.update(1)

    # Set axis properties
    device_labels = [f"Device {i+1}" for i in range(num_devices)]

    # Calculate tick positions with no gaps
    tick_positions = [(num_devices - i - 1) * y_spacing for i in range(num_devices)]

    # Adjust the range to ensure there is no empty space at the end
    x_end = max_time * 1.05  # Add a small margin

    title_text = "Pipeline Parallelism Schedule"

    fig.update_layout(
        yaxis=dict(
            tickmode="array",
            tickvals=tick_positions,
            ticktext=device_labels,
            showgrid=False,
            zeroline=False,
        ),
        margin=dict(l=50, r=20, t=40, b=40),
        plot_bgcolor="white",
        title=dict(
            text=title_text,
            x=0.5,
            y=0.98,  # Keep the title close to the top
            font=dict(size=20),
        ),
        legend=dict(
            orientation="v",  # Vertical rather than horizontal
            yanchor="top",
            y=1.02,  # Position at the top
            xanchor="right",
            x=1.20,  # Position further to the right to accommodate more items
            title=dict(text="<b>Operation Types:</b>"),
            itemsizing="constant",
            tracegroupgap=0,
        ),
        width=2000,  # Wide enough to accommodate the expanded legend
        height=400,
        bargap=0,
        bargroupgap=0,
    )

    if show_progress:
        progress_bar.update(1)
        progress_bar.close()

    return fig


# Cache for storing processed schedule data
_schedule_data_cache = {}


def create_dash_app(
    schedule: Schedule, schedule_type="1f1b", enable_caching: bool = True
):
    """
    Create a Dash app to visualize the pipeline schedule.

    Args:
        schedule: Schedule object to visualize
        schedule_type: Type of schedule ("1f1b", "zb1p", or a custom description)
        enable_caching: Whether to cache the schedule data and figure
    """
    # Process the schedule data only once and cache it
    global _schedule_data_cache
    cache_key = id(schedule)

    if enable_caching and cache_key in _schedule_data_cache:
        schedule_data = _schedule_data_cache[cache_key]
        print("Using cached schedule data")
    else:
        schedule_data = convert_schedule_to_visualization_format(schedule)
        if enable_caching:
            _schedule_data_cache[cache_key] = schedule_data
            print("Cached schedule data")

    total_tasks = sum(len(tasks) for tasks in schedule_data.values())
    print(f"Total tasks in schedule: {total_tasks}")

    app = dash.Dash(__name__)
    app.title = f"Pipeline Parallelism Visualization - {schedule_type}"

    # Create an informative layout with data size information
    app.layout = html.Div(
        [
            html.H1(
                f"Pipeline Parallelism Visualization - {schedule_type}",
                style={"textAlign": "center"},
            ),
            html.Div(
                [
                    html.P(
                        f"Number of devices: {len(schedule_data)}",
                        style={"display": "inline-block", "marginRight": "20px"},
                    ),
                    html.P(
                        f"Total tasks: {total_tasks}",
                        style={"display": "inline-block", "marginRight": "20px"},
                    ),
                ],
                style={"marginBottom": "20px"},
            ),
            html.Div(id="graph-container", children=[]),
            dcc.Loading(
                id="loading-graph",
                type="circle",
                children=[
                    dcc.Graph(
                        id="pipeline-graph",
                        config={
                            "displayModeBar": True,
                            "toImageButtonOptions": {
                                "format": "png",
                                "filename": "pipeline_visualization",
                            },
                        },
                    ),
                ],
            ),
        ]
    )

    # Cache for storing the figure to avoid regenerating it
    figure_cache = {}

    @app.callback(
        Output("pipeline-graph", "figure"),
        Input("graph-container", "children"),
        prevent_initial_call=False,
    )
    def load_graph(_):
        # Use the cached figure if available
        cache_key = f"{id(schedule)}"
        if enable_caching and cache_key in figure_cache:
            print("Using cached figure")
            return figure_cache[cache_key]

        # Create the figure
        figure = create_pipeline_figure(schedule_data, show_progress=True)

        # Cache the figure
        if enable_caching:
            figure_cache[cache_key] = figure
            print("Cached figure")

        return figure

    return app


def visualize_pipeline_parallelism_dash(
    schedule: Schedule,
    port: int = 8050,
    debug: bool = False,
    enable_caching: bool = True,
    schedule_type="1f1b",
    open_browser: bool = True,
):
    """
    Launch a Dash app to visualize the pipeline schedule interactively.

    Args:
        schedule: Schedule object to visualize
        port: Port to run the Dash app on
        debug: Whether to run the Dash app in debug mode
        enable_caching: Whether to cache schedule data and figures
        schedule_type: Type of schedule ("1f1b", "zb1p", or a custom description)
        open_browser: Whether to automatically open a browser window
    """
    app = create_dash_app(
        schedule, schedule_type=schedule_type, enable_caching=enable_caching
    )

    # Define a function to open the browser after a short delay
    def open_browser_tab():
        webbrowser.open_new_tab(f"http://localhost:{port}/")

    # Open the browser automatically if requested
    if open_browser:
        # Use a timer so the browser opens after the server has started
        Timer(1.0, open_browser_tab).start()

    print(f"Starting Dash app on http://localhost:{port}/")
    app.run_server(debug=debug, port=port)
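
# Illustrative end-to-end usage (a sketch; the executor step is hypothetical
# wiring, since convert_schedule_to_visualization_format only requires that
# every op has start/end times assigned, e.g. via ScheduleExecutor.execute()):
#   from src.execution_model import ScheduleConfig
#   from src.strategies import generate_1f1b_schedule
#   config = ScheduleConfig(...)  # num_devices, num_stages, num_batches, op times, ...
#   schedule = generate_1f1b_schedule(config)
#   ...                           # run the executor to assign start/end times
#   visualize_pipeline_parallelism_dash(schedule, port=8050, schedule_type="1f1b")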