Long-Tailed Classification by Keeping the Good and Removing the Bad Momentum Causal Effect (Personal Notes)
M denotes the momentum term in torch's SGD implementation. Momentum adds an inertia-like effect that makes gradient descent more stable, but its buffer keeps accumulating over the whole training run, much like the temporal state of an LSTM: it inherits information from earlier batches, influences how the parameters are updated, and thereby influences the feature-extraction process and hence the features X.
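The accumulation can be seen in a minimal sketch of the momentum update used by `torch.optim.SGD` (the function name and toy values here are illustrative, not the paper's code):

```python
import torch

# m <- mu * m + g, then theta <- theta - lr * m: the buffer m keeps a
# decaying sum of ALL past gradients, so every earlier batch still
# influences the current parameter update.
def sgd_momentum_step(param, grad, buf, lr=0.1, momentum=0.9):
    buf.mul_(momentum).add_(grad)   # accumulate gradient history
    param.add_(buf, alpha=-lr)      # parameter step along the buffer
    return param, buf

w = torch.zeros(3)
m = torch.zeros(3)
for g in [torch.tensor([1.0, 0.0, 0.0]), torch.tensor([0.0, 1.0, 0.0])]:
    w, m = sgd_momentum_step(w, g, m)
# After the second step, m still carries 0.9 of the first batch's gradient.
```

This is why the notes treat M as a cause of X: the buffer is a function of the whole (imbalanced) data stream, not just the current batch.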
D denotes the projection of the feature X onto the direction of the head-class features. M is the culprit that makes X drift toward the head-class features, and it is this drift that makes Y's loss on the long-tail classes large.
When computing $P(Y|X)$ there is a backdoor path $X \leftarrow M \rightarrow D \rightarrow Y$, so $do(X)$ is needed to remove its influence.
$$
\begin{aligned}
P(Y=i \mid do(X=x)) &= \sum_{m \in M} P(Y=i \mid X=x, M=m)\,P(M=m) \\
&= \sum_{m \in M} \frac{P(Y=i, X=x, M=m)}{P(X=x \mid M=m)}
\end{aligned}
$$
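A toy numerical example (hypothetical probabilities, chosen only for illustration) shows why the backdoor-adjusted sum differs from the naive conditional when M confounds X:

```python
# Binary toy model: M confounds X. The naive P(Y=1 | X=1) weights each m
# by the posterior P(m | X=1); the backdoor adjustment weights by the
# prior P(m) instead, which is what do(X=1) requires.
P_m = {0: 0.5, 1: 0.5}                  # P(M = m)
P_x1_m = {0: 0.2, 1: 0.8}               # P(X = 1 | M = m)
P_y1_x1m = {0: 0.9, 1: 0.3}             # P(Y = 1 | X = 1, M = m)

# Observational estimate: Bayes gives P(m | X = 1).
post = {m: P_x1_m[m] * P_m[m] for m in P_m}
z = sum(post.values())
post = {m: v / z for m, v in post.items()}
obs = sum(P_y1_x1m[m] * post[m] for m in P_m)

# Interventional estimate: sum_m P(Y=1 | X=1, m) P(m).
intervened = sum(P_y1_x1m[m] * P_m[m] for m in P_m)

print(obs, intervened)  # the two disagree because M biases the naive estimate
```

Here the observational estimate over-weights $m=1$ (which makes $X=1$ likely), so the naive conditional understates the interventional effect.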
The causal graph after $do(X)$, with the edge $M \rightarrow X$ cut:
When computing $P(Y|X)$, D also acts as a mediator; its contribution must be removed so that only the direct causal effect of X on Y remains.
$$
\begin{aligned}
TDE(Y_i) &= Y_{x,d} - Y_{x^*,d} \\
&= P(Y=i \mid do(X=x)) - P(Y=i \mid do(X=x^*))
\end{aligned}
$$
Since M ranges over an unbounded set of values, the notes adopt the re-sampling strategy commonly used for long-tailed problems, and the sum is then approximated by inverse probability weighting:
$$
\sum_{m \in M} P(Y=i \mid X=x, M=m)\,P(M=m) \approx \frac{1}{K} \sum_{k=1}^{K} \tilde{P}(Y=i, X=x^k, D=d^k)
$$
A propensity score $g(x^k, d^k)$ is applied to remove the confounding effect: computing $f(x^k, d^k)$ with the neural network alone still retains M's influence, so dividing by the propensity score is needed to strip M out.
$$
\tilde{P}(Y=i, X=x^k, D=d^k) \propto E(x^k, d^k; w_i^k) = \tau \frac{f(x^k, d^k; w_i^k)}{g(x^k, d^k; w_i^k)}
$$
The author found experimentally that the cosine classifier's $||w_i^k|| \cdot ||x^k||$ works poorly, because $||x^k||$ itself contains the M-induced drift toward the head-class features; the propensity score is therefore taken to be $||w_i^k|| \cdot ||x^k|| + \gamma\,||x^k|| = ||x^k||(||w_i^k|| + \gamma)$.
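Putting $f$ and $g$ together gives a logit of the form $\tau\,(w_i^T x)/(||x||(||w_i|| + \gamma))$. A single-head sketch (the paper uses a K-head version; the $\tau$ and $\gamma$ values below are illustrative defaults, not prescribed here):

```python
import torch

# De-confounded logit sketch: f / g = (w^T x) / (||x|| (||w|| + gamma)),
# scaled by a temperature tau. Values of tau and gamma are illustrative.
def deconfounded_logit(x, w, tau=16.0, gamma=1.0 / 32):
    return tau * (w @ x) / (x.norm() * (w.norm() + gamma))

x = torch.tensor([3.0, 4.0])
w = torch.tensor([1.0, 0.0])
logit = deconfounded_logit(x, w)
```

Dividing by $||x||$ removes the magnitude bias that M injects into the feature norm, which is exactly the role of the propensity score above.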
while

$$
f(x^k, d^k; w_i^k) = (w_i^k)^T(\ddot{x}^k + d^k) = (w_i^k)^T x^k
$$
Therefore

$$
P(Y=i \mid do(X=x)) = \frac{\tau}{K} \sum_{k=1}^{K} \frac{(w_i^k)^T x^k}{||x^k||\,(||w_i^k|| + \gamma)}
$$
Since

$$
d = ||d|| \cdot \hat{d} = \cos(x, \hat{d}) \cdot ||x|| \cdot \hat{d}
$$
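One way to realize $\hat{d}$ in practice is as the unit direction of a moving average of features; the sketch below is an assumption about the bookkeeping, not the paper's exact code, and just checks the projection identity $d = \cos(x,\hat{d})\,||x||\,\hat{d} = (x^T\hat{d})\,\hat{d}$:

```python
import torch

# Running mean of features as a stand-in for the momentum-induced
# head direction; d_hat is its unit vector.
def update_direction(x_bar, x, mu=0.9):
    x_bar = mu * x_bar + x
    return x_bar, x_bar / x_bar.norm()

# Projection of x onto d_hat: cos(x, d_hat) * ||x|| * d_hat.
def head_projection(x, d_hat):
    cos = (x @ d_hat) / x.norm()    # ||d_hat|| = 1 by construction
    return cos * x.norm() * d_hat   # equals (x . d_hat) * d_hat

x = torch.tensor([1.0, 1.0])
d_hat = torch.tensor([1.0, 0.0])
d = head_projection(x, d_hat)
```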
When $x = x^*$, the content term $\ddot{x}^k$ is wiped out and only the projection remains:

$$
x^k = (x^*)^k + d^k = d^k
$$
At this point

$$
\begin{aligned}
P(Y=i \mid do(X=x^*)) &= \frac{\tau}{K} \sum_{k=1}^{K} \frac{(w_i^k)^T d^k}{||x^k||\,(||w_i^k|| + \gamma)} \\
&= \frac{\tau}{K} \sum_{k=1}^{K} \frac{(w_i^k)^T \cdot \cos(x^k, \hat{d}^k) \cdot ||x^k|| \cdot \hat{d}^k}{||x^k||\,(||w_i^k|| + \gamma)} \\
&= \frac{\tau}{K} \sum_{k=1}^{K} \frac{(w_i^k)^T \cdot \cos(x^k, \hat{d}^k) \cdot \hat{d}^k}{||w_i^k|| + \gamma}
\end{aligned}
$$
Substituting both expressions into the TDE formula above:
$$
\begin{aligned}
TDE(Y_i) &= P(Y=i \mid do(X=x)) - P(Y=i \mid do(X=x^*)) \\
&= \frac{\tau}{K} \sum_{k=1}^{K} \frac{(w_i^k)^T x^k}{||x^k||\,(||w_i^k|| + \gamma)} - \frac{\tau}{K} \sum_{k=1}^{K} \frac{(w_i^k)^T \cdot \cos(x^k, \hat{d}^k) \cdot \hat{d}^k}{||w_i^k|| + \gamma} \\
&= \frac{\tau}{K} \sum_{k=1}^{K} \left( \frac{(w_i^k)^T x^k}{||x^k||\,(||w_i^k|| + \gamma)} - \alpha \cdot \frac{\cos(x^k, \hat{d}^k)\,(w_i^k)^T \hat{d}^k}{||w_i^k|| + \gamma} \right)
\end{aligned}
$$

where the coefficient of the subtracted counterfactual term is relaxed into a tunable trade-off hyperparameter $\alpha$.
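At inference time this amounts to subtracting a counterfactual logit along $\hat{d}$ from the factual one. A single-head sketch (hyperparameter values $\tau$, $\gamma$, $\alpha$ below are illustrative, and the multi-head averaging of the paper is omitted):

```python
import torch

# TDE logit sketch: factual term minus alpha-weighted counterfactual term,
# following the final formula above for one head.
def tde_logit(x, w, d_hat, tau=16.0, gamma=1.0 / 32, alpha=1.5):
    cos = (x @ d_hat) / (x.norm() * d_hat.norm())
    factual = (w @ x) / (x.norm() * (w.norm() + gamma))
    counterfactual = alpha * cos * (w @ d_hat) / (w.norm() + gamma)
    return tau * (factual - counterfactual)

x = torch.tensor([3.0, 4.0])      # feature
w = torch.tensor([1.0, 0.0])      # classifier weight for class i
d_hat = torch.tensor([1.0, 0.0])  # unit head direction
score = tde_logit(x, w, d_hat)
```

Classes whose weights align with $\hat{d}$ (typically head classes) get their logits suppressed, while classes roughly orthogonal to $\hat{d}$ lose little, which is the intended "keep the good, remove the bad" effect.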
Source: https://blog.csdn.net/qq_42856273/article/details/120721594