目 录CONTENT

文章目录

使用张量并行在多GPU上训练大型模型

Administrator
2026-01-01 / 0 评论 / 0 点赞 / 0 阅读 / 0 字

📢 转载信息

原文链接:https://machinelearningmastery.com/train-your-large-model-on-multiple-gpus-with-tensor-parallelism/

原文作者:Adrian Tam


张量并行(Tensor parallelism)是一种模型并行技术,它沿特定维度对张量进行分片,并将张量的计算分布到多个设备上,同时将通信开销降至最低。这项技术非常适合那些参数张量非常大,以至于单个矩阵乘法都无法在单个GPU上容纳的模型。在本文中,您将学习如何使用张量并行。具体来说,您将了解到:

  • 什么是张量并行
  • 如何设计张量并行方案
  • 如何在PyTorch中应用张量并行

让我们开始吧!

使用张量并行在多GPU上训练您的大型模型。
照片作者:Seth kane。保留部分权利。

概述

本文分为五个部分:

  • 张量并行示例
  • 设置张量并行
  • 准备用于张量并行的模型
  • 使用张量并行训练模型
  • 将张量并行与FSDP结合

张量并行示例

张量并行起源于Megatron-LM论文。此技术并非适用于所有操作;然而,某些操作(例如矩阵乘法)通过分片计算来实现。

列式张量并行:您将权重 $\mathbf{W}$ 分片为多列,并执行矩阵乘法 $\mathbf{XW}=\mathbf{Y}$ 来产生需要拼接的分片输出。

让我们考虑一个简单的矩阵-矩阵乘法操作,如下所示:

这是一个 $3\times 4$ 矩阵 $\mathbf{X}$ 乘以一个 $4\times 6$ 矩阵 $\mathbf{W}$ 以产生一个 $3\times 6$ 矩阵 $\mathbf{Y}$。您确实可以将其分解为两次矩阵乘法:一次是 $\mathbf{X}$ 乘以一个 $4\times 3$ 矩阵 $\mathbf{W}_1$ 以产生一个 $3\times 3$ 矩阵 $\mathbf{Y}_1$,另一次是 $\mathbf{X}$ 乘以另一个 $3\times 2$ 矩阵 $\mathbf{W}_2$ 以产生一个 $3\times 3$ 矩阵 $\mathbf{Y}_2$。然后最终的 $\mathbf{Y}$ 是 $\mathbf{Y}_1$ 和 $\mathbf{Y}_2$ 的拼接。

您可以看到,在这种情况下,您无需托管大型矩阵 $\mathbf{W}$,而是处理其较小的分片。这节省了内存。每个分片的输出更小,因此与其他设备通信速度也更快。

上述情况称为列式并行(column-wise parallel)。您可以将其推广,沿列维度对矩阵 $\mathbf{W}$ 进行多于两次的分片。

另一种变体是行式并行(row-wise parallel),如下例所示:

行式张量并行:您将权重 $\mathbf{W}$ 分片为多行,并执行矩阵乘法 $\mathbf{XW}=\mathbf{Y}$ 来产生需要逐元素相加的局部输出。

对于相同的 $\mathbf{XW}=\mathbf{Y}$ 矩阵乘法,您现在将 $\mathbf{X}$ 分割成列,将 $\mathbf{W}$ 分割成行。在上面的插图中,左半部分的 $\mathbf{X}$($3\times 2$ 矩阵)与 $\mathbf{W}$ 上半部分($2\times 6$ 矩阵)相乘,产生一个 $3\times 6$ 矩阵。输出形状与完整方程相同,但值仅对应于 $\mathbf{W}$ 的上半部分。通过 $\mathbf{X}$ 的右半部分和 $\mathbf{W}$ 的下半部分重复相同的操作,然后将两个结果相加,即可得到最终输出 $\mathbf{Y}$。

在行式并行中,您处理 $\mathbf{X}$ 和 $\mathbf{W}$ 的分片。工作量比执行完整的矩阵乘法要轻。输出比列式并行更大,需要更多的带宽将结果通信给其他设备。

深度学习模型中并非所有操作都是矩阵乘法。因此,张量并行并非适用于模型中的每个元素。某些操作,如激活函数或归一化层,可以以不同的方式并行化。对于无法并行化的操作,您的模型必须以其原始形式计算它们。

张量并行带来的好处不仅仅是节省内存——它还提供了对计算和通信模式的细粒度控制。由于矩阵乘法是分片的,您可以控制是否取消分片结果,从而在直接使用分片的 DTensor 对象时避免通信开销。

设置张量并行

PyTorch中的张量并行是分布式框架的一部分。因此,脚本的启动方式与数据并行流水线并行全分片数据并行一样,使用 torchrun 命令。您需要像往常一样初始化分布式环境,但还需要设置设备网格(device mesh):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import os
import torch
import torch.distributed as dist
 
# Initialize the distributed environment
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{local_rank}")
rank = dist.get_rank()
world_size = dist.get_world_size()
 
# Initialize the mesh for tensor parallelism
mesh = dist.device_mesh.init_device_mesh(
    "cuda",
    (world_size,),
)

