(AAAI2020)Adversarial Domain Adaptation with Domain Mixup论文笔记
作者:互联网
(AAAI2020)Adversarial Domain Adaptation with Domain Mixup论文笔记
基于对抗方式的对齐源域数据和目标域数据。本文提出一种mixup的对齐方式。传统的对齐方式中,直接在对齐源域和目标域的特征分布。而该方法中,构造出一些minup的数据,相当于在源域的特征和目标域的特征之间建立桥梁,逐渐对齐特征。
图中表达的意思是,该方法将构造出一些mixup的图片,通过这些图片建立源域和目标域之间的联系,从而对齐源域的特征和目标域的特征。
模型结构
对于输入数据 x s x^s xs和 x t x^t xt,我们将其mixup,合并的权值使用 λ \lambda λ,即
x m = λ x s + ( 1 − λ ) x t x^m=\lambda x^s + (1-\lambda)x^t xm=λxs+(1−λ)xt
将源域和目标域数据输入编码器 N e N_e Ne,得到两个特征向量 μ \mu μ和 σ \sigma σ。源域的命名为 μ s \mu_s μs和 σ s \sigma_s σs,目标域的命名为 μ s \mu_s μs和 σ t \sigma_t σt。
命名时候感觉其中一个代表均值,另一个代表方差。但在代码中实现的时候, N e N_e Ne是卷积层后接全连接层, μ \mu μ和 σ \sigma σ公用卷积层,卷积层后接两个全连接层分别输出 μ \mu μ和 σ \sigma σ
之后我们估计mixup图片的 μ \mu μ和 σ \sigma σ,估计方法同样使用 λ \lambda λ加权
μ m = λ μ s + ( 1 − λ ) μ t \mu_m=\lambda\mu_s+(1-\lambda)\mu_t μm=λμs+(1−λ)μt
σ m = λ σ s + ( 1 − λ ) σ t \sigma_m=\lambda\sigma_s+(1-\lambda)\sigma_t σm=λσs+(1−λ)σt
之后拼接起来作为解码器 N d N_d Nd的输入,解码器的输入包括 [ μ , σ , z , l c l s , l c o m p ] [\mu,\sigma,z,l_{cls},l_{comp}] [μ,σ,z,lcls,lcomp]
z z z是一个噪声向量, l c l s l_{cls} lcls代表类别向量, l c o m p l_{comp} lcomp代表域类别向量。源域,目标域和mixup的标签信息为
源域: l c l s s = [ 0 , 0... , 1 , . . . , 0 ] , l c o m p s = 0 l_{cls}^s=[0,0...,1,...,0],l_{comp}^s=0 lclss=[0,0...,1,...,0],lcomps=0
目标域: l c l s t = [ 0 , 0... , 0 , . . . , 0 ] , l c o m p s = 1 l_{cls}^t=[0,0...,0,...,0],l_{comp}^s=1 lclst=[0,0...,0,...,0],lcomps=1
mixup图片: l c l s m = [ 0 , 0... , λ , . . . , 0 ] , l c o m p s = 1 − λ l_{cls}^m=[0,0...,\lambda,...,0],l_{comp}^s=1-\lambda lclsm=[0,0...,λ,...,0],lcomps=1−λ
N d N_d Nd可以认为成一个条件生成网络,目的是产生出图片。
之后通过计算损失函数来优化参数。
损失函数
首先对于 μ \mu μ和 σ \sigma σ,我们希望数据的分布和 N ( 0 , 1 ) N(0,1) N(0,1)对齐,所以使用KL散度对齐,损失函数为
m i n N e L K L = D K L ( N ( μ , σ ) ∣ ∣ N ( 0 , I ) ) min_{N_e}L_{KL}=D_{KL}(N(\mu,\sigma)||N(0,I)) minNeLKL=DKL(N(μ,σ)∣∣N(0,I))
这步损失函数表明,虽然都是卷积层+全连接层的输出。但 μ \mu μ和 σ \sigma σ在意义上是存在区别的。
之后和传统的对抗方法类似,我们让判别器D和编码器解码器进行对抗,损失函数为
min N e , N d max D L a d v s + L a d v t + L a d v m \min\limits_{N_e,N_d}\max\limits_{D}L^s_{adv}+L^t_{adv}+L^m_{adv} Ne,NdminDmaxLadvs+Ladvt+Ladvm
其中
L a d v s = E x s ∼ P s l o g ( D d o m ( x s ) ) + l o g ( 1 − D d o m ( x g s ) ) L^s_{adv}=E_{x^s\sim P_s}log(D_{dom}(x^s))+log(1-D_{dom}(x^s_g)) Ladvs=Exs∼Pslog(Ddom(xs))+log(1−Ddom(xgs))
L a d v t = E x t ∼ P t l o g ( 1 − D d o m ( x g t ) ) L^t_{adv}=E_{x^t\sim P_t}log(1-D_{dom}(x_g^t)) Ladvt=Ext∼Ptlog(1−Ddom(xgt))
L a d v m = E x s ∼ P s , x t ∼ P t l o g ( 1 − D d o m ( x g m ) ) L_{adv}^m=E_{x^s\sim P_s,x^t\sim P_t}log(1-D_{dom}(x_g^m)) Ladvm=Exs∼Ps,xt∼Ptlog(1−Ddom(xgm))
式子中的 x s , x t x^s,x^t xs,xt表示源域和目标域的原始图片, x g s , x g t , x g m x^s_g,x^t_g,x^m_g xgs,xgt,xgm分别表示源域的生成图片,目标域的生成图片和mixup的生成图片。
- 损失函数中的 L a d v s L^s_{adv} Ladvs中 D D D的目标是区分出 x s x_s xs和 x g s x_g^s xgs的区别,而 N e , N d N_e,N_d Ne,Nd的目标是混淆 x s x_s xs和 x g s x_g^s xgs。我们可以将 x g s x_g^s xgs看成另一个领域的数据,起初, x s x_s xs和 x g s x_g^s xgs是两个不同的域的数据,我们希望可以将 x s x_s xs和 x g s x_g^s xgs对齐,与 x s x_s xs和$x_t对齐类似。所以这个损失函数和传统的对抗方法的域判别损失类似。
- 损失函数中的 L a d v t L^t_{adv} Ladvt是类似于传统对抗方法中的域判别损失的一部分。将其中的 E x t ∼ P t l o g ( 1 − D d o m ( x t ) ) E_{x^t\sim P_t}log(1-D_{dom}(x^t)) Ext∼Ptlog(1−Ddom(xt))换成了 E x t ∼ P t l o g ( 1 − D d o m ( x g t ) ) E_{x^t\sim P_t}log(1-D_{dom}(x_g^t)) Ext∼Ptlog(1−Ddom(xgt))。将原始图片更换成了生成图片
- 第三个式子的作用是将mixup的生成图片与源域对齐。
个人认为,这三个式子的目的是逐渐将目标域的图片,mixup的图片与源域的图片对齐。本文中的判别器 D D D和之前的判别器不同。这个判别器包括特征提取功能和判别功能(否则直接将图片输入判别器,由于特征是low-level的特征,很难进行判别和对齐)。代码中的 D D D的网络是卷积层+全连接层+sigmoid层实现。
综合一下上述的损失函数可以发现, N e , N d N_e,N_d Ne,Nd的目的是对图片解码编码后,生成的图片和源域图片对齐。这部分损失函数并没有使用到 x t x^t xt和 x m x^m xm。后续的soft label和triplet loss将会用到。
soft label 损失为
min D L s o f t m = − E x s ∼ P s , x t ∼ P t l d o m m l o g ( D d o m ( x m ) ) + ( 1 − l d o m m ) l o g ( 1 − D d o m ( x m ) ) \min\limits_D{L^m_{soft}}=-E_{x^s\sim P_s,x^t\sim P_t}l_{dom}^mlog(D_{dom}(x^m))+(1-l_{dom}^m)log(1-D_{dom}(x^m)) DminLsoftm=−Exs∼Ps,xt∼Ptldommlog(Ddom(xm))+(1−ldomm)log(1−Ddom(xm))
其中 l d o m m l_{dom}^m ldomm表示mixup图像的领域标签,即 λ \lambda λ
这个的作用是希望 D D D对于mixup图像,可以输出其领域标签信息为 λ \lambda λ
triplet loss
triplet loss中包含三类样本, ( a , p , n ) (a,p,n) (a,p,n),分别表示取出的样本,同类的样本和非同类的样本。本文中的triplet loss并不是针对类别层面,而是针对领域层面。
如果mixup的样本的 λ ≥ 0.5 \lambda\geq 0.5 λ≥0.5,说明这类样本更接近源域,那么 ( a , p , n ) = ( x m , x s , x t ) (a,p,n)=(x^m,x^s,x^t) (a,p,n)=(xm,xs,xt)
如果mixup的样本的 λ < 0.5 \lambda < 0.5 λ<0.5,说明这类样本更接近目标域,那么 ( a , p , n ) = ( x m , x t , x s ) (a,p,n)=(x^m,x^t,x^s) (a,p,n)=(xm,xt,xs)
之后计算triplet loss,triplet loss中的偏置设定为 ∣ 2 λ − 1 ∣ |2\lambda-1| ∣2λ−1∣
(代码中的 λ \lambda λ按照 β \beta β分布随机生成,但如果 λ \lambda λ的值比较靠近 0.5 0.5 0.5,就会修改到稍微远离 0.5 0.5 0.5)
和之前很多方法类似,我们让判别器 D D D拥有分类的能力,这里分类针对的是源域和目标域数据的生成图像。不同的是这里的判别器 D D D包括特征提取功能。所以我们只需要在卷积层后加入全连接层用于分类。
min N e , N d , D L c l s s + L c l s t \min\limits_{N_e,N_d,D}L_{cls}^s+L_{cls}^t Ne,Nd,DminLclss+Lclst
都是交叉熵损失,目标域的标签使用分类器 C C C给出的伪标签。
最后还有个分类器 C C C的优化损失函数,分类损失 min N e , C L C \min\limits_{N_e,C}L_C Ne,CminLC
文章效果
总结
本文的创新点在于引入mixup图像的方法,让源域和目标域的对齐不那么直接,而是通过mixup作为桥梁连接源域和目标域的对齐过程。
但这篇文章中并没有提到对齐条件概率分布(不知道是不是弄漏了)。没有在类别层面的对齐,效果依然很好,很奇怪,如果说给 D D D加入分类效果,其中用上了 C C C给的伪标签,这部分的对齐效果有这么好么?
在本文的主体部分,mixup图像以及生产图像的对齐上,并没有对齐条件概率分布的对齐。
标签:...,Domain,AAAI2020,log,Adversarial,源域,对齐,mixup,lambda 来源: https://blog.csdn.net/weixin_43141836/article/details/111225691