其他分享
首页 > 其他分享> > 基础DL模型-STN-Spatial Transformer Networks-论文笔记

基础DL模型-STN-Spatial Transformer Networks-论文笔记

作者:互联网

原文链接:https://arleyzhang.github.io/articles/7c7952f0/

论文:Spatial Transformer Networks,是Google旗下 DeepMind 公司的研究成果。

这篇论文的试验做的特别好。

1 简介

1.2 问题提出

CNN在图像分类中取得了显著的成效,主要是得益于 CNN 的深层结构具有 空间不变性(spatially invariance)(平移不变性,旋转不变性),所以图像上的目标物体就算是做了平移或者旋转,CNN仍然能够准确的识别出来,这对于CNN的泛化能力是有益的。

那么如何在保证准确率的情况下,即不损失局部信息的前提下,增强网络的空间不变性呢?这篇文章就是为了解决这个问题。

1.2 解决方法

对于CNN 来说,即便通过选择合适的降采样比例来保证准确率和空间不变性,但是 池化层 带来的空间不变性是不够的,它受限于预先选定的固定尺寸的池化核(感受野是固定的,局部的)。因为物体的变形包括旋转,平移,扭曲,缩放,混淆噪声等,所以后面feature map中像素点的感受野不一定刚好包含物体或者反映物体的形变。

文章提出了一种 Spatial Transformer Networks,简称 STN,引进了一种可学习的采样模块 Spatial Transformer ,姑且称为空间变换器,Spatial Transformer的学习不需要引入额外的数据标签,它可以在网络中对数据(feature map)进行空间变换操作。这个模块是可微的(后向传播必须),并且可以插入到现有的CNN模型中,使得 feature map具有空间变换能力,也就是说 感受野是动态变化的,feature map的空间变换方向 与 原图片上的目标的空间变换方向(一般认为是数据噪声)是相反的,所以使得整个网络的空间不变性增强。试验结果展示这种方法确实增强了空间不变性,在一些标志性的数据集(benchmark)上取得了先进的水平。

1517381763710

图1 在输入层使用 Spatial Transformer

空说无凭,先看一个简单效果,如图1:

整体上来看是一种视觉 attention 机制,也更像一种弱的目标检测机制,就是把图片中物体所在区域送到网络后面的层中,使得后面的分类任务更简单。

CNN是尽力让网络适应物体的形变,而STN是直接通过 Spatial Transformer 将形变的物体给变回到正常的姿态(比如把字摆正),然后再给网络识别。

文章给的 Spatial Transformer 的使用场景:

看完这篇论文之后,个人觉得目标检测(object detection)也是可以用的,果不其然,真有人将类似的方法用在了 目标检测上,这篇论文就是 Deformable Convolutional Networks ,后面再讲。

2 Spatial Transformer结构

文章最重要的一个结构就是 Spatial Transformers ,这个结构的示意图如下:

1517313809413

图2 Spatial Transformers 结构图

这样一个结构相当于 CNN中的一个 卷积层或者池化层:

这个结构又被分为三部分:localisation network ,grid generator和sampler

一些符号意义:

这个图与图1做个对应,U 相当于 图1 中的 (a) , V相当于 图1 中的(c),中间那一部分相当于图1 中的(b), 作用就是为了找到那个物体所在的框,或者叫做弱目标检测。

2.1 Localisation network

这一部分很简单,可以使用全连接层或者全卷积层,只要保证最后一层是一个回归层即可,最后输出的一个向量是 θθ 。 θθ 的维度下面再说。

2.2 Grid generator

前面提到中间那一部分是为了找到那个物体所在的框,并把它给 变换回 “直立的状态”。很自然就能想到使用仿射变换就可以完成,如下图:

 

 

图3 (a)恒等变换与采样; (b)仿射变换与采样

我们期望的是输出 V 是 将U中某一部分(比如绿色点覆盖的部分)做了旋转,放缩,平移之后的feature map。

看一下Grid generator是如何进行仿射变换的。


先简单的看一下仿射变换:

仿射变换用于表示旋转,缩放和平移,表示的是两副图之间的关系,

以下 A 为旋转矩阵,B 为平移矩阵,M称为仿射变换矩阵。

1517390080627

1517390143003

假设要对二维向量1517391663142进行仿射变换,仿射变换可以写成如下两式,两种写法等价:

1517390161398

1517390173934

输出的结果是:

1517390244077

对于仿射变换来说,一般的用法有两种:


这里使用的是第一种用法。其中 图3 (b) U 中的被绿点覆盖的那一部分相当于这里的 T,V相当于这里的 X,那不是应该 M也是已知的吗?M哪去了?还记得上面提到的 θθ ? θθ 就相当于这里的M。因为 M的大小是 2×3 ,所以 θθ 的维度为6。如果使用了别的变换方法,那就根据变换矩阵的大小相应调整。也就是说这里的变换矩阵是学习出来的。

