File size: 22,180 Bytes
c048b97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
import dash
from dash import dcc, html
from dash.dependencies import Input, Output
import plotly.graph_objects as go
from typing import List, Dict
from tqdm import tqdm
from functools import lru_cache
import webbrowser
from threading import Timer

from src.execution_model import Schedule, OverlappedOperation


def convert_schedule_to_visualization_format(schedule: Schedule):
    """
    Flatten an executed Schedule into plain dictionaries for plotting.

    Every operation in ``schedule.ops`` must already carry a start and an end
    time. Each device queue is then turned into a list of task dictionaries
    with the keys ``type``, ``batch`` (1-based), ``stage``, ``start_time``,
    ``duration`` and ``is_overlapped``. Overlapped operations additionally get
    an ``operations`` list describing their nested operations.

    Returns:
        Dict[int, List[Dict]]: Dictionary mapping device_id to a list of operation dictionaries

    Raises:
        ValueError: If any operation is missing its start or end time.
    """
    # Refuse to render a schedule that has not been timed yet
    if any(
        op.start_time is None or op.end_time is None
        for op in schedule.ops.values()
    ):
        raise ValueError(
            "Operations must have start and end times. Run ScheduleExecutor.execute() first."
        )

    def describe(op):
        # Fields shared by regular and overlapped operations
        overlapped = isinstance(op, OverlappedOperation)
        entry = {
            "type": op.op_type,
            "batch": op.batch_id + 1,  # batch_id is 0-indexed, display is 1-indexed
            "stage": op.stage_id,
            "start_time": op.start_time,
            "duration": op.end_time - op.start_time,
            "is_overlapped": overlapped,
        }
        if overlapped:
            # Nested operations only need identity fields, not timing
            entry["operations"] = [
                {
                    "type": inner.op_type,
                    "batch": inner.batch_id + 1,
                    "stage": inner.stage_id,
                }
                for inner in op.operations
            ]
        return entry

    # One entry per device, in queue order
    return {
        device_id: [describe(op) for op in queue.ops]
        for device_id, queue in enumerate(schedule.device_queues)
    }


# Memoised because the same (op_type, stage_id, num_devices) triple is looked
# up once per drawn task.
@lru_cache(maxsize=128)
def get_color(op_type: str, stage_id: int, num_devices: int):
    """
    Return the hex fill color for an operation type at a given stage.

    The shade is chosen by the virtual stage, ``stage_id // num_devices``.
    Each operation type has its own palette and the virtual stage cycles
    through that palette using the palette's own length, so consecutive
    virtual stages always receive different shades.

    Args:
        op_type: One of "forward", "backward", "backward_D" or "backward_W".
        stage_id: Stage index of the operation.
        num_devices: Number of devices; must be non-zero.

    Returns:
        A "#rrggbb" color string.

    Raises:
        ValueError: If ``op_type`` is not one of the supported types.
        ZeroDivisionError: If ``num_devices`` is 0.
    """
    palettes = {
        # Blues, from intense to very light
        "forward": (
            "#0a5aff",  # Intense blue
            "#4c88ff",  # Blue (deeper)
            "#7aa7ff",  # Medium blue
            "#a8c5ff",  # Soft blue
            "#d6e4ff",  # Very light blue
        ),
        # Oranges for the combined backward pass
        "backward": (
            "#f47b00",  # Intense orange
            "#ffa952",  # Orange
            "#ffc78e",  # Light orange
            "#ffe6cc",  # Very light orange
        ),
        # Teals for the input-gradient half of a split backward pass
        "backward_D": (
            "#4dcccc",  # Light teal
            "#33b3b3",  # Teal
            "#009999",  # Medium teal
            "#008080",  # Dark teal
        ),
        # Greens for the weight-gradient half of a split backward pass
        "backward_W": (
            "#33b373",  # Medium green
            "#009959",  # Forest green
            "#008040",  # Dark green
        ),
    }

    virtual_stage = stage_id // num_devices

    palette = palettes.get(op_type)
    if palette is None:
        raise ValueError(f"Invalid operation type: {op_type}")

    # Wrap by this palette's length. Wrapping by the forward palette first
    # would hand the same shade to adjacent virtual stages of shorter palettes.
    return palette[virtual_stage % len(palette)]


def _hover_marker(x_center, y_pos, hover_text):
    """Build kwargs for an invisible scatter marker that only carries hover text."""
    return dict(
        x=[x_center],
        y=[y_pos],
        mode="markers",
        marker=dict(opacity=0),  # Invisible marker
        hoverinfo="text",
        text=hover_text,
        showlegend=False,
    )


