其他分享
首页 > 其他分享> > 『论文笔记』SuperGlue

『论文笔记』SuperGlue

作者:互联网

https://zhuanlan.zhihu.com/p/342105673

特征处理部分比较好理解,点的self、cross注意力机制实现建议看下源码(MultiHeadedAttention),

def attention(query, key, value):
    dim = query.shape[1]
    scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5
    prob = torch.nn.functional.softmax(scores, dim=-1)
    return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob


class MultiHeadedAttention(nn.Module):
    """ Multi-head attention to increase model expressivitiy """
    def __init__(self, num_heads: int, d_model: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.dim = d_model // num_heads
        self.num_heads = num_heads
        self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])

    def forward(self, query, key, value):
        batch_dim = query.size(0)
        query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1)
                             for l, x in zip(self.proj, (query, key, value))]
        x, prob = attention(query, key, value)
        self.prob.append(prob)
        return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1))

这里直接跳到最后的逻辑部分,这部分论文写的比较粗略,需要看下源码才知道在讲啥(也许有大佬不用看)。

看这里,即是说推理时检出的匹配关系是不考虑最后一行和最后一列的,而是设定阈值,将不合格的匹配过滤掉

        # Get the matches with score above "match_threshold".
        max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
        indices0, indices1 = max0.indices, max1.indices
        mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)  # [0,0...,1,..0]
        mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
        zero = scores.new_tensor(0)
        mscores0 = torch.where(mutual0, max0.values.exp(), zero)
        mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)
        valid0 = mutual0 & (mscores0 > self.config['match_threshold'])
        valid1 = mutual1 & valid0.gather(1, indices1)
        indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
        indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))

推理时代码如下,可见图A和图B互相匹配的结果(按照score的行列取最大值的index)不必完全一致:

                kpts0, kpts1 = pred['keypoints0'].cpu().numpy()[0], pred['keypoints1'].cpu().numpy()[0]
                matches, conf = pred['matches0'].cpu().detach().numpy(), pred['matching_scores0'].cpu().detach().numpy()
                image0 = read_image_modified(image0, opt.resize, opt.resize_float)
                image1 = read_image_modified(image1, opt.resize, opt.resize_float)
                valid = matches > -1
                mkpts0 = kpts0[valid]
                mkpts1 = kpts1[matches[valid]]
                mconf = conf[valid]

然后看求解分配矩阵的部分,couplings为相似度得分矩阵,为其添加了最后一行一列,并赋值为1,在原文提到的约束下,使用sinkhorn(待看)算法求解,求出分配矩阵Z,

# b(m+1)(n+1), b(m+1), b(n+1)
def log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int):
    """ Perform Sinkhorn Normalization in Log-space for stability"""
    u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)
    for _ in range(iters):
        # [log(m+n) ..., log(n)+log(m+n)] - []
        u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2)  # b(m+1)
        v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1)
    return Z + u.unsqueeze(2) + v.unsqueeze(1)


def log_optimal_transport(scores, alpha, iters: int):
    """ Perform Differentiable Optimal Transport in Log-space for stability"""
    b, m, n = scores.shape
    one = scores.new_tensor(1)
    ms, ns = (m*one).to(scores), (n*one).to(scores)

    bins0 = alpha.expand(b, m, 1)  # only a new view
    bins1 = alpha.expand(b, 1, n)
    alpha = alpha.expand(b, 1, 1)

    # b(m+1)(n+1), 额外行列值为1
    couplings = torch.cat([torch.cat([scores, bins0], -1),  # bmn,bm1->bm(n+1)
                           torch.cat([bins1, alpha], -1)], 1)  # b1n,b11->b1(n+1)

    norm = - (ms + ns).log()
    log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm])  # m+1: [log(m+n) ..., log(n)+log(m+n)]
    log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm])  # n+1: [log(m+n) ..., log(m)+log(m+n)]
    log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1)  # b(m+1), b(n+1)

    Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters)
    Z = Z - norm  # multiply probabilities by M+N
    return Z

损失函数就是最大化这个分配矩阵Z,即下面的scores矩阵,匹配对中肯定不包含dustbin点的,也就是说对dustbin的考量蕴含在sinkhorn中,注意下面的函数调用的参数self.bin_score,这是superglue网络的一个可学习的参数:

        bin_score = torch.nn.Parameter(torch.tensor(1.))         self.register_parameter('bin_score', bin_score) 回头看上面的log_optimal_transport代码,每次给couplings的额外行列赋的值就是这个值。
        all_matches = data['all_matches'].permute(1,2,0) # shape=torch.Size([1, 87, 2])

        ……

        # Run the optimal transport.
        scores = log_optimal_transport(
            scores, self.bin_score,
            iters=self.config['sinkhorn_iterations'])

        ……

        # check if indexed correctly
        loss = []
        for i in range(len(all_matches[0])):
            x = all_matches[0][i][0]
            y = all_matches[0][i][1]
            loss.append(-torch.log( scores[0][x][y].exp() )) # check batch size == 1 ?

原文里对分配矩阵的约束如下,这个应该是引入sinkhorn的作用,在代码中分配矩阵P_head并没有显式出现,所以没法辅助我理解这个公式:

 相对应的,P的约束就很好理解:

 

标签:dim,SuperGlue,torch,log,matches,self,论文,笔记,scores
来源: https://www.cnblogs.com/hellcat/p/15260145.html