大模型显示优化之ZeRO-1/ZeRO-2/ZeRO-3 1. 简介zero-1、zero-2、zero-3 是deepspeed的配置方法对应megatron也有相应的方法Megatron-LM 的实现方式Distributed Optimizer分布式优化器。等效于 ZeRO-1Megatron 的 Distributed Optimizer 默认行为就是将优化器状态Optimizer States均匀地切分并分布在数据并行DP组的所有 GPU 上。等效于 ZeRO-2由于 Megatron 通常结合混合精度训练它在计算完梯度后会通过Reduce-Scatter操作直接将梯度同步并切分到各卡上不再保留全量梯度。这在效果上完全等同于 ZeRO-2。zero-3 将参数也拆分卡来存但后续实际反向梯度更新时操作时还是需要all-gather参数显存还是会全量缓存再一个Megatron针对参数拆分更多使用的是TP/PP拆分所以业界megatron架构使用zero-3不多, 所以本文不做重点分析。Zero架构说的是DP并行域GPU之间。阶段优化对象核心原理效果ZeRO-1优化器状态 (OS)将优化器状态切分并分布到各个 GPU 上每个 GPU 只负责更新自己那一块。显存占用降低约为原来的 1/4以 Adam 为例。ZeRO-2OS 梯度 (G)在 ZeRO-1 基础上进一步将梯度也进行切分。每个 GPU 只保留对应参数的梯度。进一步降低显存占用是目前最常用的平衡配置。ZeRO-3OS G 参数 (W)最彻底的切分。模型参数在平时也分布在不同 GPU 上只有在正向/反向传播需要时才临时同步。显存占用理论上随 GPU 数量线性下降支持训练超大规模模型。实际官方Megatron实现中ZeRO-2 反向不只是对梯度进行切分还对参数在back阶段进行了小段时间的切分后面AllGather回收是一个技术操作。这样好处1. 节省显存2. 避免冗余计算3. 最后的AllGather可以和后续的layer forward 做overlap纯 DPZeRO-2Forward各 rank 用完整 W各 rank 用完整 W相同Backward 后通信AllReduce梯度每人拿完整梯度ReduceScatter梯度每人只拿 1/DP 梯度显存也只存1/DP属于自己的梯度Optimizer step各自完整更新 W结果一致冗余计算各自只更新 W 的 1/DP 段(此更新过程比较复杂Step 后无需额外通信W 天然一致需要AllGather W恢复完整参数显存节省无梯度 优化器状态各节省 1/DP注意AdamW全局grad_norm路径通信方式时机标准路径all_reduceon model parallel groupTP × DPoptimizer.step() 内clip grad 前PP bypass 路径TP 内all_reduce PP 间send/recv逐 stage 累加pre_step 阶段流水线化减少同步 barrierAdamW 的step()中确实有一次全局 grad norm 的all_reduce通信用于计算全局 L2 norm 以确定clip_coeff梯度裁剪系数。这是每一步更新都必须做的集合通信会引入跨所有 model parallel rank 的同步点。2. 显存与通信量分析为了让 ZeRO-1 和 ZeRO-2 的区别更加直观我把之前流程图里的抽象内容具体化成了4 张 GPU 卡在不同阶段的显存状态。这样你可以像看“快照”一样清晰地看到每张卡上到底存了什么。设定假设模型有4个参数块[P0, P1, P2, P3]。4 张 GPU 卡训练。FP16训练的模型为例参数量为参数 (Weights):字节。梯度 (Grads):字节。优化器Adam 状态:FP32 权重副本为了精度。Momentum动量。Variance方差。场景一ZeRO-1 (只切分优化器状态)核心特征每张卡都有完整的参数和完整的梯度但只负责更新1/4的优化器状态。GPU 卡前向/反向计算时梯度通信后 (All-Reduce)参数更新后GPU 0参数:[P0, P1, P2, P3]梯度:[G0, G1, G2, G3]优化器状态:[O0](只负责P0)梯度:[G_avg0, G_avg1, G_avg2, G_avg3](已同步为平均梯度)*用G_avg0更新O0 计算出P0_new然后拼出完整参数[P0_new, P1_new, P2_new....]GPU 1参数:[P0, P1, P2, P3]梯度:[G0, G1, G2, G3]优化器状态:[O1](只负责P1)梯度:[G_avg0, G_avg1, G_avg2, G_avg3]*用G_avg1更新O1 计算出P1_new然后拼出完整参数[P0_new, P1_new, P2_new....]显存占用高。因为每张卡都要存下4份参数 4份梯度。冗余度高。P0被同时存在了 4 张卡上。场景二ZeRO-2 (切分梯度 优化器状态)核心特征每张卡有完整的参数但只保留1/4的梯度并只更新对应的1/4优化器状态。GPU 卡前向/反向计算时 (初始)梯度通信后 (Reduce-Scatter)参数更新后GPU 0参数:[P0, P1, P2, P3]梯度(原始):[G0, G1, G2, G3]优化器状态:[O0]梯度(保留):[G_avg0]梯度(丢弃):[G_avg1, G_avg2, G_avg3]✔️ 丢弃用G_avg0更新O0 计算出P0_new。然后通过 All-Gather 从其他卡获取 P1~P3 的更新。GPU 1参数:[P0, P1, P2, P3]梯度(原始):[G0, G1, G2, G3]优化器状态:[O1]梯度(保留):[G_avg1]梯度(丢弃):[G_avg0, G_avg2, G_avg3]✔️ 丢弃用G_avg1更新O1 计算出P1_new。然后通过 All-Gather 从其他卡获取 P0, P2, P3 的更新。显存占用中等。每张卡存4份参数 1份梯度。显存优化相比 ZeRO-1节省了 3 份梯度的存储空间。两张图的对比总结特征ZeRO-1 (图里场景)ZeRO-2 (图里场景)每张卡上的参数全部[P0, P1, P2, P3]全部[P0, P1, P2, P3]每张卡上的梯度全部[G_avg0...G_avg3](All-Reduce后)只有1块[G_avg0](Reduce-Scatter后)优化器状态分片[O0]分片[O0]参数更新方式各卡独立计算出完整参数各卡计算部分参数再互相广播合并主要节省不节省梯度节省了3/4的梯度显存通过这两张“快照”你应该能清晰地看到ZeRO-2 的本质就是用梯度通信后的一个“丢弃”动作换来了大量的显存空间。通信量总结维度ZeRO-1ZeRO-2ZeRO-3参数存储完整 (每卡都有)完整 (每卡都有)切分(每卡1/DP)梯度存储完整 (每卡都有)切分(每卡1/DP)切分(每卡1/DP)优化器状态切分 (每卡1/DP)切分 (每卡1/DP)切分 (每卡1/DP)单卡模型状态显存2Ψ 2Ψ 12Ψ/DP2Ψ 2Ψ/DP 12Ψ/DP(2Ψ2Ψ12Ψ)/DP主要通信All-Reduce (梯度)Reduce-Scatter All-GatherAll-Gather ×2 Reduce-Scatter通信量2×Ψ(最小)2×Ψ3×Ψ(最大)显存节省仅优化器状态优化器梯度全部3. Megatron ZeRO配置Stage分片内容Megatron对应参数ZeRO-1优化器状态分片m,v)--user-distributed-optimizerZeRO-2优化器分片梯度分片--user-distributed-optimizer--overlap-grad-reduceZeRO-3优化器分片梯度参数需要单独搞4. ZeRO2架构 backward过程计算梯度和更新参数的过程