Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
作者:互联网
论文:https://arxiv.org/abs/2103.14030
代码:https://github.com/microsoft/Swin-Transformer
论文中提出了一种新型的Transformer架构(Swin Transformer),其利用滑动窗口和分层结构使得Swin Transformer成为了机器视觉领域新的Backbone,在图像分类、目标检测、语义分割等多种机器视觉任务中达到了SOTA水平。
目前Transformer应用到图像领域主要有两大挑战:
- 视觉实体变化大,在不同场景下视觉Transformer性能未必很好
- 图像分辨率高,像素点多,Transformer基于全局自注意力的计算导致计算量较
本文借鉴了CNN中的inductive bias,其中滑窗操作包括不重叠的local window,和重叠的cross-window。将注意力计算限制在一个窗口中,一方面能引入CNN卷积操作的局部性,另一方面能节省计算量。
假设一张图片共有个patches(每个patches是原图4*4像素区域),每个窗口包括个patches.
原始Transformer self-attention计算复杂度 =
在Swin Transformer中采用的是window self-attention,其计算复杂度为窗口计算复杂度*窗口数量,窗口数量= ,窗口计算复杂度= ,
Swin Transformer self-attention计算复杂度 =
计算复杂度由patches数量的平方关系降低到线性关系。
层次设计则是类似CNN随着网络变深,感受野变大的特性,将window内的多个patch变成一个patch,类似下采样。
整体结构
整个模型采取层次化的设计,一共包含4个Stage,每个stage都会缩小输入特征图的分辨率,像CNN一样逐层扩大感受野。
Patch Partition+Linear Embedding
Patch Partition 将像素分辨率图像转换为patches分辨率的图像,每个patch视为一个token,特征就是patch范围内的RGB值的展开,token_feature = 48;代码通过二维卷积层,将stride,kernelsize设置为patch_size大小。设定输出通道来确定嵌入向量的大小。最后将H,W维度展开,并移动到第一维度。
Linear Embedding 将token_feature转换为需要的维度(Swin_T/C=96) 。
class PatchEmbed(nn.Module):
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size) # -> (img_size, img_size)
patch_size = to_2tuple(patch_size) # -> (patch_size, patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
# 假设采取默认参数
x = self.proj(x) # 出来的是(N, 96, 224/4, 224/4)
x = torch.flatten(x, 2) # 把HW维展开,(N, 96, 56*56)
x = torch.transpose(x, 1, 2) # 把通道维放到最后 (N, 56*56, 96)
if self.norm is not None:
x = self.norm(x)
return x
Patch Merging
该模块的作用是在每个Stage开始前做降采样,用于缩小分辨率,调整通道数 进而形成层次化的设计,同时也能节省一定运算量。
每次降采样是两倍,因此在行方向和列方向上,间隔2选取元素。然后拼接在一起作为一整个张量,最后展开。此时通道维度会变成原先的4倍(因为H,W各缩小2倍),此时再通过一个全连接层再调整通道维度为原来的两倍。
class PatchMerging(nn.Module):
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C) # (B, H*W, C)->(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
Window Partition/Reverse
window partition
函数是用于对张量划分窗口,指定窗口大小。将原本的张量从 N H W C
, 划分成 num_windows*B, window_size, window_size, C。
而window reverse
函数则是对应的逆过程。这两个函数分别用在windows attention前后。
Window Attention
传统的Transformer都是基于全局来计算注意力的,因此计算复杂度十分高。而Swin Transformer则将注意力的计算限制在每个窗口内,进而减少了计算量。
主要区别是在原始计算Attention的公式中的Q,K时加入了相对位置编码。实验有证明相对位置编码的加入提升了模型性能,所以这里主要讲一下相对位置编码。
假设window_size = 2*2即每个窗口有4个patch ,如图1所示,在计算self-attention时,每个patch都要与所有的ptch计算QK值,如图6所示,当位置1的patch计算self-attention时,要计算位置1与位置(1,2,3,4)的QK值,即以位置1的patch为中心点,中心点位置坐标(0,0),其他位置计算与当前位置坐标的偏移量。
首先我们利用torch.arange
和torch.meshgrid
函数生成对应的坐标,这里我们以windowsize=2
为例子
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.meshgrid([coords_h, coords_w]) # -> 2*(wh, ww)
"""
(tensor([[0, 0],
[1, 1]]),
tensor([[0, 1],
[0, 1]]))
"""
然后堆叠起来,展开为一个二维向量
coords = torch.stack(coords) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
"""
tensor([[0, 0, 1, 1],
[0, 1, 0, 1]])
"""
利用广播机制,分别在第一维,第二维,插入一个维度,进行广播相减,得到 2, wh*ww, wh*ww
的张量
relative_coords_first = coords_flatten[:, :, None] # 2, wh*ww, 1
relative_coords_second = coords_flatten[:, None, :] # 2, 1, wh*ww
relative_coords = relative_coords_first - relative_coords_second # 最终得到 2, wh*ww, wh*ww 形状的张量
首先 Window Partition将原本的张量从 N H W C
, 划分成 num_windows*B, window_size, window_size, C。
然后经过self.qkv
这个全连接层后,进行reshape,调整轴的顺序,得到形状为3, numWindows*B, num_heads, window_size*window_size, c//num_heads
,并分配给q,k,v
。再加上之前的相对位置编码,剩下就是跟transformer一样的softmax,dropout,与V
矩阵乘,再经过一层全连接层和dropout。
Shifted Window Attention
前面的Window Attention是在每个窗口下计算注意力的,为了更好的和其他window进行信息交互,Swin Transformer还引入了shifted window操作。
论文中以 向下取整的窗口重新对原图进行分割(这里M是Window Attention窗口的尺寸 ),并将之前没有联系的新窗口合并得到新的窗口划分方案,如图8所示,带来的问题就是窗口个数增加了,为了避免窗口增加导致的额外计算量并保证不重叠窗口间有关联,论文提出了cyclic shift方法,如下图所示:
通过对特征图移位,并给Attention设置mask来间接实现的。能在保持原有的window个数下,最后的计算结果等价。
特征图移位通过torch.roll实现,并且给每个子窗口进行编码:
Attention Mask
希望在计算Attention的时候,让具有相同index QK进行计算,而忽略不同index QK计算结果。
但实际上代码使用张量相减来得到mask矩阵的:
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
这里代码写的还是很巧妙的,patch与patch的自相关对应mask之间的比较,所以广播再相减就可以得到结果,如下:
标签:Shifted,Transformer,Swin,self,mask,patch,window,coords,size 来源: https://blog.csdn.net/leihuang0822/article/details/121708311