from collections import defaultdict
from typing import Dict, List, Optional, Union


class Operation:
    """Operation is a single operation in the pipeline."""

    def __init__(self, batch_id: int, stage_id: int, op_type: str):
        self.batch_id = batch_id
        self.stage_id = stage_id
        self.op_type = op_type
        # Set by DeviceQueue.add_operation.
        self.device_id = None
        # Set by Schedule.execute.
        self.start_time = None
        self.end_time = None

    def set_end_time(self, end_time: float):
        self.end_time = end_time

    def set_start_time(self, start_time: float):
        self.start_time = start_time

    def __repr__(self) -> str:
        return f"Operation(batch_id={self.batch_id}, stage_id={self.stage_id}, op_type={self.op_type})"


class OverlappedOperation:
    """Represents multiple operations that are overlapped/executed concurrently."""

    def __init__(self, operations: List[Operation]):
        self.operations = operations
        self.device_id = operations[0].device_id

        # Validate all operations are on the same device
        for op in operations:
            assert op.device_id == self.device_id, "All operations must be on the same device"

        # Create a combined op_type (e.g., "overlapped_forward_backward")
        self.op_type = "overlapped_" + "_".join([op.op_type for op in operations])

        # Use the batch_id and stage_id of the first operation for identification
        # (though we'll track all operations internally)
        self.batch_id = operations[0].batch_id
        self.stage_id = operations[0].stage_id

        # Initialize timing information
        self.start_time = None
        self.end_time = None

    def set_end_time(self, end_time: float):
        """Set the end time here and on every sub-operation."""
        self.end_time = end_time
        for op in self.operations:
            op.set_end_time(end_time)

    def set_start_time(self, start_time: float):
        """Set the start time here and on every sub-operation."""
        self.start_time = start_time
        for op in self.operations:
            op.set_start_time(start_time)

    def __repr__(self) -> str:
        op_str = ", ".join([f"({op.batch_id},{op.stage_id},{op.op_type})" for op in self.operations])
        return f"OverlappedOperation([{op_str}])"


class DeviceQueue:
    """Ordered list of operations that run on a single device."""

    def __init__(self, stages: List[int], device_id: int):
        self.stages = stages
        self.device_id = device_id
        self.ops = []  # List of operations

    def add_operation(self, op: Operation):
        """Append ``op`` to this device's queue and bind it to this device.

        Both checks run before the queue is touched, so a rejected operation
        does not leave a stray entry in ``self.ops``.
        """
        assert op.stage_id in self.stages
        assert op.device_id is None, f"Operation {op.batch_id}, {op.stage_id}, {op.op_type} already has a device id on {op.device_id}"
        self.ops.append(op)
        op.device_id = self.device_id


