CIFAR10 - Network Training Techniques

Author: unattributed (from the Internet)

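1. Data augmentation

  1) Random cropping

  Pad each side of the original image with 4 pixels, then randomly crop a 32x32 image back out of the padded result.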

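import tensorflow as tf  # TensorFlow 1.x API
# Zero-pad 4 pixels on each side (32x32 -> 40x40), then randomly crop back to imageHeight x imageWidth.
# Note: on a 4-D batch, tf.random_crop picks one crop offset shared by every image in the batch.
distorted_images = tf.image.resize_image_with_crop_or_pad(record_images, imageHeight + 8, imageWidth + 8)
distorted_images = tf.random_crop(distorted_images, size=[batch_size, imageHeight, imageWidth, 3])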
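The pad-then-crop step can be sketched in plain Python on a single-channel image stored as a list of rows. This is an illustrative sketch, not the author's code, and the helper name `pad_and_random_crop` is hypothetical. Unlike the batched `tf.random_crop` call, it draws a fresh offset for every image it is called on.

```python
import random


def pad_and_random_crop(image, pad=4, rng=random):
    """Hypothetical helper: zero-pad a 2-D image (list of rows) by `pad` pixels
    on every side, then cut out a random window of the original size."""
    h, w = len(image), len(image[0])
    # Build the (h + 2*pad) x (w + 2*pad) padded image, filled with zeros.
    padded = [[0] * (w + 2 * pad) for _ in range(h + 2 * pad)]
    for r in range(h):
        for c in range(w):
            padded[r + pad][c + pad] = image[r][c]
    # Pick the top-left corner of the crop; both offsets lie in [0, 2*pad].
    top = rng.randint(0, 2 * pad)
    left = rng.randint(0, 2 * pad)
    crop = [row[left:left + w] for row in padded[top:top + h]]
    return crop, (top, left)


rng = random.Random(0)
image = [[r * 32 + c + 1 for c in range(32)] for r in range(32)]  # toy 32x32 image
crop, (top, left) = pad_and_random_crop(image, pad=4, rng=rng)
print(len(crop), len(crop[0]), (top, left))
```

With `pad=4` on a 32x32 input there are 9 x 9 = 81 possible crop positions. The centered one (offset 4, 4) reproduces the original image.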

  2) Random flipping, brightness and contrast adjustment, and standardization

# Tensors do not support item assignment, so map a per-image function over the batch.
# Assumes float32 pixels on a 0-255 scale; max_delta=63 is far too large for [0, 1] inputs.
def distort_image(image):
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=63)
    image = tf.image.random_contrast(image, lower=0.2, upper=1.8)
    return tf.image.per_image_standardization(image)

distorted_images = tf.map_fn(distort_image, distorted_images)
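Each of the four per-image operations has a short definition. The sketch below re-implements them in plain Python for a single-channel image (a list of rows), following the formulas documented for `tf.image`:
- flip with probability 0.5;
- add one uniform delta in [-max_delta, max_delta] to every pixel;
- scale each pixel's distance from the mean by a uniform factor in [lower, upper];
- standardize with `(x - mean) / max(stddev, 1/sqrt(N))`.

The function names mirror TensorFlow's, but these are illustrative stand-ins, not the library code.

```python
import math
import random


def random_flip_left_right(image, rng):
    # With probability 0.5, mirror every row horizontally.
    return [row[::-1] for row in image] if rng.random() < 0.5 else image


def random_brightness(image, max_delta, rng):
    # Add a single delta, drawn from [-max_delta, max_delta], to every pixel.
    delta = rng.uniform(-max_delta, max_delta)
    return [[p + delta for p in row] for row in image]


def random_contrast(image, lower, upper, rng):
    # Scale each pixel's distance from the image mean by a factor in [lower, upper].
    factor = rng.uniform(lower, upper)
    mean = sum(sum(row) for row in image) / (len(image) * len(image[0]))
    return [[(p - mean) * factor + mean for p in row] for row in image]


def per_image_standardization(image):
    # (x - mean) / max(stddev, 1/sqrt(N)); the floor avoids dividing by 0 on flat images.
    n = len(image) * len(image[0])
    mean = sum(sum(row) for row in image) / n
    var = sum((p - mean) ** 2 for row in image for p in row) / n
    adjusted_std = max(math.sqrt(var), 1.0 / math.sqrt(n))
    return [[(p - mean) / adjusted_std for p in row] for row in image]


rng = random.Random(42)
image = [[float(r * 8 + c) for c in range(8)] for r in range(8)]  # toy 8x8 image, 0-255 scale
out = random_flip_left_right(image, rng)
out = random_brightness(out, max_delta=63, rng=rng)
out = random_contrast(out, lower=0.2, upper=1.8, rng=rng)
out = per_image_standardization(out)
```

Standardization runs last, so the brightness shift is removed again by the mean subtraction. The network therefore always sees zero-mean, unit-variance inputs, whatever random distortions were applied.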

 

2. Learning rate

  1) Linear decay
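The post gives no code for linear decay. One common form, which is what TF 1.x `tf.train.polynomial_decay` computes with its default `power=1.0`, interpolates linearly from an initial rate to a final rate over `decay_steps` steps and then holds the final rate. The helper `linear_decay` and the numbers used below are illustrative assumptions.

```python
def linear_decay(step, initial_lr, end_lr, decay_steps):
    # Hypothetical helper: clamp the step so the rate stays at end_lr after decay_steps.
    step = min(step, decay_steps)
    return (initial_lr - end_lr) * (1 - step / decay_steps) + end_lr


for step in (0, 5000, 10000, 20000):
    print(step, linear_decay(step, initial_lr=0.1, end_lr=0.001, decay_steps=10000))
```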

  2) Exponential decay
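Exponential decay has no code in the post either. A minimal sketch of the formula used by TF 1.x `tf.train.exponential_decay` is `lr * decay_rate ** (step / decay_steps)`. With `staircase=True`, the exponent is truncated to an integer, so the rate drops in discrete jumps. The helper name and the example values are assumptions.

```python
def exponential_decay(step, initial_lr, decay_rate, decay_steps, staircase=False):
    # Hypothetical helper: the exponent grows smoothly, or in whole steps if staircase=True.
    exponent = step // decay_steps if staircase else step / decay_steps
    return initial_lr * decay_rate ** exponent


for step in (0, 5000, 10000, 15000):
    print(step,
          exponential_decay(step, 0.1, 0.96, 10000),
          exponential_decay(step, 0.1, 0.96, 10000, staircase=True))
```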

  3) Piecewise-constant (step) decay

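# global_step is incremented once per training step when passed to optimizer.minimize(...).
global_step = tf.Variable(0, trainable=False)
# lr = 0.1 up to step 10000, 0.05 up to 15000, 0.01 up to 20000, 0.005 up to 25000, then 0.001.
boundaries = [10000, 15000, 20000, 25000]
values = [0.1, 0.05, 0.01, 0.005, 0.001]
learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)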
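The lookup that `tf.train.piecewise_constant` performs can be mirrored in plain Python. It returns `values[0]` while `step <= boundaries[0]`, `values[i]` while `boundaries[i-1] < step <= boundaries[i]`, and `values[-1]` after the last boundary. The sketch below uses the same `boundaries` and `values` as above. The helper `piecewise_lr` is hypothetical and only illustrates the semantics.

```python
import bisect


def piecewise_lr(step, boundaries, values):
    # Hypothetical helper: bisect_left counts the boundaries strictly below `step`,
    # so a step equal to a boundary still gets the earlier value.
    assert len(values) == len(boundaries) + 1
    return values[bisect.bisect_left(boundaries, step)]


boundaries = [10000, 15000, 20000, 25000]
values = [0.1, 0.05, 0.01, 0.005, 0.001]
for step in (0, 10000, 10001, 15001, 20001, 25001):
    print(step, piecewise_lr(step, boundaries, values))
```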

 

3. Weight decay

# Add weight decay (an L2 penalty on all trainable variables) to the loss.
# weight_decay is a scalar hyperparameter; tf.nn.l2_loss(v) computes sum(v ** 2) / 2.
# The penalty is computed in fp32 for numerical stability.
l2_loss = weight_decay * tf.add_n(
    [tf.nn.l2_loss(tf.cast(v, tf.float32)) for v in tf.trainable_variables()])
tf.summary.scalar('l2_loss', l2_loss)
loss = cross_entropy_mean + l2_loss
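Because `tf.nn.l2_loss` includes a factor of 1/2, the penalty's gradient with respect to each weight is exactly `weight_decay * w`. Every step therefore pulls each weight toward zero in proportion to its size. The plain-Python sketch below shows both the penalty value and that gradient. The helper names and the toy weights are illustrative.

```python
def l2_penalty(weights, weight_decay):
    # Hypothetical helper mirroring weight_decay * sum(tf.nn.l2_loss(v)) over variables,
    # where l2_loss(v) = sum(v ** 2) / 2 and each variable is a flat list here.
    return weight_decay * sum(sum(x * x for x in v) / 2 for v in weights)


def penalty_grad(v, weight_decay):
    # d/dx [weight_decay * x**2 / 2] = weight_decay * x.
    return [weight_decay * x for x in v]


weights = [[1.0, 2.0], [3.0]]  # two toy "variables"
print(l2_penalty(weights, weight_decay=0.01))
print(penalty_grad(weights[0], weight_decay=0.01))
```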

 

4. Optimizer

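# Momentum SGD (momentum 0.9) driven by the learning-rate schedule.
optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)

# Batch normalization registers its moving mean/variance updates in UPDATE_OPS;
# the control dependency makes them run before each training step.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    opt_op = optimizer.minimize(loss, global_step=global_step)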
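The two mechanisms in this snippet can be sketched in plain Python.
- `momentum_step` follows the update rule documented for `tf.train.MomentumOptimizer` without Nesterov momentum: `accum = momentum * accum + grad`, then `var -= lr * accum`.
- `moving_average` shows the kind of exponential moving average that batch-normalization update ops apply to the running mean and variance. The decay of 0.99 is an assumed value (it is the default `momentum` of `tf.layers.batch_normalization`).

Both helpers are hypothetical, and the toy objective is for illustration only.

```python
def momentum_step(var, accum, grad, lr, momentum=0.9):
    # Hypothetical helper: accum <- momentum * accum + grad; var <- var - lr * accum.
    accum = momentum * accum + grad
    return var - lr * accum, accum


def moving_average(moving, batch_value, decay=0.99):
    # Hypothetical helper: the running statistic drifts toward the batch statistic.
    # decay=0.99 is an assumed value.
    return decay * moving + (1 - decay) * batch_value


# Minimize f(w) = (w - 3)**2, whose gradient is 2 * (w - 3).
w, accum = 0.0, 0.0
for _ in range(500):
    w, accum = momentum_step(w, accum, 2 * (w - 3), lr=0.1)
print(w)
```

If `update_ops` were never run, the moving statistics would stay at their initial values, and evaluation with the moving statistics would use the wrong normalization. This is why the training op is built inside `tf.control_dependencies(update_ops)`.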

 

Source: https://www.cnblogs.com/wt-seu/p/12382130.html