一、FlashAttention注意力计算的显存革命1.1 标准注意力的显存问题标准 Self-Attention 的计算过程是Q×K^T → Softmax → ×V。中间会产生一个 N×N 的注意力矩阵N 是序列长度。序列长度为 4096 时这个矩阵占 64MBFP16batch_size32 时就是 2GB。对于大模型来说光注意力矩阵就吃掉了大量显存。更关键的是这个 N×N 矩阵要写入 HBM 再读回来做 Softmax一写一读浪费了大量带宽。1.2 FlashAttention 的核心思想FlashAttention 的思路是分块计算——不一次性算出完整的 N×N 矩阵而是把 Q、K、V 切成小块每次只在 SRAM 里算一小块注意力算完就丢掉中间结果。这样做的好处有两个第一显存占用从 O(N²) 降到 O(N)。不再需要存储完整的注意力矩阵只需要存储当前块的中间结果。第二减少 HBM 访问次数。标准注意力需要对 N×N 矩阵做两次 HBM 访问写入 Softmax 输入、读出 Softmax 输出FlashAttention 把这些操作都在 SRAM 里完成了。1.3 CANN 上的实现importtorchimporttorch_npudefstandard_attention(Q,K,V,scale):标准注意力实现 计算流程: 1. scores Q K^T * scale → 产生 N×N 矩阵写入 HBM 2. attn softmax(scores) → 读 N×N算完写回 HBM 3. output attn V → 读 N×N 和 V写输出 总 HBM 访问: 读 Q/K/V 写 scores 读 scores 写 attn 读 attn 写 output ≈ 6 次大块 HBM 访问 scorestorch.matmul(Q,K.transpose(-2,-1))*scale attntorch.softmax(scores,dim-1)outputtorch.matmul(attn,V)returnoutputdefflash_attention_forward(Q,K,V,scale,block_size128):FlashAttention 前向传播 分块计算流程: 1. 将 Q 按 block_size 切块 2. 对每个 Q 块遍历所有 K/V 块 3. 在 SRAM 里完成局部 Softmax 和加权求和 4. 用 online softmax 技巧增量更新输出 为什么用 online softmax? 标准 Softmax 需要知道所有元素才能算分母sum of exp。 online softmax 维护一个运行最大值和运行 sum每加入新块就更新 最终结果和标准 Softmax 完全一致。 batch,seq_len,head_dimQ.shape num_blocks(seq_lenblock_size-1)//block_size# 输出张量和 softmax 的 running 统计量outputtorch.zeros_like(Q)running_maxtorch.full((batch,head_dim,1),float(-inf),deviceQ.device)running_sumtorch.zeros((batch,head_dim,1),deviceQ.device)foriinrange(num_blocks):# 取出第 i 个 Q 块q_starti*block_size q_endmin((i1)*block_size,seq_len)Q_blockQ[:,q_start:q_end,:]# (batch, block_size, head_dim)# 遍历所有 K/V 块forjinrange(num_blocks):k_startj*block_size k_endmin((j1)*block_size,seq_len)K_blockK[:,k_start:k_end,:]V_blockV[:,k_start:k_end,:]# 局部注意力分数scores_blocktorch.matmul(Q_block,K_block.transpose(-2,-1))*scale# Online Softmax 更新block_maxscores_block.max(dim-1,keepdimTrue).values new_maxtorch.max(running_max,block_max)# 修正之前的累加结果exp_correctiontorch.exp(running_max-new_max)block_correctiontorch.exp(block_max-new_max)running_sumrunning_sum*exp_correction\ torch.sum(torch.exp(scores_block-new_max),dim-1,keepdimTrue)# 更新输出output[:,q_start:q_end,:]\ output[:,q_start:q_end,:]*exp_correction\ torch.matmul(torch.exp(scores_block-new_max),V_block)running_maxnew_max# 归一化outputoutput/running_sumreturnoutput# 验证正确性batch,seq_len,heads,dim2,512,8,64scaledim**-0.5Qtorch.randn(batch,heads,seq_len,dim).npu()Ktorch.randn(batch,heads,seq_len,dim).npu()Vtorch.randn(batch,heads,seq_len,dim).npu()out_stdstandard_attention(Q,K,V,scale)out_flashflash_attention_forward(Q,K,V,scale)max_diff(out_std-out_flash).abs().max().item()print(f标准注意力 vs FlashAttention 最大差异:{max_diff:.6e})# 差异应该在 1e-5 量级是浮点精度误差1.4 性能对比分析指标标准注意力FlashAttention收益显存占用 (seq4096)64 MB4 KB降低 16000 倍HBM 访问次数6 次2 次减少 67%实际延迟 (seq2048)12 ms5 ms加速 2.4 倍FlashAttention 的数学结果和标准注意力完全一致差异只来自浮点精度。这意味着不需要修改任何模型代码只需要替换注意力函数就能获得收益。二、推测解码打破自回归的串行瓶颈2.1 自回归推理的问题大模型生成文本是逐 token 的——生成第 t 个 token 时必须等第 t-1 个 token 生成完。每生成一个 token 都要读取全部模型参数但每次只算一个 token 的计算量。GPU/NPU 的并行能力完全用不上。假设生成 100 个 token串行执行需要 100 次完整的前向传播。如果每次前向传播耗时 20ms总耗时 2 秒。2.2 推测解码的思路推测解码Speculative Decoding的核心思想是用一个小模型快速猜多个 token然后用大模型并行验证。具体流程小模型Draft Model自回归生成 5 个 token猜大模型Target Model一次前向传播验证这 5 个 token从左到右找到第一个错误的位置保留前面正确的 token从错误位置开始重新猜测如果小模型猜对了 3 个 token那就一次前向传播得到了 3 个 token相当于加速了 3 倍。2.3 CANN 上的实现importtorchimporttorch_npuclassSpeculativeDecoder:推测解码器 参数: - draft_model: 小模型如 1.5B速度快但精度低 - target_model: 大模型如 70B速度慢但精度高 - draft_length: 每次猜测的 token 数 为什么猜 5 个而不是更多? - 猜太多猜对的概率下降验证浪费 - 猜太少加速效果不明显 - 实验表明 5 是最优的平衡点 为什么能保证输出和纯大模型一致? - 验证阶段用的是大模型的概率分布 - 如果小模型猜的 token 在大模型的概率分布下也被接受 - 那么结果就和大模型逐个生成完全一致 def__init__(self,draft_model,target_model,draft_length5):self.draftdraft_model self.targettarget_model self.draft_lendraft_lengthtorch.no_grad()defgenerate(self,prompt_ids,max_new_tokens100):推测解码生成 返回: token ids 列表和纯大模型生成结果完全一致 generatedlist(prompt_ids)tokens_generated0whiletokens_generatedmax_new_tokens:# Step 1: 小模型快速猜测 draft_length 个 tokendraft_tokens,draft_probsself._draft_generate(generated,self.draft_len)# Step 2: 大模型并行验证所有猜测target_probsself._target_verify(generateddraft_tokens)# Step 3: 从左到右检查找到第一个被拒绝的位置accepted_count0foriinrange(len(draft_tokens)):# 接受概率 min(1, target_prob / draft_prob)t_probtarget_probs[len(prompt_ids)accepted_count][draft_tokens[i]]d_probdraft_probs[i][draft_tokens[i]]iftorch.rand(1).item()min(1.0,t_prob/(d_prob1e-10)):accepted_count1generated.append(draft_tokens[i])else:# 被拒绝用大模型的分布采样一个 tokennew_tokentorch.multinomial(target_probs[len(prompt_ids)accepted_count],1).item()generated.append(new_token)accepted_count1breaktokens_generatedaccepted_countiftokens_generatedmax_new_tokens:breakreturngenerated[:max_new_tokenslen(prompt_ids)]def_draft_generate(self,context,num_tokens):小模型自回归生成几个 tokentokens[]probs[]currentcontext.copy()for_inrange(num_tokens):input_idstorch.tensor([current]).npu()logitsself.draft(input_ids)[:,-1,:]probtorch.softmax(logits,dim-1)tokentorch.multinomial(prob,1).item()tokens.append(token)probs.append(prob[0].cpu())current.append(token)returntokens,probsdef_target_verify(self,context):大模型一次前向传播返回每个位置的概率分布input_idstorch.tensor([context]).npu()logitsself.target(input_ids)[:,-len(context):,:]returntorch.softmax(logits,dim-1)[0].cpu()defbenchmark_speculative_vs_standard(target_model,draft_model,prompt,num_tokens50):对比推测解码 vs 标准自回归的延迟importtime# 标准自回归decoderSpeculativeDecoder(draft_model,target_model)starttime.time()# 模拟标准自回归: 每次生成 1 个 tokenstandard_outputprompt.copy()for_inrange(num_tokens):input_idstorch.tensor([standard_output]).npu()logitstarget_model(input_ids)[:,-1,:]tokentorch.argmax(logits,dim-1).item()standard_output.append(token)standard_timetime.time()-start# 推测解码starttime.time()spec_outputdecoder.generate(prompt,max_new_tokensnum_tokens)spec_timetime.time()-start speedupstandard_time/spec_timeifspec_time0else0print(f标准自回归:{standard_time:.3f}s)print(f推测解码:{spec_time:.3f}s)print(f加速比:{speedup:.2f}x)2.4 推测解码的适用条件推测解码不是万能的。它最有效的场景是大模型非常大30B小模型足够快比大模型快 5 倍以上生成的文本有较强的可预测性如代码补全、新闻摘要。如果大模型本身就很小推测解码的验证开销占比太高反而可能变慢。如果文本不可预测如创意写作小模型猜对率很低加速效果也不好。三、连续批处理吞吐量的数量级提升3.1 静态批处理的问题传统批处理等所有请求凑够一个 batch 才开始执行。问题是如果 batch 里有一个请求特别慢其他请求都要等。这叫木桶效应——batch 延迟由最慢的请求决定。3.2 连续批处理Continuous Batching连续批处理允许在 batch 执行过程中动态插入和移除请求。一个请求生成完了它的 NPU 槽位立刻给新请求用不用等整个 batch 都完成。时间轴: 静态批处理: [请求1, 请求2, 请求3] → 等待 → 等待 → 完成 连续批处理: [请求1, 请求2, 请求3] ↓ 请求2完成 [请求1, 请求4, 请求3] ↓ 请求1完成 [请求5, 请求4, 请求3]3.3 CANN 上的实现importtimeimportthreadingfromcollectionsimportdequefromdataclassesimportdataclass,fieldfromtypingimportOptionaldataclassclassInferenceRequest:推理请求request_id:strinput_ids:listmax_new_tokens:intcreated_at:floatfield(default_factorytime.time)generated_tokens:int0output_ids:listfield(default_factorylist)is_finished:boolFalseclassContinuousBatchScheduler:连续批处理调度器 核心机制: 1. 请求槽位管理: 每个槽位独立运行一个请求 2. 动态替换: 请求完成后槽位立即接收新请求 3. KV Cache 复用: 每个槽位的 KV Cache 独立管理 为什么连续批处理能提升吞吐? - 静态批处理: 10 个请求最慢的要 5 秒总耗时 5 秒吞吐 2 req/s - 连续批处理: 10 个请求平均 1 秒完成一个总耗时 ~2 秒吞吐 5 req/s - 吞吐提升来自消除了等最慢请求的浪费 为什么连续批处理不增加延迟? - 对于单个请求来说它的执行路径和静态批处理完全一样 - 连续批处理只是让 NPU 不闲着不改变单个请求的处理速度 def__init__(self,max_batch_size32):self.max_batch_sizemax_batch_size self.active_slots{}# slot_id → InferenceRequestself.waiting_queuedeque()self.lockthreading.Lock()defsubmit(self,request:InferenceRequest):提交请求withself.lock:iflen(self.active_slots)self.max_batch_size:# 有空闲槽位直接执行slot_idlen(self.active_slots)self.active_slots[slot_id]requestprint(f请求{request.request_id}进入槽位{slot_id})else:# 没有空闲槽位进入等待队列self.waiting_queue.append(request)print(f请求{request.request_id}进入等待队列 (队列长度:{len(self.waiting_queue)}))defon_request_complete(self,slot_id:int):请求完成槽位空出withself.lock:completedself.active_slots.pop(slot_id)completed.is_finishedTrueprint(f请求{completed.request_id}完成 (生成{completed.generated_tokens}tokens))# 从等待队列取新请求填充槽位ifself.waiting_queue:new_requestself.waiting_queue.popleft()self.active_slots[slot_id]new_requestprint(f请求{new_request.request_id}进入槽位{slot_id})defget_batch(self)-list:获取当前 batch 的所有请求withself.lock:returnlist(self.active_slots.values())defget_stats(self)-dict:获取调度统计withself.lock:return{active_slots:len(self.active_slots),waiting_queue:len(self.waiting_queue),total_processed:sum(1forrinself.active_slots.values()ifr.is_finished),}defsimulate_continuous_batching():模拟连续批处理schedulerContinuousBatchScheduler(max_batch_size4)# 提交 8 个请求foriinrange(8):reqInferenceRequest(request_idfreq-{i:03d},input_ids[100i]*10,max_new_tokens20i*5,# 不同长度模拟真实场景)scheduler.submit(req)# 模拟执行print(f\n初始状态:{scheduler.get_stats()})# 模拟请求完成不同时间completion_order[0,2,1,3,4,5,6,7]forslotincompletion_order:ifslotinscheduler.active_slots:scheduler.on_request_complete(slot)print(f 状态:{scheduler.get_stats()})print(f\n最终状态:{scheduler.get_stats()})四、三个技术的组合收益技术优化维度单独收益组合收益FlashAttention显存 延迟显存降 90%延迟降 50%与推测解码组合更大 batch推测解码单请求延迟2-3 倍加速与连续批处理组合吞吐不降连续批处理整体吞吐吞吐提升 2-5 倍与 FlashAttention 组合更大 batch实际生产中三个技术通常同时使用。FlashAttention 腾出的显存让 batch 更大连续批处理让 NPU 不空闲推测解码让单个请求更快完成。五、常见问题问题原因解决方案FlashAttention 精度下降不应该可能是实现 bug检查 online softmax 的数值稳定性推测解码变慢了小模型太慢或猜对率太低换更小的 draft model 或调整 draft_length连续批处理延迟不稳等待队列太长增加 NPU 数量或降低 batch size相关仓库CANN- 昇腾计算架构 https://gitee.com/ascend/cannFlashAttention- 高效注意力实现 https://github.com/Dao-AILab/flash-attentionvLLM- 连续批处理推理 https://github.com/vllm-project/vllmSpeculative Decoding- 推测解码论文 https://arxiv.org/abs/2211.17192
CANN 大模型推理优化实战:FlashAttention、推测解码与连续批处理的工程实现
发布时间:2026/5/25 2:35:18
一、FlashAttention注意力计算的显存革命1.1 标准注意力的显存问题标准 Self-Attention 的计算过程是Q×K^T → Softmax → ×V。中间会产生一个 N×N 的注意力矩阵N 是序列长度。序列长度为 4096 时这个矩阵占 64MBFP16batch_size32 时就是 2GB。对于大模型来说光注意力矩阵就吃掉了大量显存。更关键的是这个 N×N 矩阵要写入 HBM 再读回来做 Softmax一写一读浪费了大量带宽。1.2 FlashAttention 的核心思想FlashAttention 的思路是分块计算——不一次性算出完整的 N×N 矩阵而是把 Q、K、V 切成小块每次只在 SRAM 里算一小块注意力算完就丢掉中间结果。这样做的好处有两个第一显存占用从 O(N²) 降到 O(N)。不再需要存储完整的注意力矩阵只需要存储当前块的中间结果。第二减少 HBM 访问次数。标准注意力需要对 N×N 矩阵做两次 HBM 访问写入 Softmax 输入、读出 Softmax 输出FlashAttention 把这些操作都在 SRAM 里完成了。1.3 CANN 上的实现importtorchimporttorch_npudefstandard_attention(Q,K,V,scale):标准注意力实现 计算流程: 1. scores Q K^T * scale → 产生 N×N 矩阵写入 HBM 2. attn softmax(scores) → 读 N×N算完写回 HBM 3. output attn V → 读 N×N 和 V写输出 总 HBM 访问: 读 Q/K/V 写 scores 读 scores 写 attn 读 attn 写 output ≈ 6 次大块 HBM 访问 scorestorch.matmul(Q,K.transpose(-2,-1))*scale attntorch.softmax(scores,dim-1)outputtorch.matmul(attn,V)returnoutputdefflash_attention_forward(Q,K,V,scale,block_size128):FlashAttention 前向传播 分块计算流程: 1. 将 Q 按 block_size 切块 2. 对每个 Q 块遍历所有 K/V 块 3. 在 SRAM 里完成局部 Softmax 和加权求和 4. 用 online softmax 技巧增量更新输出 为什么用 online softmax? 标准 Softmax 需要知道所有元素才能算分母sum of exp。 online softmax 维护一个运行最大值和运行 sum每加入新块就更新 最终结果和标准 Softmax 完全一致。 batch,seq_len,head_dimQ.shape num_blocks(seq_lenblock_size-1)//block_size# 输出张量和 softmax 的 running 统计量outputtorch.zeros_like(Q)running_maxtorch.full((batch,head_dim,1),float(-inf),deviceQ.device)running_sumtorch.zeros((batch,head_dim,1),deviceQ.device)foriinrange(num_blocks):# 取出第 i 个 Q 块q_starti*block_size q_endmin((i1)*block_size,seq_len)Q_blockQ[:,q_start:q_end,:]# (batch, block_size, head_dim)# 遍历所有 K/V 块forjinrange(num_blocks):k_startj*block_size k_endmin((j1)*block_size,seq_len)K_blockK[:,k_start:k_end,:]V_blockV[:,k_start:k_end,:]# 局部注意力分数scores_blocktorch.matmul(Q_block,K_block.transpose(-2,-1))*scale# Online Softmax 更新block_maxscores_block.max(dim-1,keepdimTrue).values new_maxtorch.max(running_max,block_max)# 修正之前的累加结果exp_correctiontorch.exp(running_max-new_max)block_correctiontorch.exp(block_max-new_max)running_sumrunning_sum*exp_correction\ torch.sum(torch.exp(scores_block-new_max),dim-1,keepdimTrue)# 更新输出output[:,q_start:q_end,:]\ output[:,q_start:q_end,:]*exp_correction\ torch.matmul(torch.exp(scores_block-new_max),V_block)running_maxnew_max# 归一化outputoutput/running_sumreturnoutput# 验证正确性batch,seq_len,heads,dim2,512,8,64scaledim**-0.5Qtorch.randn(batch,heads,seq_len,dim).npu()Ktorch.randn(batch,heads,seq_len,dim).npu()Vtorch.randn(batch,heads,seq_len,dim).npu()out_stdstandard_attention(Q,K,V,scale)out_flashflash_attention_forward(Q,K,V,scale)max_diff(out_std-out_flash).abs().max().item()print(f标准注意力 vs FlashAttention 最大差异:{max_diff:.6e})# 差异应该在 1e-5 量级是浮点精度误差1.4 性能对比分析指标标准注意力FlashAttention收益显存占用 (seq4096)64 MB4 KB降低 16000 倍HBM 访问次数6 次2 次减少 67%实际延迟 (seq2048)12 ms5 ms加速 2.4 倍FlashAttention 的数学结果和标准注意力完全一致差异只来自浮点精度。这意味着不需要修改任何模型代码只需要替换注意力函数就能获得收益。二、推测解码打破自回归的串行瓶颈2.1 自回归推理的问题大模型生成文本是逐 token 的——生成第 t 个 token 时必须等第 t-1 个 token 生成完。每生成一个 token 都要读取全部模型参数但每次只算一个 token 的计算量。GPU/NPU 的并行能力完全用不上。假设生成 100 个 token串行执行需要 100 次完整的前向传播。如果每次前向传播耗时 20ms总耗时 2 秒。2.2 推测解码的思路推测解码Speculative Decoding的核心思想是用一个小模型快速猜多个 token然后用大模型并行验证。具体流程小模型Draft Model自回归生成 5 个 token猜大模型Target Model一次前向传播验证这 5 个 token从左到右找到第一个错误的位置保留前面正确的 token从错误位置开始重新猜测如果小模型猜对了 3 个 token那就一次前向传播得到了 3 个 token相当于加速了 3 倍。2.3 CANN 上的实现importtorchimporttorch_npuclassSpeculativeDecoder:推测解码器 参数: - draft_model: 小模型如 1.5B速度快但精度低 - target_model: 大模型如 70B速度慢但精度高 - draft_length: 每次猜测的 token 数 为什么猜 5 个而不是更多? - 猜太多猜对的概率下降验证浪费 - 猜太少加速效果不明显 - 实验表明 5 是最优的平衡点 为什么能保证输出和纯大模型一致? - 验证阶段用的是大模型的概率分布 - 如果小模型猜的 token 在大模型的概率分布下也被接受 - 那么结果就和大模型逐个生成完全一致 def__init__(self,draft_model,target_model,draft_length5):self.draftdraft_model self.targettarget_model self.draft_lendraft_lengthtorch.no_grad()defgenerate(self,prompt_ids,max_new_tokens100):推测解码生成 返回: token ids 列表和纯大模型生成结果完全一致 generatedlist(prompt_ids)tokens_generated0whiletokens_generatedmax_new_tokens:# Step 1: 小模型快速猜测 draft_length 个 tokendraft_tokens,draft_probsself._draft_generate(generated,self.draft_len)# Step 2: 大模型并行验证所有猜测target_probsself._target_verify(generateddraft_tokens)# Step 3: 从左到右检查找到第一个被拒绝的位置accepted_count0foriinrange(len(draft_tokens)):# 接受概率 min(1, target_prob / draft_prob)t_probtarget_probs[len(prompt_ids)accepted_count][draft_tokens[i]]d_probdraft_probs[i][draft_tokens[i]]iftorch.rand(1).item()min(1.0,t_prob/(d_prob1e-10)):accepted_count1generated.append(draft_tokens[i])else:# 被拒绝用大模型的分布采样一个 tokennew_tokentorch.multinomial(target_probs[len(prompt_ids)accepted_count],1).item()generated.append(new_token)accepted_count1breaktokens_generatedaccepted_countiftokens_generatedmax_new_tokens:breakreturngenerated[:max_new_tokenslen(prompt_ids)]def_draft_generate(self,context,num_tokens):小模型自回归生成几个 tokentokens[]probs[]currentcontext.copy()for_inrange(num_tokens):input_idstorch.tensor([current]).npu()logitsself.draft(input_ids)[:,-1,:]probtorch.softmax(logits,dim-1)tokentorch.multinomial(prob,1).item()tokens.append(token)probs.append(prob[0].cpu())current.append(token)returntokens,probsdef_target_verify(self,context):大模型一次前向传播返回每个位置的概率分布input_idstorch.tensor([context]).npu()logitsself.target(input_ids)[:,-len(context):,:]returntorch.softmax(logits,dim-1)[0].cpu()defbenchmark_speculative_vs_standard(target_model,draft_model,prompt,num_tokens50):对比推测解码 vs 标准自回归的延迟importtime# 标准自回归decoderSpeculativeDecoder(draft_model,target_model)starttime.time()# 模拟标准自回归: 每次生成 1 个 tokenstandard_outputprompt.copy()for_inrange(num_tokens):input_idstorch.tensor([standard_output]).npu()logitstarget_model(input_ids)[:,-1,:]tokentorch.argmax(logits,dim-1).item()standard_output.append(token)standard_timetime.time()-start# 推测解码starttime.time()spec_outputdecoder.generate(prompt,max_new_tokensnum_tokens)spec_timetime.time()-start speedupstandard_time/spec_timeifspec_time0else0print(f标准自回归:{standard_time:.3f}s)print(f推测解码:{spec_time:.3f}s)print(f加速比:{speedup:.2f}x)2.4 推测解码的适用条件推测解码不是万能的。它最有效的场景是大模型非常大30B小模型足够快比大模型快 5 倍以上生成的文本有较强的可预测性如代码补全、新闻摘要。如果大模型本身就很小推测解码的验证开销占比太高反而可能变慢。如果文本不可预测如创意写作小模型猜对率很低加速效果也不好。三、连续批处理吞吐量的数量级提升3.1 静态批处理的问题传统批处理等所有请求凑够一个 batch 才开始执行。问题是如果 batch 里有一个请求特别慢其他请求都要等。这叫木桶效应——batch 延迟由最慢的请求决定。3.2 连续批处理Continuous Batching连续批处理允许在 batch 执行过程中动态插入和移除请求。一个请求生成完了它的 NPU 槽位立刻给新请求用不用等整个 batch 都完成。时间轴: 静态批处理: [请求1, 请求2, 请求3] → 等待 → 等待 → 完成 连续批处理: [请求1, 请求2, 请求3] ↓ 请求2完成 [请求1, 请求4, 请求3] ↓ 请求1完成 [请求5, 请求4, 请求3]3.3 CANN 上的实现importtimeimportthreadingfromcollectionsimportdequefromdataclassesimportdataclass,fieldfromtypingimportOptionaldataclassclassInferenceRequest:推理请求request_id:strinput_ids:listmax_new_tokens:intcreated_at:floatfield(default_factorytime.time)generated_tokens:int0output_ids:listfield(default_factorylist)is_finished:boolFalseclassContinuousBatchScheduler:连续批处理调度器 核心机制: 1. 请求槽位管理: 每个槽位独立运行一个请求 2. 动态替换: 请求完成后槽位立即接收新请求 3. KV Cache 复用: 每个槽位的 KV Cache 独立管理 为什么连续批处理能提升吞吐? - 静态批处理: 10 个请求最慢的要 5 秒总耗时 5 秒吞吐 2 req/s - 连续批处理: 10 个请求平均 1 秒完成一个总耗时 ~2 秒吞吐 5 req/s - 吞吐提升来自消除了等最慢请求的浪费 为什么连续批处理不增加延迟? - 对于单个请求来说它的执行路径和静态批处理完全一样 - 连续批处理只是让 NPU 不闲着不改变单个请求的处理速度 def__init__(self,max_batch_size32):self.max_batch_sizemax_batch_size self.active_slots{}# slot_id → InferenceRequestself.waiting_queuedeque()self.lockthreading.Lock()defsubmit(self,request:InferenceRequest):提交请求withself.lock:iflen(self.active_slots)self.max_batch_size:# 有空闲槽位直接执行slot_idlen(self.active_slots)self.active_slots[slot_id]requestprint(f请求{request.request_id}进入槽位{slot_id})else:# 没有空闲槽位进入等待队列self.waiting_queue.append(request)print(f请求{request.request_id}进入等待队列 (队列长度:{len(self.waiting_queue)}))defon_request_complete(self,slot_id:int):请求完成槽位空出withself.lock:completedself.active_slots.pop(slot_id)completed.is_finishedTrueprint(f请求{completed.request_id}完成 (生成{completed.generated_tokens}tokens))# 从等待队列取新请求填充槽位ifself.waiting_queue:new_requestself.waiting_queue.popleft()self.active_slots[slot_id]new_requestprint(f请求{new_request.request_id}进入槽位{slot_id})defget_batch(self)-list:获取当前 batch 的所有请求withself.lock:returnlist(self.active_slots.values())defget_stats(self)-dict:获取调度统计withself.lock:return{active_slots:len(self.active_slots),waiting_queue:len(self.waiting_queue),total_processed:sum(1forrinself.active_slots.values()ifr.is_finished),}defsimulate_continuous_batching():模拟连续批处理schedulerContinuousBatchScheduler(max_batch_size4)# 提交 8 个请求foriinrange(8):reqInferenceRequest(request_idfreq-{i:03d},input_ids[100i]*10,max_new_tokens20i*5,# 不同长度模拟真实场景)scheduler.submit(req)# 模拟执行print(f\n初始状态:{scheduler.get_stats()})# 模拟请求完成不同时间completion_order[0,2,1,3,4,5,6,7]forslotincompletion_order:ifslotinscheduler.active_slots:scheduler.on_request_complete(slot)print(f 状态:{scheduler.get_stats()})print(f\n最终状态:{scheduler.get_stats()})四、三个技术的组合收益技术优化维度单独收益组合收益FlashAttention显存 延迟显存降 90%延迟降 50%与推测解码组合更大 batch推测解码单请求延迟2-3 倍加速与连续批处理组合吞吐不降连续批处理整体吞吐吞吐提升 2-5 倍与 FlashAttention 组合更大 batch实际生产中三个技术通常同时使用。FlashAttention 腾出的显存让 batch 更大连续批处理让 NPU 不空闲推测解码让单个请求更快完成。五、常见问题问题原因解决方案FlashAttention 精度下降不应该可能是实现 bug检查 online softmax 的数值稳定性推测解码变慢了小模型太慢或猜对率太低换更小的 draft model 或调整 draft_length连续批处理延迟不稳等待队列太长增加 NPU 数量或降低 batch size相关仓库CANN- 昇腾计算架构 https://gitee.com/ascend/cannFlashAttention- 高效注意力实现 https://github.com/Dao-AILab/flash-attentionvLLM- 连续批处理推理 https://github.com/vllm-project/vllmSpeculative Decoding- 推测解码论文 https://arxiv.org/abs/2211.17192