Pytorch:模型Finetune
作者:互联网
通常会采用一个已经训练好的模型的权值参数作为模型的初始化参数,称之为Finetune。本质上就是构建新的模型,拥有一个较好的权值初始化。
一、Finetune权值初始化
- 保存模型:保存一个预训练好的模型
- 加载模型:把预训练模型中的权值取出来
- 初始化模型:将权值对应的放到新模型中
step1:保存模型
net = Net()
torch.save(net.state_dict(),'net_params.pkl')
step2:加载模型
pretrained_dict = torch.load('net_params.pkl')
step3:初始化
# 创建net
net = Net()
# 获取已创建net的state_dict
net_state_dict = net.state_dict()
# 将pretrain_dict中 不属于net_state_dict的键剃掉:
pretrained_dict_1 = {k:v for k,v in pretrained_dict.items() if f in net_state_dict}
# 用与训练模型俄参数字典对新模型的参数字典net_state_dice进行更新
net_state_dict.update(pretrained_dict_1)
# 将更新了的参数字典放回网络
net.load_state_dict(net_state_dict)
二、不同层设置不同的学习率
采用fintune的训练过程中,有时候希望前面的学习率低一些,更新慢一些,后面的全连接层的学习率大一些,相对更新的快一些。将原始的参数组划分成多个组,每个组分别设置相应的学习率。
ignored_params = list(map(id,net.fc3.parameters()) # 返回parameters的内存地址
base_params = filter(lambda p:id(p) not in ignored_params,net.parameters())
optimizer = optim.SGD([
{'params':base_params},
{'params':net.fc3.parameters(),'lr':0.001*10}
],0.001,momentum=0.9,weight_decay= 1e-4)
以上代码的意思就是,将fc3层的参数net.fc3.parameters()从原始netparameters()中剥离出来,两层设置不同的学习率。
标签:parameters,Finetune,Pytorch,state,dict,params,net,模型 来源: https://blog.csdn.net/weixin_39393712/article/details/88954309