Spaces:
Running
Running
from collections import defaultdict | |
from typing import Dict, List, Optional, Union | |
class Operation:
    """One unit of pipeline work, identified by (batch_id, stage_id, op_type)."""

    def __init__(self, batch_id: int, stage_id: int, op_type: str):
        # Identity of the operation; used as the lookup key elsewhere in the file.
        self.batch_id, self.stage_id, self.op_type = batch_id, stage_id, op_type
        # Placement and timing are unknown at construction and filled in later.
        self.device_id = None
        self.start_time = None
        self.end_time = None

    def set_start_time(self, start_time: float):
        """Record the time at which this operation starts."""
        self.start_time = start_time

    def set_end_time(self, end_time: float):
        """Record the time at which this operation finishes."""
        self.end_time = end_time

    def __repr__(self) -> str:
        return (
            f"Operation(batch_id={self.batch_id}, "
            f"stage_id={self.stage_id}, "
            f"op_type={self.op_type})"
        )
class OverlappedOperation:
    """A group of operations that run concurrently on one device and share timing.

    The group exposes the same attributes as a single operation (``batch_id``,
    ``stage_id``, ``op_type``, ``device_id``, ``start_time``, ``end_time``) so it
    can stand in for one. ``batch_id`` and ``stage_id`` are taken from the first
    member; the full member list stays available in ``operations``.
    """

    def __init__(self, operations: List["Operation"]):
        first = operations[0]
        self.operations = operations
        self.device_id = first.device_id

        # Every member has to live on the same device as the first one.
        for member in operations:
            assert member.device_id == self.device_id, "All operations must be on the same device"

        # Combined type name, e.g. "overlapped_forward_backward".
        member_types = [member.op_type for member in operations]
        self.op_type = "overlapped_" + "_".join(member_types)

        # The first member's ids identify the whole group.
        self.batch_id = first.batch_id
        self.stage_id = first.stage_id

        # Timing is unknown until the schedule is executed.
        self.start_time = None
        self.end_time = None

    def set_start_time(self, start_time: float):
        """Set the start time on the group and propagate it to every member."""
        self.start_time = start_time
        for member in self.operations:
            member.set_start_time(start_time)

    def set_end_time(self, end_time: float):
        """Set the end time on the group and propagate it to every member."""
        self.end_time = end_time
        for member in self.operations:
            member.set_end_time(end_time)

    def __repr__(self) -> str:
        described = [
            f"({member.batch_id},{member.stage_id},{member.op_type})"
            for member in self.operations
        ]
        return f"OverlappedOperation([{', '.join(described)}])"
class DeviceQueue:
    """Ordered list of the operations assigned to one device.

    Attributes:
        stages: stage ids this device hosts; only ops of these stages are accepted.
        device_id: id stamped onto every operation added to the queue.
        ops: operations in the order they were added.
    """

    def __init__(self, stages: List[int], device_id: int):
        self.stages = stages
        self.device_id = device_id
        self.ops = []  # List of operations, in insertion order

    def add_operation(self, op: "Operation"):
        """Append ``op`` to this queue and mark it as placed on this device.

        Raises:
            AssertionError: if ``op.stage_id`` is not one of this device's stages,
                or if ``op`` already carries a ``device_id``.
        """
        # Validate everything BEFORE touching the queue, so a rejected operation
        # never ends up half-registered in ``self.ops``.
        assert op.stage_id in self.stages
        assert op.device_id is None, f"Operation {op.batch_id}, {op.stage_id}, {op.op_type} already has a device id on {op.device_id}"
        op.device_id = self.device_id
        self.ops.append(op)
