PointNet++预测结果可视化
作者:互联网
目前网上对于PointNet++的预测结果可视化的资料比较少,一般都是直接可视化数据集。下面介绍一种我利用Matplotlib可视化预测的代码,希望能够对大家有所帮助。
原理:
简单阐述一下代码的原理,首先我们利用网络给出输入图像的预测结果,并存入为txt文件;然后利用Matplotlib读取txt文件,画出3d图像。
预测效果
由于3d图像无法直接显示,需要特定的软件才行,因此我们只能将它转换为2d图像报错,其结果如下
准备材料
1)网络模型。本代码是基于 pytorch版的PointNet++实现的,所以网络模型的输入输出格式要保持一致。(代码中会在相应位置进行注释)
2)数据集。既然都准备预测结果了,应该对网络的整体流程都有一个比较详细的理解了,数据集的准备在这里就不详细介绍了。需要注意的是,不同的数据集输出的内容不一样。例如ShapeNet会输出点云图像,Label(分类类别)和target(分割类别),而S3DIS就只会输出点云图像和分割类别。这里就需要我们进行调整一下。
3)训练好的权重文件。将训练好的权重文件以PointNet++的格式保存。
准备好以上材料就可以开始准备生成预测结果了。
主要代码
class Generate_txt_and_3d_img:
def __init__(self,img_root,target_root,num_classes,testDataLoader,model_dict,color_map=None):
self.img_root = img_root # 点云数据路径
self.target_root = target_root # 生成txt标签和预测结果路径
self.testDataLoader = testDataLoader
self.num_classes = num_classes
self.color_map = color_map
self.heat_map = False # 控制是否输出heatmap
self.label_path_txt = os.path.join(self.target_root, 'label_txt') # 存放label的txt文件
self.make_dir(self.label_path_txt)
# 拿到模型 并加载权重
self.model_name = []
self.model = []
self.model_weight_path = []
for k,v in model_dict.items():
self.model_name.append(k)
self.model.append(v[0])
self.model_weight_path.append(v[1])
# 加载权重
self.load_cheackpoint_for_models(self.model_name,self.model,self.model_weight_path)
# 创建文件夹
self.all_pred_image_path = [] # 所有预测结果的路径列表
self.all_pred_txt_path = [] # 所有预测txt的路径列表
for n in self.model_name:
self.make_dir(os.path.join(self.target_root,n+'_predict_txt'))
self.make_dir(os.path.join(self.target_root, n + '_predict_image'))
self.all_pred_txt_path.append(os.path.join(self.target_root,n+'_predict_txt'))
self.all_pred_image_path.append(os.path.join(self.target_root, n + '_predict_image'))
"将模型对应的预测txt结果和img结果生成出来,对应几个模型就在列表中添加几个元素"
self.generate_predict_to_txt() # 生成预测txt
self.draw_3d_img() # 画图
def generate_predict_to_txt(self):
for batch_id, (points, label, target) in tqdm.tqdm(enumerate(self.testDataLoader),
total=len(self.testDataLoader),smoothing=0.9):
#点云数据、整个图像的标签、每个点的标签、 没有归一化的点云数据(带标签)torch.Size([1, 7, 2048])
points = points.transpose(2, 1)
#print('1',target.shape) # 1 torch.Size([1, 2048])
xyz_feature_point = points[:, :6, :] # B C N ---->B N C
# 将标签保存为txt文件
point_set_without_normal = np.asarray(torch.cat([points.permute(0, 2, 1),target[:,:,None]],dim=-1)).squeeze(0) # 代标签 没有归一化的点云数据 的numpy形式
np.savetxt(os.path.join(self.label_path_txt,f'{batch_id}_label.txt'), point_set_without_normal, fmt='%.04f') # 将其存储为txt文件
" points torch.Size([16, 2048, 6]) label torch.Size([16, 1]) target torch.Size([16, 2048])"
assert len(self.model) == len(self.all_pred_txt_path) , '路径与模型数量不匹配,请检查'
for n,model,pred_path in zip(self.model_name,self.model,self.all_pred_txt_path):
points = points.long()
seg_pred, trans_feat = model(points, self.to_categorical(label, 16))
seg_pred = seg_pred.cpu().data.numpy()
#=================================================
seg_pred = np.argmax(seg_pred, axis=-1) # 获得网络的预测结果 b n c
#=================================================
seg_pred = np.concatenate([np.asarray(xyz_feature_point), seg_pred[:, None, :]],
axis=1).transpose((0, 2, 1)).squeeze(0) # 将点云与预测结果进行拼接,准备生成txt文件
svae_path = os.path.join(pred_path, f'{n}_{batch_id}.txt')
np.savetxt(svae_path,seg_pred, fmt='%.04f')
def draw_3d_img(self):
# 调用matpltlib 画3d图像
each_label = os.listdir(self.label_path_txt) # 所有标签txt路径
self.label_path_3d_img = os.path.join(self.target_root, 'label_3d_img')
self.make_dir(self.label_path_3d_img)
assert len(self.all_pred_txt_path) == len(self.all_pred_image_path)
for i,(pre_txt_path,save_img_path,name) in enumerate(zip(self.all_pred_txt_path,self.all_pred_image_path,self.model_name)):
each_txt_path = os.listdir(pre_txt_path) # 拿到txt文件的全部名字
for idx,(txt,lab) in tqdm.tqdm(enumerate(zip(each_txt_path,each_label)),total=len(each_txt_path)):
if i == 0:
self.draw_each_img(os.path.join(self.label_path_txt, lab), idx,heat_maps=False)
self.draw_each_img(os.path.join(pre_txt_path,txt),idx,name=name,save_path=save_img_path,heat_maps=self.heat_map)
print(f'所有预测图片已生成完毕,请前往:{self.all_pred_image_path} 查看')
def draw_each_img(self,root,idx,name=None,skip=1,save_path=None,heat_maps=False):
"root:每个txt文件的路径"
points = np.loadtxt(root)[:, :3] # 点云的xyz坐标
points_all = np.loadtxt(root) # 点云的所有坐标
points = self.pc_normalize(points)
skip = skip # Skip every n points
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
point_range = range(0, points.shape[0], skip) # skip points to prevent crash
x = points[point_range, 0]
z = points[point_range, 1]
y = points[point_range, 2]
"根据传入的类别数 自定义生成染色板 标签 0对应 随机颜色1 标签1 对应随机颜色2"
if self.color_map is not None:
color_map = self.color_map
else:
color_map = {idx: i for idx, i in enumerate(np.linspace(0, 0.9, num_classes))}
Label = points_all[point_range, -1] # 拿到标签
# 将标签传入前面的字典,找到对应的颜色 并放入列表
Color = list(map(lambda x: color_map[x], Label))
ax.scatter(x, # x
y, # y
z, # z
c=Color, # Color, # height data for color
s=25,
marker=".")
ax.axis('auto') # {equal, scaled}
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.axis('off') # 设置坐标轴不可见
ax.grid(False) # 设置背景网格不可见
ax.view_init(elev=0, azim=0)
if save_path is None:
plt.savefig(os.path.join(self.label_path_3d_img,f'{idx}_label_img.png'), dpi=300,bbox_inches='tight',transparent=True)
else:
plt.savefig(os.path.join(save_path, f'{idx}_{name}_img.png'), dpi=300, bbox_inches='tight',
transparent=True)
def pc_normalize(self,pc):
l = pc.shape[0]
centroid = np.mean(pc, axis=0)
pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
pc = pc / m
return pc
def make_dir(self, root):
if os.path.exists(root):
print(f'{root} 路径已存在 无需创建')
else:
os.mkdir(root)
def to_categorical(self,y, num_classes):
""" 1-hot encodes a tensor """
new_y = torch.eye(num_classes)[y.cpu().data.numpy(),]
if (y.is_cuda):
return new_y.cuda()
return new_y
def load_cheackpoint_for_models(self,name,model,cheackpoints):
assert cheackpoints is not None,'请填写权重文件'
assert model is not None, '请实例化模型'
for n,m,c in zip(name,model,cheackpoints):
print(f'正在加载{n}的权重.....')
weight_dict = torch.load(os.path.join(c,'best_model.pth'))
m.load_state_dict(weight_dict['model_state_dict'])
print(f'{n}权重加载完毕')
if __name__ =='__main__':
import copy
img_root = r'你的数据集路径' # 数据集路径
target_root = r'保存预测结果的路径' # 输出结果路径
num_classes = 13 # 填写数据集的类别数
choice_dataset = 'S3dis'
# 导入模型 部分
"所有的模型以PointNet++为标准 输入两个参数 输出两个参数,如果模型仅输出一个,可以将其修改为多输出一个None!!!!"
#==============================================
from models.pointnet2_sem_seg import get_model as pointnet2
from models.finally_csa_part_seg import get_model as csa
from models.pointcouldtransformer_part_seg import get_model as pct
model1 = pointnet2(num_classes=num_classes).eval()
model2 = csa(num_classes, normal_channel=True).eval()
model3 = pct(num_class=num_classes,normal_channel=False).eval()
#============================================
# 实例化数据集
"Dataset同理,都按ShapeNet格式输出三个变量 point_set, cls, seg # pointset是点云数据,cls十六个大类别,seg是一个数据中,不同点对应的小类别"
"不是这个格式的话就手动添加一个"
if choice_dataset == 'ShapeNet':
print('实例化ShapeNet')
TEST_DATASET = PartNormalDataset(root=img_root, npoints=2048, split='test', normal_channel=True)
testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=1, shuffle=False, num_workers=0,
drop_last=True)
color_map = {idx: i for idx, i in enumerate(np.linspace(0, 0.9, num_classes))}
else:
TEST_DATASET = S3DISDataset(split='test', data_root=img_root, num_point=4096, test_area=5,
block_size=1.0, sample_rate=1.0, transform=None)
testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=1, shuffle=False, num_workers=0,
pin_memory=True, drop_last=True)
color_maps = [(152, 223, 138), (174, 199, 232), (255, 127, 14), (91, 163, 138), (255, 187, 120), (188, 189, 34),
(140, 86, 75)
, (255, 152, 150), (214, 39, 40), (197, 176, 213), (196, 156, 148), (23, 190, 207), (112, 128, 144)]
color_map = []
for i in color_maps:
tem = ()
for j in i:
j = j / 255
tem += (j,)
color_map.append(tem)
print('实例化S3DIS')
#将模型和权重路径填写到字典中,以下面这个格式填写就可以了
# 如果加载权重报错,可以查看类里面的加载权重部分,进行对应修改即可
model_dict = {
'PonintNet': [model1,r'权重路径1'],
'CSA': [model2, r'权重路径2'],
'PCT': [model3,r'权重路径3']
}
c = Generate_txt_and_3d_img(img_root,target_root,num_classes,testDataLoader,model_dict,color_map)
通过这个脚本就能实现利用matplotlib画出预结果了。由于本人实力有限,这些代码都是基于PointNet++的,所以在预测时候应该属性PointNet++代码。一些注意点在代码中进行了注释和说明,在这里再次强调一下:
1.所以模型都是两输入两输出
2.记得导入数据集,数据集默认以ShapNet为模板 输出三个值,具体可见ShapeNet数据集。(数据集的导入没有放进来,如有需要记得自己导入)。训练和预测的数据集的预处理一定要一样,否则预测效果会很差。
3.所有模型和权重都要放在那个model_dict中
4.如果觉得保存的图像视角比较怪, 可以通过ax.view_init(elev=0, azim=0)设置角度
还有什么以后想起来再说吧。
开头导入
开头导入部分
"传入模型权重文件,读取预测点,生成预测的txt文件"
import tqdm
import matplotlib.pyplot as plt
import matplotlib
import torch
import os
import json
import warnings
import numpy as np
from torch.utils.data import Dataset
warnings.filterwarnings('ignore')
matplotlib.use("Agg")
def pc_normalize(pc):
centroid = np.mean(pc, axis=0)
pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
pc = pc / m
return pc
这里一定要加
warnings.filterwarnings('ignore')
matplotlib.use("Agg")
这两行代码,否则在画图的时候会导致中断。Matplotlib的画图程序好像不能一次画太多,如果太多了可能会报错(大概要超出1w多张吧)
以上代码纯属自己手码,我使用是没有问题的。如果有什么问题,大家可以阅读一下代码进行对应调整或在评论区留言,希望能够帮到大家。
参考:
https://blog.csdn.net/ssq183/article/details/104603454/
最后,大家一定要记得:
标签:img,++,self,PointNet,可视化,path,model,txt,root 来源: https://blog.csdn.net/weixin_47142735/article/details/121884501