基于Triton的layernorm算子调优实践分析 作者昇腾实战派背景在进行视频生成模型的推理调优时通过分析profiling发现layernorm算子存在异常耗时现象。为了提高模型的推理效率需要对layernorm算子进行优化。本文将详细介绍问题的背景、原因分析及优化方案。问题描述在profiling中layernorm算子的执行过程中host在正式下发layernorm算子之前先分别下发了aclnnCast_SliceAiCore_Slice和aclnnCast_CastAiCore_Cast两个算子。这两个算子的作用是什么能否省掉原因分析通过查看op_summary文件可以详细了解到这两个算子的输入输出dtype和shape。具体信息如下图所示从summary中可以看出第一个算子的作用是将shape为[50220, 9, 128]的输入张量切分为shape为[50220, 3, 128]的张量第二个算子的作用是将切分后的张量数据类型从bfloat16转换为float32。第二个算子的原因不难分析因为代码中layernorm的实现是用的torch原生算子如图所示nn.layernorm的底层算子输入数据类型为float因此需要使用cast算子对数据类型进行转换。观察第一个算子的summary进一步产生了另一个疑问layernorm的输入shape为什么是[50220, 9, 128]呢明明在代码中已经通过unbind操作转换成[50220, 3, 128]了如下图所示这涉及到PyTorch中tensor的存储机制。tensor分为头信息区Tensor和存储区Storage。信息区主要保存着tensor的形状size、步长stride、数据类型dtype等信息而真正的数据则以连续一维数组的形式存储在存储区。如下图所示像view、reshape、unbind这一类的操作只是在host侧改变头信息区的指针ptr、步长stride等索引信息实际上并没有改变存储区device侧的storage。因此vid_q, vid_k, vid_v vid_qkv.unbind(1)这行代码的操作具象到实际内存中可以用下图来表示由此可见layernorm的输入张量vid_q并不是连续内存只不过是host侧的索引变了。因此遵循aclnnLayerNorm算子的输入规范需使用aclnnCast_SliceAiCore_Slice进行切片转换在device侧变成连续存储形式。算子优化本次优化的目的是跳过host侧的unbind操作并消除aclnnCast_SliceAiCore_Slice算子。为此需要开发一个支持非连续内存的layernorm算子。调用接口deftriton_inplace_layer_norm(qk:torch.Tensor,# 支持从 qkv slice 出来的不连续 tensorgamma:torch.Tensor,beta:torch.Tensor,):seq_len,n_heads,head_dimqk.shape# 50220, 3, 128seq_strideqk.stride(0)# 1152grid(48,)_inplace_layer_norm_kernel[grid](qk,gamma,beta,seq_len,n_heads,seq_stride,head_dim,eps1e-5,BLOCK_SIZE_SEQ64)returnqk调用接口比较简单易懂入参包括输入张量qk实际上是上文中的vid_q。gamma指的是原始layernorm的权重weight。beta指的是原始layernorm的权重bias。seq_len序列长度这里等于50220。n_headshead的数目这里等于3。seq_stride输入张量在序列维度的步长这里等于3x3x1281152。head_dim每个头的维度这里等于128。epslayernorm的分母防0参数。BLOCK_SIZE_SEQ这里指将每64个token划分为一个block方便在kernel中处理。中括号中的grid可以简单理解为并行处理的内核数和硬件能力有关这里设置为48是因为设备共有48个vector计算单元。kernel实现triton.jitdef_inplace_layer_norm_kernel(# Pointers to inputs/outputsinout_ptr,# [seq_len, n_heads, head_dim]gamma_ptr,# [head_dim]beta_ptr,# [head_dim]# Shapesseq_len:tl.constexpr,n_heads:tl.constexpr,seq_len_stride:tl.constexpr,head_dim:tl.constexpr,eps:tl.constexpr,BLOCK_SIZE_SEQ:tl.constexpr,):pidtl.program_id(0)num_programstl.num_programs(0)# 返回沿着指定 axis0 启动的程序实例的数量。 48个num_seq_blocks(seq_lenBLOCK_SIZE_SEQ-1)//BLOCK_SIZE_SEQ# 按序列划分共有785个seq block待处理num_programs_seqnum_programs//n_heads# 所有pid一起能够并行处理16个seq blockcol_offstl.arange(0,head_dim)# [0, 1, 2, ...., 127]gammatl.load(gamma_ptrcol_offs)betatl.load(beta_ptrcol_offs)forseq_block_idinrange(pid//n_heads,num_seq_blocks,num_programs_seq):seq_indicesseq_block_id*BLOCK_SIZE_SEQtl.arange(0,BLOCK_SIZE_SEQ)seq_maskseq_indicesseq_len head_idxpid%n_heads input_row_base_offsseq_indices*seq_len_stride input_row_offsinput_row_base_offshead_idx*head_dim output_row_base_offsseq_indices*seq_len_stride output_row_offsoutput_row_base_offshead_idx*head_dim q_block_offsinput_row_offs[:,None]col_offs[None,:]# load q and cast to float32qtl.load(inout_ptrq_block_offs,maskseq_mask[:,None],other0.0)q_fp32q.to(tl.float32)# compute mean varrow_meantl.sum(q_fp32,axis1,keep_dimsTrue)/head_dim row_vartl.sum(q_fp32*q_fp32,axis1,keep_dimsTrue)/head_dim-row_mean*row_mean rstdtl.rsqrt(row_vareps)# normalize qq_fp32(q_fp32-row_mean)*rstd q_fp16q_fp32.to(inout_ptr.dtype.element_ty)q_fp16q_fp16*gammabeta# store back qq_out_block_offsoutput_row_offs[:,None]col_offs[None,:]tl.store(inout_ptrq_out_block_offs,q_fp16,maskseq_mask[:,None])