def _task_rectangle(x0, x1, y0, y1, fillcolor):
    """Build the layout-shape dict for one filled, black-outlined rectangle."""
    return dict(
        type="rect",
        x0=x0,
        y0=y0,
        x1=x1,
        y1=y1,
        line=dict(color="black", width=0.5),
        fillcolor=fillcolor,
        layer="above",
    )


def _batch_label(x, y, batch, text_color):
    """Build the annotation dict that prints a batch number at (x, y)."""
    return dict(
        x=x,
        y=y,
        text=f"{batch}",
        showarrow=False,
        font=dict(color=text_color, size=12, family="Arial, bold"),
    )


def _legend_entries(schedule_data, num_devices):
    """List legend name/color pairs for every virtual stage found in the data."""
    all_tasks = [task for tasks in schedule_data.values() for task in tasks]

    # Highest virtual stage present; never below 0 so VS 0 is always listed
    max_virtual_stage = max(
        0, max((task["stage"] // num_devices for task in all_tasks), default=0)
    )

    kinds = [("Forward", "forward"), ("Backward", "backward")]
    # Split-backward entries are only listed when such tasks exist
    if any(task["type"] in ("backward_D", "backward_W") for task in all_tasks):
        kinds += [("Backward Grad", "backward_D"), ("Backward Weight", "backward_W")]

    # get_color divides by the device count; keep it non-zero for empty data
    color_devices = max(num_devices, 1)

    return [
        dict(
            name=f"{label} (VS {vs})",
            color=get_color(op_type, vs * color_devices, color_devices),
        )
        for vs in range(max_virtual_stage + 1)
        for label, op_type in kinds
    ]


def create_pipeline_figure(
    schedule_data: Dict[int, List[Dict]], max_time=None, show_progress=True
):
    """
    Create a Plotly figure for pipeline parallelism scheduling.

    Every task becomes a rectangle on its device's row (the first device is
    drawn at the top). Overlapped tasks are split vertically into one
    rectangle per nested operation. An invisible marker at the centre of each
    task carries the hover text, and one legend entry is added per operation
    type and virtual stage.

    Each task dictionary must provide "type", "batch", "stage", "start_time"
    and "duration"; overlapped tasks also provide "is_overlapped" and
    "operations".

    Args:
        schedule_data: Dictionary mapping device IDs to lists of tasks (converted from Schedule)
        max_time: Optional maximum time to display. When given, the x-axis is
            fixed to ``[0, max_time * 1.05]``; when omitted the axis range is
            left to Plotly.
        show_progress: Whether to show a progress bar

    Returns:
        The populated ``plotly.graph_objects.Figure``.

    Raises:
        ValueError: If a nested operation of an overlapped task has an
            operation type that ``get_color`` does not know.
    """
    num_devices = len(schedule_data)
    empty_color = "whitesmoke"
    y_spacing = 1.0  # Use 1.0 for no gaps between device rows

    # Display name and label color for each known regular operation type
    task_styles = {
        "forward": ("Forward", "white"),
        "backward": ("Backward", "black"),
        "backward_D": ("Backward (Grad)", "black"),
        "backward_W": ("Backward (Weight)", "black"),
    }
    dark_label_types = ("backward", "backward_D", "backward_W")

    # Batch labels are only readable on sparse rows; density is judged from
    # the first device and an empty schedule simply shows no labels.
    first_device_tasks = next(iter(schedule_data.values()), [])
    show_text_labels = len(first_device_tasks) <= 64

    legend_items = _legend_entries(schedule_data, num_devices)
    total_tasks = sum(len(tasks) for tasks in schedule_data.values())

    fig = go.Figure()

    # Collected first and attached in bulk, which is much faster than adding
    # shapes and annotations one at a time.
    shapes = []
    annotations = []
    hover_traces = []

    # One tick per task, one per legend entry and one for the final layout.
    # The context manager closes the bar even if drawing raises.
    with tqdm(
        total=total_tasks + len(legend_items) + 1,
        desc="Creating visualization",
        disable=not show_progress,
    ) as progress_bar:
        for device_idx, device in enumerate(schedule_data):
            # Reverse the row order so the first device ends up on top
            y_pos = (num_devices - device_idx - 1) * y_spacing

            # Sort tasks by start time to ensure correct rendering
            for task in sorted(schedule_data[device], key=lambda t: t["start_time"]):
                start_time = task["start_time"]
                duration = task["duration"]
                end_time = start_time + duration
                x_center = start_time + duration / 2
                timing_text = (
                    f"Start: {start_time:.2f}<br>"
                    f"End: {end_time:.2f}<br>"
                    f"Duration: {duration:.2f}"
                )

                # An overlapped task without nested operations cannot be
                # split, so it is drawn like a regular task instead.
                sub_ops = (
                    task.get("operations")
                    if task.get("is_overlapped", False)
                    else None
                )

                if sub_ops:
                    op_details = "<br>".join(
                        f"- {op['type']} (Batch {op['batch']}, Stage {op['stage']})"
                        for op in sub_ops
                    )
                    hover_traces.append(
                        _hover_marker(
                            x_center,
                            y_pos,
                            f"Overlapped Operations:<br>{op_details}<br>{timing_text}",
                        )
                    )

                    # Stack the nested operations inside the task's row
                    sub_height = 1.0 / len(sub_ops)
                    for i, sub_op in enumerate(sub_ops):
                        color = get_color(sub_op["type"], sub_op["stage"], num_devices)
                        sub_y_pos_bottom = y_pos - 0.5 + (i * sub_height)
                        sub_y_pos_top = sub_y_pos_bottom + sub_height

                        shapes.append(
                            _task_rectangle(
                                start_time,
                                end_time,
                                sub_y_pos_bottom,
                                sub_y_pos_top,
                                color,
                            )
                        )

                        if show_text_labels:
                            # Dark text on the backward palettes, white on forward
                            text_color = (
                                "black"
                                if sub_op["type"] in dark_label_types
                                else "white"
                            )
                            annotations.append(
                                _batch_label(
                                    x_center,
                                    (sub_y_pos_bottom + sub_y_pos_top) / 2,
                                    sub_op["batch"],
                                    text_color,
                                )
                            )
                else:
                    # Regular (non-overlapped) operation
                    op_type = task["type"]
                    if op_type in task_styles:
                        name, text_color = task_styles[op_type]
                        color = get_color(op_type, task["stage"], num_devices)
                    else:
                        # Unknown types are still drawn, in a neutral color
                        name, text_color, color = "Unknown", "black", empty_color

                    shapes.append(
                        _task_rectangle(
                            start_time, end_time, y_pos - 0.5, y_pos + 0.5, color
                        )
                    )

                    if show_text_labels:
                        annotations.append(
                            _batch_label(x_center, y_pos, task["batch"], text_color)
                        )

                    hover_traces.append(
                        _hover_marker(
                            x_center,
                            y_pos,
                            f"Batch: {task['batch']}<br>"
                            f"Stage: {task['stage']}<br>"
                            f"Type: {name}<br>"
                            f"{timing_text}",
                        )
                    )

                progress_bar.update(1)

        # Attach all shapes and annotations at once for better performance
        fig.update_layout(shapes=shapes, annotations=annotations)

        for trace in hover_traces:
            fig.add_trace(go.Scatter(**trace))

        # Legend entries are empty traces that exist only for their marker
        for item in legend_items:
            fig.add_trace(
                go.Scatter(
                    x=[None],
                    y=[None],
                    mode="markers",
                    marker=dict(size=10, color=item["color"]),
                    name=item["name"],
                    showlegend=True,
                )
            )
            progress_bar.update(1)

        # One tick per device row, labelled from the top down
        device_labels = [f"Device {i+1}" for i in range(num_devices)]
        tick_positions = [
            (num_devices - i - 1) * y_spacing for i in range(num_devices)
        ]

        fig.update_layout(
            yaxis=dict(
                tickmode="array",
                tickvals=tick_positions,
                ticktext=device_labels,
                showgrid=False,
                zeroline=False,
            ),
            margin=dict(l=50, r=20, t=40, b=40),
            plot_bgcolor="white",
            title=dict(
                text="Pipeline Parallelism Schedule",
                x=0.5,
                y=0.98,  # Keep the title close to the top edge
                font=dict(size=20),
            ),
            legend=dict(
                orientation="v",
                yanchor="top",
                y=1.02,  # Position at the top
                xanchor="right",
                x=1.20,  # Right of the plot area to leave room for many items
                title=dict(text="<b>Operation Types:</b>"),
                itemsizing="constant",
                tracegroupgap=0,
            ),
            width=2000,  # Wide enough to accommodate the expanded legend
            height=400,
            bargap=0,
            bargroupgap=0,
        )

        # Honor an explicit time limit, with a small right-hand margin
        if max_time is not None:
            fig.update_layout(xaxis=dict(range=[0, max_time * 1.05]))

        progress_bar.update(1)

    return fig


# Module-level cache of converted schedule data, keyed by id(schedule).
# create_dash_app reads and fills it when its enable_caching flag is true.
_schedule_data_cache: Dict[int, Dict[int, List[Dict]]] = {}


def create_dash_app(
    schedule: Schedule, schedule_type="1f1b", enable_caching: bool = True
):
    """
    Create a Dash app to visualize the pipeline schedule.

    The schedule is converted to plotting data once. With caching enabled the
    converted data is kept in the module-level cache for as long as the
    schedule object is alive, and the rendered figure is reused across
    callback invocations of the returned app.

    Args:
        schedule: Schedule object to visualize
        schedule_type: Type of schedule ("1f1b", "zb1p", or custom description)
        enable_caching: Whether to cache the schedule data and figure

    Returns:
        The configured ``dash.Dash`` application (not yet running).

    Raises:
        ValueError: Propagated from the conversion step when the schedule has
            operations without start or end times.
    """
    import weakref  # local import: only needed for cache eviction below

    cache_key = id(schedule)

    if enable_caching and cache_key in _schedule_data_cache:
        schedule_data = _schedule_data_cache[cache_key]
        print("Using cached schedule data")
    else:
        schedule_data = convert_schedule_to_visualization_format(schedule)
        if enable_caching:
            # id() values are recycled once an object is collected, so the
            # entry must be dropped together with the schedule; otherwise a
            # later schedule could be served another schedule's data.
            try:
                weakref.finalize(
                    schedule, _schedule_data_cache.pop, cache_key, None
                )
            except TypeError:
                # Lifetime cannot be tracked for this object; not caching is
                # safer than risking a stale hit.
                pass
            else:
                _schedule_data_cache[cache_key] = schedule_data
                print("Cached schedule data")

    total_tasks = sum(len(tasks) for tasks in schedule_data.values())
    print(f"Total tasks in schedule: {total_tasks}")

    app = dash.Dash(__name__)
    app.title = f"Pipeline Parallelism Visualization - {schedule_type}"

    # Header with summary figures, followed by the graph behind a spinner
    app.layout = html.Div(
        [
            html.H1(
                f"Pipeline Parallelism Visualization - {schedule_type}",
                style={"textAlign": "center"},
            ),
            html.Div(
                [
                    html.P(
                        f"Number of devices: {len(schedule_data)}",
                        style={"display": "inline-block", "marginRight": "20px"},
                    ),
                    html.P(
                        f"Total tasks: {total_tasks}",
                        style={"display": "inline-block", "marginRight": "20px"},
                    ),
                ],
                style={"marginBottom": "20px"},
            ),
            # Empty container whose only job is to trigger the callback below
            html.Div(id="graph-container", children=[]),
            dcc.Loading(
                id="loading-graph",
                type="circle",
                children=[
                    dcc.Graph(
                        id="pipeline-graph",
                        config={
                            "displayModeBar": True,
                            "toImageButtonOptions": {
                                "format": "png",
                                "filename": "pipeline_visualization",
                            },
                        },
                    ),
                ],
            ),
        ]
    )

    # Single-slot cache: this app only ever renders one schedule
    figure_cache = {}

    @app.callback(
        Output("pipeline-graph", "figure"),
        Input("graph-container", "children"),
        prevent_initial_call=False,
    )
    def load_graph(_):
        # Reuse the rendered figure on later callback invocations
        if enable_caching and "figure" in figure_cache:
            print("Using cached figure")
            return figure_cache["figure"]

        figure = create_pipeline_figure(schedule_data, show_progress=True)

        if enable_caching:
            figure_cache["figure"] = figure
            print("Cached figure")

        return figure

    return app


def visualize_pipeline_parallelism_dash(
    schedule: Schedule,
    port: int = 8050,
    debug: bool = False,
    enable_caching: bool = True,
    schedule_type="1f1b",
    open_browser: bool = True,
):
    """
    Launch a Dash app to visualize the pipeline schedule interactively.

    This call blocks while the server is running.

    Args:
        schedule: Schedule object to visualize
        port: Port to run the Dash app on
        debug: Whether to run the Dash app in debug mode
        enable_caching: Whether to cache schedule data and figures
        schedule_type: Type of schedule ("1f1b", "zb1p", or custom description)
        open_browser: Whether to automatically open a browser window
    """
    import os  # local import: only needed for the reloader check below

    app = create_dash_app(
        schedule, schedule_type=schedule_type, enable_caching=enable_caching
    )
    url = f"http://localhost:{port}/"

    # In debug mode the Werkzeug reloader re-runs the program in a child
    # process that has WERKZEUG_RUN_MAIN set; skip it there so only one tab
    # is opened.
    in_reloader_child = os.environ.get("WERKZEUG_RUN_MAIN") == "true"

    if open_browser and not in_reloader_child:
        # Delay the browser slightly so the server has time to start
        browser_timer = Timer(1.0, webbrowser.open_new_tab, args=(url,))
        # Daemon thread: do not open a tab if the server exits before the delay
        browser_timer.daemon = True
        browser_timer.start()

    print(f"Starting Dash app on {url}")

    # Newer Dash releases expose run() and no longer provide run_server();
    # older releases only have run_server().
    run = getattr(app, "run", None)
    if run is None:
        run = app.run_server
    run(debug=debug, port=port)