其他分享
首页 > 其他分享> > Pytorch加载与保存模型(利用pth的参数更新h5预训练模型)

Pytorch加载与保存模型(利用pth的参数更新h5预训练模型)

作者:互联网

前言

以前用Keras用惯了,fit和fit_generator真的太好使了,模型断电保存搞个checkpoint回调函数就行了。近期使用pytorch进行训练,苦于没有类似的回调函数,写完网络进行训练的时候总不能每次都从头开始训练,于是乎就学了一下pytorch的模型相关操作。

训练过程

ArgumentParser解析器

argparse是一个Python模块:命令行选项、参数和子命令解析器。
主要有三个步骤:

如下:

    parser = argparse.ArgumentParser()
    parser.add_argument('--train-file', type=str, default='pre/91-image_x2.h5')
    parser.add_argument('--eval-file', type=str, default='pre/Set5_x2.h5')
    parser.add_argument('--outputs-dir', type=str, default='output/')
    parser.add_argument('--weights-file', type=str, default='weight/')
    parser.add_argument('--scale', type=int, default=2)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--batch-size', type=int, default=1024)
    parser.add_argument('--num-epochs', type=int, default=1000)
    parser.add_argument('--num-workers', type=int, default=16)
    parser.add_argument('--seed', type=int, default=123)
    args = parser.parse_args()

使用参数的时候可以通过args.xxx进行调用就行了,这样的好处就是方便统一管理。 如果用编译器运行就得给赋默认值,如果命令行运行,在运行的时候命令后面给出参数就行,有默认值的会进行覆盖。

参数设置

包含网络模型的实例化,损失函数,优化器等一系列操作。

  if not os.path.exists(args.outputs_dir):
        os.makedirs(args.outputs_dir)

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    torch.manual_seed(args.seed)

    model = FSRCNN(scale_factor=args.scale).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam([
        {'params': model.first_part.parameters()},
        {'params': model.mid_part.parameters()},
        {'params': model.last_part.parameters(), 'lr': args.lr * 0.1}
    ], lr=args.lr)

模型加载/参数更新

加载上次训练生成的参数文件,通过update操作进行更新,并加载到现有模型中进行训练,这个也就是预训练参数,还要去掉参数中多余的k,v对。

    model_dict = model.state_dict()
    pre_dict = torch.load('训练参数文件.pth')
    pre_dict = {k: v for k, v in pre_dict.items() if k in model_dict}
    model_dict.update(pre_dict)
    model.load_state_dict(model_dict)

模型保存

torch.save就可以进行模型的保存,里面传入的参数不一样保存方式就不一样。

torch.save(model.state_dict(), path)

eg: 

 torch.save(model.state_dict(),
                       os.path.join(args.outputs_dir, 'epoch_{}_psnr{:.2f}.pth'.format(best_epoch, best_psnr)))
state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
torch.save(state, path)

这种方式加载的时候:

checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint(['epoch'])
torch.save(model, path)

标签:args,pth,模型,parser,argument,h5,add,dict,model
来源: https://blog.csdn.net/qq_41573860/article/details/116564815