【ARXIV2205】EdgeViTs: Competing Light-weight CNNs on Mobile Devices with Vision Transformers
作者:互联网
【ARXIV2205】EdgeViTs: Competing Light-weight CNNs on Mobile Devices with Vision Transformers
91/100
发布文章
gaopursuit
未选择文件
【ARXIV2205】EdgeViTs: Competing Light-weight CNNs on Mobile Devices with Vision Transformers
基于自注意力机制的视觉Transformer(ViT)在视觉任务上,已经形成和CNN一样强有力的架构,但其计算量和模型大小很大。虽然一些工作通过引入先验信息或级联多阶段结构到ViT中,但在移动设备上仍不够高效。本文研究基于MobileNetV2的轻量化ViT,其通过引入局部-全局-局部(LGL)的bottleneck实现,其结合了注意力机制和CNN的优势。
作者提出了将VIT模型应用于移动端需要考虑的三个问题:
- 1)推理速度要快。 当前的一些指标如 FLOPs 难以反映模型在移动端的速度,因为内存访问速度、并行性等因素还要综合考虑。
- 2)模型可以大。 当前手机可以拥有32GB的内存,存储模型并应该做为限制因素。
- 3)实现的友好性。 SWIN里的 cyclic shift 不便于在移动端实现,因此模型要考虑是否便于在移动端实现。
以上面三个原则为指导,作者提出了 EdgeViTs,设计了一个高效的 局部-全局-局部(LGL) 模块,能够实现更好的准确性和计算效率。
模型如上图所示,重点是其中的LGL模块,包括个关键部分:
- local aggregation: 由卷积和 depth conv 组成
- global sparse attention: 平均池化后进行注意力计算
- local propagation: 使用反卷积将缩小的特征图恢复到原来大小。
下面看具体代码,理解起来没有什么难度。
class LocalAgg():
def __init__(self, dim):
self.conv1 = Conv2d(dim, dim, 1)
self.conv2 = Conv2d(im, dim, 3, padding=1, groups=dim)
self.conv3 = Conv2d(dim, dim, 1)
self.norm1 = BatchNorm2d(dim)
self.norm2 = BatchNorm2d(dim)
forward(self, x):
x = self.conv1(self.norm1(x))
x = self.conv2(x)
x = self.conv3(self.norm2(x))
return x
class GlobalSparseAttn():
def __init__(self, dim, sample_rate, scale):
self.scale = scale
self.qkv = Linear(dim, dim * 3)
self.sampler = AvgPool2d(1, stride=sample_rate)
kernel_size=sr_ratio
self.LocalProp = ConvTranspose2d(dim, dim, kernel_size, stride=sample_rate, groups=dim
)
self.norm = LayerNorm(dim)
self.proj = Linear(dim, dim)
def forward(self, x):
x = self.sampler(x)
q, k, v = self.qkv(x)
attn = q @ k * self.scale
attn = attn.softmax(dim=-1)
x = attn @ v
x = self.LocalProp(x)
x = self.proj(self.norm(x))
return x
其实,网络整体就是基于CNN的,只不过沿用了 SWIN 的典型架构。实验结果如下表所示。尽管作者说,与MobileViTs相比,EdgeViTs在三种复杂度设置下分别实现了5.4%、2.8%和2.7%的提高,但是我感觉从FLOPs等指标来看,并没有约对的优势。这里是我的个人理解,有不同意见的地方可以随时交流。
标签:__,dim,scale,Transformers,weight,CNNs,self,EdgeViTs,attn 来源: https://www.cnblogs.com/gaopursuit/p/16388674.html