其他分享
首页 > 其他分享> > SO-Net中分类器(classifier)的实现过程

SO-Net中分类器(classifier)的实现过程

作者:互联网

一、数据准备

  该部分代码位于./data/modelnet_shrec_loader.py中。读取的数据为pc_np(点的坐标),surface_normal_np(法向量),som_node_np(som节点坐标)和class_id(类别)。然后对数据增强,包括旋转、微扰、尺度变换和位移。返回点的坐标、法向量、类别、som节点和每个som节点在som节点中的k个近邻点的索引。

二、模型

  该部分代码位于./models/classifier.py。主要包括两个部分:编码网络和分类网络。
在这里插入图片描述

2.1编码网络

  编码器如上图蓝色底色的部分所示。具体的定义在./models/networks.py中。在Encoder类forward中可以看到,其输入为点的坐标、法向量、节点坐标和节点到节点的knn索引。

2.1.1 SOM层

  如上图所示,第一部分为SOM层。通过./util/som.pyBatchSOM类中的query_topk()函数实现映射,返回mask,mask_row_max和min_idx。
(1)mask:其大小为[B,kN,M],第i个N*M行表示节点是否为N个点的第i近邻节点。
在这里插入图片描述

上图中红色的1表示第2个节点是第一个点的第一近邻点。蓝色的1表示第3个节点是第四个点的第一近邻点,而最下方的1表示第5个节点是第三个点的第二近邻点。
(2)mask_row_max:其大小为[B,M],表示每个节点是否存在近邻的点
在这里插入图片描述

其中1表示第i个节点是某个点的近邻点。
(3)min_idx:其大小为[B,kN],第i个N*1行表示N个点的第i近邻节点的索引。
在这里插入图片描述

与mask对应,表示最近邻索引值。
  然后获取以每个节点为近邻的所有点的中心。

		self.mask, mask_row_max, min_idx = self.som_builder.query_topk(x.data, k=self.opt.k)  # BxkNxnode_num, Bxnode_num
        mask_row_sum = torch.sum(self.mask, dim=1)  # BxM
        mask = self.mask.unsqueeze(1)  # Bx1xkNxM

        #将x和sn堆叠
        x_list, sn_list = [], []
        for i in range(self.opt.k):
            x_list.append(x)
            sn_list.append(sn)
        x_stack = torch.cat(tuple(x_list), dim=2)  # B x C x kN
        sn_stack = torch.cat(tuple(sn_list), dim=2)# B x C x kN

        # 计算以每个点为近邻的所有点的平均坐标,作为新的节点坐标
        x_stack_data_unsqueeze = x_stack.data.unsqueeze(3)  # BxCxkNx1
        x_stack_data_masked = x_stack_data_unsqueeze * mask.float()  # BxCxkNxM
        cluster_mean = torch.sum(x_stack_data_masked, dim=2) / (mask_row_sum.unsqueeze(1).float()+1e-5)  # BxCxM,为了防止数值不稳定,即没有点以该节点为近邻点
        self.som_builder.node = cluster_mean
        self.som_node = self.som_builder.node

然后对于每个点进行去中心化,并与sn拼接到一起,作为输入。

		node_expanded = self.som_node.data.unsqueeze(2)  # BxCx1xM
        self.centers = torch.sum(mask.float() * node_expanded, dim=3).detach()  # BxCxkN

        self.x_decentered = (x_stack - self.centers).detach()  # Bx3xkN
        x_augmented = torch.cat((self.x_decentered, sn_stack), dim=1)  # Bx6xkN

2.1.2 first_pointnet

  该部分代码在./models/layers.py。其实质是一个残差网络,每一层是一个EquivariantLayer的结构,其定义在同样在./models/layers.py中。

在这里插入图片描述

(index_max是cuda的c++扩展,还没弄懂是什么意思,太菜了555)

