深度学习中的并行策略概述(三):流水线并行(Pipeline Parallelism)
本例使用 PyTorch 的分布式流水线并行(torch.distributed.pipelining)来切分并调度一个 Transformer 模型的计算。首先定义 ModelArgs 类来存储模型超参数,然后创建继承自 nn.Module 的 Transformer 类,它包含嵌入层、Transformer 解码层、层归一化和输出线性层。接着定义 init_distributed 函数来初始化分布式环境,并设置进程组和设备;manual_model_split 函数用于手动切分模型,使每个阶段(进程)只保留并执行模型的一部分。在主程序中,先初始化分布式环境,创建模型和虚拟数据,并将模型手动切分为两个阶段;然后将模型和数据移动到指定设备上,并定义损失函数。最后,使用 ScheduleGPipe 按 GPipe 方式调度各微批次(micro-batch)在流水线上的执行:第一个阶段输入数据,最后一个阶段接收目标值并计算损失。程序结束前还会销毁进程组,以确保正确清理资源。
import torch
import torch.nn as nn
from dataclasses import dataclass


@dataclass
class ModelArgs:
    """Hyper-parameters of the toy Transformer used in this pipelining demo."""

    dim: int = 512  # embedding / hidden width
    n_layers: int = 8  # number of decoder layers
    n_heads: int = 8  # attention heads per decoder layer
    vocab_size: int = 10000  # size of the token vocabulary and of the logits


class Transformer(nn.Module):
    """Small Transformer that can be pruned in place to one pipeline stage.

    ``tok_embeddings``, ``norm`` and ``output`` may each be set to ``None``,
    and entries of ``layers`` may be deleted. ``forward`` skips whatever is
    missing, so the same class serves as the first, a middle, or the last
    stage of a pipeline.
    """

    def __init__(self, model_args: ModelArgs):
        super().__init__()
        self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
        # Using a ModuleDict lets us delete layers without affecting names,
        # ensuring checkpoints will correctly save and load.
        self.layers = torch.nn.ModuleDict()
        for layer_id in range(model_args.n_layers):
            # batch_first=True: activations are (batch, seq, dim), matching
            # (batch, seq) token inputs whose micro-batches are chunked along
            # dim 0. With the default (False) the layer would treat dim 0 as
            # the sequence axis and attend across samples of the batch.
            self.layers[str(layer_id)] = nn.TransformerDecoderLayer(
                model_args.dim, model_args.n_heads, batch_first=True
            )
        self.norm = nn.LayerNorm(model_args.dim)
        self.output = nn.Linear(model_args.dim, model_args.vocab_size)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """Run whichever parts of the model this (possibly pruned) instance holds.

        Args:
            tokens: token ids, expected (batch, seq), when ``tok_embeddings``
                is present; otherwise the hidden states produced by the
                previous pipeline stage, expected (batch, seq, dim).

        Returns:
            Logits of size ``vocab_size`` in the last dimension when
            ``output`` is present, otherwise the hidden states.
        """
        # Handling layers being None at runtime enables easy pipeline splitting.
        h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens
        for layer in self.layers.values():
            # The hidden state serves both as the target and as the
            # cross-attention memory of the decoder layer.
            h = layer(h, h)
        h = self.norm(h) if self.norm is not None else h
        # NOTE(review): .clone() is kept from the upstream example; its purpose
        # is not evident from this file.
        output = self.output(h).clone() if self.output is not None else h
        return output


import os
import torch.distributed as dist
from torch.distributed.pipelining import pipeline, SplitPoint, PipelineStage, ScheduleGPipeglobal rank, device, pp_group, stage_index, num_stages
def init_distributed():global rank, device, pp_group, stage_index, num_stagesrank = int(os.environ["LOCAL_RANK"])world_size = int(os.environ["WORLD_SIZE"])device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")dist.init_process_group()# This group can be a sub-group in the N-D parallel casepp_group = dist.new_group()stage_index = ranknum_stages = world_sizedef manual_model_split(model) -> PipelineStage:if stage_index == 0:# prepare the first stage modelfor i in range(4, 8):del model.layers[str(i)]model.norm = Nonemodel.output = Noneelif stage_index == 1:# prepare the second stage modelfor i in range(4):del model.layers[str(i)]model.tok_embeddings = Nonestage = PipelineStage(model,stage_index,num_stages,device,)return stageif __name__ == "__main__":init_distributed()num_microbatches = 4model_args = ModelArgs()model = Transformer(model_args)# Dummy datax = torch.ones(32, 500, dtype=torch.long)y = torch.randint(0, model_args.vocab_size, (32, 500), dtype=torch.long)example_input_microbatch = x.chunk(num_microbatches)[0]# Option 1: Manual model splittingstage = manual_model_split(model)# Option 2: Tracer model splitting# stage = tracer_model_split(model, example_input_microbatch)model.to(device)x = x.to(device)y = y.to(device)def tokenwise_loss_fn(outputs, targets):loss_fn = nn.CrossEntropyLoss()outputs = outputs.reshape(-1, model_args.vocab_size)targets = targets.reshape(-1)return loss_fn(outputs, targets)schedule = ScheduleGPipe(stage, n_microbatches=num_microbatches, loss_fn=tokenwise_loss_fn)if rank == 0:schedule.step(x)elif rank == 1:losses = []output = schedule.step(target=y, losses=losses)print(f"losses: {losses}")dist.destroy_process_group()
运行方式(单机、2 个进程,每个进程运行一个流水线阶段):
torchrun --nnodes 1 --nproc_per_node 2 pipelining_xxx.py