Person_reID_baseline_pytorch 源码解析之 test.py
作者:互联网
源码中有两个用于测试的脚本: test.py 和 evaluate_gpu.py 。其中, test.py 加载通过脚本 train.py 训练好的模型,实现对 query 和 gallery 图片的特征提取;本文对脚本 test.py 进行解析。
1. 加载模型和数据
首先需要载入训练好的模型,这里以基于 Resnet50 输出类别为 751 类的行人重识别模型 ft_net 为例。
model_structure = ft_net(751)
model = load_network(model_structure)
然后需要载入经过预处理的 gallery 和 query 数据集
data_transforms = transforms.Compose([
transforms.Resize((256,128), interpolation=3),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['gallery','query']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
shuffle=False, num_workers=0) for x in ['gallery','query']}
加载预处理过的数据集和训练好的模型,然后使用函数 extract_feature 进行特征提取
with torch.no_grad():
gallery_feature = extract_feature(model,dataloaders['gallery'])
query_feature = extract_feature(model,dataloaders['query'])
2. 完成特征提取
extract_feature 是 test.py 中非常重要的一个函数,用于提取图片的特征,下面对它逐行解析
def extract_feature(model,dataloaders):
features = torch.FloatTensor()
count = 0
# 加载数据集
for data in dataloaders:
img, label = data
n, c, h, w = img.size()
count += n
# 统计数据集图片数量
print(count)
ff = torch.FloatTensor(n,512).zero_().cuda()
for i in range(2):
if(i==1):
# 翻转图片
img = fliplr(img)
# 将图片变成 Variable,准备加载到网络中
input_img = Variable(img.cuda())
# 缩放尺寸 multiple_scale
for scale in ms:
if scale != 1:
# bicubic is only available in pytorch>= 1.1
input_img = nn.functional.interpolate(input_img, scale_factor=scale, mode='bicubic', align_corners=False)
# 模型推理
outputs = model(input_img)
# 拼接多尺度预测结果
ff += outputs
# norm feature 特征归一化
fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
ff = ff.div(fnorm.expand_as(ff))
# 返回提取到的特征
features = torch.cat((features,ff.data.cpu()), 0)
return features
3. 实现特征归一化
fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
这里是在输入张量 ff 的第 1 维进行 L2-norm,即 2 范数归一化。特征向量中每个元素均除以向量的L2范数。
pytorch 中使用 torch.norm 计算张量的范数。
fnorm = torch.norm(input, p='fro', dim=None, keepdim=False, out=None, dtype=None)
- input 输入张量
- p 是范数计算中的幂指数值,p = 2 时即为 2 范数
- dim 指定计算的维度,如果 dim 是整数值,则计算向量范数。当输入张量 input 超过2维,将在最后一维计算向量范数
- keepdim 指明是否保留输出张量的维度dim
- out 输出张量
- dtype 返回张量的期待数据类型
令特征向量除以向量的L2范数,expand_as 函数将范数 fnorm 扩展成张量 ff 相同的维度。
ff = ff.div(fnorm.expand_as(ff))
然后使用 tensor.div 完成除法。
Tensor.div(value, *, rounding_mode=None)
最后,使用 torch.cat 在第 0 维上拼接输入张量
features = torch.cat((features,ff.data.cpu()), 0)
4. 生成 Matlab 文件
通过上述步骤实现了 query 和 gallery 图片特征的提取,将特征矩阵存储到 pytorch_result.mat 文件中。
# Save to Matlab for check
result = {'gallery_f':gallery_feature.numpy(),'gallery_label':gallery_label,'gallery_cam':gallery_cam,'query_f':query_feature.numpy(),'query_label':query_label,'query_cam':query_cam}
scipy.io.savemat('pytorch_result.mat',result)
为了评估模型效果,还要记录图片的 label 和 camera 。
这里使用 get_id 函数通过图片名称获取 label 和 camera 信息。
def get_id(img_path):
camera_id = []
labels = []
for path, v in img_path:
#filename = path.split('/')[-1]
filename = os.path.basename(path)
label = filename[0:4]
camera = filename.split('c')[1]
if label[0:2]=='-1':
labels.append(-1)
else:
labels.append(int(label))
camera_id.append(int(camera[0]))
return camera_id, labels
gallery_path = image_datasets['gallery'].imgs
query_path = image_datasets['query'].imgs
gallery_cam,gallery_label = get_id(gallery_path)
query_cam,query_label = get_id(query_path)
生成的 Matlab 文件将被脚本 evaluate_gpu.py 使用,用于计算模型的评估指标。
参考链接
- pytorch求范数函数——torch.norm
- pytorch torch.norm 文档
- Pytorch expand_as()函数
- torch.cat()函数的官方解释,详解以及例子
- torch.stack()的官方解释,详解以及例子
标签:torch,baseline,img,py,label,源码,ff,query,gallery 来源: https://blog.csdn.net/qq_39220334/article/details/121630259