CIFAR10: Network Training Techniques
Author: Internet
1. Data augmentation
1) Random cropping
Pad each side of the original image with 4 pixels, then randomly crop it back to a 32×32 image.
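# Pad 4 pixels on each side, then randomly crop back to imageHeight x imageWidth.
distorted_images = tf.image.resize_image_with_crop_or_pad(record_images, imageHeight + 8, imageWidth + 8)
# Note: on a 4-D batch, tf.random_crop draws ONE offset that is shared by every image in the batch.
distorted_images = tf.random_crop(distorted_images, size=[batch_size, imageHeight, imageWidth, 3])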
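The plain-Python sketch below shows pad-then-crop for a single image, with a fresh random offset drawn for each image. The helper name `pad_and_random_crop`, the list-of-rows image layout and the `fill` value are hypothetical, chosen only to keep the example free of dependencies. The sketch is not the TensorFlow implementation.

```python
import random


def pad_and_random_crop(image, pad=4, fill=0, rng=random):
    """Pad an H x W image (a list of rows) with `pad` pixels of `fill` on every
    side, then return a random H x W crop of the padded image."""
    height, width = len(image), len(image[0])
    padded_width = width + 2 * pad
    padded = (
        [[fill] * padded_width for _ in range(pad)]                     # top border
        + [[fill] * pad + list(row) + [fill] * pad for row in image]    # left/right borders
        + [[fill] * padded_width for _ in range(pad)]                   # bottom border
    )
    # randint is inclusive, so each offset ranges over all 2 * pad + 1 positions.
    # Each call draws fresh offsets, so every image gets its own crop.
    top = rng.randint(0, 2 * pad)
    left = rng.randint(0, 2 * pad)
    return [row[left:left + width] for row in padded[top:top + height]]


# Usage on a hypothetical single-channel 32x32 image.
image = [[r * 32 + c for c in range(32)] for r in range(32)]
crop = pad_and_random_crop(image, rng=random.Random(0))
print(len(crop), len(crop[0]))  # 32 32
```

In a TF1 graph, one way to get per-image offsets is to apply the crop inside a per-image loop, such as the one used for flipping below.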
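2) Random horizontal flipping, brightness and contrast adjustment, and standardization
# Tensors support neither len() nor item assignment, so split the batch into per-image tensors first.
images = tf.unstack(distorted_images, num=batch_size)
for i in range(len(images)):
    images[i] = tf.image.random_flip_left_right(images[i])
    images[i] = tf.image.random_brightness(images[i], max_delta=63)  # assumes float pixels in 0-255
    images[i] = tf.image.random_contrast(images[i], lower=0.2, upper=1.8)
    images[i] = tf.image.per_image_standardization(images[i])
distorted_images = tf.stack(images)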
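The plain-Python sketch below makes the last two steps concrete. It implements the formulas the TensorFlow documentation gives for `tf.image.adjust_contrast` and `tf.image.per_image_standardization`. Note that `random_contrast` calls `adjust_contrast` with a factor drawn from `[lower, upper]`. Modelling images as flat lists of floats is an assumption made to keep the example self-contained.

```python
import math


def adjust_contrast(channel, factor):
    """Contrast adjustment as documented for tf.image.adjust_contrast, for one
    channel given as a flat list: (x - mean) * factor + mean."""
    mean = sum(channel) / len(channel)
    return [(x - mean) * factor + mean for x in channel]


def per_image_standardization(pixels):
    """Formula documented for tf.image.per_image_standardization, applied to a
    flat list holding every pixel value of one image (all channels)."""
    n = len(pixels)
    mean = sum(pixels) / n
    stddev = math.sqrt(sum((x - mean) ** 2 for x in pixels) / n)
    # Lower-bounding the stddev avoids dividing by zero on a uniform image.
    adjusted_stddev = max(stddev, 1.0 / math.sqrt(n))
    return [(x - mean) / adjusted_stddev for x in pixels]


pixels = [1.0, 2.0, 3.0, 4.0]
print(adjust_contrast([0.0, 2.0, 4.0], 0.5))  # [1.0, 2.0, 3.0]
brightened = [x + 63.0 for x in pixels]       # a uniform brightness shift
print(per_image_standardization(pixels))
print(per_image_standardization(brightened))  # same values as the line above
```

Assuming float input and no clipping, a brightness change that adds the same delta to every pixel is removed again by the mean subtraction in standardization. The last two printed lines show this, so the order of these augmentation steps is worth checking.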
2. Learning rate
1) Linear decay
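No code is given for linear decay, so the following is only a hedged sketch. The plain-Python function implements the formula documented for `tf.train.polynomial_decay` with `power=1.0` and `cycle=False`, which is one way to get linear decay in TF1. The function name and the example values are hypothetical.

```python
def linear_decay(step, initial_lr, end_lr, decay_steps):
    """Linear decay from initial_lr to end_lr over decay_steps steps, then
    constant at end_lr. Same formula as tf.train.polynomial_decay with
    power=1.0 and cycle=False."""
    step = min(step, decay_steps)
    return (initial_lr - end_lr) * (1 - step / decay_steps) + end_lr


# Hypothetical schedule: 0.1 -> 0.001 over 25000 steps.
for step in (0, 12500, 25000, 30000):
    print(step, linear_decay(step, 0.1, 0.001, 25000))
```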
2) Exponential decay
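Exponential decay is also listed without code. The sketch below implements the formula documented for `tf.train.exponential_decay`: `lr * decay_rate ** (step / decay_steps)`, with integer division of the step when `staircase=True`. The function name and settings are hypothetical.

```python
def exponential_decay(step, initial_lr, decay_rate, decay_steps, staircase=False):
    """initial_lr * decay_rate ** (step / decay_steps), the formula documented
    for tf.train.exponential_decay. With staircase=True the exponent is
    truncated to an integer, so the rate drops in discrete jumps."""
    exponent = step // decay_steps if staircase else step / decay_steps
    return initial_lr * decay_rate ** exponent


# Hypothetical settings: start at 0.1 and multiply by 0.96 every 1000 steps.
for step in (0, 500, 1000, 1999):
    print(step,
          exponential_decay(step, 0.1, 0.96, 1000),
          exponential_decay(step, 0.1, 0.96, 1000, staircase=True))
```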
3) Piecewise constant (step) decay
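global_step = tf.Variable(0, trainable=False)
# 0.1 up to step 10000, 0.05 up to 15000, 0.01 up to 20000, 0.005 up to 25000,
# then 0.001 (values needs exactly one more entry than boundaries).
boundaries = [10000, 15000, 20000, 25000]
values = [0.1, 0.05, 0.01, 0.005, 0.001]
learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)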
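To check which rate applies at a given step without building a graph, the following plain-Python mirror of the `tf.train.piecewise_constant` semantics can help. It uses the same boundaries and values. The function is a hypothetical helper, not part of TensorFlow.

```python
import bisect


def piecewise_constant(step, boundaries, values):
    """Plain-Python mirror of tf.train.piecewise_constant: values[0] while
    step <= boundaries[0], values[i] while boundaries[i-1] < step <= boundaries[i],
    and values[-1] once step > boundaries[-1]."""
    if len(values) != len(boundaries) + 1:
        raise ValueError("values must have exactly one more entry than boundaries")
    # bisect_left counts the boundaries strictly below `step`, so a step that
    # lands exactly on a boundary still uses the earlier value.
    return values[bisect.bisect_left(boundaries, step)]


boundaries = [10000, 15000, 20000, 25000]
values = [0.1, 0.05, 0.01, 0.005, 0.001]
for step in (0, 10000, 10001, 25000, 25001):
    print(step, piecewise_constant(step, boundaries, values))
```

A step that falls exactly on a boundary still uses the earlier rate. For example, step 10000 uses 0.1 and step 10001 uses 0.05.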
3. Weight decay
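# Add weight decay (an L2 penalty on the weights) to the loss.
# tf.nn.l2_loss(t) returns sum(t ** 2) / 2; casting to fp32 keeps the penalty
# numerically stable even if the variables are stored in lower precision.
# This penalizes every trainable variable, including biases and any
# batch-normalization scale/offset parameters.
# weight_decay and cross_entropy_mean are defined elsewhere in the training script.
l2_loss = weight_decay * tf.add_n(
    [tf.nn.l2_loss(tf.cast(v, tf.float32)) for v in tf.trainable_variables()])
tf.summary.scalar('l2_loss', l2_loss)
loss = cross_entropy_mean + l2_loss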
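The sketch below shows why the L2 term is called weight decay. The gradient of `weight_decay * sum(w**2) / 2` is `weight_decay * w`. With plain SGD, this term therefore shrinks each weight by a factor of `(1 - lr * weight_decay)` on top of the data gradient. The helper names and numbers are hypothetical, and the sketch uses plain SGD rather than the momentum optimizer used below.

```python
def l2_loss(weights):
    """Half the sum of squares, the same quantity tf.nn.l2_loss computes."""
    return sum(w * w for w in weights) / 2


def sgd_step_with_l2(weights, grads, lr, weight_decay):
    """One plain SGD step on (data loss + weight_decay * l2_loss(weights)).
    The penalty's gradient is weight_decay * w, so every weight is also pulled
    toward zero in proportion to its size."""
    return [w - lr * (g + weight_decay * w) for w, g in zip(weights, grads)]


print(l2_loss([1.0, 2.0, 3.0]))  # 7.0
# With a zero data gradient, each weight shrinks by a factor (1 - lr * weight_decay).
print(sgd_step_with_l2([1.0, -2.0], [0.0, 0.0], lr=0.1, weight_decay=5e-4))
```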
4. Optimizer
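# Define the optimizer: SGD with momentum 0.9.
optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)
# Batch normalization: the moving mean/variance updates are collected in UPDATE_OPS and
# are not run automatically, so make the training op depend on them.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    # Passing global_step increments it on each step, which drives the learning-rate schedule.
    opt_op = optimizer.minimize(loss, global_step=global_step)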
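The plain-Python sketch below implements the update rule documented for `tf.train.MomentumOptimizer` without Nesterov momentum. The rule is `accum = momentum * accum + grad`, followed by `param -= lr * accum`. The helper name and the constant gradient are hypothetical.

```python
def momentum_step(params, grads, accum, lr, momentum=0.9):
    """One update following the rule documented for tf.train.MomentumOptimizer
    (use_nesterov=False):
        accum = momentum * accum + grad
        param = param - lr * accum
    """
    accum = [momentum * a + g for a, g in zip(accum, grads)]
    params = [p - lr * a for p, a in zip(params, accum)]
    return params, accum


params, accum = [1.0], [0.0]
for _ in range(3):
    # A hypothetical constant gradient of 1.0.
    params, accum = momentum_step(params, [1.0], accum, lr=0.1)
    print(params, accum)
```

With a constant gradient, the accumulator approaches 1 / (1 - momentum) = 10. Once it has built up, momentum 0.9 makes the effective step roughly ten times `lr * grad`.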
Source: https://www.cnblogs.com/wt-seu/p/12382130.html