📢 转载信息
原文作者:Adrian Tam
训练一个语言模型是内存密集型的,这不仅是因为模型本身很大,还因为训练数据批次中的序列很长。在内存受限的环境下训练模型是一项挑战。在本文中,您将学习到能够在内存受限环境中启用模型训练的技术。具体来说,您将了解到:
- 低精度浮点数和混合精度训练
- 使用梯度检查点(Gradient Checkpointing)
让我们开始吧!
使用混合精度和梯度检查点在内存受限环境下训练模型
图片来自 Meduana。保留部分权利。
概述
本文分为三个部分;它们是:
- 浮点数
- 自动混合精度训练
- 梯度检查点
让我们开始吧!
浮点数
PyTorch中的默认数据类型是IEEE 754 32位浮点格式,也称为单精度。它不是你可以使用的唯一浮点类型。例如,大多数CPU支持64位双精度浮点,GPU也通常支持半精度浮点。下表列出了一些浮点类型:
| 数据类型 | PyTorch 类型 | 总位数 | 符号位 | 指数位数 | 尾数位数 | 最小值 | 最大值 | eps |
|---|---|---|---|---|---|---|---|---|
| IEEE 754 双精度 | torch.float64 |
64 | 1 | 11 | 52 | -1.79769e+308 | 1.79769e+308 | 2.22045e-16 |
| IEEE 754 单精度 | torch.float32 |
32 | 1 | 8 | 23 | -3.40282e+38 | 3.40282e+38 | 1.19209e-07 |
| IEEE 754 半精度 | torch.float16 |
16 | 1 | 5 | 10 | -65504 | 65504 | 0.000976562 |
| bf16 | torch.bfloat16 |
16 | 1 | 8 | 7 | -3.38953e+38 | 3.38953e+38 | 0.0078125 |
| fp8 (e4m3) | torch.float8_e4m3fn |
8 | 1 | 4 | 3 | -448 | 448 | 0.125 |
| fp8 (e5m2) | torch.float8_e5m2 |
8 | 1 | 5 | 2 | -57344 | 57344 | 0.25 |
| fp8 (e8m0) | torch.float8_e8m0fnu |
8 | 1 | 8 | 0 | 1.70141e+38 | 5.87747e-39 | 1.0 |
| fp6 (e3m2) | 6 | 1 | 3 | 2 | -28 | 28 | 0.25 | |
| fp6 (e2m3) | 6 | 1 | 2 | 3 | -7.5 | 7.5 | 0.125 | |
| fp4 (e2m1) | 4 | 1 | 2 | 1 | -6 | 6 |
浮点数是实数的二进制表示。每个浮点数由一个符号位、几位用于指数的位数和几位用于尾数的位数组成。它们的布局如图所示。当按二进制表示排序时,浮点数保持其按实数值的顺序。
浮点数表示。图源自 Wikimedia。
不同的浮点类型具有不同的范围和精度。并非所有硬件都支持所有类型。例如,fp4仅在Nvidia的Blackwell架构中受支持。PyTorch仅支持少数几种数据类型。您可以运行以下代码来打印有关各种浮点类型的信息:
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
|
import torch
from tabulate import tabulate
# float types:
float_types = [
torch.float64,
torch.float32,
torch.float16,
torch.bfloat16,
torch.float8_e4m3fn,
torch.float8_e5m2,
torch.float8_e8m0fnu,
]
# collect finfo for each type
table = []
for dtype in float_types:
info = torch.finfo(dtype)
try:
typename = info.dtype
except:
typename = str(dtype)
table.append([typename, info.max, info.min, info.smallest_normal, info.eps])
headers = ['data type', 'max', 'min', 'smallest normal', 'eps']
print(tabulate(table, headers=headers))
|
请注意每种类型的最小值和最大值,以及 eps 值。最小值和最大值表示类型可以支持的范围(动态范围)。如果使用某种类型训练模型时,模型权重超出了此范围,您将遇到溢出或下溢,通常会导致模型输出 NaN 或 Inf。eps 值是最小的正数,使得类型能够区分 1+eps 和 1。这是精度的度量标准。如果模型的梯度更新小于 eps,您可能会观察到梯度消失问题。
因此,float32是深度学习的一个不错的默认选择:它具有宽广的动态范围和高精度。然而,每个float32数字需要4字节内存。作为权衡,您可以使用float16来节省内存,但由于动态范围小得多,您可能会遇到溢出或下溢问题。
Google Brain 团队发现了这个问题,并提出了bfloat16,这是一种16位浮点格式,具有与float32相同的动态范围。作为权衡,其精度比float16差一个数量级。事实证明,动态范围对于深度学习比精度更重要,这使得bfloat16非常有用。
当您在PyTorch中创建张量时,可以指定数据类型。例如:
|
1
2
|
x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16)
print(x)
|
有一种简单的方法可以将默认类型更改为另一种类型,例如bfloat16。这对于模型训练非常方便。您需要做的就是在创建任何模型或优化器之前设置以下行:
|
1
2
|
# set default dtype to bfloat16
torch.set_default_dtype(torch.bfloat16)
|
仅通过这样做,您就强制所有模型权重和梯度都采用bfloat16类型。这节省了一半的内存。在上一篇文章中,我们建议将批次大小设置为8以适应只有12GB VRAM的GPU。使用bfloat16,您应该能够将批次大小设置为16。
请注意,尝试使用8位浮点或更低精度的类型可能无法工作。这是因为您需要硬件支持以及PyTorch来执行相应的数学运算。您可以尝试以下代码(需要CUDA设备),并发现您需要额外的努力才能在8位浮点上进行操作:
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
|
dtype = torch.float8_e4m3fn
# Define a tensor with float8 will see
# NotImplementedError: "normal_kernel_cuda" not implemented for 'Float8_e4m3fn'
x = torch.randn(16, 16, dtype=dtype, device="cuda")
# Create in float32 and convert to float8 works
x = torch.randn(16, 16, device="cuda").to(dtype)
# But matmul is not supported. You will see
# NotImplementedError: "addmm_cuda" not implemented for 'Float8_e4m3fn'
y = x @ x.T
# The correct way to run matrix multiplication on 8-bit float
y = torch._scaled_mm(x, x.T, out_dtype=dtype,
scale_a=torch.tensor(1.0, device="cuda"),
scale_b=torch.tensor(1.0, device="cuda"))
print(y)
|
自动混合精度训练
使用float16训练模型可能会遇到问题,因为并非所有操作都应以较低精度执行。例如,矩阵乘法在较低精度下是稳健的,但归约操作、池化和某些激活函数需要float32。
您可以手动为模型的每个组件设置数据类型,但这很繁琐,因为您必须在组件之间转换数据类型。一个更好的解决方案是在PyTorch中使用自动混合精度训练。
PyTorch有一个子库torch.amp,可以根据操作自动转换数据类型。并非所有操作都以相同的浮点类型执行。如果已知某个操作在较低精度下是稳健的,该库将在运行操作之前将张量转换为该精度。因此得名“混合精度”。使用较低的精度不仅可以节省内存,还可以加快训练速度。某些GPU可以以float32两倍的速度运行float16操作。
当您使用torch.amp训练模型时,您需要做的就是在torch.amp.autocast()的上下文中运行前向传播。通常,您还会使用GradScaler来处理梯度缩放。这是必要的,因为在低精度下,您可能会由于浮点类型的精度有限而遇到梯度消失。GradScaler在反向传播之前缩放梯度,以防止梯度流失。在反向传播过程中,您应该将梯度按比例回退以进行准确的更新。这个过程可能很麻烦,因为您需要确定正确的比例因子,而GradScaler会为您处理这个问题。
与上一篇文章中的训练循环相比,下面是如何通常使用torch.amp训练模型的方法:
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
|
...
# Check if mixed precision training is supported
assert torch.amp.autocast_mode.is_autocast_available("cuda")
# Creates a GradScaler before the training loop
scaler = torch.amp.GradScaler("cuda", enabled=True)
# start training
for epoch in range(begin_epoch, epochs):
pbar = tqdm.tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
for batch_id, batch in enumerate(pbar):
# get batched data
input_ids, target_ids = batch
# create attention mask: causal mask + padding mask
attn_mask = create_causal_mask(input_ids.shape[1], device) + \
create_padding_mask(input_ids, PAD_TOKEN_ID, device)
# with autocasting to bfloat16, run the forward pass
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
logits = model(input_ids, attn_mask)
loss = loss_fn(logits.view(-1, logits.size(-1)), target_ids.view(-1))
# backward with loss, scaled by the GradScaler
optimizer.zero_grad()
scaler.scale(loss).backward()
# step the optimizer and check if the scale has been updated
scaler.step(optimizer)
old_scale = scaler.get_scale()
scaler.update()
if scaler.get_scale() < old_scale:
scheduler.step()
pbar.set_postfix(loss=loss.item())
pbar.update(1)
pbar.close()
|
使用AMP自动转换非常直接:将模型的默认精度保持为float32,然后用torch.autocast()包装前向传播和损失计算。在此上下文中,所有支持的操作将以指定的浮点类型运行。
一旦获得损失,就让GradScaler处理反向传播。它将放大损失并更新模型的梯度。然而,如果缩放太大,可能导致NaN或Inf梯度,从而引发问题。因此,使用scaler.step(optimizer)来推进优化器,它会在执行优化器步骤之前验证梯度。如果GradScaler决定不推进优化器,它会在调用update()时减小比例因子。检查比例因子是否已更新,以确定是否应该推进学习率调度器。
由于反向传播使用缩放后的损失,如果您使用梯度裁剪,则应在裁剪之前反向缩放梯度。操作方法如下:
|
1
2
3
4
5
6
7
8
9
10
11
12
13
|
...
# backward with loss, scaled by the GradScaler
optimizer.zero_grad()
scaler.scale(loss).backward()
# unscaled the gradients and apply gradient clipping
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
# step the optimizer and check if the scale has been updated
scaler.step(optimizer
|
🚀 想要体验更好更全面的AI调用?
欢迎使用青云聚合API,约为官网价格的十分之一,支持300+全球最新模型,以及全球各种生图生视频模型,无需翻墙高速稳定,文档丰富,小白也可以简单操作。
评论区