对应于图3的变换公式如下:

1517393808097

注意他这个仿射变换是 从后向前变换的,就是说这个模块的输出是仿射变换的输入,这个模块的输入的其中一部分(图3(b) 绿点覆盖部分)是仿射变换的输出。

按照一般的做法,应该是从前往后变换,即从 source coordinates 得到 target coordinates 。但是这样做的问题是,如何确定变换的输入?如果是从前往后做变换,U 中绿色部分相当于 X,那怎么确定这一部分是多大,什么形状,位置在哪?

实际上从后往前变换也就是为了解决这个问题,就是要根据输出V的坐标得到输入U中目标所在的区域的坐标(绿色的区域)。

仿射变换变换的是坐标,既是坐标,那么变换的输入和输出的坐标的参考系应该是一样的,就是说 V 中像素的坐标 和 U 中像素的坐标应该是同一个参考系。这里使用的是针对 宽和高 进行的归一化坐标(height and width normalised coordinates),把在U和V中的像素坐标归一化到 [-1,1] 之间。U的 尺寸是上一层决定的,V的尺寸是人为固定的,输出 H′,W′H′,W′ 可以分别比 输入H,WH,W 大或者小,或者相等。

可以给仿射变换的变换矩阵添加更多的约束:

1517398935580

这时候,绿色区域已经确定了,相当于V中对应坐标(xti,yti)(xit,yit) 的像素都将从U中这块绿色区域中获取。 H′,W′H′,W′ 与H,WH,W 不一定相等;即便是相等,由于变换后的源坐标 (xsi,ysi)(xis,yis) 很有可能不是整数 ,对应U中不是整数像素点,所以没有像素值,没办法直接拷贝。所以V中 (xti,yti)(xit,yit) 坐标的像素值如何确定就成了问题。这时就涉及到采样和插值。

2.3 Sampler

实际上 CNN中的卷积核 或者 池化核起到的就是采样的作用。

(xsi,ysi)(xis,yis) 是U中绿色区域的坐标,来看看更加具有一般性的采样问题如何描述:

1517398970784

注意上式只是针对一个通道的像素进行采样,实际上每个通道的采样都是一样的,这样可以保留 空间一致性。

卷积的操作也是符合上式的,比如一维卷积:

1517400545429

理论来说 任意 对 xsi,ysixis,yis 可导或局部可导的采样核函数都是可以使用的.

比如最近邻插值核函数:

1517401679000

这个插值核函数做的就是把U中 离 当前源坐标 (xsi,ysi)(xis,yis) (小数坐标) 最近的 整数坐标 (n,m)(n,m) 处的像素值拷贝到V中的 (xti,yti)(xit,yit) 坐标处;

不过这篇文章使用的是双线性插值,双线性插值 参考 维基百科 和 图像处理之插值运算,这里放一张示意图吧:

20170403231241311

图4 双线性插值(来源于[参考资料 6])

这里的公式如下:

1517402932968

这个插值核函数做的是利用 U中 离 当前源坐标 (xsi,ysi)(xis,yis) (小数坐标) 最近的 4个整数坐标 (n,m)(n,m) 处的像素值做双线性插值然后拷贝到V中的 (xti,yti)(xit,yit)坐标处;

我在想他那个通过仿射变换确定绿色区域之后,绿色区域相当于ROI,那采样能不能使用ROI 池化的方式?

2.4 前向传播

结合前面的分析,总结一下前向传播的过程,如下图:

 

