【转载】 tensorflow batch_normalization的正确使用姿势
作者:互联网
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/computerme/article/details/80836060
————————————————
BN在如今的CNN结果中已经普遍应用,在tensorflow中可以通过tf.layers.batch_normalization()
这个op来使用BN。该op隐藏了对BN的mean var alpha beta参数的显示声明,因此在训练和部署测试中需要特征注意正确使用BN的姿势。
###正确使用BN训练
注意把tf.layers.batch_normalization(x, training=is_training,name=scope)
输入参数的training=True
。另外需要在来训练中添加update_ops
以便在每一次训练完后及时更新BN的参数。
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): #保证train_op在update_ops执行之后再执行。 train_op = optimizer.minimize(loss)
###正确保存带BN的模型
保存模型的时候不能只保存trainable_variables
,因为BN的参数不属于trainable_variables
。为了方便,可以用tf.global_variables()
。使用姿势如下
saver = tf.train.Saver(var_list=tf.global_variables()) savepath = saver.save(sess, 'here_is_your_personal_model_path’)
###正确读取带BN的模型
与保存类似,读的时候变量也需要为 global_variables 。如下:
saver = tf.train.Saver() or saver = tf.train.Saver(tf.global_variables()) saver.restore(sess, 'here_is_your_personal_model_path')
PS:inference的时候还需要把tf.layers.batch_normalization(x, training=is_training,name=scope)
这里的training
设为False
标签:training,BN,batch,saver,train,tf,tensorflow,variables,normalization 来源: https://www.cnblogs.com/devilmaycry812839668/p/12444494.html