class ScheduleConfig:
    """Static description of a pipeline: sizes, placement and op durations."""

    def __init__(
        self,
        num_devices: int,
        num_stages: int,
        num_batches: int,
        p2p_latency: float = 0.0,
        placement_strategy: str = "standard",
        split_backward: bool = False,
        op_times: Optional[Dict[str, Union[float, Dict[int, float]]]] = None,
    ):
        """Build the config.

        Args:
            op_times: overrides for op durations. A float applies to every
                stage; a dict maps stage_id -> time and overrides only the
                listed stages (other stages keep their default, if any).
        """
        self.num_devices = num_devices
        self.num_stages = num_stages
        self.num_batches = num_batches
        self.p2p_latency = p2p_latency
        self.placement_strategy = placement_strategy
        self.split_backward = split_backward

        # Initialize default operation times
        if self.split_backward:
            self.op_times = {
                "forward": 1.0,
                "backward_D": 1.0,
                "backward_W": 1.0,
                "backward": 2.0,
            }
        else:
            self.op_times = {
                "forward": 1.0,
                "backward": 2.0,
            }

        # Update with user-provided operation times
        if op_times:
            for op_type, times in op_times.items():
                if isinstance(times, dict):
                    # A dict maps stage_id -> time.
                    if op_type not in self.op_times:
                        stage_times = {}
                    elif isinstance(self.op_times[op_type], dict):
                        stage_times = self.op_times[op_type]
                    else:
                        # Expand the scalar default so unspecified stages keep it.
                        default_time = self.op_times[op_type]
                        stage_times = {i: default_time for i in range(num_stages)}
                    # Copy entries so the caller's dict is never aliased.
                    stage_times.update(times)
                    self.op_times[op_type] = stage_times
                else:
                    # If a float is provided, use same time for all stages
                    self.op_times[op_type] = times

        assert num_stages % num_devices == 0, "num_stages must be divisible by num_devices"
        self.num_stages_per_device = num_stages // num_devices

        self.init_device_to_stages()
        total_placed = sum(len(stages) for stages in self.device_to_stages.values())
        if self.placement_strategy == "dualpipe":
            # DualPipe lists every stage on two devices (i and num_stages - 1 - i).
            assert total_placed == num_stages * 2
        else:
            assert total_placed == num_stages

    def init_device_to_stages(self):
        """Populate ``self.device_to_stages`` (device_id -> list of stage ids)."""
        if self.placement_strategy == "standard":
            # Evenly distributed: consecutive stages share a device.
            stages_per_device = self.num_stages // self.num_devices
            self.device_to_stages = defaultdict(list)
            for i in range(self.num_stages):
                device_to_put = i // stages_per_device
                self.device_to_stages[device_to_put].append(i)
        elif self.placement_strategy == "interleave":
            # Round-robin stages across devices.
            self.device_to_stages = defaultdict(list)
            for i in range(self.num_stages):
                device_to_put = i % self.num_devices
                self.device_to_stages[device_to_put].append(i)
        elif self.placement_strategy == "dualpipe":
            # For DualPipe, each device has two stages
            assert self.num_devices == self.num_stages, "DualPipe requires num_devices == num_stages"
            assert self.num_devices % 2 == 0, "DualPipe requires an even number of devices"
            self.device_to_stages = defaultdict(list)
            for i in range(self.num_stages):
                self.device_to_stages[i] = [i, self.num_stages - i - 1]
        else:
            raise ValueError(f"Invalid placement strategy: {self.placement_strategy}")

    def _split_overlapped_op_type(self, op_type: str) -> Optional[List[str]]:
        """Split ``"overlapped_<a>_<b>..."`` into its component op types.

        Component names may themselves contain underscores (``backward_D``),
        so tokens are matched against the non-overlapped keys of
        ``self.op_times``, preferring the longest match and backtracking when
        the remainder cannot be parsed. Returns ``None`` if no split exists.
        """
        known = {name for name in self.op_times if not name.startswith("overlapped_")}
        tokens = op_type[len("overlapped_"):].split("_")

        def parse(start: int) -> Optional[List[str]]:
            if start == len(tokens):
                return []
            for end in range(len(tokens), start, -1):
                candidate = "_".join(tokens[start:end])
                if candidate in known:
                    rest = parse(end)
                    if rest is not None:
                        return [candidate] + rest
            return None

        return parse(0)

    def get_op_time(self, op_type: str, stage_id: int):
        """Return the duration of ``op_type`` at ``stage_id``.

        Overlapped types (``"overlapped_<a>_<b>..."``) use an explicit
        ``op_times`` entry when present; otherwise the durations of all
        component op types are summed.

        Raises:
            ValueError: unknown op type, or no time for ``stage_id``.
        """
        if op_type.startswith("overlapped_") and op_type not in self.op_times:
            components = self._split_overlapped_op_type(op_type)
            if components:
                return sum(self.get_op_time(part, stage_id) for part in components)
            # Unparseable overlapped type: fall through to the error below.

        if op_type not in self.op_times:
            raise ValueError(f"Invalid operation type: {op_type}")
        times = self.op_times[op_type]
        if isinstance(times, dict):
            # If we have stage-specific times, use those
            if stage_id not in times:
                raise ValueError(f"No time specified for operation {op_type} at stage {stage_id}")
            return times[stage_id]
        # If we have a single float, use the same value for all stages
        return times


