# PP-schedule-visualizer / src/strategies.py
# Committed by Victarry
# Initial commit: PP schedule visualization.
# Commit: c048b97
from collections import defaultdict, deque
from src.execution_model import OverlappedOperation, Operation, Schedule, ScheduleConfig
def generate_1f1b_schedule(config: ScheduleConfig):
    """Build the classic (non-interleaved) 1F1B pipeline schedule.

    Device ``i`` runs stage ``i``. It issues ``num_devices - i - 1`` warmup
    forwards, then alternates one forward with one backward, and finally
    drains the outstanding backwards in a cooldown phase as long as warmup.

    The warmup depth is capped at ``config.num_batches``. Without the cap,
    fewer microbatches than pipeline depth would make ``steady_batches``
    negative while warmup/cooldown still requested ops for batch ids
    ``>= config.num_batches``.

    Args:
        config: Scheduling configuration; ``num_devices`` must equal
            ``num_stages``.

    Returns:
        The populated Schedule.
    """
    schedule = Schedule(config)
    assert config.num_devices == config.num_stages, "num_devices must be equal to num_stages for 1F1B"
    for i in range(config.num_devices):
        queue = schedule.device_queues[i]
        warmup_batches = min(config.num_devices - i - 1, config.num_batches)
        cooldown_batches = warmup_batches
        steady_batches = config.num_batches - warmup_batches
        fwd_batch_id = 0
        bwd_batch_id = 0
        # Warmup: forwards only.
        for _ in range(warmup_batches):
            queue.add_operation(schedule.get_op(fwd_batch_id, i, "forward"))
            fwd_batch_id += 1
        # Steady state: one forward followed by one backward.
        for _ in range(steady_batches):
            queue.add_operation(schedule.get_op(fwd_batch_id, i, "forward"))
            fwd_batch_id += 1
            queue.add_operation(schedule.get_op(bwd_batch_id, i, "backward"))
            bwd_batch_id += 1
        # Cooldown: drain the backwards still in flight.
        for _ in range(cooldown_batches):
            queue.add_operation(schedule.get_op(bwd_batch_id, i, "backward"))
            bwd_batch_id += 1
    return schedule
def generate_zero_bubble_1p_schedule(config: ScheduleConfig):
    """Build a ZB-1P (zero-bubble) schedule with split backward passes.

    Each backward is split into ``backward_D`` and ``backward_W`` ops.
    Device ``i`` runs stage ``i`` in four phases:

    * warmup: ``num_devices - i - 1`` forwards;
    * steady state: one forward and one ``backward_D`` per step. A pending
      ``backward_W`` is emitted only once
      ``fwd_batch_id - bwd_w_batch_id >= num_devices - 1``, which defers
      weight-gradient ops;
    * cooldown: each remaining ``backward_D`` is followed by one
      ``backward_W``;
    * tail: every ``backward_W`` still outstanding.

    The warmup depth is capped at ``config.num_batches`` so that small batch
    counts never request ops for batch ids ``>= config.num_batches``.

    Args:
        config: Scheduling configuration; requires ``num_devices == num_stages``
            and ``split_backward=True``.

    Returns:
        The populated Schedule.
    """
    schedule = Schedule(config)
    total_batches = config.num_batches
    assert config.num_devices == config.num_stages, "num_devices must be equal to num_stages for ZB-1P"
    assert config.split_backward, "ZB-1P requires split_backward=True"
    for i in range(config.num_devices):
        queue = schedule.device_queues[i]
        warmup_batches = min(config.num_devices - i - 1, total_batches)
        cooldown_batches = warmup_batches
        steady_batches = total_batches - warmup_batches
        fwd_batch_id = 0
        bwd_d_batch_id = 0
        bwd_w_batch_id = 0
        for _ in range(warmup_batches):
            queue.add_operation(schedule.get_op(fwd_batch_id, i, "forward"))
            fwd_batch_id += 1
        for _ in range(steady_batches):
            queue.add_operation(schedule.get_op(fwd_batch_id, i, "forward"))
            queue.add_operation(schedule.get_op(bwd_d_batch_id, i, "backward_D"))
            # Only emit a weight-gradient op once enough forwards are ahead of it.
            if fwd_batch_id - bwd_w_batch_id >= config.num_devices - 1:
                queue.add_operation(schedule.get_op(bwd_w_batch_id, i, "backward_W"))
                bwd_w_batch_id += 1
            bwd_d_batch_id += 1
            fwd_batch_id += 1
        for _ in range(cooldown_batches):
            queue.add_operation(schedule.get_op(bwd_d_batch_id, i, "backward_D"))
            queue.add_operation(schedule.get_op(bwd_w_batch_id, i, "backward_W"))
            bwd_w_batch_id += 1
            bwd_d_batch_id += 1
        # Flush any weight-gradient ops that are still pending.
        while bwd_w_batch_id < total_batches:
            queue.add_operation(schedule.get_op(bwd_w_batch_id, i, "backward_W"))
            bwd_w_batch_id += 1
    return schedule
