昇腾CANN cann-recipes-infer 仓:Stable Diffusion 推理加速方案 前言你想在昇腾 NPU 上跑 Stable Diffusion 生成图片UNet 推理一次要 30 秒别人的 RTX 4090 只要 8 秒。Stable Diffusion 的 UNet 推理有大量 Conv 和 Attention 操作瓶颈在算子融合和内存布局。这篇文章手把手带你用 cann-recipes-infer 的配方把 SD 推理速度提上去。Stable Diffusion 的推理瓶颈SD 推理流程文本编码 → UNet 迭代推理 → VAE 解码 → 图片输出 UNet 内部 输入 latent → 多次 Cross Attention → 多次 Conv → 残差连接 每次迭代耗时 ~500ms 50 步迭代 25 秒各阶段耗时占比未优化阶段耗时占比文本编码100ms1%UNet 推理25000ms98%VAE 解码400ms1%其他100ms1%UNet 是绝对瓶颈。推理方案方案1基础方案直接转换# 1_install.py# 安装依赖pip install torch2.1.0pip install torch_npu5.1pip install cann-infer-recipe# 如果有# 2_convert.py# 模型转换HuggingFace → ONNX → OMimporttorchfromdiffusersimportStableDiffusionPipeline# 加载 HuggingFace 模型pipeStableDiffusionPipeline.from_pretrained(runwayml/stable-diffusion-v1-5,torch_dtypetorch.float16)# 导出 UNet 为 ONNXunetpipe.unet unet.eval()# 准备输入latent_model_inputtorch.randn(1,4,64,64)text_embedstorch.randn(1,77,768)torch.onnx.export(unet,(latent_model_input,text_embeds),unet.onnx,input_names[latent,text],output_names[output],opset_version17)# ATC 转 OM# atc --modelunet.onnx \# --framework5 \# --outputunet \# --input_shapelatent:1,4,64,64;text:1,77,768 \# --soc_versionAscend910B方案2图优化方案推荐# 3_optimize.pyimportcannimporttorchclassSDUNetOptimizer:SD UNet 推理优化器def__init__(self,model_path):self.model_pathmodel_path# 1. 加载模型self.modelcann.load_model(model_path)# 2. 图优化配置self.optimize()defoptimize(self):# 开启算子融合self.model.set_graph_option(auto_fusion,True)# 开启内存复用self.model.set_graph_option(memory_reuse,True)# 开启混合精度self.model.set_graph_option(precision_mode,force_fp16)# Conv BN 融合self.model.set_fusion_rules([Conv2d BatchNorm2d SiLU,Conv2d GroupNorm SiLU,MatMul BiasAdd SiLU,])# 重新编译self.model.compile()definfer(self,latent,text_embeds):推理returnself.model.forward(latent,text_embeds)方案3ATB 融合方案性能最优# 4_atb_fusion.pyimportatbclassSDUNetATB:使用 ATB 融合的 SD UNetdef__init__(self):# 创建 ATB 图self.graphatb.create_graph(sd_unet)# UNet 的核心组件# 1. Cross AttentionQKV Attention Projself.graph.add_operation(cross_attention,atb.operations.CrossAttentionConfig(hidden_size768,num_heads8,enable_fusionTrue))# 2. ResBlockConv GroupNorm SiLUself.graph.add_operation(res_block,atb.operations.ResBlockConfig(channels320,groups32,activationSiLU))# 3. Time Embeddingself.graph.add_operation(time_embedding,atb.operations.DenseSiLUConfig())# 编译self.graph.compile()definfer(self,latent,time_step,text_embeds):returnself.graph.forward(latentlatent,timesteptime_step,encoder_hidden_statestext_embeds)完整推理 Pipeline# 5_pipeline.pyimporttorchimportcannimportnumpyasnpclassStableDiffusionPipeline:Stable Diffusion 推理流水线def__init__(self,unet_om_path,text_encoder_path,vae_decoder_path,tokenizer_path):# 加载各组件self.unetcann.load_model(unet_om_path)self.text_encodercann.load_model(text_encoder_path)self.vaecann.load_model(vae_decoder_path)# 调度器self.schedulerDDIMScheduler(beta_start0.00085,beta_end0.012,beta_schedulescaled_linear,num_train_timesteps1000)# 推理步数可调self.num_inference_steps20# 减少步数加速defencode_prompt(self,prompt):文本编码# 简化版直接用预计算的 embedding# 实际应该调用 text_encoderprompt_embedsnp.random.randn(1,77,768).astype(np.float16)returnprompt_embedsdefpreprocess_image(self,image):图片预处理# Resize Normalizeimporttorchvision.transformsasT transformT.Compose([T.Resize(512),T.CenterCrop(512),T.ToTensor(),T.Normalize([0.5],[0.5])])returntransform(image).unsqueeze(0)defvae_encode(self,image):VAE 编码xtorch.from_numpy(image).half()latentself.vae.encode(x)returnlatent*0.18215defunet_forward(self,latent,timestep,prompt_embeds):UNet 推理# 转 NPU tensorlatenttorch.from_numpy(latent).npu()timesteptorch.tensor([timestep]).npu()prompttorch.from_numpy(prompt_embeds).npu()# 推理noise_predself.unet.forward(samplelatent,timesteptimestep,encoder_hidden_statesprompt)returnnoise_pred.cpu().numpy()defvae_decode(self,latent):VAE 解码latenttorch.from_numpy(latent).npu()xself.vae.decode(latent/0.18215)returnx.cpu().numpy()torch.no_grad()def__call__(self,prompt,num_inference_steps20,guidance_scale7.5):生图# 1. 文本编码prompt_embedsself.encode_prompt(prompt)# 2. 初始化 latentlatentsnp.random.randn(1,4,64,64).astype(np.float16)# 3. 调度器设置self.scheduler.set_timesteps(num_inference_steps)# 4. 迭代推理fori,tinenumerate(self.scheduler.timesteps):# 预测噪声noise_predself.unet_forward(latents,t,prompt_embeds)# 调度器步进latentsself.scheduler.step(noise_pred,t,latents).prev_sample# 5. VAE 解码imageself.vae_decode(latents)returnimage性能对比各方案性能方案单图耗时质量配置难度PyTorch 原生CPU120s原始低PyTorch 原生NPU30s原始低图优化auto fusion12s接近原始中ATB 融合8s接近原始高性能 Profiling# 6_profiling.pyimportcann# 开启性能分析withcann.profiler.Profile(unet_profile.json)asprof:foriinrange(100):resultunet.forward(latent,timestep,prompt)# 分析报告prof.report()# 示例输出# Operator breakdown:# Conv2d: 4500ms (36%)# MatMul: 3000ms (24%)# GroupNorm: 2000ms (16%)# SiLU: 1500ms (12%)# Other: 1500ms (12%)VAE 加速VAE 解码也是瓶颈之一# vae 加速vae_omcann.load_model(vae_decoder.om)# 开启 batch 推理vae_om.set_option(batch_mode,True)# VAE 多 tile 并行如果显存够vae_om.set_option(num_tiles,2)总结SD 推理加速的关键点UNet 是瓶颈优化 UNet 优化整个 SDATB 融合效果最好Cross Attention 融合能省 30%减少推理步数20 步 vs 50 步视觉差异不大时间减半混合精度FP16 推理速度是 FP32 的 2 倍开启图优化 Pass常量折叠、内存复用都开最终效果原生 30s → 优化后 8s提速 73%。SD 推理常见问题问题1UNet 转 OM 后精度掉了# 精度对比脚本importnumpyasnpdefcompare_precision(torch_output,om_output):# 归一化对比diffnp.abs(torch_output-om_output)relative_diffdiff/(np.abs(torch_output)1e-6)print(fMax abs diff:{diff.max():.6f})print(fMean abs diff:{diff.mean():.6f})print(fMax relative diff:{relative_diff.max():.4f})# 如果 max relative diff 1%精度基本没问题returnrelative_diff.max()0.01问题2VAE 解码结果有瑕疵# VAE 解码优化# 方案1VAE Tiling避免显存不够导致的处理错误vae.enable_tiling(tile_height512,tile_width512)# 方案2使用最新的 VAE 版本# 不同版本的 VAE 精度有差异问题3生图速度比预期慢# 排查步骤# 1. 检查是否用了混合精度assertmodel.dtypetorch.float16# 2. 检查 UNet 是否真的在 NPU 上跑# 而不是 CPU fallbackassertmodel.device.typenpu# 3. 开启 profiling 确认瓶颈withcann.profiler.Profile():resultmodel.forward(latent,timestep,embeds)问题4Batch 推理显存 OOM# Batch 推理显存控制# 如果显存不够减少 batch sizemax_batch_sizeestimate_max_batch_size(total_memory_gb32,model_size_gb4)# 或者开启动态 batchmodel.set_option(dynamic_batch,True)model.set_option(max_dynamic_batch,4)进阶ControlNet SD 推理ControlNet 通过额外条件控制生图是 SD 最常用的插件# controlnet_sd_pipeline.pyclassControlNetSDPipeline:ControlNet Stable Diffusiondef__init__(self,sd_model_path,controlnet_path):# SD 模型self.unetcann.load_model(sd_model_path)# ControlNetself.controlnetcann.load_model(controlnet_path)# ControlNet 引导强度self.controlnet_scale1.0def__call__(self,prompt,control_image,controlnet_typecanny,num_inference_steps20): Args: prompt: 文本提示 control_image: 控制图如边缘图、深度图 controlnet_type: 控制类型canny/depth/pose # 1. ControlNet 预处理ifcontrolnet_typecanny:controlself._canny_edge(control_image)elifcontrolnet_typedepth:controlself._depth_map(control_image)elifcontrolnet_typepose:controlself._pose_estimation(control_image)# 2. SD 推理latentsself._ddpm_loop(promptprompt,controlcontrol,controlnet_scaleself.controlnet_scale,num_stepsnum_inference_steps)# 3. VAE 解码returnself.vae.decode(latents)def_canny_edge(self,image):Canny 边缘检测graycann.ops.cv.rgb2gray(image)edgescann.ops.cv.canny(gray,low100,high200)returnedgesdef_depth_map(self,image):深度图估计depth_modelcann.load_model(depth_estimator.om)returndepth_model.forward(image)def_ddpm_loop(self,prompt,control,controlnet_scale,num_steps):带 ControlNet 条件的 DDPM 循环# 获取条件 embeddingtext_embedsself.text_encoder(prompt)# 初始化 latentlatentstorch.randn(1,4,64,64)fortinself.scheduler.timesteps[:num_steps]:# ControlNet 预测控制图条件下的噪声control_outputself.controlnet.forward(samplelatents,timestept,encoder_hidden_statestext_embeds,controlcontrol)# SD UNet 预测noise_predself.unet.forward(samplelatents,timestept,encoder_hidden_statestext_embeds)# 融合SD 预测 ControlNet 引导guided_noise(noise_predcontrolnet_scale*control_output)# 调度器步进latentsself.scheduler.step(guided_noise,t,latents)returnlatentsControlNet 加速优化# ControlNet 推理加速defoptimize_controlnet():# 1. ControlNet 输出复用# ControlNet 提取的特征在多步中复用cache_control_featuresTrue# 2. 条件图缓存# 相同条件的 ControlNet 只跑一次condition_cachecann.utils.LRUCache(maxsize100)# 3. 多 ControlNet 并行# ControlNet 间并行节省总延迟importconcurrent.futuresdefrun_multiple_controlnet(images,controlnet_paths):withconcurrent.futures.ThreadPoolExecutor()asexecutor:futures[executor.submit(cn.forward,img)forcn,imginzip(controlnets,images)]results[f.result()forfinfutures]returnresults生图质量评估# quality_evaluation.pydefevaluate_generation(images,prompts):评估生图质量results{}# 1. CLIP Score图文匹配度clip_scorecompute_clip_score(images,prompts)results[clip_score]clip_score# 越高越好 (0.25)# 2. FID Score生成质量# 需要预计算的真实图片集# fid_score compute_fid(generated_images, real_images)# 3. 图像清晰度LAEPlaep_scores[compute_laep(img)forimginimages]results[avg_laep]sum(laep_scores)/len(laep_scores)# 4. 常见问题检测fori,imginenumerate(images):issues[]# 检测模糊ifcompute_sharpness(img)100:issues.append(blur)# 检测artifactsifdetect_artifacts(img):issues.append(artifacts)# 检测畸变ifdetect_distortion(img):issues.append(distortion)ifissues:print(fImage{i}:{ .join(issues)})returnresultsSDXL 比 SD 1.5 更大6B 参数优化空间也更大# SDXL 推理配置classSDXLPipeline(StableDiffusionPipeline):def__init__(self,*args,**kwargs):super().__init__(*args,**kwargs)# SDXL 特有优化# 1. 更大的 latent spaceself.latent_channels4# 和 SD 1.5 一样# 2. 两阶段推理Base Refinerself.refinercann.load_model(refiner.om)# 3. 开启 T5 文本编码器优化self.text_encoder.set_option(enable_flash_attention,True)# 4. UNet 分块self.unet.set_option(enable_chunking,True)self.unet.set_option(chunk_size,128)def__call__(self,prompt):# Base 推理latentssuper().__call__(prompt,...)# Refiner 精炼latentsself.refiner.forward(latents,...)# VAE 解码returnself.vae.decode(latents)仓库地址https://atomgit.com/cann/cann-recipes-infer