class Schedule:
    """Operations of a pipeline schedule, their device placement and timing."""

    def __init__(self, config: ScheduleConfig, init_ops: bool = True):
        self.ops = {}  # (batch_id, stage_id, op_type) -> Operation
        self.device_queues: List[DeviceQueue] = []
        for dev_id in range(config.num_devices):
            self.device_queues.append(DeviceQueue(config.device_to_stages[dev_id], dev_id))
        self.config = config

        if init_ops:
            self.init_operations()
        # (batch_id, stage_id, op_type) -> OverlappedOperation containing it
        self.op_to_overlapped = {}

    def register_overlapped_operation(self, overlapped_op: OverlappedOperation):
        """Make every sub-op key resolve to ``overlapped_op``."""
        for op in overlapped_op.operations:
            self.op_to_overlapped[(op.batch_id, op.stage_id, op.op_type)] = overlapped_op
            self.ops[(op.batch_id, op.stage_id, op.op_type)] = overlapped_op

    def register_operation(self, op: Operation):
        assert (op.batch_id, op.stage_id, op.op_type) not in self.ops, f"Operation {op.batch_id}, {op.stage_id}, {op.op_type} already registered"
        self.ops[(op.batch_id, op.stage_id, op.op_type)] = op

    def init_operations(self):
        """Create one op per (batch, stage, op type) for the configured mode."""
        op_types = ["forward", "backward"]
        if self.config.split_backward:
            op_types = ["forward", "backward_D", "backward_W"]
        for batch_id in range(self.config.num_batches):
            for stage_id in range(self.config.num_stages):
                for op_type in op_types:
                    self.ops[(batch_id, stage_id, op_type)] = Operation(
                        batch_id, stage_id, op_type
                    )

    def get_op(self, batch_id: int, stage_id: int, op_type: str, allow_none=False):
        """Look up an op; an overlapped wrapper takes precedence.

        Returns ``None`` for a missing key only when ``allow_none`` is set;
        otherwise a missing key raises ``KeyError``.
        """
        if (batch_id, stage_id, op_type) in self.op_to_overlapped:
            return self.op_to_overlapped[(batch_id, stage_id, op_type)]
        if allow_none:
            if (batch_id, stage_id, op_type) not in self.ops:
                return None
        return self.ops[(batch_id, stage_id, op_type)]

    def _append_device_dependency(self, deps, op):
        """Append the op queued just before ``op`` on its device (gap 0.0), if any."""
        queue_ops = self.device_queues[op.device_id].ops
        device_index = queue_ops.index(op)
        if device_index > 0:
            deps.append((queue_ops[device_index - 1], 0.0))

    def get_dependencies(self, op: Operation, include_device_dependency=True):
        """Return ``(dep_op, gap)`` pairs; ``op`` may start at ``dep.end_time + gap``.

        Data dependencies come first, then (optionally) the previous op in the
        same device queue.
        """
        deps = []
        if isinstance(op, OverlappedOperation):
            # Inherit every sub-op's data dependencies; device ordering is
            # applied once, to the overlapped op itself.
            for sub_op in op.operations:
                deps.extend(self.get_dependencies(sub_op, include_device_dependency=False))
            if include_device_dependency:
                self._append_device_dependency(deps, op)
            return deps

        has_next_stage = op.stage_id < self.config.num_stages - 1
        latency = self.config.p2p_latency

        if op.op_type == "forward":
            if op.stage_id > 0:
                deps.append((self.get_op(op.batch_id, op.stage_id - 1, "forward"), latency))

        if self.config.split_backward:
            if op.op_type == "backward_D":
                if has_next_stage:
                    next_d = self.get_op(op.batch_id, op.stage_id + 1, "backward_D", allow_none=True)
                    if next_d is not None:
                        deps.append((next_d, latency))
                    else:
                        deps.append((self.get_op(op.batch_id, op.stage_id + 1, "backward"), latency))
            elif op.op_type == "backward_W":
                # Depends on the SAME stage's backward_D (or combined backward), so
                # this applies to every stage, including the last one.
                # NOTE(review): the gap is p2p_latency even though both ops share
                # a stage — kept as-is; confirm this is the intended model.
                same_d = self.get_op(op.batch_id, op.stage_id, "backward_D", allow_none=True)
                if same_d is not None:
                    deps.append((same_d, latency))
                else:
                    deps.append((self.get_op(op.batch_id, op.stage_id, "backward"), latency))
            elif op.op_type == "backward":
                if has_next_stage:
                    next_bwd = self.get_op(op.batch_id, op.stage_id + 1, "backward", allow_none=True)
                    if next_bwd is not None:
                        deps.append((next_bwd, latency))
                    else:
                        deps.append((self.get_op(op.batch_id, op.stage_id + 1, "backward_D"), latency))
        elif op.op_type == "backward" and has_next_stage:
            deps.append((self.get_op(op.batch_id, op.stage_id + 1, "backward"), latency))

        if include_device_dependency:
            self._append_device_dependency(deps, op)
        return deps

    def show(self):
        """Display detailed information about the schedule for debugging purposes."""
        print("\n=== SCHEDULE DETAILS ===")
        print(f"Devices: {self.config.num_devices}, Stages: {self.config.num_stages}, Batches: {self.config.num_batches}")
        print(f"Placement Strategy: {self.config.placement_strategy}")
        print("\n=== DEVICE QUEUES ===")

        for dev_id in range(self.config.num_devices):
            print(f"\nDEVICE {dev_id} (Stages: {self.device_queues[dev_id].stages}):")
            print("-" * 80)
            print(f"{'Batch':^6} | {'Stage':^6} | {'Type':^10} | {'Start':^10} | {'End':^10} | {'Duration':^10}")
            print("-" * 80)

            for op in self.device_queues[dev_id].ops:
                op_type = op.op_type
                start = f"{op.start_time:.2f}" if op.start_time is not None else "N/A"
                end = f"{op.end_time:.2f}" if op.end_time is not None else "N/A"

                duration = "N/A"
                if op.start_time is not None and op.end_time is not None:
                    duration = f"{op.end_time - op.start_time:.2f}"

                print(f"{op.batch_id:^6} | {op.stage_id:^6} | {op_type:^10} | {start:^10} | {end:^10} | {duration:^10}")

        # Find the total execution time (if timing info is available);
        # guard against an empty schedule, where max() would raise.
        if self.ops and all(op.end_time is not None for op in self.ops.values()):
            total_time = max(op.end_time for op in self.ops.values())
            print(f"\nTotal execution time: {total_time:.2f}")

    def execute(self):
        """Assign start/end times to every op.

        Device queues are walked round-robin by position; each op's
        dependencies are resolved recursively before it is timed.
        """
        # TODO: change the execution order to topological order via DAG
        def execute_op(op: Operation):
            if op.end_time is not None:
                return
            deps = self.get_dependencies(op)
            if len(deps) == 0:
                op.set_start_time(0.0)
            else:
                for dep, gap in deps:
                    if dep.end_time is None or dep.start_time is None:
                        execute_op(dep)
                op.set_start_time(max(dep.end_time + gap for dep, gap in deps))
            op.set_end_time(op.start_time + self.config.get_op_time(
                op.op_type, op.stage_id
            ))

        # Walk up to the LONGEST queue; using device 0's length would skip the
        # tail of any longer queue.
        op_num = max((len(queue.ops) for queue in self.device_queues), default=0)
        for i in range(op_num):
            for queue in self.device_queues:
                if i < len(queue.ops):
                    execute_op(queue.ops[i])

        for op in self.ops.values():
            assert (
                op.start_time is not None
            ), f"op {op.batch_id}, {op.stage_id}, {op.op_type} has no start time"
            assert (
                op.end_time is not None
            ), f"op {op.batch_id}, {op.stage_id}, {op.op_type} has no end time"

    def get_total_execution_time(self):
        """Latest end time across all registered ops."""
        return max(op.end_time for op in self.ops.values())

    def get_bubble_rate(self):
        """Return ``(actual - ideal) / ideal``.

        ``ideal`` is the sum over stages of the "forward" and "backward" times,
        times ``num_batches``, divided by ``num_devices``. NOTE(review): the
        combined "backward" time is used even when ``split_backward`` is set.
        """
        actual_time = self.get_total_execution_time()

        ideal_time = 0
        for stage_id in range(self.config.num_stages):
            for op_type in ["forward", "backward"]:
                ideal_time += self.config.get_op_time(op_type, stage_id)
        ideal_time = ideal_time * self.config.num_batches / self.config.num_devices

        return (actual_time - ideal_time) / ideal_time

    def get_device_running_time(self):
        """Per-device sum of op durations (end_time - start_time)."""
        return [
            sum(op.end_time - op.start_time for op in queue.ops)
            for queue in self.device_queues
        ]