model.load_state_dict(state_dict) error

Author: Internet

First, look at the parameters saved in the trained checkpoint:

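    import torch
    # Load onto the GPU or the CPU, depending on the training config dict (args)
    state_dict = torch.load('logs/sanity_3/checkpoint', map_location='cuda' if args['train']['cuda'] else 'cpu')
    state_dict = state_dict['model']  # the checkpoint stores the weights under the 'model' key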
Print the parameter names (the keys):
    for k, v in state_dict.items():
        print(k)

Output:
module.block.0.layers.0.weight
module.block.0.layers.0.bias
module.block.0.layers.2.weight
module.block.0.layers.2.bias
module.block.0.layers.4.weight
module.block.0.layers.4.bias
module.block.0.layers.6.weight
module.block.0.layers.6.bias
module.block.0.layers.8.weight
module.block.0.layers.8.bias
module.block.2.layers.0.weight

Now look at the parameter names of the freshly built network:

    model = CascadeNetwork(**args['network'])
    params = model.state_dict()  # the model's initial state: parameter names mapped to tensors
    for k, v in params.items():
        print(k)  # print only the key, not the tensor values

Output:
block.0.layers.0.weight
block.0.layers.0.bias
block.0.layers.2.weight
block.0.layers.2.bias
block.0.layers.4.weight
block.0.layers.4.bias
block.0.layers.6.weight
block.0.layers.6.bias
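Comparing the two printouts shows that the names differ only in the leading "module.". Below is a minimal plain-Python sketch that compares the two key sets. The name diff_keys is a hypothetical helper. It only mimics the missing/unexpected-key check that a strict load_state_dict performs; it does not call PyTorch.

```python
from typing import Iterable, List, Tuple


def diff_keys(checkpoint_keys: Iterable[str],
              model_keys: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected), sorted.

    missing:    keys the model expects but the checkpoint lacks
    unexpected: keys the checkpoint has but the model does not expect
    """
    ckpt = set(checkpoint_keys)
    model = set(model_keys)
    return sorted(model - ckpt), sorted(ckpt - model)


# A few key names copied from the two printouts above
checkpoint_keys = ["module.block.0.layers.0.weight", "module.block.0.layers.0.bias"]
model_keys = ["block.0.layers.0.weight", "block.0.layers.0.bias"]

missing, unexpected = diff_keys(checkpoint_keys, model_keys)
print("missing:", missing)
print("unexpected:", unexpected)

# Check that every unexpected key turns into a missing key once "module." is removed
prefix = "module."
print(all(k.startswith(prefix) and k[len(prefix):] in missing for k in unexpected))
```

Every checkpoint key shows up as "unexpected" and every model key as "missing". Yet the two lists match one-to-one once the prefix is dropped, which points to a naming problem rather than an architecture mismatch.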

Solution

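Every key in the checkpoint carries an extra "module." prefix that the network's own keys lack. This prefix typically appears when the model was wrapped in torch.nn.DataParallel (or DistributedDataParallel) during training, because the wrapper keeps the real network in its "module" attribute. load_state_dict is strict by default, so the mismatched names are reported as missing and unexpected keys, and loading fails. The fix is to build a new dictionary from the loaded state_dict, with the unneeded "module." prefix removed from each key.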

First, load the trained model:
    state_dict = torch.load('logs/sanity_3/checkpoint', map_location='cuda' if args['train']['cuda'] else 'cpu')
    state_dict = state_dict['model']
Then create a new dictionary whose keys do not contain the "module." prefix:
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # "module." is 7 characters long, so k[7:] drops exactly that prefix;
        # strip it only when it is actually present
        name = k[7:] if k.startswith('module.') else k
        new_state_dict[name] = v  # same value, renamed key
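The slice k[7:] works because "module." is exactly 7 characters long. Below is a slightly more general sketch. It assumes you have the model's key names at hand, for example from model.state_dict().keys(). The helper match_prefix is hypothetical and handles two cases:
- If only the checkpoint has the prefix, it strips the prefix.
- If only the model has it, it adds the prefix. This covers loading a plain checkpoint into a model that is already wrapped in DataParallel.

The demo uses strings as values so it runs without PyTorch; a real state_dict maps the same string keys to tensors, which are passed through untouched.

```python
from collections import OrderedDict
from typing import Iterable, Mapping

PREFIX = "module."  # len(PREFIX) == 7, hence k[7:] in the loop above


def match_prefix(state_dict: Mapping[str, object], model_keys: Iterable[str],
                 prefix: str = PREFIX) -> OrderedDict:
    """Rename checkpoint keys so they line up with the model's own keys.

    checkpoint prefixed, model not  -> strip the prefix
    model prefixed, checkpoint not  -> add the prefix
    otherwise                       -> keep keys unchanged
    """
    model_keys = list(model_keys)
    ckpt_prefixed = bool(state_dict) and all(k.startswith(prefix) for k in state_dict)
    model_prefixed = bool(model_keys) and all(k.startswith(prefix) for k in model_keys)

    out = OrderedDict()
    for k, v in state_dict.items():
        if ckpt_prefixed and not model_prefixed:
            k = k[len(prefix):]
        elif model_prefixed and not ckpt_prefixed:
            k = prefix + k
        out[k] = v  # value is not copied or modified
    return out


# Case from this post: checkpoint saved from a DataParallel-wrapped model
ckpt = OrderedDict([("module.block.0.layers.0.weight", "W0"),
                    ("module.block.0.layers.0.bias", "b0")])
fixed = match_prefix(ckpt, ["block.0.layers.0.weight", "block.0.layers.0.bias"])
print(list(fixed))

# Reverse case: plain checkpoint, model already wrapped
wrapped = match_prefix(OrderedDict([("block.0.layers.0.weight", "W0")]),
                       ["module.block.0.layers.0.weight"])
print(list(wrapped))
```

With a real model, this would be called as model.load_state_dict(match_prefix(state_dict, model.state_dict().keys())). Recent PyTorch releases also include torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, prefix), which strips a prefix in place. Check that it exists in your installed version before relying on it.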

Finally, load the parameters into the network:

    model.load_state_dict(new_state_dict)  # load the renamed weights into the model

Source: https://blog.csdn.net/xuru_0927/article/details/119274321