Swin-Transformer中MSA和W-MSA模块计算复杂度推导(非常详细,最新)
作者:互联网
在Swin-Transformer一文中,有这样两个公式,分别为:
1. Transformer中提出的Multi-head Self-Attention模块(MSA)
2. Swin-Transformer中提出的Window Multi-head Self-Attention模块(W-MSA)
两者计算量即计算复杂度分别为:
Ω
(
M
S
A
)
=
4
h
w
C
2
+
2
(
h
w
)
2
C
Ω(MSA) = 4hwC^2 + 2(hw)^2C
Ω(MSA)=4hwC2+2(hw)2C
Ω
(
W
−
M
S
A
)
=
4
h
w
C
2
+
2
M
2
h
w
C
Ω(W-MSA) = 4hwC^2 + 2M^2hwC
Ω(W−MSA)=4hwC2+2M2hwC
一、MSA模块计算量
Transformer中提出的多头自注意力模块运算公式为
A
t
t
e
n
t
i
o
n
(
Q
,
K
,
V
)
=
S
o
f
t
m
a
x
(
Q
K
T
d
k
)
V
Attention(Q,K,V)=Softmax(\frac {QK^T} {\sqrt{d_k} \quad})V
Attention(Q,K,V)=Softmax(dk
QKT)V
运算过程以及计算量如下图:
(注:图中黑色字体为矩阵维度,橙色字体为计算量)
- 首先,一张维度为H * W * C的图片,分别通过乘上一个C*C的变换矩阵,分别变成QKV矩阵,维度仍为H * W * C。这样三个变换用到的计算量都为:HW * C^2,加起来为: 3 H W ∗ C 2 3HW*C^2 3HW∗C2
- 然后做Q*K的转置,(HW * C) * (C * HW),得到的矩阵维度为HW * HW,计算量为: ( H W ) 2 ∗ C (HW)^2*C (HW)2∗C
- 忽略根号dk和softmax的计算量,2中得到的 A矩阵 × V矩阵 ,结果的维度为HW * C,计算量为: ( H W ) 2 ∗ C (HW)^2*C (HW)2∗C
- 因为是多头自注意力,所以在做完矩阵乘法后,需要矩阵拼接融合,做一个Proj映射,得到输出矩阵,维度和输入一样为HW * C,计算量为: H W ∗ C 2 HW*C^2 HW∗C2
- 最后计算量加起来,为: Ω ( M S A ) = 4 h w C 2 + 2 ( h w ) 2 C Ω(MSA) = 4hwC^2 + 2(hw)^2C Ω(MSA)=4hwC2+2(hw)2C
二、W-MSA模块计算量
W-MSA与MSA总体的计算过程是一致的,区别在于:W-MSA的长宽不再是H和W,而是
窗
口
:
M
∗
M
窗口: M*M
窗口:M∗M并且有
H
M
∗
W
M
个
窗
口
需
要
计
算
\frac{H}{M}*\frac{W}{M}个窗口需要计算
MH∗MW个窗口需要计算
所以它的计算量为:
(
h
M
∗
h
M
)
∗
(
4
M
2
C
2
+
2
M
4
C
)
=
4
h
w
C
2
+
2
M
2
h
w
C
(\frac{h}{M}*\frac{h}{M})*(4M^2C^2+2M^4C)=4hwC^2+2M^2hwC
(Mh∗Mh)∗(4M2C2+2M4C)=4hwC2+2M2hwC
分析
从他们公式可以看出区别主要在于两个公式的后半部分
带一点数进去就可以看出W-MSA在计算量上比MSA少很多,比如以原文中的一些参数设定为例:HW都为56,C为96,M为7
前者MSA为:
2
∗
(
56
∗
56
)
2
∗
96
=
1888223232
2*(56*56)^2*96=1888223232
2∗(56∗56)2∗96=1888223232
后者为W-MSA为:
2
∗
7
2
∗
56
∗
56
∗
96
=
29503488
2*7^2*56*56*96=29503488
2∗72∗56∗56∗96=29503488
二者相差了64倍。
论文名称:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
原论文地址: https://arxiv.org/abs/2103.14030
标签:Transformer,Swin,56,HW,计算,MSA,96 来源: https://blog.csdn.net/qq_45588019/article/details/122599502