FlashDecode:Decode 阶段的 Attention 并行化改造 本文基于昇腾CANN和昇腾NPU围绕 ops-transformer 仓库的相关技术展开。FlashDecode 解决了 Decode 阶段的一个结构性浪费每个 Decode Step 只产生 1 个新 Token但 Attention 计算仍然要走完整的 QK^T 路径。FlashDecode 在 CANN 上做了一个关键优化——把多个 Decode Step 的 Attention 计算合并到一起让 NPU 的 Cube Unit 跑满。Decode 阶段 Attention 的痛点# 标准 Decode Attention——每步只算 1 个 Queryimporttorchimporttorch.nn.functionalasFdefdecode_attention(q,k_cache,v_cache,step_idx): q: [1, num_heads, 1, head_dim] —— 当前步的 Query k_cache: [1, num_heads, L, head_dim] —— L 是已缓存长度 v_cache: [1, num_heads, L, head_dim] step_idx: 当前是第几步 # Q: [1, h, 1, d] × K^T: [1, h, d, L] → score: [1, h, 1, L]scoretorch.matmul(q,k_cache.transpose(-2,-1))scorescore/(head_dim**0.5)attnF.softmax(score,dim-1)# attn: [1, h, 1, L] × V: [1, h, L, d] → [1, h, 1, d]outtorch.matmul(attn,v_cache)returnout# 问题MQ 序列长度1 → Cube 利用率只有 15-25%# 瓶颈在 ScoreV 这一步——Matrix-Vector 而不是 Matrix-Matrix每步 M1NPU 的 Cube Unit 大部分时间在等数据搬运。FlashDecode 的思路很简单把 K 缓存切块让多个 Query 并行查。FlashDecode 的块式 Attention# FlashDecode按块读取 KV Cache多个 Query Step 并行计算defflash_decode_attention(q_block,k_cache,v_cache,block_size64): q_block: [num_steps, num_heads, 1, head_dim] —— 合并多个 Decode Step 的 Q k_cache: [num_heads, total_len, head_dim] v_cache: [num_heads, total_len, head_dim] block_size: 每次从 Cache 读几组 KV num_stepsq_block.shape[0]num_headsq_block.shape[1]dq_block.shape[-1]total_lenk_cache.shape[1]# 输出累积器outputtorch.zeros(num_steps,num_heads,1,d)# 分块读取 KV Cache——NPU 的 L1 Buffer 只能装 block_size 个 KVforblock_startinrange(0,total_len,block_size):block_endmin(block_startblock_size,total_len)k_blockk_cache[:,block_start:block_end,:]# [h, bs, d]v_blockv_cache[:,block_start:block_end,:]# [h, bs, d]# Q 块 × K 块^T——现在 Mnum_steps, Kbs# Cube 实际算的是 [num_steps, d] × [d, bs] [num_steps, bs]# Mnum_steps 可以到 32-64Cube 利用率 70%forhinrange(num_heads):q_hq_block[:,h,0,:]# [num_steps, d]k_hk_block[h]# [bs, d]# 批量的 Score 计算——从 Vector 变 Matrixscore_htorch.matmul(q_h,k_h.transpose(-1,-2))# [num_steps, bs]score_hscore_h/(d**0.5)# Online-Softmax避免整段 Softmax 的显存开销local_maxscore_h.max(dim-1,keepdimTrue).values local_exptorch.exp(score_h-local_max)local_sumlocal_exp.sum(dim-1,keepdimTrue)local_outtorch.matmul(local_exp,v_block[h])# [num_steps, d]# 合并到输出——实际生产用 rescale 累加而不是简单加法output[:,h,0,:]local_out.squeeze(1)returnoutputFlashDecode 把 M1 的 Matrix-Vector 变成了 Mnum_steps 的 Matrix-Matrix。步子越大利用率越高但不能超过 64——超过了注意力分布就开始分散精度会掉。CANN 上的 FlashDecode 融合// FlashDecode 在 Ascend C 上的实现——融合了 Score Softmax 累加classFlashDecodeKernel:publicAscendC::Kernel{public:__aicore__inlineFlashDecodeKernel(){}__aicore__inlinevoidProcess()override{// 从 Global Memory 搬 Q 到 L1 BufferAscendC::LocalTensorfloatq_localAscendC::LocalAllocfloat(num_steps*head_dim);AscendC::DataCopy(q_local,gm_q,num_steps*head_dim);// 逐块处理 KV Cachefor(intblock0;blocknum_blocks;block){// 搬 K 块到 L1AscendC::LocalTensorfloatk_localAscendC::LocalAllocfloat(block_size*head_dim);AscendC::DataCopy(k_local,gm_kblock_offset,block_size*head_dim);// Cube 做 QK^T——走 MMA 指令AscendC::LocalTensorfloatscore_localAscendC::LocalAllocfloat(num_steps*block_size);// 这里触发 Cube Unit 的矩阵乘法AscendC::MatMul(score_local,q_local,k_local,AscendC::CUBE_MATRIX_TYPE::TRAN_A);// 直接在 L1 上做 Scale Softmax——不用回显存AscendC::Mul(score_local,score_local,inv_scale);AscendC::Exp(score_local,score_local);// 逐元素 ExpAscendC::ReduceSum(row_sum,score_local,1);// 逐行求和// 读 V 块算加权和AscendC::LocalTensorfloatv_localAscendC::LocalAllocfloat(block_size*head_dim);AscendC::DataCopy(v_local,gm_vblock_offset,block_size*head_dim);// Score (归一化后) V——仍在 L1 完成AscendC::MatMul(partial_out,score_local,v_local);// 累加输出——走了两轮再写回 Global MemoryAscendC::Add(output_local,output_local,partial_out);}// 最终写回AscendC::DataCopy(gm_out,output_local,num_steps*head_dim);}};实测下来 FlashDecode 在 Decode 阶段能把 GPU 的利用率从 15% 拉到 52%。每步处理 32 个合并 Query 时收益最高——再多缓存就装不下 K 块了。参考仓库FlashDecode 算子实现Runtime 多流调度