2.1.3 knnlayer

  该部分代码在./models/layers.py中。首先计算每个center的前k个距离的center的索引:

        coordinate_tensor = coordinate.data  # Bx3xM  以节点为近邻点的所有点的中心
        if precomputed_knn_I is not None:
            assert precomputed_knn_I.size()[2] >= K
            knn_I = precomputed_knn_I[:, :, 0:K]
        else:
            coordinate_Mx1 = coordinate_tensor.unsqueeze(3)  # Bx3xMx1
            coordinate_1xM = coordinate_tensor.unsqueeze(2)  # Bx3x1xM
            norm = torch.sum((coordinate_Mx1 - coordinate_1xM) ** 2, dim=1)  # BxMxM, each row corresponds to each coordinate - other coordinates
            knn_D, knn_I = torch.topk(norm, k=K, dim=2, largest=False, sorted=True)  # BxMxK 每个center到其他center的前k个最近距离的距离和索引

然后对于每个center,计算k个近邻center的坐标、均值及去中心化后的坐标值:

        neighbors = operations.knn_gather_wrapper(coordinate_tensor, knn_I)  # Bx3xMxK 每个center最近的k个center的坐标
        if center_type == 'avg':  # 如果以k个平均值为中心
            neighbors_center = torch.mean(neighbors, dim=3, keepdim=True)  # Bx3xMx1 每个center最近的k个center的坐标中心
        elif center_type == 'center':  # 以center本身为中心
            neighbors_center = coordinate_tensor.unsqueeze(3)  # Bx3xMx1 每个center的坐标
        neighbors_decentered = (neighbors - neighbors_center).detach() # Bx3xMxK 每个center最近的k个center的去中心坐标
        neighbors_center = neighbors_center.squeeze(3).detach()  # Bx3xM  中心坐标

最后得到每个center最近邻的k个center的特征向量,并作为卷积层的输入,该卷积层是在同一个文件中定义的。返回值是center点的坐标和特征向量。

        x_neighbors = operations.knn_gather_by_indexing(x, knn_I)  # BxCxMxK 每个center最近邻k个center的特征向量
        x_augmented = torch.cat((neighbors_decentered, x_neighbors), dim=1)  # Bx(3+C)xMxK 与中心坐标拼接

2.1.4 final_pointnet

  该部分代码在./models/layers.py中。是一个常规的pointnet网络结构。得到全局的特征向量(图中global feature)。

2.2 分类器

  该部分代码位于models/networks.py中。其实质是一个三层的全连接层,输出点云对于每个类别的概率,最后用交叉熵损失进行训练。

三、测试与保存

3.1 get_current_errors & visualizer.plot_current_errors

  该部分代码在./models/classifier.py中。统计预测的准确率,并进行可视化(loss-time曲线)。

3.2 model.save_network

  保存模型到指定路径。

    def save_network(self, network, network_label, epoch_label, gpu_id):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.opt.checkpoints_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if gpu_id>=0 and torch.cuda.is_available():
            # torch.cuda.device(gpu_id)
            network.to(self.opt.device)

3.3 model.update_learning_rate

  更新学习速率。

def update_learning_rate(self, ratio):
        lr_clip = 0.00001

        # encoder
        lr_encoder = self.old_lr_encoder * ratio
        if lr_encoder < lr_clip:
            lr_encoder = lr_clip
        for param_group in self.optimizer_encoder.param_groups:
            param_group['lr'] = lr_encoder
        print('update encoder learning rate: %f -> %f' % (self.old_lr_encoder, lr_encoder))
        self.old_lr_encoder = lr_encoder

        # classifier
        lr_classifier = self.old_lr_classifier * ratio
        if lr_classifier < lr_clip:
            lr_classifier = lr_clip
        for param_group in self.optimizer_classifier.param_groups:
            param_group['lr'] = lr_classifier
        print('update classifier learning rate: %f -> %f' % (self.old_lr_classifier, lr_classifier))
        self.old_lr_classifier = lr_classifier

标签:center,self,mask,节点,分类器,lr,SO,Net,classifier
来源: https://blog.csdn.net/qq_43173635/article/details/121107760