关于model.fuse().eval()理解
作者:互联网
问题描述:
在学习yolov5过程中,我们可以通过如下代码进行模型导入,为什么要使用fuse() 和 eval() ?
def get_model(weights):
# fuse conv_bn and repvgg
# only fuse conv_bn
model = torch.load(weights, map_location=device)['model'].float().fuse()
return model.eval()
问题分析:
fuse()是用来进行conv和bn层合并,为了提速模型推理速度。
eval()是模型进行预测推理时关闭BN(预测数据均值方差计算)和Dropout以免影响预测结果。具体如下:
- 训练过程中BN的变化
在训练过程中BN会不断计算均值和方差,训练结束后会得到最终的均值和方差,可以记作mean_train, variance_train。 - 预测过程中的BN的变化
如果预测过程中不适用model.eval(),BN 层还是会根据输入的数据继续计算均值和方差,相比于训练过程中的均值和方差发生了变化因此会导致预测结果发生变化。 - 训练过程中Dropout变化
训练过程中会依据设置的dropout比例会使一部分的网络连接不进行计算, - 预测过程中的Dropout变化
使用model.eval()会使所有网络连接参与计算,显然预测时都参与计算结果会更准确。
标签:预测,方差,BN,fuse,eval,model 来源: https://www.cnblogs.com/chentiao/p/16656918.html