def generate_1f1b_overlap_schedule(config: ScheduleConfig):
    """Build a 1F1B schedule whose steady state overlaps a forward with a backward.

    Device ``i`` runs stage ``i``. It issues ``2 * (num_devices - i - 1) + 1``
    warmup forwards. Each steady-state step then wraps one forward and one
    backward in a registered ``OverlappedOperation``. A cooldown of the same
    length as warmup drains the remaining backwards.

    The warmup depth is capped at ``config.num_batches`` so that small batch
    counts never request ops for batch ids ``>= config.num_batches``.

    Args:
        config: Scheduling configuration; ``num_devices`` must equal
            ``num_stages``.

    Returns:
        The populated Schedule.
    """
    schedule = Schedule(config)
    assert config.num_devices == config.num_stages, "num_devices must be equal to num_stages for 1F1B"
    for i in range(config.num_devices):
        queue = schedule.device_queues[i]
        # NOTE(review): presumably deeper than plain 1F1B so a backward is always
        # ready to pair with each steady-state forward -- confirm intent.
        warmup_batches = min(2 * (config.num_devices - i - 1) + 1, config.num_batches)
        cooldown_batches = warmup_batches
        steady_batches = config.num_batches - warmup_batches
        fwd_batch_id = 0
        bwd_batch_id = 0
        for _ in range(warmup_batches):
            queue.add_operation(schedule.get_op(fwd_batch_id, i, "forward"))
            fwd_batch_id += 1
        for _ in range(steady_batches):
            fwd_op = schedule.get_op(fwd_batch_id, i, "forward")
            bwd_op = schedule.get_op(bwd_batch_id, i, "backward")
            overlapped_op = OverlappedOperation([fwd_op, bwd_op])
            schedule.register_overlapped_operation(overlapped_op)
            queue.add_operation(overlapped_op)
            fwd_batch_id += 1
            bwd_batch_id += 1
        for _ in range(cooldown_batches):
            queue.add_operation(schedule.get_op(bwd_batch_id, i, "backward"))
            bwd_batch_id += 1
    return schedule
