CANN 8.5 之前ops-transformer 仓库的 FlashAttention 只融合了前向传播。推理没问题训练就尴尬了——反向传播还得拆成三个独立 kernel 分别算 dQ、dK、dV中间结果全落显存。CANN 8.5 的 FlashAttention V2 把反向传播也融合了训练场景的吞吐直接拉了 30%。V1 的反向传播为什么没融合FlashAttention V1 的前向融合相对好做Q·K^T 的分块 Softmax 结果存在片上缓存直接乘 V 输出。反向传播不一样它需要前向的 Softmax 中间结果来算梯度。V1 的做法是前向把 Softmax 的归一化因子存下来叫O_scale反向再读出来用。问题在于O_scale的存储格式。前向分块计算时每个分块的O_scale是按 block 顺序存的反向计算梯度时需要按完整的行来读。这个读写模式的不匹配让 V1 没法把 dQ/dK/dV 融进同一个 kernel——数据对不齐。V2 怎么解决的FlashAttention V2 换了一个反向传播算法。不再依赖前向存的O_scale而是在反向 kernel 里重新算一遍 Softmax 的归一化因子。听起来像是浪费计算——多算了一次 Softmax——但省掉了显存读写。在昇腾NPU上这个取舍特别划算。达芬奇架构的 Cube 单元算力充裕Vector 单元做 Softmax 也快但 HBM 到片上缓存的数据搬运是瓶颈。多用 5% 的计算换掉 40% 的显存读写这笔账怎么算都值。具体实现上V2 的反向 kernel 流程1. 读入 dO来自上层的梯度 2. 重新计算 Softmax 归一化因子Vector 单元 3. 分块计算 dV Softmax(Q·K^T)^T · dO 4. 分块计算 dK Q^T · (Softmax(Q·K^T) · dO) 5. 分块计算 dQ (Softmax(Q·K^T) · dO) · V^T 6. 三组梯度在片上缓存完成一次性写回显存步骤 3-5 在同一个 kernel 里流水执行dV/dK/dQ 共享中间结果不需要反复读取 Q、K、V。实测数据Atlas 800I A2Llama2-7B 训练序列长度 4096配置训练吞吐 (tokens/s/p)显存占用kernel launch 次数/层FlashAttention V1前向融合反向拆分1,82056 GB前向1反向34FlashAttention V2全融合2,41044 GB前向1反向12kernel launch 次数从 4 降到 2减少了 2 次调度开销。显存省 21% 来自不再存储O_scale中间结果。Llama2-70B 的数据更夸张配置训练吞吐 (tokens/s/p)最大序列长度V14208K显存不够V256016K显存省下来直接撑到 16K 序列长度。V1 在 8K 以上就 OOM 了。迁移方法CANN 8.5 的torch_npu自动把 SDPA 路由到 V2不需要改代码。但如果你直接调了npu.flash_attention的反向相关接口有个参数变更# V1 写法CANN 8.0outtorch_npu.npu.flash_attention(q,k,v)# 反向自动拆分无法控制# V2 写法CANN 8.5— scale 改成关键字参数outtorch_npu.npu.flash_attention(q,k,v,scale1.0/math.sqrt(dim))# 反向自动融合不需要额外调用如果你用了 ATB 做训练框架ATB 0.8 默认走 V2 路径。ATB 0.7 及以下只能走 V1。只做推理要升级吗不需要。V2 的改进全部在反向传播。如果你的 NPU 只跑推理V1 和 V2 的前向性能完全一致升级没有收益。训练场景下 FlashAttention V2 是刚需30% 的吞吐提升和 21% 的显存节省相当于白捡半张卡的算力。CANN 8.5 torch_npu 2.3 就能用仓库在这里https://atomgit.com/cann/ops-transformer
CANN-FlashAttentionV2-昇腾NPU反向传播融合到底快在哪
发布时间:2026/5/23 16:45:36
CANN 8.5 之前ops-transformer 仓库的 FlashAttention 只融合了前向传播。推理没问题训练就尴尬了——反向传播还得拆成三个独立 kernel 分别算 dQ、dK、dV中间结果全落显存。CANN 8.5 的 FlashAttention V2 把反向传播也融合了训练场景的吞吐直接拉了 30%。V1 的反向传播为什么没融合FlashAttention V1 的前向融合相对好做Q·K^T 的分块 Softmax 结果存在片上缓存直接乘 V 输出。反向传播不一样它需要前向的 Softmax 中间结果来算梯度。V1 的做法是前向把 Softmax 的归一化因子存下来叫O_scale反向再读出来用。问题在于O_scale的存储格式。前向分块计算时每个分块的O_scale是按 block 顺序存的反向计算梯度时需要按完整的行来读。这个读写模式的不匹配让 V1 没法把 dQ/dK/dV 融进同一个 kernel——数据对不齐。V2 怎么解决的FlashAttention V2 换了一个反向传播算法。不再依赖前向存的O_scale而是在反向 kernel 里重新算一遍 Softmax 的归一化因子。听起来像是浪费计算——多算了一次 Softmax——但省掉了显存读写。在昇腾NPU上这个取舍特别划算。达芬奇架构的 Cube 单元算力充裕Vector 单元做 Softmax 也快但 HBM 到片上缓存的数据搬运是瓶颈。多用 5% 的计算换掉 40% 的显存读写这笔账怎么算都值。具体实现上V2 的反向 kernel 流程1. 读入 dO来自上层的梯度 2. 重新计算 Softmax 归一化因子Vector 单元 3. 分块计算 dV Softmax(Q·K^T)^T · dO 4. 分块计算 dK Q^T · (Softmax(Q·K^T) · dO) 5. 分块计算 dQ (Softmax(Q·K^T) · dO) · V^T 6. 三组梯度在片上缓存完成一次性写回显存步骤 3-5 在同一个 kernel 里流水执行dV/dK/dQ 共享中间结果不需要反复读取 Q、K、V。实测数据Atlas 800I A2Llama2-7B 训练序列长度 4096配置训练吞吐 (tokens/s/p)显存占用kernel launch 次数/层FlashAttention V1前向融合反向拆分1,82056 GB前向1反向34FlashAttention V2全融合2,41044 GB前向1反向12kernel launch 次数从 4 降到 2减少了 2 次调度开销。显存省 21% 来自不再存储O_scale中间结果。Llama2-70B 的数据更夸张配置训练吞吐 (tokens/s/p)最大序列长度V14208K显存不够V256016K显存省下来直接撑到 16K 序列长度。V1 在 8K 以上就 OOM 了。迁移方法CANN 8.5 的torch_npu自动把 SDPA 路由到 V2不需要改代码。但如果你直接调了npu.flash_attention的反向相关接口有个参数变更# V1 写法CANN 8.0outtorch_npu.npu.flash_attention(q,k,v)# 反向自动拆分无法控制# V2 写法CANN 8.5— scale 改成关键字参数outtorch_npu.npu.flash_attention(q,k,v,scale1.0/math.sqrt(dim))# 反向自动融合不需要额外调用如果你用了 ATB 做训练框架ATB 0.8 默认走 V2 路径。ATB 0.7 及以下只能走 V1。只做推理要升级吗不需要。V2 的改进全部在反向传播。如果你的 NPU 只跑推理V1 和 V2 的前向性能完全一致升级没有收益。训练场景下 FlashAttention V2 是刚需30% 的吞吐提升和 21% 的显存节省相当于白捡半张卡的算力。CANN 8.5 torch_npu 2.3 就能用仓库在这里https://atomgit.com/cann/ops-transformer