设备网格是进程组的高级抽象。在后面的部分中,您将明白为什么需要它。这是必要的,因为您使用网格将模型包装成一个张量并行模型,使用以下语法:

1
2
3
from torch.distributed.tensor.parallel import parallelize_module
 
model = parallelize_module(model, mesh, tp_plan)

此后,模型的一些权重将被替换为 DTensor 对象,使操作能够并行运行。内部可能会执行分布式集体操作,如全收集(all-gather)、全规约(all-reduce)和规约-分片(reduce-scatter),以使模型工作起来就像它在单个设备上一样。这些细节对您来说是透明的

准备用于张量并行的模型

让我们看看如何使用张量并行来运行上一篇文章中训练脚本。${br}$
${br}$

将模型转换为使用张量并行运行时,不需要更改模型架构。相反,您需要知道模型中每个模块和子模块的完全限定名称。这些名称与模型 state_dict() 中的键相同,或者您可以回顾模型架构代码以查找模块名称。

您需要识别这些名称来创建并行化方案。这是一个 Python 字典,它将模块名称映射到 ParallelStyle 对象。示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
tp_plan = {
    "input_layernorm": SequenceParallel(),
    "self_attn": PrepareModuleInput(
        input_layouts=Shard(dim=1),  # only one position arg will be used
        desired_input_layouts=Replicate(),
    ),
    # Q/K/V output will be used with GQA, prefer to be replicated
    "self_attn.q_proj": ColwiseParallel(output_layouts=Replicate()),
    "self_attn.k_proj": ColwiseParallel(output_layouts=Replicate()),
    "self_attn.v_proj": ColwiseParallel(output_layouts=Replicate()),
    "self_attn.o_proj": RowwiseParallel(input_layouts=Replicate(), output_layouts=Shard(1)),
    "post_attention_layernorm": SequenceParallel(),
    "mlp": PrepareModuleInput(
        input_layouts=Shard(dim=1),
        desired_input_layouts=Replicate(),
    ),
    "mlp.gate_proj": ColwiseParallel(),
    "mlp.up_proj": ColwiseParallel(),
    "mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)),
}
 
with torch.device("meta"):
    model_config = LlamaConfig()
    model = LlamaForPretraining(model_config)
 
for layer in model.base_model.layers:
    parallelize_module(layer, mesh, tp_plan)

模型架构代码与上一篇文章中完全相同。字典 tp_plan 是从 Transformer 块的角度创建的。请注意,每个块的声明如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class LlamaAttention(nn.Module):
    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        ...
        # Linear layers for Q, K, V projections
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
    ...
 
class LlamaMLP(nn.Module):
    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.act_fn = F.silu  # SwiGLU activation function
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
    ...
 
class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
        self.self_attn = LlamaAttention(config)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
        self.mlp = LlamaMLP(config)
    ...

LlamaDecoderLayer 类的角度来看,您可以找到一个 nn.Linear 层,如 layer.mlp.gate_proj,因此您在 tp_plan 字典中使用键 mlp.gate_proj 来引用它。

当您使用张量并行时,应该意识到 PyTorch 不知道如何并行化任意模块,例如您定义的自定义模块。但是,少数标准模块具有张量并行的具体实现:

  • nn.Linearnn.Embedding:您可以使用 ColwiseParallelRowwiseParallel 来并行化它们。操作如上所述。
  • nn.LayerNormnn.RMSNormnn.Dropout:您可以使用 SequenceParallel 来并行化它们。

随着 PyTorch 的不断发展,请参阅官方文档以获取张量并行的最新覆盖范围。

设置好 tp_plan 后,您可以使用 parallelize_module() 函数将其应用于模型。这将用 DTensor 对象替换模块的参数,并使模块以分片计算方式运行,就像它在单个设备上一样。上面的 for 循环出于方便将 parallelize_module() 应用于每个 Transformer 块。您也可以将其应用于整个基础模型,但由于 Transformer 块在基础模型中重复多次,您需要对 tp_plan 进行大量重复的更新。

Transformer 块包含两个 RMS 范数层。您使用 SequenceParallel 标记它们,以指示它们应沿序列维度(输入张量的维度 1)分片。

Transformer 块包含多个 nn.Linear 层:前馈子层中有三个,注意力子层中有四个。您可以使用 ColwiseParallelRowwiseParallel 标记它们以实现并行化,但有一些需要注意的事项:

  • ColwiseParallel 默认期望一个完整的张量作为输入,输出沿最后一个维度分片,使输出张量变小。
  • RowwiseParallel 默认期望一个沿最后一个维度分片的张量作为输入,输出是一个完整大小的张量,与未并行化时的大小相同。

这就是您指定以下内容的原因:

1
2
3
{ ... "mlp.gate_proj": ColwiseParallel(), "mlp.up_proj": ColwiseParallel(), "m... [内容被截断]



🚀 想要体验更好更全面的AI调用?

欢迎使用青云聚合API,约为官网价格的十分之一,支持300+全球最新模型,以及全球各种生图生视频模型,无需翻墙高速稳定,文档丰富,小白也可以简单操作。

0

评论区