去年我们把一个13B参数的推理服务从GPU迁移到昇腾NPUattention部分从标准实现换成catlass模板的FlashAttention吞吐从1,200 tokens/s提到4,800 tokens/s。但这个过程不是换个模板就完事——数据布局、精度对齐、分块策略、算子融合每一步都有坑。今天把整个调优过程记录下来包含具体的配置参数和实测数据。背景为什么选catlasscatlass不是CUTLASS的昇腾移植版它是昇腾CANN体系内的算子模板库定位是给开发者提供高性能算子的开发骨架。ops-nn、ops-math、ops-blas这些算子仓库底层都依赖catlass的模板。选catlass而不是直接写Ascend C算子原因很简单手动写一个达芬奇架构上高性能的FlashAttention你需要处理分块加载、Unified Buffer管理、bank conflict规避、流水线调度……一个人搞可能要一两个月。catlass模板把这些封装好了你只需要调参数。但调参数这三个字背后的事也不少。精度选择FP16还是BF16第一个决策点。昇腾910支持FP16和BF16两种半精度catlass模板两种都支持。选择依据维度FP16BF16表示范围±65504±3.4×10³⁸尾数精度10位7位softmax溢出风险高指数容易超65504低累加精度损失低高达芬奇算力利用率更高略低我们的场景是推理softmax中间值容易爆FP16的范围。实测数据# FP16 FlashAttention序列长度8192 [ERROR] softmax overflow detected, batch2 head15 tile_m48 # 17个tile中有3个溢出输出NaN # BF16 FlashAttention同样配置 [PASS] no overflow, max softmax value 1.2e38 # 所有tile正常所以长序列场景直接用BF16省去溢出排查的麻烦。短序列2048以内FP16精度更好推理结果跟FP32的误差更小。我们的折中方案4K以内FP164K以上BF16。FlashAttnConfig config; if (seq_len 4096) { config.use_fp16 true; } else { config.use_fp16 false; // 启用BF16 }分块策略不是越大越好catlass模板的核心参数是block_m和block_n控制Q和K/V的分块大小。直觉上block越大并行度越高性能越好。但达芬奇架构的约束不允许你无限加大约束1Unified Buffer容量达芬奇架构的Unified Buffer大约256KB具体大小随芯片版本略有差异。一个tile的数据量 block_m × head_dim × sizeof(data_type) × 3QKV。加上中间变量实际占用大概是这个值的2-3倍。block_m128, head_dim128, FP16: 单tile 128 × 128 × 2 × 3 96KB 加上softmax统计量和O的累加buffer ≈ 200KB → 勉强能塞进去 block_m256, head_dim128, FP16: 单tile 256 × 128 × 2 × 3 192KB 加上中间变量 ≈ 420KB → 超了超了会怎样catlass模板不会报错而是自动降级——把一个tile拆成多次加载性能反而比block_m128更差。约束2K/V的复用模式FlashAttention的outer loop是沿M方向Q的序列方向遍历inner loop是沿N方向K/V的序列方向。每个Q的tile要跟所有K/V的tile做计算。所以K/V的tile会被反复加载block_n越大单次加载的数据量越大但加载次数越少。block_mblock_nK/V加载次数(Q单tile)单次加载量(KB)实测吞吐12812832323,4001286464163,8002566464164,2001283212883,200block_n64比128快因为小tile的cache命中率更高。block_n32太碎了调度开销吃掉了cache收益。block_m256block_n64是最优组合但要确认Unified Buffer够用。数据布局这步做错后面全白搭catlass模板要求输入数据的layout是[batch, heads, seq_len, head_dim]row-major存储stride必须128字节对齐。PyTorch默认的tensor layout恰好满足但如果你从其他框架MindSpore、Paddle传入数据大概率layout不一样。我们踩过的坑MindSpore的attention输入layout是[batch, seq_len, heads, head_dim]直接传给catlass模板结果不对但也不报错。数值偏了大概5%肉眼不容易看出来端到端推理结果就是差一截。import torch_npu def ensure_layout(tensor, target_layoutBSHD): 确保tensor的layout符合catlass要求 current_layout detect_layout(tensor) # 根据stride判断 if current_layout BSHD and target_layout BHSD: # [batch, seq, heads, dim] - [batch, heads, seq, dim] tensor tensor.transpose(1, 2).contiguous() elif current_layout BHSD and target_layout BSHD: tensor tensor.transpose(1, 2).contiguous() # 128字节对齐检查 assert tensor.stride(0) % 128 0, fstride未对齐: {tensor.stride(0)} return tensor另一个容易忽略的点contiguous()。transpose之后tensor不再连续必须调contiguous()才会真正重排内存。不调的话catlass模板读到的数据是乱的。Causal Mask的实现差异自回归推理必须用causal mask每个位置只能看到之前的token。catlass模板的causal实现有两种模式模式1下三角mask矩阵显式构造一个下三角bool矩阵传入kernel。优点是通用缺点是占用O(N²)显存——跟标准attention一样的毛病。模式2对角线跳过kernel内部根据tile坐标判断哪些计算可以跳过。不需要额外显存而且能跳过大量无效计算。// catlass模板内部的对角线跳过逻辑简化版 for (int tile_n 0; tile_n num_kv_tiles; tile_n) { // 当前Q tile的行范围: [tile_m * block_m, (tile_m1) * block_m) // 当前K tile的列范围: [tile_n * block_n, (tile_n1) * block_n) if (causal tile_n * block_n (tile_m 1) * block_m) { // 这个K tile完全在mask之外跳过 continue; // 长序列时能跳过约50%的tile } // 加载K/V tile做局部attention计算 load_kv_tile(k_tile, v_tile, tile_n); compute_local_attention(q_tile, k_tile, v_tile, o_tile); }对角线跳过的收益跟序列长度正相关。序列越长能跳过的tile越多序列长度总tile数跳过tile数跳过比例吞吐提升204825612850%1.3x4096102451250%1.3x81924096204850%1.4x1638416384819250%1.5x收益随序列增长而增加因为跳过计算的占比不变但省下来的显存带宽可以用于有效计算。16384序列时causal模式的吞吐比non-causal模式还高15%这就是跳过无效计算的回报。跟GE图引擎的融合优化单算子调优到4,200 tokens/s之后还有一档免费性能算子融合。昇腾CANN的GE图引擎能自动把FlashAttention和相邻算子合并执行。融合的前提是算子都走GE的图模式。如果你用AscendCL的单算子API调用FlashAttentionGE没法做融合。必须把整个模型编译成图import torch_npu from torch_npu.contrib import transfer_to_npu # 模型迁移到NPU自动走GE图模式 model model.npu() # GE日志确认融合 import os os.environ[GE_OPTYPE_BLACKLIST] # 清空黑名单允许所有融合 os.environ[DUMP_GE_GRAPH] 1 # 导出GE图 # 推理一次触发图编译 with torch.no_grad(): output model(input_ids) # 检查融合结果 # 日志路径/usr/local/Ascend/ascend-toolkit/latest/xx/dump/ # 搜索关键词FlashAttention Fuse融合前后GE图的对比融合前6个独立算子 RMSNorm → MatMul(Q) → MatMul(K) → MatMul(V) → FlashAttention → MatMul(O) 融合后2个融合算子 FusedNormQKV(RMSNorm MatMul Q/K/V) → FusedAttnProj(FlashAttention MatMul O)显存读写次数从12次降到4次吞吐从4,200提到4,860 tokens/s。反向传播的特殊处理推理服务只跑前向但如果你的场景是训练或finetuneFlashAttention的反向也需要catlass模板。反向有个额外参数deterministic。FlashAttnBwdConfig bwd_config; bwd_config.deterministic false; // 非确定性模式用atomic add bwd_config.deterministic true; // 确定性模式用排序累加非确定性模式快15%左右但梯度在多卡之间可能有微小差异FP16的atomic add不满足交换律。对训练来说这点差异通常不影响收敛但如果你在做数值对比测试建议开确定性模式。完整调优结果我们的13B模型在Ascend 910上的端到端性能变化阶段吞吐首token延迟显存标准attention基线1,2002,85052GBcatlass FlashAttention4,2001,28014GBblock参数调优4,5001,15012GBGE算子融合4,86098011GB从1,200到4,860整体提升4倍。其中catlass模板贡献最大3.5x参数调优贡献7%GE融合贡献8%。想在自己的昇腾NPU上复现这些数据去AtomGit拉catlass仓库https://atomgit.com/cann/catlass建议先把examples目录下的FlashAttention示例跑通确认环境没问题。然后对照本文的参数表逐步调优。如果遇到精度问题先用BF16排除溢出再逐步切回FP16。cann-recipes-train仓库里有FlashAttention在训练场景下的完整集成方案包括反向传播和多卡并行。
昇腾CANN上FlashAttention的工程实践:catlass模板调优全记录
发布时间:2026/5/21 7:14:18
去年我们把一个13B参数的推理服务从GPU迁移到昇腾NPUattention部分从标准实现换成catlass模板的FlashAttention吞吐从1,200 tokens/s提到4,800 tokens/s。但这个过程不是换个模板就完事——数据布局、精度对齐、分块策略、算子融合每一步都有坑。今天把整个调优过程记录下来包含具体的配置参数和实测数据。背景为什么选catlasscatlass不是CUTLASS的昇腾移植版它是昇腾CANN体系内的算子模板库定位是给开发者提供高性能算子的开发骨架。ops-nn、ops-math、ops-blas这些算子仓库底层都依赖catlass的模板。选catlass而不是直接写Ascend C算子原因很简单手动写一个达芬奇架构上高性能的FlashAttention你需要处理分块加载、Unified Buffer管理、bank conflict规避、流水线调度……一个人搞可能要一两个月。catlass模板把这些封装好了你只需要调参数。但调参数这三个字背后的事也不少。精度选择FP16还是BF16第一个决策点。昇腾910支持FP16和BF16两种半精度catlass模板两种都支持。选择依据维度FP16BF16表示范围±65504±3.4×10³⁸尾数精度10位7位softmax溢出风险高指数容易超65504低累加精度损失低高达芬奇算力利用率更高略低我们的场景是推理softmax中间值容易爆FP16的范围。实测数据# FP16 FlashAttention序列长度8192 [ERROR] softmax overflow detected, batch2 head15 tile_m48 # 17个tile中有3个溢出输出NaN # BF16 FlashAttention同样配置 [PASS] no overflow, max softmax value 1.2e38 # 所有tile正常所以长序列场景直接用BF16省去溢出排查的麻烦。短序列2048以内FP16精度更好推理结果跟FP32的误差更小。我们的折中方案4K以内FP164K以上BF16。FlashAttnConfig config; if (seq_len 4096) { config.use_fp16 true; } else { config.use_fp16 false; // 启用BF16 }分块策略不是越大越好catlass模板的核心参数是block_m和block_n控制Q和K/V的分块大小。直觉上block越大并行度越高性能越好。但达芬奇架构的约束不允许你无限加大约束1Unified Buffer容量达芬奇架构的Unified Buffer大约256KB具体大小随芯片版本略有差异。一个tile的数据量 block_m × head_dim × sizeof(data_type) × 3QKV。加上中间变量实际占用大概是这个值的2-3倍。block_m128, head_dim128, FP16: 单tile 128 × 128 × 2 × 3 96KB 加上softmax统计量和O的累加buffer ≈ 200KB → 勉强能塞进去 block_m256, head_dim128, FP16: 单tile 256 × 128 × 2 × 3 192KB 加上中间变量 ≈ 420KB → 超了超了会怎样catlass模板不会报错而是自动降级——把一个tile拆成多次加载性能反而比block_m128更差。约束2K/V的复用模式FlashAttention的outer loop是沿M方向Q的序列方向遍历inner loop是沿N方向K/V的序列方向。每个Q的tile要跟所有K/V的tile做计算。所以K/V的tile会被反复加载block_n越大单次加载的数据量越大但加载次数越少。block_mblock_nK/V加载次数(Q单tile)单次加载量(KB)实测吞吐12812832323,4001286464163,8002566464164,2001283212883,200block_n64比128快因为小tile的cache命中率更高。block_n32太碎了调度开销吃掉了cache收益。block_m256block_n64是最优组合但要确认Unified Buffer够用。数据布局这步做错后面全白搭catlass模板要求输入数据的layout是[batch, heads, seq_len, head_dim]row-major存储stride必须128字节对齐。PyTorch默认的tensor layout恰好满足但如果你从其他框架MindSpore、Paddle传入数据大概率layout不一样。我们踩过的坑MindSpore的attention输入layout是[batch, seq_len, heads, head_dim]直接传给catlass模板结果不对但也不报错。数值偏了大概5%肉眼不容易看出来端到端推理结果就是差一截。import torch_npu def ensure_layout(tensor, target_layoutBSHD): 确保tensor的layout符合catlass要求 current_layout detect_layout(tensor) # 根据stride判断 if current_layout BSHD and target_layout BHSD: # [batch, seq, heads, dim] - [batch, heads, seq, dim] tensor tensor.transpose(1, 2).contiguous() elif current_layout BHSD and target_layout BSHD: tensor tensor.transpose(1, 2).contiguous() # 128字节对齐检查 assert tensor.stride(0) % 128 0, fstride未对齐: {tensor.stride(0)} return tensor另一个容易忽略的点contiguous()。transpose之后tensor不再连续必须调contiguous()才会真正重排内存。不调的话catlass模板读到的数据是乱的。Causal Mask的实现差异自回归推理必须用causal mask每个位置只能看到之前的token。catlass模板的causal实现有两种模式模式1下三角mask矩阵显式构造一个下三角bool矩阵传入kernel。优点是通用缺点是占用O(N²)显存——跟标准attention一样的毛病。模式2对角线跳过kernel内部根据tile坐标判断哪些计算可以跳过。不需要额外显存而且能跳过大量无效计算。// catlass模板内部的对角线跳过逻辑简化版 for (int tile_n 0; tile_n num_kv_tiles; tile_n) { // 当前Q tile的行范围: [tile_m * block_m, (tile_m1) * block_m) // 当前K tile的列范围: [tile_n * block_n, (tile_n1) * block_n) if (causal tile_n * block_n (tile_m 1) * block_m) { // 这个K tile完全在mask之外跳过 continue; // 长序列时能跳过约50%的tile } // 加载K/V tile做局部attention计算 load_kv_tile(k_tile, v_tile, tile_n); compute_local_attention(q_tile, k_tile, v_tile, o_tile); }对角线跳过的收益跟序列长度正相关。序列越长能跳过的tile越多序列长度总tile数跳过tile数跳过比例吞吐提升204825612850%1.3x4096102451250%1.3x81924096204850%1.4x1638416384819250%1.5x收益随序列增长而增加因为跳过计算的占比不变但省下来的显存带宽可以用于有效计算。16384序列时causal模式的吞吐比non-causal模式还高15%这就是跳过无效计算的回报。跟GE图引擎的融合优化单算子调优到4,200 tokens/s之后还有一档免费性能算子融合。昇腾CANN的GE图引擎能自动把FlashAttention和相邻算子合并执行。融合的前提是算子都走GE的图模式。如果你用AscendCL的单算子API调用FlashAttentionGE没法做融合。必须把整个模型编译成图import torch_npu from torch_npu.contrib import transfer_to_npu # 模型迁移到NPU自动走GE图模式 model model.npu() # GE日志确认融合 import os os.environ[GE_OPTYPE_BLACKLIST] # 清空黑名单允许所有融合 os.environ[DUMP_GE_GRAPH] 1 # 导出GE图 # 推理一次触发图编译 with torch.no_grad(): output model(input_ids) # 检查融合结果 # 日志路径/usr/local/Ascend/ascend-toolkit/latest/xx/dump/ # 搜索关键词FlashAttention Fuse融合前后GE图的对比融合前6个独立算子 RMSNorm → MatMul(Q) → MatMul(K) → MatMul(V) → FlashAttention → MatMul(O) 融合后2个融合算子 FusedNormQKV(RMSNorm MatMul Q/K/V) → FusedAttnProj(FlashAttention MatMul O)显存读写次数从12次降到4次吞吐从4,200提到4,860 tokens/s。反向传播的特殊处理推理服务只跑前向但如果你的场景是训练或finetuneFlashAttention的反向也需要catlass模板。反向有个额外参数deterministic。FlashAttnBwdConfig bwd_config; bwd_config.deterministic false; // 非确定性模式用atomic add bwd_config.deterministic true; // 确定性模式用排序累加非确定性模式快15%左右但梯度在多卡之间可能有微小差异FP16的atomic add不满足交换律。对训练来说这点差异通常不影响收敛但如果你在做数值对比测试建议开确定性模式。完整调优结果我们的13B模型在Ascend 910上的端到端性能变化阶段吞吐首token延迟显存标准attention基线1,2002,85052GBcatlass FlashAttention4,2001,28014GBblock参数调优4,5001,15012GBGE算子融合4,86098011GB从1,200到4,860整体提升4倍。其中catlass模板贡献最大3.5x参数调优贡献7%GE融合贡献8%。想在自己的昇腾NPU上复现这些数据去AtomGit拉catlass仓库https://atomgit.com/cann/catlass建议先把examples目录下的FlashAttention示例跑通确认环境没问题。然后对照本文的参数表逐步调优。如果遇到精度问题先用BF16排除溢出再逐步切回FP16。cann-recipes-train仓库里有FlashAttention在训练场景下的完整集成方案包括反向传播和多卡并行。