从零实现Transformer第 3 部分 - 掩码多头注意力的掩码广播Broadcasting of Masks in Masked Multi-Head Attentionflyfish以生成填充掩码 前瞻掩码的组合掩码 为例1. 生成 Padding Mask填充掩码屏蔽序列中的填充占位符pad_id0填充的0是无效字符模型不应该关注、学习这些无意义的占位符2. 生成 Look-ahead Mask前瞻掩码屏蔽当前位置之后的所有未来 token解码器是自回归生成一步步生成文本绝对不能提前看到未来的词3. 合并掩码用|运算把两个掩码合二为一只要是「填充位」或「未来位」统一屏蔽True用处输出形状[batch, 1, seq_len, seq_len]这个掩码直接传入解码器的多头自注意力层掩码为True→ 注意力分数置为负无穷模型完全忽略该位置掩码为False→ 正常计算注意力模型可以关注该位置importtorchdefcreate_tgt_mask(tgt_ids,pad_id):创建目标序列掩码padding mask look-ahead mask#1.2维padding掩码[batch,seq_len]padding_mask_2d(tgt_idspad_id)#2.升维适配注意力维度-[batch,1,1,seq_len]tgt_padding_maskpadding_mask_2d.unsqueeze(1).unsqueeze(1)#3.生成序列长度 tgt_seq_lentgt_ids.shape[1]#4.构造上三角前瞻掩码[seq_len,seq_len]#diagonal1主对角线上方为1遮挡未来位置look_ahead_masktorch.triu(torch.ones(tgt_seq_len,tgt_seq_len,devicetgt_ids.device),diagonal1).bool()#5.升维支持批量广播-[1,1,seq_len,seq_len]look_ahead_masklook_ahead_mask.unsqueeze(0).unsqueeze(0)#6.合并掩码任意一个为True就遮挡returntgt_padding_mask|look_ahead_mask # 测试if__name____main__:pad_id0#2个batch序列长度50为padding tgt_idstorch.tensor([[1,2,3,0,0],[4,5,0,0,0]])maskcreate_tgt_mask(tgt_ids,pad_id)print(最终掩码形状:,mask.shape)# torch.Size([2,1,5,5])print(掩码内容:\n,mask)输出最终掩码形状:torch.Size([2,1,5,5])掩码内容:tensor([[[[False,True,True,True,True],[False,False,True,True,True],[False,False,False,True,True],[False,False,False,True,True],[False,False,False,True,True]]],[[[False,True,True,True,True],[False,False,True,True,True],[False,False,True,True,True],[False,False,True,True,True],[False,False,True,True,True]]]])广播 PyTorch 自动把 形状不同但兼容 的张量复制拉伸成相同形状然后再运算两个张量形状returntgt_padding_mask|look_ahead_mask两个输入形状tgt_padding_mask[B, 1, 1, S]→ 举例[2, 1, 1, 4]look_ahead_mask[1, 1, S, S]→ 举例[1, 1, 4, 4]广播目标把两个张量都自动变成[2, 1, 4, 4]再做|运算广播规则维度为1的位置可以自动复制扩展成任意大小扩展后两个张量形状完全一致就能运算例子1最简单的2维广播模拟小张量自动拉伸importtorch# 形状 [1,4] → 1行4列atorch.tensor([[True,False,True,False]])# 形状 [4,4] → 4行4列btorch.ones(4,4).bool()# 广播运算a自动复制4行变成[4,4]再和b运算ca|bprint(a形状:,a.shape)print(b形状:,b.shape)print(广播后运算结果形状:,c.shape)# 输出 [4,4][1,4]自动扩成[4,4]例子23维广播过渡# [2,1,4]atorch.rand(2,1,4).bool()# [1,4,4]btorch.rand(1,4,4).bool()# 自动广播成 [2,4,4]ca|bprint(c.shape)# [2,4,4]例子3模拟代码的4维广播importtorch# 模拟两个掩码B,S2,4# 1. padding掩码 [2,1,1,4]tgt_pad_masktorch.rand(B,1,1,S).bool()# 2. 前瞻掩码 [1,1,4,4]look_ahead_masktorch.rand(1,1,S,S).bool()# 广播运算final_masktgt_pad_mask|look_ahead_mask# 打印形状print(padding掩码形状:,tgt_pad_mask.shape)# [2,1,1,4]print(前瞻掩码形状:,look_ahead_mask.shape)# [1,1,4,4]print(广播后最终形状:,final_mask.shape)# [2,1,4,4]代码里用到的输入参数# 2个句子每个句子最长5个词tgt_idstorch.tensor([[1,2,3,0,0],# 第1个样本有效词3个后2个是填充0[4,5,0,0,0]# 第2个样本有效词2个后3个是填充0])pad_id0# 0代表填充位批次大小B 2序列长度S 5标准掩码维度[batch, num_heads, seq_q, seq_k]最终维度是[2, 1, 5, 5][2, 1, 5, 5] [批次B, 头数H, 查询序列长Q, 键序列长K]2一次性处理2 个句子batch21代码里没做多头默认1 个注意力头5Query 向量数量 目标序列长度 55Key 向量数量 目标序列长度 5代码里的广播tgt_padding_mask形状[2, 1, 1, 5]look_ahead_mask形状[1, 1, 5, 5]PyTorch自动广播把两个张量都拉伸为[2, 1, 5, 5]再做|运算掩码内容最终掩码 前瞻掩码或填充掩码True 遮挡不让看False 允许看1. 前瞻掩码固定不变所有样本共用torch.triu(..., diagonal1)生成固定上三角矩阵# 5x5 前瞻掩码对角线以上全是True遮挡未来词 [ [F, T, T, T, T], # 第1个词只能看自己不能看后面4个 [F, F, T, T, T], # 第2个词能看自己前1个不能看后面3个 [F, F, F, T, T], # 第3个词能看自己前2个不能看后面2个 [F, F, F, F, T], # 第4个词能看自己前3个不能看后面1个 [F, F, F, F, F] # 第5个词能看所有前面的词 ]2. 填充掩码每个样本不一样样本1[1,2,3,0,0]第4、5位是填充→ 掩码[F,F,F,T,T]样本2[4,5,0,0,0]第3、4、5位是填充→ 掩码[F,F,T,T,T]最终合并结果样本1 输出第一块 5x5[[False, True, True, True, True], [False, False, True, True, True], [False, False, False, True, True], [False, False, False, True, True], # 第4位是填充永久遮挡 [False, False, False, True, True]] # 第5位是填充永久遮挡前3行只受前瞻掩码影响后2行前瞻掩码 填充掩码双重遮挡样本2 输出第二块 5x5[[False, True, True, True, True], [False, False, True, True, True], [False, False, True, True, True], # 第3位是填充永久遮挡 [False, False, True, True, True], # 第4位是填充永久遮挡 [False, False, True, True, True]] # 第5位是填充永久遮挡前2行只受前瞻掩码影响后3行前瞻掩码 填充掩码双重遮挡简单的流程就是维度[2,1,5,5][2个句子, 1个注意力头, 每个句子5个Query, 每个句子5个Key]掩码内容上三角的True 遮挡未来词前瞻掩码后半列的True 遮挡填充0填充掩码两者合并就是看到的输出不用广播的写法importtorchdefcreate_tgt_mask_no_broadcast(tgt_ids,pad_id):创建目标序列掩码无广播版手动扩展维度B,Stgt_ids.shape # 直接获取批次B2序列长S5#1.2维padding掩码[batch,seq_len]→[2,5]padding_mask_2d(tgt_idspad_id)#2.升维 →[B,1,1,S]→[2,1,1,5]tgt_padding_maskpadding_mask_2d.unsqueeze(1).unsqueeze(1)#替代广播# 把第3维seq_q从1复制成 S → 形状变成[B,1,S,S][2,1,5,5]tgt_padding_masktgt_padding_mask.repeat(1,1,S,1)#3.构造上三角前瞻掩码[S,S]→[5,5]look_ahead_masktorch.triu(torch.ones(S,S,devicetgt_ids.device),diagonal1).bool()#4.升维 →[1,1,S,S]→[1,1,5,5]look_ahead_masklook_ahead_mask.unsqueeze(0).unsqueeze(0)#替代广播# 把第0维batch从1复制成 B → 形状变成[B,1,S,S][2,1,5,5]look_ahead_masklook_ahead_mask.repeat(B,1,1,1)#6.两个掩码形状完全一致直接运算无任何广播returntgt_padding_mask|look_ahead_mask # 测试if__name____main__:pad_id0tgt_idstorch.tensor([[1,2,3,0,0],[4,5,0,0,0]])# 运行无广播版本 maskcreate_tgt_mask_no_broadcast(tgt_ids,pad_id)print(最终掩码形状:,mask.shape)# 依旧是 torch.Size([2,1,5,5])print(掩码内容:\n,mask)输出最终掩码形状:torch.Size([2,1,5,5])掩码内容:tensor([[[[False,True,True,True,True],[False,False,True,True,True],[False,False,False,True,True],[False,False,False,True,True],[False,False,False,True,True]]],[[[False,True,True,True,True],[False,False,True,True,True],[False,False,True,True,True],[False,False,True,True,True],[False,False,True,True,True]]]])
从零实现Transformer:第 3 部分 - 掩码多头注意力的掩码广播(Broadcasting of Masks in Masked Multi-Head Attention)
发布时间:2026/6/27 23:12:42
从零实现Transformer第 3 部分 - 掩码多头注意力的掩码广播Broadcasting of Masks in Masked Multi-Head Attentionflyfish以生成填充掩码 前瞻掩码的组合掩码 为例1. 生成 Padding Mask填充掩码屏蔽序列中的填充占位符pad_id0填充的0是无效字符模型不应该关注、学习这些无意义的占位符2. 生成 Look-ahead Mask前瞻掩码屏蔽当前位置之后的所有未来 token解码器是自回归生成一步步生成文本绝对不能提前看到未来的词3. 合并掩码用|运算把两个掩码合二为一只要是「填充位」或「未来位」统一屏蔽True用处输出形状[batch, 1, seq_len, seq_len]这个掩码直接传入解码器的多头自注意力层掩码为True→ 注意力分数置为负无穷模型完全忽略该位置掩码为False→ 正常计算注意力模型可以关注该位置importtorchdefcreate_tgt_mask(tgt_ids,pad_id):创建目标序列掩码padding mask look-ahead mask#1.2维padding掩码[batch,seq_len]padding_mask_2d(tgt_idspad_id)#2.升维适配注意力维度-[batch,1,1,seq_len]tgt_padding_maskpadding_mask_2d.unsqueeze(1).unsqueeze(1)#3.生成序列长度 tgt_seq_lentgt_ids.shape[1]#4.构造上三角前瞻掩码[seq_len,seq_len]#diagonal1主对角线上方为1遮挡未来位置look_ahead_masktorch.triu(torch.ones(tgt_seq_len,tgt_seq_len,devicetgt_ids.device),diagonal1).bool()#5.升维支持批量广播-[1,1,seq_len,seq_len]look_ahead_masklook_ahead_mask.unsqueeze(0).unsqueeze(0)#6.合并掩码任意一个为True就遮挡returntgt_padding_mask|look_ahead_mask # 测试if__name____main__:pad_id0#2个batch序列长度50为padding tgt_idstorch.tensor([[1,2,3,0,0],[4,5,0,0,0]])maskcreate_tgt_mask(tgt_ids,pad_id)print(最终掩码形状:,mask.shape)# torch.Size([2,1,5,5])print(掩码内容:\n,mask)输出最终掩码形状:torch.Size([2,1,5,5])掩码内容:tensor([[[[False,True,True,True,True],[False,False,True,True,True],[False,False,False,True,True],[False,False,False,True,True],[False,False,False,True,True]]],[[[False,True,True,True,True],[False,False,True,True,True],[False,False,True,True,True],[False,False,True,True,True],[False,False,True,True,True]]]])广播 PyTorch 自动把 形状不同但兼容 的张量复制拉伸成相同形状然后再运算两个张量形状returntgt_padding_mask|look_ahead_mask两个输入形状tgt_padding_mask[B, 1, 1, S]→ 举例[2, 1, 1, 4]look_ahead_mask[1, 1, S, S]→ 举例[1, 1, 4, 4]广播目标把两个张量都自动变成[2, 1, 4, 4]再做|运算广播规则维度为1的位置可以自动复制扩展成任意大小扩展后两个张量形状完全一致就能运算例子1最简单的2维广播模拟小张量自动拉伸importtorch# 形状 [1,4] → 1行4列atorch.tensor([[True,False,True,False]])# 形状 [4,4] → 4行4列btorch.ones(4,4).bool()# 广播运算a自动复制4行变成[4,4]再和b运算ca|bprint(a形状:,a.shape)print(b形状:,b.shape)print(广播后运算结果形状:,c.shape)# 输出 [4,4][1,4]自动扩成[4,4]例子23维广播过渡# [2,1,4]atorch.rand(2,1,4).bool()# [1,4,4]btorch.rand(1,4,4).bool()# 自动广播成 [2,4,4]ca|bprint(c.shape)# [2,4,4]例子3模拟代码的4维广播importtorch# 模拟两个掩码B,S2,4# 1. padding掩码 [2,1,1,4]tgt_pad_masktorch.rand(B,1,1,S).bool()# 2. 前瞻掩码 [1,1,4,4]look_ahead_masktorch.rand(1,1,S,S).bool()# 广播运算final_masktgt_pad_mask|look_ahead_mask# 打印形状print(padding掩码形状:,tgt_pad_mask.shape)# [2,1,1,4]print(前瞻掩码形状:,look_ahead_mask.shape)# [1,1,4,4]print(广播后最终形状:,final_mask.shape)# [2,1,4,4]代码里用到的输入参数# 2个句子每个句子最长5个词tgt_idstorch.tensor([[1,2,3,0,0],# 第1个样本有效词3个后2个是填充0[4,5,0,0,0]# 第2个样本有效词2个后3个是填充0])pad_id0# 0代表填充位批次大小B 2序列长度S 5标准掩码维度[batch, num_heads, seq_q, seq_k]最终维度是[2, 1, 5, 5][2, 1, 5, 5] [批次B, 头数H, 查询序列长Q, 键序列长K]2一次性处理2 个句子batch21代码里没做多头默认1 个注意力头5Query 向量数量 目标序列长度 55Key 向量数量 目标序列长度 5代码里的广播tgt_padding_mask形状[2, 1, 1, 5]look_ahead_mask形状[1, 1, 5, 5]PyTorch自动广播把两个张量都拉伸为[2, 1, 5, 5]再做|运算掩码内容最终掩码 前瞻掩码或填充掩码True 遮挡不让看False 允许看1. 前瞻掩码固定不变所有样本共用torch.triu(..., diagonal1)生成固定上三角矩阵# 5x5 前瞻掩码对角线以上全是True遮挡未来词 [ [F, T, T, T, T], # 第1个词只能看自己不能看后面4个 [F, F, T, T, T], # 第2个词能看自己前1个不能看后面3个 [F, F, F, T, T], # 第3个词能看自己前2个不能看后面2个 [F, F, F, F, T], # 第4个词能看自己前3个不能看后面1个 [F, F, F, F, F] # 第5个词能看所有前面的词 ]2. 填充掩码每个样本不一样样本1[1,2,3,0,0]第4、5位是填充→ 掩码[F,F,F,T,T]样本2[4,5,0,0,0]第3、4、5位是填充→ 掩码[F,F,T,T,T]最终合并结果样本1 输出第一块 5x5[[False, True, True, True, True], [False, False, True, True, True], [False, False, False, True, True], [False, False, False, True, True], # 第4位是填充永久遮挡 [False, False, False, True, True]] # 第5位是填充永久遮挡前3行只受前瞻掩码影响后2行前瞻掩码 填充掩码双重遮挡样本2 输出第二块 5x5[[False, True, True, True, True], [False, False, True, True, True], [False, False, True, True, True], # 第3位是填充永久遮挡 [False, False, True, True, True], # 第4位是填充永久遮挡 [False, False, True, True, True]] # 第5位是填充永久遮挡前2行只受前瞻掩码影响后3行前瞻掩码 填充掩码双重遮挡简单的流程就是维度[2,1,5,5][2个句子, 1个注意力头, 每个句子5个Query, 每个句子5个Key]掩码内容上三角的True 遮挡未来词前瞻掩码后半列的True 遮挡填充0填充掩码两者合并就是看到的输出不用广播的写法importtorchdefcreate_tgt_mask_no_broadcast(tgt_ids,pad_id):创建目标序列掩码无广播版手动扩展维度B,Stgt_ids.shape # 直接获取批次B2序列长S5#1.2维padding掩码[batch,seq_len]→[2,5]padding_mask_2d(tgt_idspad_id)#2.升维 →[B,1,1,S]→[2,1,1,5]tgt_padding_maskpadding_mask_2d.unsqueeze(1).unsqueeze(1)#替代广播# 把第3维seq_q从1复制成 S → 形状变成[B,1,S,S][2,1,5,5]tgt_padding_masktgt_padding_mask.repeat(1,1,S,1)#3.构造上三角前瞻掩码[S,S]→[5,5]look_ahead_masktorch.triu(torch.ones(S,S,devicetgt_ids.device),diagonal1).bool()#4.升维 →[1,1,S,S]→[1,1,5,5]look_ahead_masklook_ahead_mask.unsqueeze(0).unsqueeze(0)#替代广播# 把第0维batch从1复制成 B → 形状变成[B,1,S,S][2,1,5,5]look_ahead_masklook_ahead_mask.repeat(B,1,1,1)#6.两个掩码形状完全一致直接运算无任何广播returntgt_padding_mask|look_ahead_mask # 测试if__name____main__:pad_id0tgt_idstorch.tensor([[1,2,3,0,0],[4,5,0,0,0]])# 运行无广播版本 maskcreate_tgt_mask_no_broadcast(tgt_ids,pad_id)print(最终掩码形状:,mask.shape)# 依旧是 torch.Size([2,1,5,5])print(掩码内容:\n,mask)输出最终掩码形状:torch.Size([2,1,5,5])掩码内容:tensor([[[[False,True,True,True,True],[False,False,True,True,True],[False,False,False,True,True],[False,False,False,True,True],[False,False,False,True,True]]],[[[False,True,True,True,True],[False,False,True,True,True],[False,False,True,True,True],[False,False,True,True,True],[False,False,True,True,True]]]])