从单头到多头用PyTorch MultiheadAttention复现Transformer核心模块的完整流程在自然语言处理和计算机视觉领域Transformer架构已经成为革命性的基础模型。而MultiheadAttention多头注意力作为其核心组件理解其工作原理对于掌握现代深度学习模型至关重要。本文将带您从最基础的Scaled Dot-Product Attention缩放点积注意力开始逐步构建完整的多头注意力模块最终实现一个简易的Transformer编码器块。1. 注意力机制基础从单头开始注意力机制的核心思想是让模型能够关注输入序列中最重要的部分。让我们先实现一个最基本的单头注意力模块import torch import torch.nn as nn import torch.nn.functional as F class ScaledDotProductAttention(nn.Module): def __init__(self, d_k): super().__init__() self.d_k d_k # 查询和键的维度 def forward(self, Q, K, V, maskNone): Q: 查询矩阵 (batch_size, seq_len, d_k) K: 键矩阵 (batch_size, seq_len, d_k) V: 值矩阵 (batch_size, seq_len, d_v) mask: 可选掩码矩阵 scores torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k)) if mask is not None: scores scores.masked_fill(mask 0, -1e9) attn_weights F.softmax(scores, dim-1) output torch.matmul(attn_weights, V) return output, attn_weights这个简单的实现包含了注意力机制的三个关键步骤计算注意力分数通过查询(Q)和键(K)的点积计算相关性缩放处理除以√d_k防止梯度消失Softmax归一化得到注意力权重加权求和用注意力权重对值(V)进行加权注意在实际应用中d_k通常等于d_v但这不是必须的。这种设计灵活性允许我们根据任务需求调整维度。2. 多头注意力的实现原理单头注意力只能学习一种关注模式而多头注意力通过并行运行多个注意力头可以捕获输入序列中不同类型的依赖关系。下面是多头注意力的关键设计并行处理多个注意力头同时处理输入线性投影每个头有自己的可学习投影矩阵拼接输出将所有头的输出拼接后通过线性变换class MultiHeadAttention(nn.Module): def __init__(self, d_model, num_heads): super().__init__() assert d_model % num_heads 0, d_model必须能被num_heads整除 self.d_model d_model self.num_heads num_heads self.d_k d_model // num_heads # 线性投影层 self.W_q nn.Linear(d_model, d_model) self.W_k nn.Linear(d_model, d_model) self.W_v nn.Linear(d_model, d_model) self.W_o nn.Linear(d_model, d_model) self.attention ScaledDotProductAttention(self.d_k) def split_heads(self, x): 将输入张量分割为多个头 x: (batch_size, seq_len, d_model) 返回: (batch_size, num_heads, seq_len, d_k) batch_size, seq_len, _ x.size() return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2) def forward(self, Q, K, V, maskNone): # 线性投影 Q self.split_heads(self.W_q(Q)) K self.split_heads(self.W_k(K)) V self.split_heads(self.W_v(V)) # 计算注意力 attn_output, attn_weights self.attention(Q, K, V, mask) # 拼接多头输出 batch_size, _, seq_len, _ attn_output.size() concat attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1) # 最终线性变换 output self.W_o(concat) return output, attn_weights3. 使用PyTorch内置MultiheadAttention模块虽然手动实现有助于理解原理但在实际项目中我们通常会使用PyTorch内置的torch.nn.MultiheadAttention模块它经过高度优化且功能完善。下面是使用示例import torch.nn as nn # 参数设置 embed_dim 512 # 输入和输出的维度 num_heads 8 # 注意力头的数量 dropout 0.1 # Dropout概率 # 创建多头注意力层 multihead_attn nn.MultiheadAttention( embed_dimembed_dim, num_headsnum_heads, dropoutdropout, batch_firstTrue # 输入格式为(batch, seq, feature) ) # 示例输入 (batch_size2, seq_len10, embed_dim512) query torch.randn(2, 10, 512) key torch.randn(2, 10, 512) value torch.randn(2, 10, 512) # 前向传播 attn_output, attn_weights multihead_attn( queryquery, keykey, valuevalue )PyTorch的MultiheadAttention模块提供了几个重要参数参数描述默认值embed_dim模型的总维度-num_heads并行注意力头的数量-dropout注意力权重上的Dropout概率0.0bias是否添加偏置到投影层Truebatch_first输入/输出是否为(batch, seq, feature)格式False提示在大多数现代应用中建议设置batch_firstTrue以获得更直观的输入输出格式。4. 构建简易Transformer编码器块现在我们将多头注意力模块整合到一个完整的Transformer编码器块中class TransformerEncoderLayer(nn.Module): def __init__(self, d_model, num_heads, dim_feedforward2048, dropout0.1): super().__init__() self.self_attn nn.MultiheadAttention( embed_dimd_model, num_headsnum_heads, dropoutdropout, batch_firstTrue ) # 前馈网络 self.linear1 nn.Linear(d_model, dim_feedforward) self.linear2 nn.Linear(dim_feedforward, d_model) # 归一化层 self.norm1 nn.LayerNorm(d_model) self.norm2 nn.LayerNorm(d_model) # Dropout self.dropout nn.Dropout(dropout) self.activation nn.ReLU() def forward(self, src, src_maskNone): # 自注意力 attn_output, _ self.self_attn( querysrc, keysrc, valuesrc, attn_masksrc_mask ) # 残差连接和归一化 src src self.dropout(attn_output) src self.norm1(src) # 前馈网络 ff_output self.linear2(self.dropout(self.activation(self.linear1(src)))) # 残差连接和归一化 src src self.dropout(ff_output) src self.norm2(src) return src这个编码器块包含以下关键组件多头自注意力层处理输入序列的内部关系前馈神经网络对每个位置独立进行非线性变换残差连接帮助梯度流动缓解梯度消失层归一化稳定训练过程5. 实际应用中的注意事项在真实项目中使用多头注意力时有几个关键点需要考虑维度匹配确保embed_dim能被num_heads整除掩码处理正确处理填充掩码和未来掩码性能优化利用PyTorch的torch.nn.functional.scaled_dot_product_attention进行高效计算梯度检查多头注意力可能面临梯度不稳定问题需要适当调整学习率# 高效注意力计算示例 def efficient_attention(Q, K, V, maskNone): return torch.nn.functional.scaled_dot_product_attention( Q, K, V, attn_maskmask, dropout_p0.1, is_causalFalse )多头注意力机制的成功应用需要平衡理论理解和工程实践。通过从基础实现开始逐步过渡到优化实现开发者可以更好地掌握这一强大工具为构建更复杂的Transformer模型打下坚实基础。
从单头到多头:用PyTorch MultiheadAttention复现Transformer核心模块的完整流程
发布时间:2026/5/19 11:59:01
从单头到多头用PyTorch MultiheadAttention复现Transformer核心模块的完整流程在自然语言处理和计算机视觉领域Transformer架构已经成为革命性的基础模型。而MultiheadAttention多头注意力作为其核心组件理解其工作原理对于掌握现代深度学习模型至关重要。本文将带您从最基础的Scaled Dot-Product Attention缩放点积注意力开始逐步构建完整的多头注意力模块最终实现一个简易的Transformer编码器块。1. 注意力机制基础从单头开始注意力机制的核心思想是让模型能够关注输入序列中最重要的部分。让我们先实现一个最基本的单头注意力模块import torch import torch.nn as nn import torch.nn.functional as F class ScaledDotProductAttention(nn.Module): def __init__(self, d_k): super().__init__() self.d_k d_k # 查询和键的维度 def forward(self, Q, K, V, maskNone): Q: 查询矩阵 (batch_size, seq_len, d_k) K: 键矩阵 (batch_size, seq_len, d_k) V: 值矩阵 (batch_size, seq_len, d_v) mask: 可选掩码矩阵 scores torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k)) if mask is not None: scores scores.masked_fill(mask 0, -1e9) attn_weights F.softmax(scores, dim-1) output torch.matmul(attn_weights, V) return output, attn_weights这个简单的实现包含了注意力机制的三个关键步骤计算注意力分数通过查询(Q)和键(K)的点积计算相关性缩放处理除以√d_k防止梯度消失Softmax归一化得到注意力权重加权求和用注意力权重对值(V)进行加权注意在实际应用中d_k通常等于d_v但这不是必须的。这种设计灵活性允许我们根据任务需求调整维度。2. 多头注意力的实现原理单头注意力只能学习一种关注模式而多头注意力通过并行运行多个注意力头可以捕获输入序列中不同类型的依赖关系。下面是多头注意力的关键设计并行处理多个注意力头同时处理输入线性投影每个头有自己的可学习投影矩阵拼接输出将所有头的输出拼接后通过线性变换class MultiHeadAttention(nn.Module): def __init__(self, d_model, num_heads): super().__init__() assert d_model % num_heads 0, d_model必须能被num_heads整除 self.d_model d_model self.num_heads num_heads self.d_k d_model // num_heads # 线性投影层 self.W_q nn.Linear(d_model, d_model) self.W_k nn.Linear(d_model, d_model) self.W_v nn.Linear(d_model, d_model) self.W_o nn.Linear(d_model, d_model) self.attention ScaledDotProductAttention(self.d_k) def split_heads(self, x): 将输入张量分割为多个头 x: (batch_size, seq_len, d_model) 返回: (batch_size, num_heads, seq_len, d_k) batch_size, seq_len, _ x.size() return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2) def forward(self, Q, K, V, maskNone): # 线性投影 Q self.split_heads(self.W_q(Q)) K self.split_heads(self.W_k(K)) V self.split_heads(self.W_v(V)) # 计算注意力 attn_output, attn_weights self.attention(Q, K, V, mask) # 拼接多头输出 batch_size, _, seq_len, _ attn_output.size() concat attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1) # 最终线性变换 output self.W_o(concat) return output, attn_weights3. 使用PyTorch内置MultiheadAttention模块虽然手动实现有助于理解原理但在实际项目中我们通常会使用PyTorch内置的torch.nn.MultiheadAttention模块它经过高度优化且功能完善。下面是使用示例import torch.nn as nn # 参数设置 embed_dim 512 # 输入和输出的维度 num_heads 8 # 注意力头的数量 dropout 0.1 # Dropout概率 # 创建多头注意力层 multihead_attn nn.MultiheadAttention( embed_dimembed_dim, num_headsnum_heads, dropoutdropout, batch_firstTrue # 输入格式为(batch, seq, feature) ) # 示例输入 (batch_size2, seq_len10, embed_dim512) query torch.randn(2, 10, 512) key torch.randn(2, 10, 512) value torch.randn(2, 10, 512) # 前向传播 attn_output, attn_weights multihead_attn( queryquery, keykey, valuevalue )PyTorch的MultiheadAttention模块提供了几个重要参数参数描述默认值embed_dim模型的总维度-num_heads并行注意力头的数量-dropout注意力权重上的Dropout概率0.0bias是否添加偏置到投影层Truebatch_first输入/输出是否为(batch, seq, feature)格式False提示在大多数现代应用中建议设置batch_firstTrue以获得更直观的输入输出格式。4. 构建简易Transformer编码器块现在我们将多头注意力模块整合到一个完整的Transformer编码器块中class TransformerEncoderLayer(nn.Module): def __init__(self, d_model, num_heads, dim_feedforward2048, dropout0.1): super().__init__() self.self_attn nn.MultiheadAttention( embed_dimd_model, num_headsnum_heads, dropoutdropout, batch_firstTrue ) # 前馈网络 self.linear1 nn.Linear(d_model, dim_feedforward) self.linear2 nn.Linear(dim_feedforward, d_model) # 归一化层 self.norm1 nn.LayerNorm(d_model) self.norm2 nn.LayerNorm(d_model) # Dropout self.dropout nn.Dropout(dropout) self.activation nn.ReLU() def forward(self, src, src_maskNone): # 自注意力 attn_output, _ self.self_attn( querysrc, keysrc, valuesrc, attn_masksrc_mask ) # 残差连接和归一化 src src self.dropout(attn_output) src self.norm1(src) # 前馈网络 ff_output self.linear2(self.dropout(self.activation(self.linear1(src)))) # 残差连接和归一化 src src self.dropout(ff_output) src self.norm2(src) return src这个编码器块包含以下关键组件多头自注意力层处理输入序列的内部关系前馈神经网络对每个位置独立进行非线性变换残差连接帮助梯度流动缓解梯度消失层归一化稳定训练过程5. 实际应用中的注意事项在真实项目中使用多头注意力时有几个关键点需要考虑维度匹配确保embed_dim能被num_heads整除掩码处理正确处理填充掩码和未来掩码性能优化利用PyTorch的torch.nn.functional.scaled_dot_product_attention进行高效计算梯度检查多头注意力可能面临梯度不稳定问题需要适当调整学习率# 高效注意力计算示例 def efficient_attention(Q, K, V, maskNone): return torch.nn.functional.scaled_dot_product_attention( Q, K, V, attn_maskmask, dropout_p0.1, is_causalFalse )多头注意力机制的成功应用需要平衡理论理解和工程实践。通过从基础实现开始逐步过渡到优化实现开发者可以更好地掌握这一强大工具为构建更复杂的Transformer模型打下坚实基础。