LLM·AMP自动混合精度训练 文章目录完整AMP训练流程混合精度训练的优势节省显存占用激活值是大头加快推理速度低精度加速低精度(bf16,fp16)的问题溢出问题FP16 的数值限制梯度下溢underflow梯度上溢overflow大数吃小数问题混合精度训练的流程Loss scaling实现完整AMP训练流程[初始化阶段]Model Weights:FP32 master weights优化器维护真实参数 Runtime Weights:FP16 weights由FP32 cast得到用于forward/backward[Forwardautocast开启]Input:FP32 ↓进入 autocast[算子1matmul/conv]→ cast 为 FP16利用Tensor Core加速 → 输出FP16[算子2exp/softmax/layernorm 等数值敏感操作]→ 自动提升为 FP32防止数值不稳定如overflow/精度损失 → 输出FP32[算子3继续 matmul]→ cast 回 FP16 → 输出FP16...最终 loss → 通常为 FP32loss/reduction类操作一般在FP32[Loss Scaling]loss(FP32)↓ × scale如1024 scaledloss(FP32)# 目的放大梯度避免FP16下溢变0[Backward]scaledloss(FP32)↓ backward自动根据forward路径传播 梯度传播过程-来自 FP16 forward 分支 → 梯度通常以 FP16 表示易受下溢影响-来自 FP32 forward 分支 → 梯度计算中涉及 FP32更稳定 最终得到 gradients大多数存储为 FP16部分计算路径为 FP32[Unscale类型转换]FP16 gradients ↓ cast FP32 gradients ↓ ÷ scaleunscale恢复真实梯度 得到 真实梯度FP32[参数更新Optimizer Step]FP32 gradients ↓ 更新 FP32 master weights高精度累积更新 # 关键原因#FP16无法表示微小更新会导致训练停滞[同步回计算图]更新后的 FP32 master weights ↓ cast FP16 weights供下一轮 forward 使用[整体数据流总结]FP32 master weights ↓ cast FP16 weights ↓ forwardFP16为主FP32关键算子 ↓ lossFP32 ↓ scale scaled lossFP32 ↓ backward gradFP16为主 ↓ castunscale gradFP32 ↓ update master weightsFP32 ↓ 再 cast → FP16进入下一轮[核心要点]1.FP16用于“计算加速”forward/backward主路径2.FP32用于“数值稳定”loss/softmax/norm等3.FP32用于“参数更新”避免精度丢失4.loss scaling 仅用于解决 FP16 梯度下溢问题混合精度训练的优势节省显存占用激活值是大头激活值是关键大头而且与batch_size和序列长度有关因此应该考虑节省这些中间值的梯度占用FP32参数:1GB activation:6GB gradient:3GB----------------总计:10GBAMP参数:FP32 master:1GB FP16 copy:0.5GB activation:3GB减半 gradient:1.5GB减半----------------总计:6GB加快推理速度低精度加速NVIDIA的显卡对于低精度BF16/FP16有专门加速精度吞吐量FP321×FP162×8×低精度(bf16,fp16)的问题溢出问题混合精度训练Mixed Precision Training的核心是用FP16或BF16进行大部分计算用FP32保留关键数值稳定性。但由于 FP16 的表示范围和精度有限会引入一系列典型问题。下面用具体例子说明这些问题。FP16 的数值限制FP16IEEE half precision指数位5 bit →范围小尾数10 bit →精度低大致范围最大值≈ (6.5 \times 10^4)最小正数正规≈ (6 \times 10^{-8})梯度下溢underflow假设某层梯度为g 1 × 10 − 8 g 1 \times 10^{-8}g1×10−8FP32可以表示 ✔️FP16直接变成 0 ❌梯度上溢overflowg 1 × 10 5 g 1 \times 10^5g1×105FP16最大值 ≈ 65504超出范围 →变成 inf大数吃小数问题浮点数加减法需要先对齐指数再比对小数部分但是小数部分往往有限数值差距悬殊时容易出现小数部分溢出。2048 1.0 × 2 11 2048 1.0 \times 2^{11}20481.0×2110.5 1.0 × 2 − 1 0.5 1.0 \times 2^{-1}0.51.0×2−1对齐指数0.5 1.0 × 2 − 1 0.000000000001 × 2 11 0.5 1.0 \times 2^{-1} 0.000000000001 \times 2^{11}0.51.0×2−10.000000000001×211但BF16 只有 7 位尾数2048: 1.0000000 × 2^11 0.5 : 0.000000000001 × 2^11 被截断结果2048 0.5 ≈ 2048 2048 0.5 \approx 204820480.5≈20480.5 被完全忽略混合精度训练的流程在涉及前向过程中可以使用低精度。在涉及梯度更新过程中优化器会保存较高精度的模型参数和并使用较高精度的梯度值。对于高精度的数学计算操作例如矩阵加法softmax操作强制使用高精度。Loss scaling动机低精度的高数值部分使用较少可以考虑将梯度更新缩放N倍确保不会溢出然后转换位低精度来进行更新。Loss scaling 的本质是在反向传播前对loss 进行数值放大使得在FP16 精度下计算的梯度不会发生下溢在得到梯度后再进行反缩放unscale恢复真实梯度并在 FP32 master weight 上进行更新从而保证数值稳定性和更新精度。实现with torch.cuda.amp.autocast(device_typecuda, dtypetorch.float16)模型会在这个上下文下优先选择FP16精度对于需要高精度的操作例如softmax计算过程仍然使用FP32不受影响。scaler.scale(loss).backward()缩放损失确保FP16的梯度不会溢出。#PyTorch AMPFP16训练片段scalertorch.cuda.amp.GradScaler()forx,y in dataloader:optimizer.zero_grad()with torch.cuda.amp.autocast(device_typecuda,dtypetorch.float16):outmodel(x)losscriterion(out,y)scaler.scale(loss).backward()# scaled loss 反传 scaler.step(optimizer)# 内部完成 unscaleupdateFP32 master scaler.update()# 动态调整 scale