图5 前向传播流程(来源于[参考资料 6]

2.5 梯度流动与反向传播

这个函数虽不是 完全可导 但也是局部可导的,求导如下,对 ysiyis 的导数也是类似的:

1517403846718

根据公式(1)很容易求得: ∂xsi∂θ∂xis∂θ 和 ∂ysi∂θ∂yis∂θ 。

所以反向传播过程,误差可以传播到输入 feature map(公式6),可以传播到 采样格点坐标(sampling grid coordinates )(公式7),还可以传播到变换参数 θθ .

下图是梯度流动的示意图:

图6 反向传播流程(来源于[参考资料 6]

其中localisation network中的 ∂xsi∂θ∂xis∂θ 和 ∂ysi∂θ∂yis∂θ 也就是这一股误差流 {∂Vci∂xSi→∂xSi∂θ ∂Vci∂ySi→∂ySi∂θ{∂Vic∂xiS→∂xiS∂θ ∂Vic∂yiS→∂yiS∂θ ,在定位网络处就断了。

定位网络是一个回归模型,相当于一个子网络,一旦更新完参数,流就断了,独立于主网络。

3 试验

3.1 Distorted MNIST

这个试验的数据集 是 MNIST,不过与原版的MNIST 不同,这个数据集对图片上的数字做了各种形变操作,比如平移,扭曲,放缩,旋转等。

如下,不同形变操作的简写表示:

文章将 Spatial Transformer 模块嵌入到 两种主流的分类网络,FCN和CNN中(ST-FCN 和 ST-CNN )。Spatial Transformer 模块嵌入位置在图片输入层与后续分类层之间。

试验也测试了不同的变换函数对结果的影响:

其中CNN的模型与 LeNet是一样的,包含两个池化层。为了公平,所有的网络变种都只包含 3 个可学习参数的层,总体网络参数基本一致,训练策略也相同。

试验结果

作者也做了噪声环境下的试验:将数字 放置在 60×60的图片上,并添加斑点噪声(图1第三行)错误率分别为:

FCN ,13.2% error; CNN , 3.5% error; ST-FCN ,2.0% error; ST-CNN ,1.7% error.

3.2 Street View House Numbers

Street View House Numbers是一个真实的 街景门牌号 数据集,共200k张图片,每张图片包含1-5个数字 ,数字都有形变。

结果:

 

3.3 Fine-Grained Classification

数据集:CUB-200-2011 birds dataset, 6k training images and 5.8k test images, covering 200 species of birds.

这里使用了并行的Spatial Transformer , 效果是可以将图片的不同 部分(part)输入到不同的 Spatial Transformer 层,会产生不同的 part representations 然后经过 inception ,最后再合并起来,经过一个单独的softmax层做分类。

结果:

3.4 MNIST Addition

这个试验是将任意两张MNIST中的数字独立的进行一系列变形,然后叠加到一块,给网络识别,标签是二者之和。

同样的测试 FCN, CNN, ST-CNN,2×ST-CNN。

2×ST-CNN在输入层使用了两个并行的Spatial Transformer,结构见下面table 4右侧。

3.5 Co-localisation

这个试验将 Spatial Transformer用在了半监督的任务Co-localisation 。

Co-localisation :给一些图片,假设这些图片包含一些目标(也可能不包含),在不使用目标类别标签和目标位置标签的情况下,定位出常见的目标。

数据集还是 MNIST ,将 28×28大小的 数字图像 随机的放在 84×84 大小的含有噪声的背景上,对每个数字产生100个不同的变形。数据有定位标签,但是在训练时不用,测试时用。

模型还是使用 LeNet CNN模型,在输入层嵌入Spatial Transformer。

文章使用了半监督的方式,监督的学习过程是这样的:

对于一个 包含 N 张图片的 数据集 I={In}I={In} ,比如table 5 右侧的图。

经过上面的分析,可以提出如下损失函数: hinge loss (triplet loss)

1517414543432

αα is a margin ,可以称为 裕度,相当于净赚多少。

半监督是因为这里的标签相当于 L2,而L2是人为构造出来的距离指标。

测试时认为检测出的box与ground truth bounding box的IOU 大于0.5为正确,table5 左侧为测试结果。

在没有噪声时,可以达到100%的准确率,有噪声时在75-93%之间。

下图是优化过程的动态可视化结果,可见随着迭代次数越来越多,模型对目标的定位越来越准确

1517414970818

这个试验使用了一种简单的损失函数,在不使用数据定位标签的情况下,构建了一种距离标签,实现了对目标的检测。这个可以推广到目标检测或追踪问题中去。

作者把前面一些检测的动态效果做成了视频,看起来很清晰明了,看这里:https://goo.gl/qdEhUu

4 总结

这篇文章提出的 Spatial Transformer 结构能够很方便的嵌入到现有的CNN模型中去,并且实现端到端(end-to-end)的训练,通过对数据进行反向空间变换来消除图片上目标的变形,从而使得分类网络的识别更加简单高效。现在的CNN的已经非常强大了,但是 Spatial Transformer 仍然能过通过增强空间不变性来提高性能表现。Spatial Transformer实际上是一种attention机制,可以用于目标检测,目标追踪等问题,还可以构建半监督模。

下一篇介绍 Deformable Convolutional Networks ,跟本篇的TSN思路很像,但是又比这个模型简单。

参考资料

  1. opencv中文教程——仿射变换
  2. 仿射变换与齐次坐标
  3. 知乎——如何通俗易懂的理解高维仿射变换与线性变换
  4. 维基百科——双线性插值
  5. 图像处理之插值运算
  6. 讲STN的一篇博客,不过关于仿射变换那一块写的是错的,但是其中的图还是挺不错的,借用几张图

标签:Transformer,变换,STN,DL,坐标,Spatial,CNN,仿射变换
来源: https://blog.csdn.net/chen645096127/article/details/100081824