在实践中,当给定相同的查询、键和值的集合时, 希望模型可以基于相同的注意力机制学习到不同的行为, 然后将不同的行为作为知识组合起来, 捕获序列内各种范围的依赖关系 (例如短距离依赖和长距离依赖关系)。 因此允许注意力机制组合使用查询、键和值的不同子空间表示(representation subspaces)可能是有益的。
为此与其只使用单独一个注意力汇聚, 我们可以用独立学习得到的 ℎ 组不同的线性投影(linear projections)来变换查询、键和值。 然后,这 ℎ 组变换后的查询、键和值将并行地送到注意力汇聚中。 最后,将这 ℎ 个注意力汇聚的输出拼接在一起, 并且通过另一个可以学习的线性投影进行变换, 以产生最终输出,这种设计被称为多头注意力(multihead attention)。 对于 ℎ 个注意力汇聚输出,每一个注意力汇聚都被称作一个头(head)。 下图展示了使用全连接层来实现可学习的线性变换的多头注意力。
在实现多头注意力之前,用数学语言将这个模型形式化地描述出来。给定查询 q ∈ R d q mathbf{q} in mathbb{R}^{d_q} q∈Rdq、键 k ∈ R d k mathbf{k} in mathbb{R}^{d_k} k∈Rdk和值 v ∈ R d v mathbf{v} in mathbb{R}^{d_v} v∈Rdv,每个注意力头 h i mathbf{h}_i hi( i = 1 , … , h i = 1, ldots, h i=1,…,h)的计算方法为:
h i = f ( W i ( q ) q , W i ( k ) k , W i ( v ) v ) ∈ R p v , mathbf{h}_i = f(mathbf W_i^{(q)}mathbf q, mathbf W_i^{(k)}mathbf k,mathbf W_i^{(v)}mathbf v) in mathbb R^{p_v}, hi=f(Wi(q)q,Wi(k)k,Wi(v)v)∈Rpv,
其中,可学习的参数包括 W i ( q ) ∈ R p q × d q mathbf W_i^{(q)}inmathbb R^{p_qtimes d_q} Wi(q)∈Rpq×dq、 W i ( k ) ∈ R p k × d k mathbf W_i^{(k)}inmathbb R^{p_ktimes d_k} Wi(k)∈Rpk×dk和 W i ( v ) ∈ R p v × d v mathbf W_i^{(v)}inmathbb R^{p_vtimes d_v} Wi(v)∈Rpv×dv,以及代表注意力汇聚的函数 f f f。 f f f可以是加性注意力和缩放点积注意力。多头注意力的输出需要经过另一个线性转换,它对应着 h h h个头连结后的结果,因此其可学习参数是 W o ∈ R p o × h p v mathbf W_oinmathbb R^{p_otimes h p_v} Wo∈Rpo×hpv:
W o [ h 1 ⋮ h h ] ∈ R p o . mathbf W_o begin{bmatrix}mathbf h_1\vdots\mathbf h_hend{bmatrix} in mathbb{R}^{p_o}. Wo⎣ ⎡h1⋮hh⎦ ⎤∈Rpo.
基于这种设计,每个头都可能会关注输入的不同部分,可以表示比简单加权平均值更复杂的函数。
在实现过程中,选择缩放点积注意力作为每一个注意力头。为了避免计算代价和参数代价的大幅增长,设定 p q = p k = p v = p o / h p_q = p_k = p_v = p_o / h pq=pk=pv=po/h。值得注意的是,如果将查询、键和值的线性变换的输出数量设置为 p q h = p k h = p v h = p o p_q h = p_k h = p_v h = p_o pqh=pkh=pvh=po,则可以并行计算 h h h个头。在下面的实现中, p o p_o po是通过参数num_hiddens指定的。
class MultiHeadAttention(nn.Module):"""多头注意力"""def __init__(self,query_size,key_size,value_size,num_hiddens,num_heads,dropout,bias=False):super(MultiHeadAttention,self).__init__()self.num_heads = num_headsself.attention = h.DotProductAttention(dropout)self.W_q = nn.Linear(query_size,num_hiddens,bias=bias)self.W_k = nn.Linear(key_size,num_hiddens,bias=bias)self.W_v = nn.Linear(value_size,num_hiddens,bias=bias)self.W_o = nn.Linear(num_hiddens,num_hiddens,bias=bias)def forward(self,queries,keys,values,valid_lens):# queries,keys,values的形状:# (batch_size,查询或者“键-值”对的个数,num_hiddens)# valid_lens 的形状:# (batch_size,)或(batch_size,查询的个数)# 经过变换后,输出的queries,keys,values 的形状:# (batch_size*num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)queries = transpose_qkv(self.W_q(queries),self.num_heads)keys = transpose_qkv(self.W_k(keys),self.num_heads)values = transpose_qkv(self.W_v(values),self.num_heads)if valid_lens is not None:# 在轴0,将第一项(标量或者矢量)复制num_heads次,# 然后如此复制第二项,然后诸如此类。valid_lens = peat_interleave(valid_lens,repeats=self.num_heads,dim=0)# output的形状:(batch_size*num_heads,查询的个数,# num_hiddens/num_heads)output = self.attention(queries,keys,values,valid_lens)# output_concat的形状:(batch_size,查询的个数,num_hiddens)output_concat = transpose_output(output,self.num_heads)return self.W_o(output_concat)
为了能够使多个头并行计算, 上面的MultiHeadAttention类将使用下面定义的两个转置函数。 具体来说transpose_output函数反转了transpose_qkv函数的操作。
def transpose_qkv(X,num_heads):"""为了多注意力头的并行计算而变换形状"""# 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)# 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,# num_hiddens/num_heads)X = X.reshape(X.shape[0],X.shape[1],num_heads,-1)# 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)X = X.permute(0,2,1,3)# 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)shape(-1,X.shape[2],X.shape[3])def transpose_output(X,num_heads):"""逆转transpose_qkv函数的操作"""X = X.reshape(-1,num_heads,X.shape[1],X.shape[2])X = X.permute(0,2,1,3)shape(X.shape[0],X.shape[1],-1)
下面使用键和值相同的例子来测试编写的MultiHeadAttention类。 多头注意力输出的形状是(batch_size,num_queries,num_hiddens)。
num_hiddens,num_heads = 100,5
multiHeadAttention = MultiHeadAttention(num_hiddens,num_hiddens,num_hiddens,num_hiddens,5,0.5)
multiHeadAttention.eval()
输出结果如下:
MultiHeadAttention((attention): DotProductAttention((dropout): Dropout(p=0.5, inplace=False))(W_q): Linear(in_features=100, out_features=100, bias=False)(W_k): Linear(in_features=100, out_features=100, bias=False)(W_v): Linear(in_features=100, out_features=100, bias=False)(W_o): Linear(in_features=100, out_features=100, bias=False)
)
batch_size,num_queries = 2,4
num_kvpairs,valid_lens = sor([3,2])
Y = s(size=(batch_size,num_kvpairs,num_hiddens))
X = s(size=(batch_size,num_queries,num_hiddens))
multiHeadAttention(X,Y,Y,valid_lens).shape
输出结果如下:
torch.Size([2, 4, 100])
import torch
h
from torch import nndef transpose_qkv(X, num_heads):"""为了多注意力头的并行计算而变换形状"""# 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)# 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,# num_hiddens/num_heads)X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)# 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)X = X.permute(0, 2, 1, 3)# 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)shape(-1, X.shape[2], X.shape[3])def transpose_output(X, num_heads):"""逆转transpose_qkv函数的操作"""X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])X = X.permute(0, 2, 1, 3)shape(X.shape[0], X.shape[1], -1)class MultiHeadAttention(nn.Module):"""多头注意力"""def __init__(self, query_size, key_size, value_size, num_hiddens, num_heads, dropout, bias=False):super(MultiHeadAttention, self).__init__()self.num_heads = num_headsself.attention = h.DotProductAttention(dropout)self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)def forward(self, queries, keys, values, valid_lens):# queries,keys,values的形状:# (batch_size,查询或者“键-值”对的个数,num_hiddens)# valid_lens 的形状:# (batch_size,)或(batch_size,查询的个数)# 经过变换后,输出的queries,keys,values 的形状:# (batch_size*num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)queries = transpose_qkv(self.W_q(queries), self.num_heads)keys = transpose_qkv(self.W_k(keys), self.num_heads)values = transpose_qkv(self.W_v(values), self.num_heads)if valid_lens is not None:# 在轴0,将第一项(标量或者矢量)复制num_heads次,# 然后如此复制第二项,然后诸如此类。valid_lens = peat_interleave(valid_lens, repeats=self.num_heads, dim=0)# output的形状:(batch_size*num_heads,查询的个数,# num_hiddens/num_heads)output = self.attention(queries, keys, values, valid_lens)# output_concat的形状:(batch_size,查询的个数,num_hiddens)output_concat = transpose_output(output, self.num_heads)return self.W_o(output_concat)num_hiddens, num_heads = 100, 5
multiHeadAttention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, 5, 0.5)
multiHeadAttention.eval()
batch_size, num_queries = 2, 4
num_kvpairs, valid_lens = 6, sor([3, 2])
Y = s(size=(batch_size, num_kvpairs, num_hiddens))
X = s(size=(batch_size, num_queries, num_hiddens))
multiHeadAttention(X, Y, Y, valid_lens).shape
注意力机制第一篇:李沐动手学深度学习V2-注意力机制
注意力机制第二篇:李沐动手学深度学习V2-注意力评分函数
注意力机制第三篇:李沐动手学深度学习V2-基于注意力机制的seq2seq
注意力机制第四篇:李沐动手学深度学习V2-自注意力机制之位置编码
注意力机制第五篇:李沐动手学深度学习V2-自注意力机制
注意力机制第六篇:李沐动手学深度学习V2-多头注意力机制和代码实现
注意力机制第七篇:李沐动手学深度学习V2-transformer和代码实现
本文发布于:2024-02-08 19:47:10,感谢您对本站的认可!
本文链接:https://www.4u4v.net/it/170739293868456.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
留言与评论(共有 0 条评论) |