CANN/ops-transformer稠密LightningIndexer梯度KL损失算子 DenseLightningIndexerGradKLLoss【免费下载链接】ops-transformer本项目是CANN提供的transformer类大模型算子库实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-transformer产品支持情况产品是否支持Ascend 950PR/Ascend 950DT×Atlas A3 训练系列产品/Atlas A3 推理系列产品√Atlas A2 训练系列产品/Atlas A2 推理系列产品√Atlas 200I/500 A2 推理产品×Atlas 推理系列产品×Atlas 训练系列产品×功能说明算子功能DenseLightningIndexerGradKlLoss算子是LightningIndexer的反向算子再额外融合了Loss计算功能。LightningIndexer算子将QueryToken和KeyToken之间的最高内在联系的TopK个筛选出来从而减少长序列场景下Attention的计算量加速长序列的网络的推理和训练的性能。稠密场景下的LightningIndexerGrad的输入query、key、query_index、key_index不用做稀疏化处理。计算公式Top-k value的计算公式$$ I_{t,:}W_{t,:}ReLU(\tilde{q}{t,:}\tilde{K}{:t,:}^\top) $$$W_{t,:}$是第$t$个token对应的$weights$$\tilde{q}_{t,:}$是$\tilde{q}$矩阵第$t$个token对应的$G$个query头合轴后的结果$\tilde{K}_{:t,:}$为$t$行$\tilde{K}$矩阵。正向的Softmax对应公式$$ p_{t,:} \text{Softmax}(q_{t,:} K_{:t,:}^\top/\sqrt{d}) $$$p_{t,:}$是第$t$个token对应的Softmax结果$q_{t,:}$是$q$矩阵第$t$个token对应的$G$个query头合轴后的结果${K}_{:t,:}$为$t$行$K$矩阵。npu_lightning_indexer会单独训练对应的loss function为$$ Loss{}\sum_tD_{KL}(p_{t,:}||Softmax(I_{t,:})) $$其中$p_{t,:}$是target distribution通过对main attention score 进行所有的head的求和然后把求和结果沿着上下文方向进行L1正则化得到。$D_{KL}$为KL散度其表达式为$$ D_{KL}(a||b){}\sum_ia_i\mathrm{log}{\left(\frac{a_i}{b_i}\right)} $$通过求导可得Loss的梯度表达式$$ dI\mathop{{}}\nolimits_{{t,:}}Softmax \left( I\mathop{{}}\nolimits_{{t,:}} \left) -p\mathop{{}}\nolimits_{{t,:}}\right. \right. $$利用链式法则可以进行weightsquery和key矩阵的梯度计算 $$ dW\mathop{{}}\nolimits_{{t,:}}dI\mathop{{}}\nolimits_{{t,:}}\text{} \left( ReLU \left( S\mathop{{}}\nolimits_{{t,:}} \left) \left) \mathop{{}}\nolimits^{\top}\right. \right. \right. \right. $$$$ d\mathop{{\tilde{q}}}\nolimits_{{t,:}}dS\mathop{{}}\nolimits_{{t,:}}\tilde{K}\mathop{{}}\nolimits_{{:t,:}} $$$$ d\tilde{K}\mathop{{}}\nolimits_{{:t,:}}\left(dS\mathop{{}}\nolimits_{{t,:}} \left) \mathop{{}}\nolimits^{\top}\tilde{q}\mathop{{}}\nolimits_{{:t, :}}\right. \right. $$其中$S$为$\tilde{q}$和$K$矩阵乘的结果。参数说明参数名输入/输出/属性描述数据类型数据格式query输入attention结构的输入Q。FLOAT16、BFLOAT16NDkey输入attention结构的输入K。FLOAT16、BFLOAT16NDqueryIndex输入lightningIndexer结构的输入queryIndex。FLOAT16、BFLOAT16NDkeyIndex输入lightningIndexer结构的输入keyIndex。FLOAT16、BFLOAT16NDweights输入权重。FLOAT16、BFLOAT16NDsoftmaxMax输入Device侧的aclTensor注意力正向计算的中间输出。FLOAT32NDsoftmaxSum输入Device侧的aclTensor注意力正向计算的中间输出。FLOAT32NDsoftmaxMaxIndex输入Device侧的aclTensor注意力正向计算的中间输出。FLOAT32NDsoftmaxSumIndex输入Device侧的aclTensor注意力正向计算的中间输出。FLOAT32NDqueryRope输入MLA rope部分Query位置编码的输出。FLOAT16、BFLOAT16NDkeyRope输入MLA rope部分Key位置编码的输出。FLOAT16、BFLOAT16NDactualSeqLengthsQuery输入每个Batch中Query的有效token数。INT64NDactualSeqLengthsKey输入每个Batch中Key的有效token数。INT64NDscaleValue输入缩放系数。double-layout输入layout格式。char*-sparseMode输入sparse的模式。INT64-preTokens输入用于稀疏计算表示Attention需要和前几个token计算关联。INT64-nextTokens输入用于稀疏计算表示Attention需要和后几个token计算关联。INT64-dQueryIndex输出QueryIndex的梯度。FLOAT16、BFLOAT16NDdKeyIndex输出KeyIndex的梯度。FLOAT16、BFLOAT16NDdWeights输出Weights的梯度。FLOAT16、BFLOAT16NDloss输出损失函数值。FLOAT32ND约束说明无调用说明调用方式调用样例说明aclnn调用test_aclnn_dense_lightning_indexer_grad_kl_loss通过aclnnDenseLightningIndexerGradKLLoss接口方式调用dense_lightning_indexer_grad_kl_loss算子。【免费下载链接】ops-transformer本项目是CANN提供的transformer类大模型算子库实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-transformer创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考