
AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE


https://arxiv.org/pdf/2010.11929.pdf

---------------------------------------------------------

2021-08-30

                        

The Transformer lacks the CNN's inductive biases of translation equivariance and locality; pre-training on large-scale datasets can make up for this. Below, the patch embedding and multi-head self-attention modules are implemented in PyTorch (using einops).

import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
from einops.layers.torch import Rearrange


class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 224):
        super(PatchEmbedding, self).__init__()
        self.patch_size = patch_size
        # Cut the image into non-overlapping patch_size x patch_size patches with a strided
        # convolution, project each patch to emb_size, then flatten the grid into a token sequence.
        self.projection = nn.Sequential(
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            Rearrange("b e h w -> b (h w) e"),
        )
        # Learnable [class] token and learnable 1D position embeddings (one per patch, plus the cls token).
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        self.positions = nn.Parameter(torch.randn((img_size // patch_size) ** 2 + 1, emb_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b = x.size(0)
        x = self.projection(x)
        # Prepend the [class] token to every sequence in the batch.
        cls_tokens = einops.repeat(self.cls_token, "() n e -> b n e", b=b)
        x = torch.cat([cls_tokens, x], dim=1)
        x += self.positions

        return x
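
As a quick sanity check (a minimal sketch, assuming the default 224×224 input, 16×16 patches and emb_size=768 above), a batch of images becomes 14×14 = 196 patch tokens plus one [class] token:

x = torch.randn(2, 3, 224, 224)       # (batch, channels, height, width)
tokens = PatchEmbedding()(x)
print(tokens.shape)                   # torch.Size([2, 197, 768]) -> 196 patches + 1 cls token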


class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0.):
        super(MultiHeadAttention, self).__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        # One fused linear layer produces queries, keys and values in a single pass.
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        # Split the fused projection into (q, k, v), each of shape (batch, heads, tokens, head_dim).
        qkv = einops.rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        # Dot-product attention scores between every pair of tokens.
        energy = torch.einsum("bhqd, bhkd -> bhqk", queries, keys)
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy = energy.masked_fill(~mask, fill_value)
        # Scale before the softmax so the attention weights still sum to 1.
        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        # Weighted sum of the values, then merge the heads back into one embedding dimension.
        out = torch.einsum("bhal, bhlv -> bhav", att, values)
        out = einops.rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)

        return out
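
Chaining the two modules shows that attention is shape-preserving over the token sequence (again a minimal sketch with the default sizes; a full ViT encoder block would additionally wrap the attention and an MLP with LayerNorm and residual connections, as in the paper):

x = torch.randn(2, 3, 224, 224)
tokens = PatchEmbedding()(x)          # (2, 197, 768)
out = MultiHeadAttention()(tokens)
print(out.shape)                      # torch.Size([2, 197, 768]) -- same shape as the input tokens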

Source: https://www.cnblogs.com/shuimobanchengyan/p/15205078.html