SO-Net中分类器(classifier)的实现过程
作者:互联网
一、数据准备
该部分代码位于./data/modelnet_shrec_loader.py
中。读取的数据为pc_np
(点的坐标),surface_normal_np
(法向量),som_node_np
(som节点坐标)和class_id
(类别)。然后对数据增强,包括旋转、微扰、尺度变换和位移。返回点的坐标、法向量、类别、som节点和每个som节点在som节点中的k个近邻点的索引。
二、模型
该部分代码位于./models/classifier.py
。主要包括两个部分:编码网络和分类网络。
2.1编码网络
编码器如上图蓝色底色的部分所示。具体的定义在./models/networks.py
中。在Encoder类
的forward
中可以看到,其输入为点的坐标、法向量、节点坐标和节点到节点的knn索引。
2.1.1 SOM层
如上图所示,第一部分为SOM层。通过./util/som.py
的BatchSOM
类中的query_topk()
函数实现映射,返回mask,mask_row_max和min_idx。
(1)mask:其大小为[B,kN,M],第i个N*M行表示节点是否为N个点的第i近邻节点。
上图中红色的1表示第2个节点是第一个点的第一近邻点。蓝色的1表示第3个节点是第四个点的第一近邻点,而最下方的1表示第5个节点是第三个点的第二近邻点。
(2)mask_row_max:其大小为[B,M],表示每个节点是否存在近邻的点
其中1表示第i个节点是某个点的近邻点。
(3)min_idx:其大小为[B,kN],第i个N*1行表示N个点的第i近邻节点的索引。
与mask对应,表示最近邻索引值。
然后获取以每个节点为近邻的所有点的中心。
self.mask, mask_row_max, min_idx = self.som_builder.query_topk(x.data, k=self.opt.k) # BxkNxnode_num, Bxnode_num
mask_row_sum = torch.sum(self.mask, dim=1) # BxM
mask = self.mask.unsqueeze(1) # Bx1xkNxM
#将x和sn堆叠
x_list, sn_list = [], []
for i in range(self.opt.k):
x_list.append(x)
sn_list.append(sn)
x_stack = torch.cat(tuple(x_list), dim=2) # B x C x kN
sn_stack = torch.cat(tuple(sn_list), dim=2)# B x C x kN
# 计算以每个点为近邻的所有点的平均坐标,作为新的节点坐标
x_stack_data_unsqueeze = x_stack.data.unsqueeze(3) # BxCxkNx1
x_stack_data_masked = x_stack_data_unsqueeze * mask.float() # BxCxkNxM
cluster_mean = torch.sum(x_stack_data_masked, dim=2) / (mask_row_sum.unsqueeze(1).float()+1e-5) # BxCxM,为了防止数值不稳定,即没有点以该节点为近邻点
self.som_builder.node = cluster_mean
self.som_node = self.som_builder.node
然后对于每个点进行去中心化,并与sn拼接到一起,作为输入。
node_expanded = self.som_node.data.unsqueeze(2) # BxCx1xM
self.centers = torch.sum(mask.float() * node_expanded, dim=3).detach() # BxCxkN
self.x_decentered = (x_stack - self.centers).detach() # Bx3xkN
x_augmented = torch.cat((self.x_decentered, sn_stack), dim=1) # Bx6xkN
2.1.2 first_pointnet
该部分代码在./models/layers.py
。其实质是一个残差网络,每一层是一个EquivariantLayer的结构,其定义在同样在./models/layers.py
中。
(index_max是cuda的c++扩展,还没弄懂是什么意思,太菜了555)
2.1.3 knnlayer
该部分代码在./models/layers.py
中。首先计算每个center的前k个距离的center的索引:
coordinate_tensor = coordinate.data # Bx3xM 以节点为近邻点的所有点的中心
if precomputed_knn_I is not None:
assert precomputed_knn_I.size()[2] >= K
knn_I = precomputed_knn_I[:, :, 0:K]
else:
coordinate_Mx1 = coordinate_tensor.unsqueeze(3) # Bx3xMx1
coordinate_1xM = coordinate_tensor.unsqueeze(2) # Bx3x1xM
norm = torch.sum((coordinate_Mx1 - coordinate_1xM) ** 2, dim=1) # BxMxM, each row corresponds to each coordinate - other coordinates
knn_D, knn_I = torch.topk(norm, k=K, dim=2, largest=False, sorted=True) # BxMxK 每个center到其他center的前k个最近距离的距离和索引
然后对于每个center,计算k个近邻center的坐标、均值及去中心化后的坐标值:
neighbors = operations.knn_gather_wrapper(coordinate_tensor, knn_I) # Bx3xMxK 每个center最近的k个center的坐标
if center_type == 'avg': # 如果以k个平均值为中心
neighbors_center = torch.mean(neighbors, dim=3, keepdim=True) # Bx3xMx1 每个center最近的k个center的坐标中心
elif center_type == 'center': # 以center本身为中心
neighbors_center = coordinate_tensor.unsqueeze(3) # Bx3xMx1 每个center的坐标
neighbors_decentered = (neighbors - neighbors_center).detach() # Bx3xMxK 每个center最近的k个center的去中心坐标
neighbors_center = neighbors_center.squeeze(3).detach() # Bx3xM 中心坐标
最后得到每个center最近邻的k个center的特征向量,并作为卷积层的输入,该卷积层是在同一个文件中定义的。返回值是center点的坐标和特征向量。
x_neighbors = operations.knn_gather_by_indexing(x, knn_I) # BxCxMxK 每个center最近邻k个center的特征向量
x_augmented = torch.cat((neighbors_decentered, x_neighbors), dim=1) # Bx(3+C)xMxK 与中心坐标拼接
2.1.4 final_pointnet
该部分代码在./models/layers.py
中。是一个常规的pointnet网络结构。得到全局的特征向量(图中global feature)。
2.2 分类器
该部分代码位于models/networks.py
中。其实质是一个三层的全连接层,输出点云对于每个类别的概率,最后用交叉熵损失进行训练。
三、测试与保存
3.1 get_current_errors & visualizer.plot_current_errors
该部分代码在./models/classifier.py
中。统计预测的准确率,并进行可视化(loss-time曲线)。
3.2 model.save_network
保存模型到指定路径。
def save_network(self, network, network_label, epoch_label, gpu_id):
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
save_path = os.path.join(self.opt.checkpoints_dir, save_filename)
torch.save(network.cpu().state_dict(), save_path)
if gpu_id>=0 and torch.cuda.is_available():
# torch.cuda.device(gpu_id)
network.to(self.opt.device)
3.3 model.update_learning_rate
更新学习速率。
def update_learning_rate(self, ratio):
lr_clip = 0.00001
# encoder
lr_encoder = self.old_lr_encoder * ratio
if lr_encoder < lr_clip:
lr_encoder = lr_clip
for param_group in self.optimizer_encoder.param_groups:
param_group['lr'] = lr_encoder
print('update encoder learning rate: %f -> %f' % (self.old_lr_encoder, lr_encoder))
self.old_lr_encoder = lr_encoder
# classifier
lr_classifier = self.old_lr_classifier * ratio
if lr_classifier < lr_clip:
lr_classifier = lr_clip
for param_group in self.optimizer_classifier.param_groups:
param_group['lr'] = lr_classifier
print('update classifier learning rate: %f -> %f' % (self.old_lr_classifier, lr_classifier))
self.old_lr_classifier = lr_classifier
标签:center,self,mask,节点,分类器,lr,SO,Net,classifier 来源: https://blog.csdn.net/qq_43173635/article/details/121107760