TensorFlow 2:warmup + 余弦(Cosine)学习率衰减
作者:互联网
class WarmUpCos(keras.callbacks.Callback):
    """Per-batch learning-rate schedule: linear warm-up, then cosine decay.

    The rate rises linearly from 0 to ``lr_max`` over the first ``warm_step``
    batches, then follows half a cosine from ``lr_max`` down to ``lr_min``
    until ``sum_step`` batches have run, and stays at ``lr_min`` afterwards.
    The step counter is reset at the start of every ``fit`` call.

    Args:
        lr_max: Peak learning rate, reached at the end of warm-up.
        lr_min: Floor that the cosine decay anneals to.
        warm_step: Number of warm-up batches (0 disables warm-up).
        sum_step: Total number of scheduled batches (warm-up + decay).
        bat: Batch size; stored for reference only, not used by the schedule.
        verbose: If truthy, print the step count and new learning rate
            after every batch.
    """

    def __init__(self, lr_max, lr_min, warm_step, sum_step, bat, verbose=0):
        super().__init__()
        self.lr_max = lr_max
        self.lr_min = lr_min
        self.warm_step = warm_step
        self.sum_step = sum_step
        self.bat = bat
        self.verbose = verbose

    def _lr_at(self, step):
        """Return the scheduled learning rate after ``step`` completed batches."""
        if step < self.warm_step:
            # Only reached when warm_step > 0, so this cannot divide by zero.
            return self.lr_max * step / self.warm_step
        decay_span = self.sum_step - self.warm_step
        if decay_span > 0:
            # Clamp to 1 so the cosine never swings back up if training runs
            # longer than sum_step batches.
            progress = min(1.0, (step - self.warm_step) / decay_span)
        else:
            # No room left for decay: go straight to the floor.
            progress = 1.0
        cosine = (1 + math.cos(math.pi * progress)) / 2
        return self.lr_min + (self.lr_max - self.lr_min) * cosine

    def _set_lr(self, lr):
        """Write ``lr`` into the attached model's optimizer."""
        K.set_value(self.model.optimizer.learning_rate, lr)

    def on_train_begin(self, logs=None):
        """Reset the step counter and put the first batch on the schedule."""
        self.init_lr = self.lr_max  # retained for any external readers
        self.step = 0
        # Without this, the first batch would run at whatever rate the
        # optimizer was compiled with instead of the warm-up rate.
        self._set_lr(self._lr_at(0))

    def on_epoch_begin(self, epoch, logs=None):
        """Record the current epoch index."""
        self.epoch = epoch

    def on_batch_end(self, batch, logs=None):
        """Advance one step and set the rate used by the next batch."""
        self.step += 1
        lr = self._lr_at(self.step)
        if self.verbose:
            print('step:', self.step, 'lr:', lr)
        self._set_lr(lr)


# Steps per epoch are taken as train_x.shape[0] // bat. If train_db keeps a
# partial final batch, epochs run slightly longer; the clamped schedule then
# just holds at lr_min for those extra steps.
warm_up = WarmUpCos(
    lr_rate,
    lr_min,
    warm_step=warm_epoch * int(train_x.shape[0] // bat),
    bat=bat,
    sum_step=epochs * int(train_x.shape[0] // bat),
)
s_model.fit(train_db, epochs=epochs, validation_data=test_db, callbacks=[warm_up])
搜索
复制
标签:tensorflow2,Cos,warmup,max,self,bat,step,warm,lr 来源: https://www.cnblogs.com/cxhzy/p/16496377.html