其他分享
首页 > 其他分享> > Transformer 中的 attention

Transformer 中的 attention

作者:互联网

Transformer 中的 attention

转自Transformer中的attention,看完不懂扇我脸

大火的transformer 本质就是:

*使用attention机制的seq2seq。*

所以它的核心就是attention机制,今天就讲attention。直奔代码VIT-pytorch:

https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py

中的

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

看吧!就是这么简单。今天就彻底搞懂这个东西。

先记住attention的这么几个点:

综上所述: attention优点 = CNN并行+RNN全局资讯+对输入尺寸(时序长度维度上)没有限制。

如果你能创造一个拥有上面三点优点的东西出来,你也可以引领潮流。

然后回到代码,再熟悉这么几个设置:

下面看这个图,看完不懂的可以扇自己了:

attention的顺序是:

  1. 你有长度为n(序列)的序列,每个元素都是一个特征,每个特征都是一个向量;
  2. 每个向量都经过FC1,FC2,FC3获取到q,k,v三个向量(长度自己定),记住,不同特征用的是同一个FC1,FC2,FC3。可以说对于一个head,就一组FC1,FC2,FC3。
  3. 特征1的q1和所有特征的k 进行点乘,获取一串值,注意:和自己的k也进行点乘;点乘向量变标量,表示相似性。多个K可不就是一串标量。
  4. 3中的那一串值进行softmax操作,作为权重 对所有v加权求和,获得特征1输出;
  5. 其他所有的特征和特征1的操作一样,注意所有特征是一块并行计算的;
  6. 最后获取的和输入一样长度的特征序列再经过FC进行长度(特征维度)调整,也可以不要;

对了,softmax之前不要忘记 除以 qkv长度开方进行scaled,其实就是标准化操作(我觉得可以理解为各种N(BN,GN,LN等))。

就是这么简单,你学会了吗?

标签:dim,Transformer,nn,特征,self,attention,out
来源: https://www.cnblogs.com/lwp-nicol/p/16245173.html