Error when calling model.load_state_dict(state_dict)
Author: Internet
First, look at the parameters stored in the trained checkpoint:
state_dict = torch.load('logs/sanity_3/checkpoint', map_location='cuda' if args['train']['cuda'] else 'cpu')
state_dict = state_dict['model']
Print the parameter names:
for k, v in state_dict.items():
    print(k)
Output:
module.block.0.layers.0.weight
module.block.0.layers.0.bias
module.block.0.layers.2.weight
module.block.0.layers.2.bias
module.block.0.layers.4.weight
module.block.0.layers.4.bias
module.block.0.layers.6.weight
module.block.0.layers.6.bias
module.block.0.layers.8.weight
module.block.0.layers.8.bias
module.block.2.layers.0.weight
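Now look at the parameter names of the network model itself:
model = CascadeNetwork(**args['network'])
params = model.state_dict()  # the freshly built model's state (parameter names and values)
for k, v in params.items():
    print(k)  # print only the key, not the parameter values
Output: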
block.0.layers.0.weight
block.0.layers.0.bias
block.0.layers.2.weight
block.0.layers.2.bias
block.0.layers.4.weight
block.0.layers.4.bias
block.0.layers.6.weight
block.0.layers.6.bias
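Comparing the two listings by eye works for a small network, but the mismatch can also be checked programmatically. The following is a minimal sketch that uses plain Python lists as stand-ins for the two key sets (only the first two names from each listing are copied here). With the real objects, one would presumably pass state_dict.keys() and model.state_dict().keys() instead.

```python
# Key names copied from the two listings above (first two entries of each).
checkpoint_keys = [
    "module.block.0.layers.0.weight",
    "module.block.0.layers.0.bias",
]
model_keys = [
    "block.0.layers.0.weight",
    "block.0.layers.0.bias",
]

# Keys the model expects but the checkpoint lacks, and the reverse.
missing = sorted(set(model_keys) - set(checkpoint_keys))
unexpected = sorted(set(checkpoint_keys) - set(model_keys))
print("missing:", missing)
print("unexpected:", unexpected)

# Check whether every checkpoint key carries the same leading prefix.
prefix = "module."
all_prefixed = all(k.startswith(prefix) for k in checkpoint_keys)
print("all checkpoint keys start with 'module.':", all_prefixed)
```

If every checkpoint key is reported as unexpected and every model key as missing, and all checkpoint keys share one prefix, the two dictionaries differ only in naming, not in content.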
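Solution
The keys in the checkpoint carry a "module." prefix that the model's own keys lack, so load_state_dict cannot match them and reports missing and unexpected keys. This prefix typically appears when the model was wrapped in torch.nn.DataParallel (or DistributedDataParallel) at the time it was saved. The fix is to build a new dictionary from the loaded checkpoint with the unwanted "module." prefix removed from every key.
First, load the trained model: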
state_dict = torch.load('logs/sanity_3/checkpoint', map_location='cuda' if args['train']['cuda'] else 'cpu')
state_dict = state_dict['model']
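Then create a new dictionary whose keys do not contain "module.":
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    # "module." is 7 characters long, so k[7:] drops exactly that prefix;
    # keys without the prefix are left unchanged
    name = k[7:] if k.startswith('module.') else k
    new_state_dict[name] = v  # each renamed key keeps its original value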
Finally, load the parameters into the network model:
model.load_state_dict(new_state_dict)  # load the checkpoint again, now with matching keys
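The renaming step can be wrapped in a small reusable helper. The sketch below is an illustration only: the function name strip_prefix is hypothetical, and integers stand in for the tensors that a real state dict would hold, so the example runs without PyTorch.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a new OrderedDict with `prefix` removed from keys that start with it.

    Keys that do not start with `prefix` are copied unchanged, and the
    original key order is preserved.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


# Stand-in values: a real state dict maps parameter names to tensors.
demo = OrderedDict([
    ("module.block.0.layers.0.weight", 1),
    ("module.block.0.layers.0.bias", 2),
    ("block.2.layers.0.weight", 3),  # already unprefixed, left as is
])
cleaned = strip_prefix(demo)
print(list(cleaned))
```

With the checkpoint from this post, the call would be model.load_state_dict(strip_prefix(state_dict)). Checking startswith before slicing means a checkpoint saved without the wrapper passes through unchanged instead of losing the first seven characters of every key.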
Source: https://blog.csdn.net/xuru_0927/article/details/119274321