其他分享
首页 > 其他分享> > pytorch中self.register_buffer()

pytorch中self.register_buffer()

作者:互联网

PyTorch中定义模型时,有时候会遇到self.register_buffer(‘name’, Tensor)的操作,该方法的作用是定义一组参数,该组参数的特别之处在于:模型训练时不会更新(即调用 optimizer.step() 后该组参数不会变化,只可人为地改变它们的值),但是保存模型时,该组参数又作为模型参数不可或缺的一部分被保存。

为了更好地理解这句话,按照惯例,我们通过一个例子实验来解释:

首先,定义一个模型并实例化:

import torch 
import torch.nn as nn
from collections import OrderedDict

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        # (1)常见定义模型时的操作
        self.param_nn = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(1, 1, 3, bias=False)),
            ('fc', nn.Linear(1, 2, bias=False))
        ]))

        # (2)使用register_buffer()定义一组参数
        self.register_buffer('param_buf', torch.randn(1, 2))

        # (3)使用形式类似的register_parameter()定义一组参数
        self.register_parameter('param_reg', nn.Parameter(torch.randn(1, 2)))

        # (4)按照类的属性形式定义一组变量
        self.param_attr = torch.randn(1, 2) 

    def forward(self, x):
        return x

net = Model()
 

上例中,我们通过继承nn.Module类定义了一个模型,在模型参数的定义中,我们分别以(1)常见的nn.Module类形式、(2)self.register_buffer()形式、(3)self.register_parameter()形式,以及(4)python类的属性形式定义了4组参数。

(1)哪些参数可以在模型训练时被更新?
这可以通过net.parameters()查看,因为定义优化器时是这样的:optimizer = SGD(net.parameters(), lr=0.1)。为了方便查看,我们使用 net.named_parameters():

In [8]: list(net.named_parameters())
Out[8]:
[('param_reg',
  Parameter containing:
  tensor([[-0.0617, -0.8984]], requires_grad=True)),
 ('param_nn.conv.weight',
  Parameter containing:
  tensor([[[[-0.3183, -0.0426, -0.2984],
            [-0.1451,  0.2686,  0.0556],
            [-0.3155,  0.0451,  0.0702]]]], requires_grad=True)),
 ('param_nn.fc.weight',
  Parameter containing:
  tensor([[-0.4647],
          [ 0.7753]], requires_grad=True))]
 

可以看到,我们定义的4组参数中,只有(1)和(3)定义的参数可以被更新,而self.register_buffer()和以python类的属性形式定义的参数都不能被更新。也就是说,modules和parameters可以被更新,而buffers和普通类属性不行。

那既然这两种形式定义的参数都不能被更新,二者可以互相替代吗?答案是不可以,原因看下一节:

(2)这其中哪些才算是模型的参数呢?
模型的所有参数都装在 state_dict 中,因为保存模型参数时直接保存 net.state_dict()。我们看一下其中究竟是哪些参数:

In [9]: net.state_dict()
Out[9]:
OrderedDict([('param_reg', tensor([[-0.0617, -0.8984]])),
             ('param_buf', tensor([[-1.0517,  0.7663]])),
             ('param_nn.conv.weight',
              tensor([[[[-0.3183, -0.0426, -0.2984],
                        [-0.1451,  0.2686,  0.0556],
                        [-0.3155,  0.0451,  0.0702]]]])),
             ('param_nn.fc.weight',
              tensor([[-0.4647],
                      [ 0.7753]]))])
 

可以看到,通过 nn.Module 类、self.register_buffer() 以及 self.register_parameter() 定义的参数都在 state-dict 中,只有用python类的属性形式定义的参数不包含其中。也就是说,保存模型时,buffers,modules和parameters都可以被保存,但普通属性不行。

(3)self.register_buffer() 的使用方法
在用self.register_buffer(‘name’, tensor) 定义模型参数时,其有两个形参需要传入。第一个是字符串,表示这组参数的名字;第二个就是tensor 形式的参数。

在模型定义中调用这个参数时(比如改变这组参数的值),可以使用self.name 获取。本文例中,就可用self.param_buf 引用。这和类属性的引用方法是一样的。

在实例化模型后,获取这组参数的值时,可以用 net.buffers() 方法获取,该方法返回一个生成器(可迭代变量):

In [10]: net.buffers()
Out[10]: <generator object Module.buffers at 0x00000289CA0032E0>

In [11]: list(net.buffers())
Out[11]: [tensor([[-1.0517,  0.7663]])]

# 也可以用named_buffers() 方法同时获取名字
In [12]: list(net.named_buffers())
Out[12]: [('param_buf', tensor([[-1.0517,  0.7663]]))]
 

(4)modules, parameters 和 buffers
实际上,PyTorch 定义的模型用OrderedDict() 的方式记录这三种类型,分别保存在self._modules, self._parameters 和 self._buffers 三个私有属性中。调试模式时就可以看到每个模型都有这几个私有属性:

在这里插入图片描述
调试模式 变量窗口
由于是私有属性,我们无法在实例化的变量上调用这些属性,可以在模型定义中调用它们:
在模型的实例化变量上调用时,三者有着相似的方法:

net.modules()
net.named_modules()

net.parameters()
net.named_parameters()

net.buffers()
net.named_buffers()
 

细心的读着可能会发现,self._parameters 和 net.parameters() 的返回值并不相同。这里self._parameters 只记录了使用 self.register_parameter() 定义的参数,而net.parameters() 返回所有可学习参数,包括self._modules 中的参数和self._parameters 参数的并集。

实际上,由nn.Module类定义的参数和self.register_parameter() 定义的参数性质是一样的,都是nn.Parameter 类型。

from:https://blog.csdn.net/dagouxiaohui/article/details/125649813

标签:parameters,nn,buffer,self,register,pytorch,参数,net
来源: https://www.cnblogs.com/chentiao/p/16683798.html