PyTorch cosine annealing learning rate with warmup
No fluff, straight to the code, warmup_lr_scheduler.py:
from torch.optim.lr_scheduler import _LRScheduler
import warnings
import math

class CosineAnnealingLRWarmup(_LRScheduler):
    def __init__(self, optimizer, T_max, eta_min=1.0e-5, last_epoch=-1, verbose=False,
                 warmup_steps=2, warmup_start_lr=1.0e-5):
        self.T_max = T_max
        self.eta_min = eta_min
        self.warmup_steps = warmup_steps
        self.warmup_start_lr = warmup_start_lr
        if warmup_steps > 0:
            # Per-group factor for exponential warmup: after warmup_steps steps,
            # warmup_start_lr * factor ** warmup_steps == base_lr.
            self.base_warmup_factors = [
                (group['lr'] / warmup_start_lr) ** (1.0 / warmup_steps)
                for group in optimizer.param_groups
            ]
        # All attributes must be set before calling the parent constructor,
        # because _LRScheduler.__init__ already calls self.step() -> get_lr().
        super(CosineAnnealingLRWarmup, self).__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)
        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        if self.last_epoch < self.warmup_steps:
            # Warmup phase: grow exponentially from warmup_start_lr towards base_lr.
            return [self.warmup_start_lr * (warmup_factor ** self.last_epoch)
                    for warmup_factor in self.base_warmup_factors]
        # Cosine annealing phase, shifted so the curve starts right after warmup.
        return [self.eta_min + (base_lr - self.eta_min) *
                (1 + math.cos(math.pi * (self.last_epoch - self.warmup_steps) /
                              (self.T_max - self.warmup_steps))) * 0.5
                for base_lr in self.base_lrs]
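Written out, the schedule has two phases. Let W = warmup_steps and f = (base_lr / warmup_start_lr)^(1/W). For step t < W the learning rate is lr(t) = warmup_start_lr * f^t, an exponential ramp that lands exactly on base_lr at t = W (since f^W = base_lr / warmup_start_lr). From t = W onward it follows the usual cosine annealing curve, shifted by the warmup length: lr(t) = eta_min + (base_lr - eta_min) * (1 + cos(pi * (t - W) / (T_max - W))) / 2, which starts at base_lr and decays to eta_min at t = T_max.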
Usage. The base learning rate comes from the initial lr set on each param group in the optimizer; in the run below the optimizer has two param groups, both starting at 0.002:
lr_scheduler_warmup = CosineAnnealingLRWarmup(optimizer,
                                              T_max=100,
                                              eta_min=1.0e-4,
                                              last_epoch=-1,
                                              warmup_steps=10,
                                              warmup_start_lr=1.0e-5)
for i in range(args.epochs):  # args.epochs is 100 in the log below
    lr_scheduler_warmup.step()
    print(i, 'lr: ', lr_scheduler_warmup.get_last_lr())
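For context, here is a minimal sketch of where the scheduler sits in an actual training loop; model, train_loader, and criterion are placeholders, not part of the original post. Note that PyTorch expects optimizer.step() to be called before scheduler.step():

for epoch in range(100):                          # 100 epochs to match T_max above
    for inputs, targets in train_loader:          # placeholder DataLoader
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)  # placeholder model and loss
        loss.backward()
        optimizer.step()
    lr_scheduler_warmup.step()                    # advance the schedule once per epoch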
The printed log (one value per param group):
0 lr: [1.6986464646342476e-05, 1.6986464646342476e-05]
1 lr: [2.8853998118144274e-05, 2.8853998118144274e-05]
2 lr: [4.9012741893949e-05, 4.9012741893949e-05]
3 lr: [8.325532074018734e-05, 8.325532074018734e-05]
4 lr: [0.00014142135623730956, 0.00014142135623730956]
5 lr: [0.00024022488679628634, 0.00024022488679628634]
6 lr: [0.00040805715467367407, 0.00040805715467367407]
7 lr: [0.0006931448431551469, 0.0006931448431551469]
8 lr: [0.0011774080373049502, 0.0011774080373049502]
9 lr: [0.002, 0.002]
10 lr: [0.001999421285668141, 0.001999421285668141]
11 lr: [0.0019976858477468327, 0.0019976858477468327]
12 lr: [0.0019947958005998596, 0.0019947958005998596]
13 lr: [0.0019907546653044916, 0.0019907546653044916]
14 lr: [0.0019855673653615975, 0.0019855673653615975]
15 lr: [0.0019792402206971153, 0.0019792402206971153]
16 lr: [0.0019717809399621964, 0.0019717809399621964]
17 lr: [0.001963198611141403, 0.001963198611141403]
18 lr: [0.001953503690480396, 0.001953503690480396]
19 lr: [0.0019427079897466131, 0.0019427079897466131]
20 lr: [0.001930824661838448, 0.001930824661838448]
21 lr: [0.0019178681847604707, 0.0019178681847604707]
22 lr: [0.0019038543439842087, 0.0019038543439842087]
23 lr: [0.0018888002132159806, 0.0018888002132159806]
24 lr: [0.001872724133595217, 0.001872724133595217]
25 lr: [0.0018556456913486046, 0.0018556456913486046]
26 lr: [0.0018375856939272896, 0.0018375856939272896]
27 lr: [0.0018185661446562002, 0.0018185661446562002]
28 lr: [0.001798610215926386, 0.001798610215926386]
29 lr: [0.001777742220963029, 0.001777742220963029]
30 lr: [0.0017559875842035246, 0.0017559875842035246]
31 lr: [0.0017333728103217185, 0.0017333728103217185]
32 lr: [0.0017099254519360473, 0.0017099254519360473]
33 lr: [0.0016856740760409154, 0.0016856740760409154]
34 lr: [0.0016606482292022124, 0.0016606482292022124]
35 lr: [0.0016348784015593754, 0.0016348784015593754]
36 lr: [0.0016083959896778495, 0.0016083959896778495]
37 lr: [0.0015812332582972094, 0.0015812332582972094]
38 lr: [0.0015534233010215447, 0.0015534233010215447]
39 lr: [0.001525, 0.001525]
40 lr: [0.0014959979846465962, 0.0014959979846465962]
41 lr: [0.0014664525894496235, 0.0014664525894496235]
42 lr: [0.0014363998109220102, 0.0014363998109220102]
43 lr: [0.0014058762637451164, 0.0014058762637451164]
44 lr: [0.0013749191361593855, 0.0013749191361593855]
45 lr: [0.0013435661446562, 0.0013435661446562]
46 lr: [0.0013118554880261492, 0.0013118554880261492]
47 lr: [0.0012798258008196845, 0.0012798258008196845]
48 lr: [0.0012475161062768714, 0.0012475161062768714]
49 lr: [0.001214965768783584, 0.001214965768783584]
50 lr: [0.0011822144459120625, 0.0011822144459120625]
51 lr: [0.0011493020401042709, 0.0011493020401042709]
52 lr: [0.0011162686500569192, 0.0011162686500569192]
53 lr: [0.0010831545218673762, 0.0010831545218673762]
54 lr: [0.00105, 0.00105]
55 lr: [0.0010168454781326244, 0.0010168454781326244]
56 lr: [0.0009837313499430809, 0.0009837313499430809]
57 lr: [0.0009506979598957294, 0.0009506979598957294]
58 lr: [0.0009177855540879379, 0.0009177855540879379]
59 lr: [0.0008850342312164163, 0.0008850342312164163]
60 lr: [0.0008524838937231288, 0.0008524838937231288]
61 lr: [0.0008201741991803156, 0.0008201741991803156]
62 lr: [0.0007881445119738509, 0.0007881445119738509]
63 lr: [0.0007564338553438001, 0.0007564338553438001]
64 lr: [0.0007250808638406148, 0.0007250808638406148]
65 lr: [0.0006941237362548836, 0.0006941237362548836]
66 lr: [0.00066360018907799, 0.00066360018907799]
67 lr: [0.0006335474105503764, 0.0006335474105503764]
68 lr: [0.0006040020153534041, 0.0006040020153534041]
69 lr: [0.0005750000000000002, 0.0005750000000000002]
70 lr: [0.0005465766989784554, 0.0005465766989784554]
71 lr: [0.0005187667417027907, 0.0005187667417027907]
72 lr: [0.0004916040103221507, 0.0004916040103221507]
73 lr: [0.0004651215984406246, 0.0004651215984406246]
74 lr: [0.0004393517707977876, 0.0004393517707977876]
75 lr: [0.00041432592395908465, 0.00041432592395908465]
76 lr: [0.0003900745480639528, 0.0003900745480639528]
77 lr: [0.00036662718967828134, 0.00036662718967828134]
78 lr: [0.0003440124157964757, 0.0003440124157964757]
79 lr: [0.000322257779036971, 0.000322257779036971]
80 lr: [0.0003013897840736142, 0.0003013897840736142]
81 lr: [0.0002814338553438, 0.0002814338553438]
82 lr: [0.00026241430607271046, 0.00026241430607271046]
83 lr: [0.00024435430865139536, 0.00024435430865139536]
84 lr: [0.00022727586640478324, 0.00022727586640478324]
85 lr: [0.00021119978678401958, 0.00021119978678401958]
86 lr: [0.00019614565601579133, 0.00019614565601579133]
87 lr: [0.0001821318152395293, 0.0001821318152395293]
88 lr: [0.00016917533816155207, 0.00016917533816155207]
89 lr: [0.0001572920102533871, 0.0001572920102533871]
90 lr: [0.00014649630951960415, 0.00014649630951960415]
91 lr: [0.00013680138885859727, 0.00013680138885859727]
92 lr: [0.00012821906003780334, 0.00012821906003780334]
93 lr: [0.00012075977930288461, 0.00012075977930288461]
94 lr: [0.00011443263463840238, 0.00011443263463840238]
95 lr: [0.00010924533469550826, 0.00010924533469550826]
96 lr: [0.00010520419940014038, 0.00010520419940014038]
97 lr: [0.00010231415225316702, 0.00010231415225316702]
98 lr: [0.00010057871433185903, 0.00010057871433185903]
99 lr: [0.0001, 0.0001]
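As a sanity check, the log can be reproduced from the closed-form schedule; the snippet below is not from the original post, just the formulas above evaluated at a few checkpoints. Keep in mind that after the i-th step() call, last_epoch == i + 1:

import math

base_lr, eta_min, T_max, W = 0.002, 1.0e-4, 100, 10
warmup_start_lr = 1.0e-5
f = (base_lr / warmup_start_lr) ** (1.0 / W)  # ~1.6986, the warmup factor

print(warmup_start_lr * f ** 1)   # i = 0  -> 1.6986...e-05, the first log line
print(warmup_start_lr * f ** 10)  # i = 9  -> 0.002, warmup lands on base_lr
print(eta_min + (base_lr - eta_min) *
      (1 + math.cos(math.pi * (55 - W) / (T_max - W))) / 2)   # i = 54 -> 0.00105
print(eta_min + (base_lr - eta_min) *
      (1 + math.cos(math.pi * (100 - W) / (T_max - W))) / 2)  # i = 99 -> 0.0001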
Source: https://blog.csdn.net/qq_22751305/article/details/123421696