def _get_pp_rank_microbatches(
num_microbatches,
num_devices,
device_id,
num_stages_per_device,
microbatch_group_size_per_vp_stage,
):
"""Get the number of total, warmup, and remaining microbatches in PP scheduling."""
total_num_microbatches = num_microbatches * num_stages_per_device
if num_devices > 1:
# Run (num_model_chunks-1)*microbatch_group_size_per_vp_stage on
# all workers, followed by more microbatches after depending on
# stage ID (more forward passes for earlier stages, later stages can
# immediately start with 1F1B).
num_warmup_microbatches = (num_devices - device_id - 1) * 2
num_warmup_microbatches += (num_stages_per_device - 1) * microbatch_group_size_per_vp_stage
else:
# forward_backward_no_pipelining
num_warmup_microbatches = 1
if num_warmup_microbatches >= total_num_microbatches:
num_warmup_microbatches = total_num_microbatches
return num_warmup_microbatches
def _get_schedule_table(num_microbatches, num_model_chunks, microbatch_group_size_per_vp_stage):
"""Get the schedule table for PP scheduling.
Create a tunable schedule lookup table.
The schedule lookup table uses the virtual_microbatch_id to find the corresponding microbatch_id and model_chunk_id.
For example, the tunable schedule table for PP2 N3M5 with VP2 is constructed as below:
virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
microbatch_id | 0 1 2 0 1 2 3 4 3 4
model_chunk_id | 0 0 0 1 1 1 0 0 1 1
"""
schedule_table = []
for min_microbatch_id_in_group in range(
0, num_microbatches, microbatch_group_size_per_vp_stage
):
if min_microbatch_id_in_group + microbatch_group_size_per_vp_stage >= num_microbatches:
# Construct schedule for the last microbatch group
schedule_table.extend(
[
(microbatch_id, model_chunk_id)
for model_chunk_id in range(num_model_chunks)
for microbatch_id in range(min_microbatch_id_in_group, num_microbatches)
]
)
else:
# Construct schedule for other microbatch groups
schedule_table.extend(
[
(microbatch_id, model_chunk_id)
for model_chunk_id in range(num_model_chunks)
for microbatch_id in range(
min_microbatch_id_in_group,
min_microbatch_id_in_group + microbatch_group_size_per_vp_stage,
)
]
)
return schedule_table
def _convert_schedule_table_to_order(num_warmup_microbatches, num_model_chunks, schedule_table):
"""Convert a tunable schedule lookup table to the te.make_graphed_callables() accepted
order format. For example, the tunable schedule table for PP2 N3M5 with VP2 is as below:
virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
microbatch_id | 0 1 2 0 1 2 3 4 3 4
model_chunk_id | 0 0 0 1 1 1 0 0 1 1
Then the forward backward separated order is:
forward | 1 1 1 2 2 2 1 1 2 2
backward | -2 -2 -2 -1 -1 -1 -2 -2 -1 -1
If num_warmup_microbatches is 5, the output order is:
1 1 1 2 2 2 -2 1 -2 1 -2 2 -1 2 -1 -1 -2 -2 -1 -1
"""
_, model_chunk_id_table = zip(*schedule_table)
forward_order = [chunk_id + 1 for chunk_id in model_chunk_id_table]
backward_order = [chunk_id - num_model_chunks for chunk_id in model_chunk_id_table]
order = forward_order[:num_warmup_microbatches]
for i in range(num_warmup_microbatches, len(forward_order)):
order.append(forward_order[i])
order.append(backward_order[i - num_warmup_microbatches])
if num_warmup_microbatches > 0:
order.extend(backward_order[-num_warmup_microbatches:])
return order
# Parts of the interleaved schedules below are adapted from Megatron-LM.
def generate_1f1b_interleave_schedule(config: ScheduleConfig):
    """Build an interleaved (virtual-pipeline) 1F1B schedule.

    Each device hosts ``config.num_stages_per_device`` model chunks. The
    per-device order comes from ``_get_pp_rank_microbatches``,
    ``_get_schedule_table`` and ``_convert_schedule_table_to_order``:

    * a positive item ``k`` is a forward on ``device_queues[d].stages[k - 1]``;
    * a negative item ``-k`` is a backward on the same stage.

    Microbatch ids are handed out in order, separately for each signed item.

    Args:
        config: Scheduling configuration.

    Returns:
        The populated Schedule.

    Raises:
        ValueError: if the generated order contains a 0 item.
    """
    schedule = Schedule(config)
    num_chunks = config.num_stages_per_device
    microbatch_group_size_per_vp_stage = config.num_devices
    # The lookup table is identical for every device, so build it once.
    schedule_table = _get_schedule_table(
        config.num_batches,
        num_chunks,
        microbatch_group_size_per_vp_stage,
    )
    for device_id in range(config.num_devices):
        num_warmup_microbatches = _get_pp_rank_microbatches(
            config.num_batches,
            config.num_devices,
            device_id,
            num_chunks,
            microbatch_group_size_per_vp_stage,
        )
        order = _convert_schedule_table_to_order(
            num_warmup_microbatches,
            num_model_chunks=num_chunks,
            schedule_table=schedule_table,
        )
        device_queue = schedule.device_queues[device_id]
        # Next microbatch id for each signed order item (+k forward, -k backward).
        next_microbatch_id = {
            sign * chunk: 0 for chunk in range(1, num_chunks + 1) for sign in (1, -1)
        }
        for order_item in order:
            if order_item == 0:
                raise ValueError(f"Invalid order item: {order_item}")
            op_type = "forward" if order_item > 0 else "backward"
            stage_id = device_queue.stages[abs(order_item) - 1]
            micro_batch_id = next_microbatch_id[order_item]
            next_microbatch_id[order_item] += 1
            device_queue.add_operation(schedule.get_op(micro_batch_id, stage_id, op_type))
    return schedule
