注意力机制广泛存在于现在的深度学习网络结构中,使用得到能够提升模型的学习效果。本文讲使用Pytorch实现多头自注意力模块。
一个典型的自注意力模块由Q、K、V三个矩阵的运算组成,Q、K、V三个矩阵都由原特征矩阵变换而来,所以本质上来说是对自身的运算。
而多头注意力机制则是单头注意力机制的进化版,把每次attention运算分组(头)进行,能够从多个维度提炼特征信息。具体原理可以参看相关的科普文章,下面是Pytorch实现。
as nn
class MHSA(nn.Module):def __init__(self, num_heads, dim):super().__init__()# Q, K, V 转换矩阵,这里假设输入和输出的特征维度相同self.q = nn.Linear(dim, dim)self.k = nn.Linear(dim, dim)self.v = nn.Linear(dim, dim)self.num_heads = num_headsdef forward(self, x):B, N, C = x.shape# 生成转换矩阵并分多头q = self.q(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)k = self.k(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)v = self.k(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)# 点积得到attention scoreattn = q @ k.transpose(2, 3) * (x.shape[-1] ** -0.5)attn = attn.softmax(dim=-1)# 乘上attention score并输出v = (attn @ v).permute(0, 2, 1, 3).reshape(B, N, C)return v
本文发布于:2024-02-08 19:47:19,感谢您对本站的认可!
本文链接:https://www.4u4v.net/it/170739295268458.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
留言与评论(共有 0 条评论) |