class ScheduleConfig:
    """Static description of a pipeline schedule: sizes, stage placement and op durations.

    Attributes set by the constructor:
        num_devices, num_stages, num_batches, p2p_latency, placement_strategy,
        split_backward: stored verbatim from the arguments.
        op_times: maps an op type to either a single duration used for every
            stage, or a ``{stage_id: duration}`` dict.
        num_stages_per_device: ``num_stages // num_devices``.
        device_to_stages: ``defaultdict(list)`` mapping device id -> stage ids.
    """

    _OVERLAPPED_PREFIX = "overlapped_"

    def __init__(
        self,
        num_devices: int,
        num_stages: int,
        num_batches: int,
        p2p_latency: float = 0.0,
        placement_strategy: str = "standard",
        split_backward: bool = False,
        op_times: Optional[Dict[str, Union[float, Dict[int, float]]]] = None,
    ):
        """Store the parameters, merge ``op_times`` over the defaults and place stages.

        Args:
            num_devices: number of devices.
            num_stages: number of pipeline stages; must be divisible by ``num_devices``.
            num_batches: number of batches.
            p2p_latency: latency stored for use by the scheduler.
            placement_strategy: ``"standard"``, ``"interleave"`` or ``"dualpipe"``.
            split_backward: when true, default times for ``backward_D`` and
                ``backward_W`` are registered in addition to ``backward``.
            op_times: optional overrides. A plain number replaces the time of an
                op type for all stages; a dict overrides individual stages.

        Raises:
            AssertionError: if ``num_stages`` is not divisible by ``num_devices``,
                or if a placement invariant does not hold.
            ValueError: if ``placement_strategy`` is not recognised.
        """
        self.num_devices = num_devices
        self.num_stages = num_stages
        self.num_batches = num_batches
        self.p2p_latency = p2p_latency
        self.placement_strategy = placement_strategy
        self.split_backward = split_backward

        # Default operation times.
        if self.split_backward:
            self.op_times = {
                "forward": 1.0,
                "backward_D": 1.0,
                "backward_W": 1.0,
                "backward": 2.0,
            }
        else:
            self.op_times = {
                "forward": 1.0,
                "backward": 2.0,
            }

        # Merge user-provided operation times over the defaults.
        if op_times:
            for op_type, times in op_times.items():
                if not isinstance(times, dict):
                    # A single number applies to every stage.
                    self.op_times[op_type] = times
                    continue
                # Stage-specific overrides: make sure the entry is a per-stage dict
                # first, expanding an existing scalar to all stages if needed.
                if op_type not in self.op_times:
                    per_stage = {}
                elif isinstance(self.op_times[op_type], dict):
                    per_stage = self.op_times[op_type]
                else:
                    scalar = self.op_times[op_type]
                    per_stage = {i: scalar for i in range(num_stages)}
                per_stage.update(times)
                self.op_times[op_type] = per_stage

        assert num_stages % num_devices == 0, "num_stages must be divisible by num_devices"
        self.num_stages_per_device = num_stages // num_devices

        self.init_device_to_stages()
        placed = sum(len(stages) for stages in self.device_to_stages.values())
        if self.placement_strategy == "dualpipe":
            # DualPipe places every stage twice (see init_device_to_stages).
            assert placed == num_stages * 2
        else:
            assert placed == num_stages

    def init_device_to_stages(self):
        """Build ``self.device_to_stages`` according to ``self.placement_strategy``.

        Raises:
            ValueError: if the placement strategy is not recognised.
            AssertionError: if the DualPipe preconditions do not hold.
        """
        if self.placement_strategy == "standard":
            # Evenly distributed: consecutive stages share a device.
            stages_per_device = self.num_stages // self.num_devices
            self.device_to_stages = defaultdict(list)
            for i in range(self.num_stages):
                self.device_to_stages[i // stages_per_device].append(i)
        elif self.placement_strategy == "interleave":
            # Round-robin: stage i goes to device i mod num_devices.
            self.device_to_stages = defaultdict(list)
            for i in range(self.num_stages):
                self.device_to_stages[i % self.num_devices].append(i)
        elif self.placement_strategy == "dualpipe":
            # For DualPipe, each device has two stages: i and its mirror.
            assert self.num_devices == self.num_stages, "DualPipe requires num_devices == num_stages"
            assert self.num_devices % 2 == 0, "DualPipe requires an even number of devices"
            self.device_to_stages = defaultdict(list)
            for i in range(self.num_stages):
                self.device_to_stages[i] = [i, self.num_stages - i - 1]
        else:
            raise ValueError(f"Invalid placement strategy: {self.placement_strategy}")

    def _split_overlapped_op_type(self, op_type: str) -> List[str]:
        """Split ``"overlapped_<a>_<b>..."`` into its component op-type names.

        Component names may themselves contain underscores (``backward_D``), so
        the suffix is matched against the keys of ``self.op_times``, longest
        candidate first, instead of being split blindly on ``"_"``.

        Raises:
            ValueError: if some part of the name matches no known op type.
        """
        tokens = op_type[len(self._OVERLAPPED_PREFIX):].split("_")
        components = []
        start = 0
        while start < len(tokens):
            for stop in range(len(tokens), start, -1):
                candidate = "_".join(tokens[start:stop])
                if candidate in self.op_times:
                    components.append(candidate)
                    start = stop
                    break
            else:
                raise ValueError(f"Invalid operation type: {op_type}")
        return components

    def get_op_time(self, op_type: str, stage_id: int):
        """Return the duration of ``op_type`` at ``stage_id``.

        An ``overlapped_*`` type with an explicit entry in ``op_times`` uses that
        entry. Without one, its duration is the sum of the durations of all its
        component op types, each looked up at the same ``stage_id``.

        Raises:
            ValueError: if the op type (or a component of an overlapped type) is
                unknown, or if a per-stage table has no entry for ``stage_id``.
        """
        if op_type.startswith(self._OVERLAPPED_PREFIX) and op_type not in self.op_times:
            components = self._split_overlapped_op_type(op_type)
            if len(components) >= 2:
                return sum(self.get_op_time(part, stage_id) for part in components)
            # A lone component is not a valid overlapped operation; fall through
            # to the generic "invalid operation type" error below.

        if op_type not in self.op_times:
            raise ValueError(f"Invalid operation type: {op_type}")

        times = self.op_times[op_type]
        if isinstance(times, dict):
            # Stage-specific times: the stage must be listed explicitly.
            if stage_id not in times:
                raise ValueError(f"No time specified for operation {op_type} at stage {stage_id}")
            return times[stage_id]
        # A single value is shared by all stages.
        return times
class Schedule:
    """Operations of a pipeline run, their per-device order and their simulated timing.

    Attributes:
        config: the schedule configuration the schedule was built from.
        ops: ``(batch_id, stage_id, op_type)`` -> operation (or the overlapped
            group that replaced it).
        op_to_overlapped: ``(batch_id, stage_id, op_type)`` -> overlapped group
            containing that operation.
        device_queues: one queue per device, indexed by device id.
    """

    def __init__(self, config: "ScheduleConfig", init_ops: bool = True):
        """Create one device queue per device and, optionally, all operations.

        Args:
            config: schedule configuration; ``num_devices`` and
                ``device_to_stages`` are read here.
            init_ops: when true, ``init_operations()`` populates ``self.ops``.
        """
        self.config = config
        self.ops = {}  # (batch_id, stage_id, op_type) -> Operation
        # Created before init_operations() so every lookup table exists from the start.
        self.op_to_overlapped = {}
        self.device_queues: List["DeviceQueue"] = [
            DeviceQueue(config.device_to_stages[dev_id], dev_id)
            for dev_id in range(config.num_devices)
        ]
        if init_ops:
            self.init_operations()

    def register_overlapped_operation(self, overlapped_op: "OverlappedOperation"):
        """Make lookups of every member of ``overlapped_op`` resolve to the group."""
        for op in overlapped_op.operations:
            key = (op.batch_id, op.stage_id, op.op_type)
            self.op_to_overlapped[key] = overlapped_op
            self.ops[key] = overlapped_op

    def register_operation(self, op: "Operation"):
        """Add ``op`` to ``self.ops``.

        Raises:
            AssertionError: if an operation with the same key is already registered.
        """
        assert (op.batch_id, op.stage_id, op.op_type) not in self.ops, f"Operation {op.batch_id}, {op.stage_id}, {op.op_type} already registered"
        self.ops[(op.batch_id, op.stage_id, op.op_type)] = op

    def init_operations(self):
        """Create one operation per (batch, stage, op type) and store it in ``self.ops``.

        The op types are forward/backward, or forward/backward_D/backward_W when
        ``config.split_backward`` is set.
        """
        if self.config.split_backward:
            op_types = ["forward", "backward_D", "backward_W"]
        else:
            op_types = ["forward", "backward"]
        for batch_id in range(self.config.num_batches):
            for stage_id in range(self.config.num_stages):
                for op_type in op_types:
                    self.ops[(batch_id, stage_id, op_type)] = Operation(
                        batch_id, stage_id, op_type
                    )

    def get_op(self, batch_id: int, stage_id: int, op_type: str, allow_none=False):
        """Return the operation for a key, preferring the overlapped group if any.

        Args:
            allow_none: return ``None`` instead of raising when the key is unknown.

        Raises:
            KeyError: if the key is unknown and ``allow_none`` is false.
        """
        key = (batch_id, stage_id, op_type)
        if key in self.op_to_overlapped:
            return self.op_to_overlapped[key]
        if allow_none and key not in self.ops:
            return None
        return self.ops[key]

    def _get_op_with_fallback(self, batch_id: int, stage_id: int, preferred: str, fallback: str):
        """Return the ``preferred`` op if it exists, otherwise the ``fallback`` op."""
        found = self.get_op(batch_id, stage_id, preferred, allow_none=True)
        if found is None:
            found = self.get_op(batch_id, stage_id, fallback)
        return found

    def _data_dependencies(self, op):
        """Return the ``(operation, gap)`` data dependencies of a single operation."""
        latency = self.config.p2p_latency
        batch_id, stage_id, kind = op.batch_id, op.stage_id, op.op_type

        if kind == "forward":
            # A forward pass waits for the forward pass of the previous stage.
            if stage_id > 0:
                return [(self.get_op(batch_id, stage_id - 1, "forward"), latency)]
            return []

        # None of the backward rules below adds a dependency at the last stage.
        if stage_id >= self.config.num_stages - 1:
            return []

        if not self.config.split_backward:
            if kind == "backward":
                return [(self.get_op(batch_id, stage_id + 1, "backward"), latency)]
            return []

        if kind == "backward_D":
            # Waits for the next stage's backward_D (or its fused backward).
            upstream = self._get_op_with_fallback(batch_id, stage_id + 1, "backward_D", "backward")
        elif kind == "backward_W":
            # Waits for backward_D (or the fused backward) of the SAME stage.
            # NOTE(review): the last stage is excluded by the guard above, exactly
            # as in the original code, and p2p_latency is charged even though the
            # dependency is on the same stage - confirm both are intended.
            upstream = self._get_op_with_fallback(batch_id, stage_id, "backward_D", "backward")
        elif kind == "backward":
            # Waits for the next stage's fused backward (or its backward_D).
            upstream = self._get_op_with_fallback(batch_id, stage_id + 1, "backward", "backward_D")
        else:
            return []
        return [(upstream, latency)]

    def get_dependencies(self, op: "Operation", include_device_dependency=True):
        """Return the operations ``op`` must wait for, as ``(operation, gap)`` pairs.

        ``gap`` is the delay added after the dependency's end time:
        ``config.p2p_latency`` for data dependencies and ``0.0`` for the
        predecessor in the same device queue. For an overlapped group, the data
        dependencies of all members are collected.

        Args:
            include_device_dependency: also depend on the operation that precedes
                ``op`` in its device queue.
        """
        deps = []
        if isinstance(op, OverlappedOperation):
            for sub_op in op.operations:
                deps.extend(self.get_dependencies(sub_op, include_device_dependency=False))
        else:
            deps.extend(self._data_dependencies(op))

        if include_device_dependency:
            queue_ops = self.device_queues[op.device_id].ops
            position = queue_ops.index(op)
            if position > 0:
                deps.append((queue_ops[position - 1], 0.0))
        return deps

    def show(self):
        """Display detailed information about the schedule for debugging purposes."""
        print("\n=== SCHEDULE DETAILS ===")
        print(f"Devices: {self.config.num_devices}, Stages: {self.config.num_stages}, Batches: {self.config.num_batches}")
        print(f"Placement Strategy: {self.config.placement_strategy}")
        print("\n=== DEVICE QUEUES ===")
        for dev_id in range(self.config.num_devices):
            print(f"\nDEVICE {dev_id} (Stages: {self.device_queues[dev_id].stages}):")
            print("-" * 80)
            print(f"{'Batch':^6} | {'Stage':^6} | {'Type':^10} | {'Start':^10} | {'End':^10} | {'Duration':^10}")
            print("-" * 80)
            for op in self.device_queues[dev_id].ops:
                op_type = op.op_type
                start = f"{op.start_time:.2f}" if op.start_time is not None else "N/A"
                end = f"{op.end_time:.2f}" if op.end_time is not None else "N/A"
                duration = "N/A"
                if op.start_time is not None and op.end_time is not None:
                    duration = f"{op.end_time - op.start_time:.2f}"
                print(f"{op.batch_id:^6} | {op.stage_id:^6} | {op_type:^10} | {start:^10} | {end:^10} | {duration:^10}")
        # Report the total execution time only when there is something to report:
        # at least one operation, and all of them timed (max() of nothing would raise).
        if self.ops and all(op.end_time is not None for op in self.ops.values()):
            total_time = max(op.end_time for op in self.ops.values())
            print(f"\nTotal execution time: {total_time:.2f}")

    def execute(self):
        """Assign start and end times to every operation.

        An operation starts at 0.0 when it has no dependencies, otherwise at the
        latest ``dependency.end_time + gap``. Its duration comes from
        ``config.get_op_time``.

        Raises:
            AssertionError: if an operation in ``self.ops`` is left without timing.
        """
        # TODO: change the execution order to topological order via DAG
        def execute_op(op):
            if op.end_time is not None:
                return  # already timed
            deps = self.get_dependencies(op)
            if not deps:
                op.set_start_time(0.0)
            else:
                # Dependencies that are not timed yet are resolved recursively.
                for dep, _gap in deps:
                    if dep.end_time is None or dep.start_time is None:
                        execute_op(dep)
                op.set_start_time(max(dep.end_time + gap for dep, gap in deps))
            op.set_end_time(op.start_time + self.config.get_op_time(
                op.op_type, op.stage_id
            ))

        # Sweep the queues position by position up to the LONGEST queue, so that
        # operations on devices with more work than device 0 are not skipped.
        longest = max((len(queue.ops) for queue in self.device_queues), default=0)
        for i in range(longest):
            for dev_id in range(self.config.num_devices):
                queue_ops = self.device_queues[dev_id].ops
                if len(queue_ops) <= i:
                    continue
                execute_op(queue_ops[i])

        for op in self.ops.values():
            assert (
                op.start_time is not None
            ), f"op {op.batch_id}, {op.stage_id}, {op.op_type} has no start time"
            assert (
                op.end_time is not None
            ), f"op {op.batch_id}, {op.stage_id}, {op.op_type} has no end time"

    def get_total_execution_time(self):
        """Return the latest end time over all operations in ``self.ops``."""
        return max(op.end_time for op in self.ops.values())

    def get_bubble_rate(self):
        """Return ``(actual - ideal) / ideal`` for the executed schedule.

        ``ideal`` is the summed forward and backward time of all stages, times
        the number of batches, divided by the number of devices.
        """
        actual_time = self.get_total_execution_time()
        # NOTE(review): the fused "backward" time is used even when
        # config.split_backward is set; confirm it stays consistent with
        # customised backward_D / backward_W times.
        ideal_time = sum(
            self.config.get_op_time(op_type, stage_id)
            for stage_id in range(self.config.num_stages)
            for op_type in ("forward", "backward")
        )
        ideal_time = ideal_time * self.config.num_batches / self.config.num_devices
        return (actual_time - ideal_time) / ideal_time

    def get_device_running_time(self):
        """Return, per device, the summed duration of the operations in its queue."""
        device_time = [0] * self.config.num_devices
        for dev_id in range(self.config.num_devices):
            for op in self.device_queues[dev_id].ops:
                device_time[dev_id] += op.end_time - op.start_time
        return device_time