def generate_1f1b_interleave_overlap_schedule(config: ScheduleConfig):
    """Build an interleaved 1F1B schedule whose steady state overlaps F and B.

    The order is built as in ``generate_1f1b_interleave_schedule``, but with
    one extra warmup microbatch. The order is then consumed in three phases:

    * warmup: forward-only items, one op each;
    * steady state: consecutive (forward, backward) items fused into one
      registered ``OverlappedOperation``;
    * cooldown: backward-only items, one op each.

    The extra warmup microbatch is re-capped at the total number of virtual
    microbatches. Without the cap, when every microbatch is already in warmup,
    the warmup phase would reach a backward item and fail its assertion.

    Args:
        config: Scheduling configuration.

    Returns:
        The populated Schedule.
    """
    schedule = Schedule(config)
    num_chunks = config.num_stages_per_device
    microbatch_group_size_per_vp_stage = config.num_devices
    # The lookup table is identical for every device, so build it once.
    schedule_table = _get_schedule_table(
        config.num_batches,
        num_chunks,
        microbatch_group_size_per_vp_stage,
    )
    total_num_microbatches = len(schedule_table)

    def _next_op(device_queue, next_microbatch_id, order_item):
        """Return the Operation for ``order_item`` and advance its microbatch counter."""
        op_type = "forward" if order_item > 0 else "backward"
        micro_batch_id = next_microbatch_id[order_item]
        next_microbatch_id[order_item] += 1
        stage_id = device_queue.stages[abs(order_item) - 1]
        return schedule.get_op(micro_batch_id, stage_id, op_type)

    for device_id in range(config.num_devices):
        num_warmup_microbatches = _get_pp_rank_microbatches(
            config.num_batches,
            config.num_devices,
            device_id,
            num_chunks,
            microbatch_group_size_per_vp_stage,
        )
        # NOTE: Add one more warmup microbatch for overlapped operations,
        # but never more than there are microbatches in total.
        num_warmup_microbatches = min(num_warmup_microbatches + 1, total_num_microbatches)
        order = _convert_schedule_table_to_order(
            num_warmup_microbatches,
            num_model_chunks=num_chunks,
            schedule_table=schedule_table,
        )
        device_queue = schedule.device_queues[device_id]
        next_microbatch_id = {
            sign * chunk: 0 for chunk in range(1, num_chunks + 1) for sign in (1, -1)
        }
        # The order has 2 * total items: W forwards, (total - W) F/B pairs, W backwards.
        cooldown_start = len(order) - num_warmup_microbatches

        # Warmup: forwards only.
        for order_item in order[:num_warmup_microbatches]:
            assert order_item > 0
            device_queue.add_operation(_next_op(device_queue, next_microbatch_id, order_item))

        # Steady state: fuse each (forward, backward) pair into one overlapped op.
        for pos in range(num_warmup_microbatches, cooldown_start, 2):
            op_a = _next_op(device_queue, next_microbatch_id, order[pos])
            op_b = _next_op(device_queue, next_microbatch_id, order[pos + 1])
            overlapped_op = OverlappedOperation([op_a, op_b])
            schedule.register_overlapped_operation(overlapped_op)
            device_queue.add_operation(overlapped_op)

        # Cooldown: backwards only.
        for order_item in order[cooldown_start:]:
            assert order_item < 0
            device_queue.add_operation(_next_op(device_queue, next_microbatch_id, order_item))
    return schedule
def create_overlapped_ops(schedule, batch_id1, batch_id2, stage_id, type1, type2):
    """Wrap two existing operations of one stage into a registered overlapped op.

    The operations keyed ``(batch_id1, stage_id, type1)`` and
    ``(batch_id2, stage_id, type2)`` are read from ``schedule.ops`` in that
    order. They are combined into an ``OverlappedOperation``, which is
    registered with the schedule so it is tracked alongside the individual
    operations.

    Returns:
        The registered OverlappedOperation. It is not added to any device
        queue here.

    NOTE(review): presumably raises KeyError when either key is missing,
    assuming ``schedule.ops`` is a dict -- confirm.
    """
    members = [
        schedule.ops[(batch_id1, stage_id, type1)],
        schedule.ops[(batch_id2, stage_id, type2)],
    ]
    combined = OverlappedOperation(members)
    schedule.register_overlapped_operation(combined)
    return combined
def generate_dualpipe_schedule(config: ScheduleConfig):
    """
    Implements the DualPipe scheduling strategy.

    DualPipe is a bidirectional pipeline parallelism algorithm that achieves full overlap of forward
    and backward computation-communication phases and reduces pipeline bubbles.

    The DualPipe strategy has the following characteristics:
    1. Requires placement_strategy="dualpipe" in ScheduleConfig.
    2. Each device handles two stages: ``device_id`` and ``num_stages - 1 - device_id``.
    3. Overlaps forward and backward operations to reduce bubble size.
    4. ``config.num_batches`` is the total number of microbatches over both directions (the
       ``num_chunks`` of the reference implementation). Each direction processes
       ``config.num_batches // 2`` of them: batch ids ``[0, H)`` for one direction and
       ``[H, 2H)`` for the other, where ``H = config.num_batches // 2``.
    5. Currently only supports split_backward=True.

    Args:
        config: The scheduling configuration.

    Returns:
        A Schedule object with the DualPipe scheduling.

    Raises:
        AssertionError: if any requirement above is violated, or if
            ``config.num_batches < 2 * (config.num_devices - 1)``.
    """
    assert config.placement_strategy == "dualpipe", "DualPipe schedule currently only supports placement_strategy='dualpipe'"
    assert config.num_stages % 2 == 0, "DualPipe requires an even number of stages (and devices)"
    assert config.num_devices == config.num_stages, "DualPipe requires num_devices == num_stages"
    assert config.num_batches % 2 == 0, "DualPipe requires an even number of microbatches (config.num_batches)"
    assert config.num_batches >= config.num_devices, "DualPipe requires config.num_batches >= config.num_devices"
    # Step 4 below runs (num_batches // 2 - num_devices + half_rank + 1) times. For half_rank == 0
    # a negative count would make steps 1-2 schedule more forwards than one direction owns,
    # reusing batch ids of the other direction (duplicate operations).
    assert config.num_batches // 2 >= config.num_devices - 1, (
        "DualPipe requires config.num_batches >= 2 * (config.num_devices - 1)"
    )
    assert config.split_backward, "DualPipe schedule currently only supports split_backward=True"

    # Operations are created on demand below (init_ops=False).
    schedule = Schedule(config, init_ops=False)

    num_stages = config.num_stages
    num_devices = config.num_devices
    # Microbatches handled per direction (phase).
    half_num_chunks = config.num_batches // 2
    num_half_ranks = num_devices // 2

    fwd_batch_ids = defaultdict(int)  # (device_id, phase) -> next forward batch_id
    bwd_d_batch_ids = defaultdict(int)  # (device_id, phase) -> next backward batch_id
    # Per-device FIFO of (stage_id, batch_id) whose backward_W is still pending.
    waited_weight_grad = [deque() for _ in range(num_devices)]

    # First-half devices start phase 0 at batch 0 and phase 1 at half_num_chunks;
    # second-half devices do the opposite.
    for device_id in range(num_devices):
        is_in_second_half = device_id >= num_half_ranks
        low_phase, high_phase = (1, 0) if is_in_second_half else (0, 1)
        fwd_batch_ids[device_id, low_phase] = 0
        fwd_batch_ids[device_id, high_phase] = half_num_chunks
        bwd_d_batch_ids[device_id, low_phase] = 0
        bwd_d_batch_ids[device_id, high_phase] = half_num_chunks

    def get_stage_for_phase(device_id, phase, num_stages, is_in_second_half):
        """Return the stage this device runs for ``phase`` (0 or 1)."""
        stage_fwd_dir = device_id  # Stage handled when moving forward (0 to N-1)
        stage_rev_dir = num_stages - 1 - device_id  # Stage handled when moving backward (N-1 to 0)
        if not is_in_second_half:
            # First half: phase 0 -> fwd_dir, phase 1 -> rev_dir
            return stage_fwd_dir if phase == 0 else stage_rev_dir
        # Second half: phase 0 -> rev_dir, phase 1 -> fwd_dir
        return stage_rev_dir if phase == 0 else stage_fwd_dir

    def add_op_to_queue(device_id, stage_id, op_type, batch_id):
        """Create a new Operation, register it, and append it to the device queue."""
        op = Operation(batch_id, stage_id, op_type)
        schedule.register_operation(op)
        schedule.device_queues[device_id].add_operation(op)

    def _schedule_forward_chunk(device_id, phase, is_in_second_half):
        """Schedules a forward compute operation."""
        stage_id = get_stage_for_phase(device_id, phase, num_stages, is_in_second_half)
        batch_id = fwd_batch_ids[device_id, phase]
        add_op_to_queue(device_id, stage_id, "forward", batch_id)
        fwd_batch_ids[device_id, phase] += 1

    def _schedule_backward_chunk(device_id, phase, is_in_second_half):
        """Schedules a full (non-split) "backward" compute operation."""
        stage_id = get_stage_for_phase(device_id, phase, num_stages, is_in_second_half)
        batch_id = bwd_d_batch_ids[device_id, phase]
        add_op_to_queue(device_id, stage_id, "backward", batch_id)
        bwd_d_batch_ids[device_id, phase] += 1

    def _schedule_backward_input_chunk(device_id, phase, is_in_second_half):
        """Schedules a backward_D op and queues its backward_W for later."""
        stage_id = get_stage_for_phase(device_id, phase, num_stages, is_in_second_half)
        batch_id = bwd_d_batch_ids[device_id, phase]
        add_op_to_queue(device_id, stage_id, "backward_D", batch_id)
        bwd_d_batch_ids[device_id, phase] += 1
        waited_weight_grad[device_id].append((stage_id, batch_id))

    def _schedule_backward_weight_chunk(device_id):
        """Schedules backward_W for the oldest pending backward_D on this device."""
        stage_id, batch_id = waited_weight_grad[device_id].popleft()
        add_op_to_queue(device_id, stage_id, "backward_W", batch_id)

    def _schedule_forward_backward_chunk(device_id, fwd_phase, bwd_phase, is_in_second_half):
        """Schedules a forward overlapped with a full "backward" operation."""
        fwd_stage_id = get_stage_for_phase(device_id, fwd_phase, num_stages, is_in_second_half)
        bwd_stage_id = get_stage_for_phase(device_id, bwd_phase, num_stages, is_in_second_half)

        fwd_batch_id = fwd_batch_ids[device_id, fwd_phase]
        fwd_op = Operation(fwd_batch_id, fwd_stage_id, "forward")
        schedule.register_operation(fwd_op)
        fwd_batch_ids[device_id, fwd_phase] += 1

        bwd_batch_id_d = bwd_d_batch_ids[device_id, bwd_phase]
        bwd_op = Operation(bwd_batch_id_d, bwd_stage_id, "backward")
        schedule.register_operation(bwd_op)
        bwd_d_batch_ids[device_id, bwd_phase] += 1

        overlapped_op = OverlappedOperation([fwd_op, bwd_op])
        schedule.register_overlapped_operation(overlapped_op)
        schedule.device_queues[device_id].add_operation(overlapped_op)

    # Process each device (rank in original code)
    for device_id in range(num_devices):
        half_rank = min(device_id, num_devices - 1 - device_id)
        is_in_second_half = device_id >= num_half_ranks

        # Step 1: nF0
        step_1_count = (num_half_ranks - half_rank - 1) * 2
        for _ in range(step_1_count):
            _schedule_forward_chunk(device_id, 0, is_in_second_half)  # F0

        # Step 2: nF0F1
        step_2_count = half_rank + 1
        for _ in range(step_2_count):
            _schedule_forward_chunk(device_id, 0, is_in_second_half)  # F0
            _schedule_forward_chunk(device_id, 1, is_in_second_half)  # F1

        # Step 3: nB1W1F1
        step_3_count = num_half_ranks - half_rank - 1
        for _ in range(step_3_count):
            _schedule_backward_input_chunk(device_id, 1, is_in_second_half)  # B1_D
            _schedule_backward_weight_chunk(device_id)  # W1 (the B1_D just queued)
            _schedule_forward_chunk(device_id, 1, is_in_second_half)  # F1

        # Step 4 (Main step): nF0B1F1B0
        # NOTE: a special case that ran F0, B1 and W1 sequentially on the first
        # iteration of the middle ranks is not modeled; every iteration overlaps.
        step_4_count = half_num_chunks - num_devices + half_rank + 1
        for _ in range(step_4_count):
            _schedule_forward_backward_chunk(device_id, 0, 1, is_in_second_half)  # F0 overlapped with B1
            _schedule_forward_backward_chunk(device_id, 1, 0, is_in_second_half)  # F1 overlapped with B0

        # Step 5: nB1F1B0
        step_5_count = num_half_ranks - half_rank - 1
        for _ in range(step_5_count):
            _schedule_backward_chunk(device_id, 1, is_in_second_half)  # B1 (full backward)
            _schedule_forward_backward_chunk(device_id, 1, 0, is_in_second_half)  # F1 overlapped with B0

        # Step 6: nB1B0 -- from the middle iteration on, backwards are split (B_D now, W later).
        step_6_count = half_rank + 1
        enable_zb = False
        for i in range(step_6_count):
            if i == step_6_count // 2 and half_rank % 2 == 1:
                enable_zb = True
            if enable_zb:
                _schedule_backward_input_chunk(device_id, 1, is_in_second_half)
            else:
                _schedule_backward_chunk(device_id, 1, is_in_second_half)
            if i == step_6_count // 2 and half_rank % 2 == 0:
                enable_zb = True
            if enable_zb:
                _schedule_backward_input_chunk(device_id, 0, is_in_second_half)
            else:
                _schedule_backward_chunk(device_id, 0, is_in_second_half)

        # Step 7: nWB0
        step_7_count = num_half_ranks - half_rank - 1
        for _ in range(step_7_count):
            _schedule_backward_weight_chunk(device_id)  # W for the oldest pending B_D
            _schedule_backward_input_chunk(device_id, 0, is_in_second_half)  # B0_D

        # Step 8: nW -- drain the remaining pending weight-gradient ops in FIFO order.
        step_8_count = half_rank + 1
        for _ in range(step_8_count):
            _schedule_backward_weight_chunk(device_id)
    return schedule